ollama source for Momentry Core verification
This commit is contained in:
100
server/auth.go
Normal file
100
server/auth.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
)
|
||||
|
||||
type registryChallenge struct {
|
||||
Realm string
|
||||
Service string
|
||||
Scope string
|
||||
}
|
||||
|
||||
func (r registryChallenge) URL() (*url.URL, error) {
|
||||
redirectURL, err := url.Parse(r.Realm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
values := redirectURL.Query()
|
||||
values.Add("service", r.Service)
|
||||
for _, s := range strings.Split(r.Scope, " ") {
|
||||
values.Add("scope", s)
|
||||
}
|
||||
|
||||
values.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||
|
||||
nonce, err := auth.NewNonce(rand.Reader, 16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
values.Add("nonce", nonce)
|
||||
|
||||
redirectURL.RawQuery = values.Encode()
|
||||
return redirectURL, nil
|
||||
}
|
||||
|
||||
func getAuthorizationToken(ctx context.Context, challenge registryChallenge, originalHost string) (string, error) {
|
||||
redirectURL, err := challenge.URL()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Validate that the realm host matches the original request host to prevent sending tokens cross-origin.
|
||||
if redirectURL.Host != originalHost {
|
||||
return "", fmt.Errorf("realm host %q does not match original host %q", redirectURL.Host, originalHost)
|
||||
}
|
||||
|
||||
sha256sum := sha256.Sum256(nil)
|
||||
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
|
||||
|
||||
headers := make(http.Header)
|
||||
signature, err := auth.Sign(ctx, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
headers.Add("Authorization", signature)
|
||||
|
||||
response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, ®istryOptions{})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%d: %v", response.StatusCode, err)
|
||||
}
|
||||
|
||||
if response.StatusCode >= http.StatusBadRequest {
|
||||
if len(body) > 0 {
|
||||
return "", fmt.Errorf("%d: %s", response.StatusCode, body)
|
||||
} else {
|
||||
return "", fmt.Errorf("%d", response.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
var token api.TokenResponse
|
||||
if err := json.Unmarshal(body, &token); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return token.Token, nil
|
||||
}
|
||||
113
server/auth_test.go
Normal file
113
server/auth_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetAuthorizationTokenRejectsCrossDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
realm string
|
||||
originalHost string
|
||||
wantMismatch bool
|
||||
}{
|
||||
{"https://example.com/token", "example.com", false},
|
||||
{"https://example.com/token", "other.com", true},
|
||||
{"https://example.com/token", "localhost:8000", true},
|
||||
{"https://localhost:5000/token", "localhost:5000", false},
|
||||
{"https://localhost:5000/token", "localhost:6000", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.originalHost, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
challenge := registryChallenge{Realm: tt.realm, Service: "test", Scope: "repo:x:pull"}
|
||||
_, err := getAuthorizationToken(ctx, challenge, tt.originalHost)
|
||||
|
||||
isMismatch := err != nil && strings.Contains(err.Error(), "does not match")
|
||||
if tt.wantMismatch && !isMismatch {
|
||||
t.Errorf("expected domain mismatch error, got: %v", err)
|
||||
}
|
||||
if !tt.wantMismatch && isMismatch {
|
||||
t.Errorf("unexpected domain mismatch error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRegistryChallenge(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantRealm, wantService, wantScope string
|
||||
}{
|
||||
{
|
||||
`Bearer realm="https://auth.example.com/token",service="registry",scope="repo:foo:pull"`,
|
||||
"https://auth.example.com/token", "registry", "repo:foo:pull",
|
||||
},
|
||||
{
|
||||
`Bearer realm="https://r.ollama.ai/v2/token",service="ollama",scope="-"`,
|
||||
"https://r.ollama.ai/v2/token", "ollama", "-",
|
||||
},
|
||||
{"", "", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := parseRegistryChallenge(tt.input)
|
||||
if result.Realm != tt.wantRealm || result.Service != tt.wantService || result.Scope != tt.wantScope {
|
||||
t.Errorf("parseRegistryChallenge(%q) = {%q, %q, %q}, want {%q, %q, %q}",
|
||||
tt.input, result.Realm, result.Service, result.Scope,
|
||||
tt.wantRealm, tt.wantService, tt.wantScope)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryChallengeURL(t *testing.T) {
|
||||
challenge := registryChallenge{
|
||||
Realm: "https://auth.example.com/token",
|
||||
Service: "registry",
|
||||
Scope: "repo:foo:pull repo:bar:push",
|
||||
}
|
||||
|
||||
u, err := challenge.URL()
|
||||
if err != nil {
|
||||
t.Fatalf("URL() error: %v", err)
|
||||
}
|
||||
|
||||
if u.Host != "auth.example.com" {
|
||||
t.Errorf("host = %q, want %q", u.Host, "auth.example.com")
|
||||
}
|
||||
if u.Path != "/token" {
|
||||
t.Errorf("path = %q, want %q", u.Path, "/token")
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
if q.Get("service") != "registry" {
|
||||
t.Errorf("service = %q, want %q", q.Get("service"), "registry")
|
||||
}
|
||||
if scopes := q["scope"]; len(scopes) != 2 {
|
||||
t.Errorf("scope count = %d, want 2", len(scopes))
|
||||
}
|
||||
if q.Get("ts") == "" {
|
||||
t.Error("missing ts")
|
||||
}
|
||||
if q.Get("nonce") == "" {
|
||||
t.Error("missing nonce")
|
||||
}
|
||||
|
||||
// Nonces should differ between calls
|
||||
u2, _ := challenge.URL()
|
||||
if q.Get("nonce") == u2.Query().Get("nonce") {
|
||||
t.Error("nonce should be unique per call")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryChallengeURLInvalid(t *testing.T) {
|
||||
challenge := registryChallenge{Realm: "://invalid"}
|
||||
if _, err := challenge.URL(); err == nil {
|
||||
t.Error("expected error for invalid URL")
|
||||
}
|
||||
}
|
||||
568
server/cloud_proxy.go
Normal file
568
server/cloud_proxy.go
Normal file
@@ -0,0 +1,568 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCloudProxyBaseURL = "https://ollama.com:443"
|
||||
defaultCloudProxySigningHost = "ollama.com"
|
||||
cloudProxyBaseURLEnv = "OLLAMA_CLOUD_BASE_URL"
|
||||
legacyCloudAnthropicKey = "legacy_cloud_anthropic_web_search"
|
||||
cloudProxyClientVersionHeader = "X-Ollama-Client-Version"
|
||||
|
||||
// maxDecompressedBodySize limits the size of a decompressed request body
|
||||
maxDecompressedBodySize = 20 << 20
|
||||
)
|
||||
|
||||
var (
|
||||
cloudProxyBaseURL = defaultCloudProxyBaseURL
|
||||
cloudProxySigningHost = defaultCloudProxySigningHost
|
||||
cloudProxySignRequest = signCloudProxyRequest
|
||||
cloudProxySigninURL = signinURL
|
||||
)
|
||||
|
||||
var hopByHopHeaders = map[string]struct{}{
|
||||
"connection": {},
|
||||
"content-length": {},
|
||||
"proxy-connection": {},
|
||||
"keep-alive": {},
|
||||
"proxy-authenticate": {},
|
||||
"proxy-authorization": {},
|
||||
"te": {},
|
||||
"trailer": {},
|
||||
"transfer-encoding": {},
|
||||
"upgrade": {},
|
||||
}
|
||||
|
||||
func init() {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL(envconfig.Var(cloudProxyBaseURLEnv), mode)
|
||||
if err != nil {
|
||||
slog.Warn("ignoring cloud base URL override", "env", cloudProxyBaseURLEnv, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
cloudProxyBaseURL = baseURL
|
||||
cloudProxySigningHost = signingHost
|
||||
|
||||
if overridden {
|
||||
slog.Info("cloud base URL override enabled", "env", cloudProxyBaseURLEnv, "url", cloudProxyBaseURL, "mode", mode)
|
||||
}
|
||||
}
|
||||
|
||||
func cloudPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request.Method != http.MethodPost {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Decompress zstd-encoded request bodies so we can inspect the model
|
||||
if c.GetHeader("Content-Encoding") == "zstd" {
|
||||
reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to decompress request body"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, io.NopCloser(reader), maxDecompressedBodySize)
|
||||
c.Request.Header.Del("Content-Encoding")
|
||||
}
|
||||
|
||||
// TODO(drifkin): Avoid full-body buffering here for model detection.
|
||||
// A future optimization can parse just enough JSON to read "model" (and
|
||||
// optionally short-circuit cloud-disabled explicit-cloud requests) while
|
||||
// preserving raw passthrough semantics.
|
||||
body, err := readRequestBody(c.Request)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
model, ok := extractModelField(body)
|
||||
if !ok {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
modelRef, err := parseAndValidateModelRef(model)
|
||||
if err != nil || modelRef.Source != modelSourceCloud {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
normalizedBody, err := replaceJSONModelField(body, modelRef.Base)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// TEMP(drifkin): keep Anthropic web search requests on the local middleware
|
||||
// path so WebSearchAnthropicWriter can orchestrate follow-up calls.
|
||||
if c.Request.URL.Path == "/v1/messages" {
|
||||
if hasAnthropicWebSearchTool(body) {
|
||||
c.Set(legacyCloudAnthropicKey, true)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
proxyCloudRequest(c, normalizedBody, disabledOperation)
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func cloudModelPathPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
modelName := strings.TrimSpace(c.Param("model"))
|
||||
if modelName == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
modelRef, err := parseAndValidateModelRef(modelName)
|
||||
if err != nil || modelRef.Source != modelSourceCloud {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
proxyPath := "/v1/models/" + modelRef.Base
|
||||
proxyCloudRequestWithPath(c, nil, proxyPath, disabledOperation)
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func proxyCloudJSONRequest(c *gin.Context, payload any, disabledOperation string) {
|
||||
// TEMP(drifkin): we currently split out this `WithPath` method because we are
|
||||
// mapping `/v1/messages` + web_search to `/api/chat` temporarily. Once we
|
||||
// stop doing this, we can inline this method.
|
||||
proxyCloudJSONRequestWithPath(c, payload, c.Request.URL.Path, disabledOperation)
|
||||
}
|
||||
|
||||
func proxyCloudJSONRequestWithPath(c *gin.Context, payload any, path string, disabledOperation string) {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
proxyCloudRequestWithPath(c, body, path, disabledOperation)
|
||||
}
|
||||
|
||||
func proxyCloudRequest(c *gin.Context, body []byte, disabledOperation string) {
|
||||
proxyCloudRequestWithPath(c, body, c.Request.URL.Path, disabledOperation)
|
||||
}
|
||||
|
||||
func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disabledOperation string) {
|
||||
if disabled, _ := internalcloud.Status(); disabled {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(disabledOperation)})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL, err := url.Parse(cloudProxyBaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
targetURL := baseURL.ResolveReference(&url.URL{
|
||||
Path: path,
|
||||
RawQuery: c.Request.URL.RawQuery,
|
||||
})
|
||||
|
||||
outReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
copyProxyRequestHeaders(outReq.Header, c.Request.Header)
|
||||
if clientVersion := strings.TrimSpace(version.Version); clientVersion != "" {
|
||||
outReq.Header.Set(cloudProxyClientVersionHeader, clientVersion)
|
||||
}
|
||||
if outReq.Header.Get("Content-Type") == "" && len(body) > 0 {
|
||||
outReq.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil {
|
||||
slog.Warn("cloud proxy signing failed", "error", err)
|
||||
writeCloudUnauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(drifkin): Add phase-specific proxy timeouts.
|
||||
// Connect/TLS/TTFB should have bounded timeouts, but once streaming starts
|
||||
// we should not enforce a short total timeout for long-lived responses.
|
||||
resp, err := http.DefaultClient.Do(outReq)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
copyProxyResponseHeaders(c.Writer.Header(), resp.Header)
|
||||
c.Status(resp.StatusCode)
|
||||
|
||||
var bodyWriter http.ResponseWriter = c.Writer
|
||||
var framedWriter *jsonlFramingResponseWriter
|
||||
// TEMP(drifkin): only needed on the cloud-proxied first leg of Anthropic
|
||||
// web_search fallback (which is a path we're removing soon). Local
|
||||
// /v1/messages writes one JSON value per streamResponse callback directly
|
||||
// into WebSearchAnthropicWriter, but this proxy copy loop may coalesce
|
||||
// multiple jsonl records into one Write. WebSearchAnthropicWriter currently
|
||||
// unmarshals one JSON value per Write.
|
||||
if path == "/api/chat" && resp.StatusCode == http.StatusOK && c.GetBool(legacyCloudAnthropicKey) {
|
||||
framedWriter = &jsonlFramingResponseWriter{ResponseWriter: c.Writer}
|
||||
bodyWriter = framedWriter
|
||||
}
|
||||
|
||||
err = copyProxyResponseBody(bodyWriter, resp.Body)
|
||||
if err == nil && framedWriter != nil {
|
||||
err = framedWriter.FlushPending()
|
||||
}
|
||||
if err != nil {
|
||||
ctxErr := c.Request.Context().Err()
|
||||
if errors.Is(err, context.Canceled) && errors.Is(ctxErr, context.Canceled) {
|
||||
slog.Debug(
|
||||
"cloud proxy response stream closed by client",
|
||||
"path", c.Request.URL.Path,
|
||||
"status", resp.StatusCode,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Warn(
|
||||
"cloud proxy response copy failed",
|
||||
"path", c.Request.URL.Path,
|
||||
"upstream_path", path,
|
||||
"status", resp.StatusCode,
|
||||
"request_context_canceled", ctxErr != nil,
|
||||
"request_context_err", ctxErr,
|
||||
"error", err,
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func replaceJSONModelField(body []byte, model string) ([]byte, error) {
|
||||
if len(body) == 0 {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
var payload map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelJSON, err := json.Marshal(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload["model"] = modelJSON
|
||||
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
|
||||
func readRequestBody(r *http.Request) ([]byte, error) {
|
||||
if r.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func extractModelField(body []byte) (string, bool) {
|
||||
if len(body) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var payload map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
raw, ok := payload["model"]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var model string
|
||||
if err := json.Unmarshal(raw, &model); err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
model = strings.TrimSpace(model)
|
||||
return model, model != ""
|
||||
}
|
||||
|
||||
func hasAnthropicWebSearchTool(body []byte) bool {
|
||||
if len(body) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Tools []struct {
|
||||
Type string `json:"type"`
|
||||
} `json:"tools"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, tool := range payload.Tools {
|
||||
if strings.HasPrefix(strings.TrimSpace(tool.Type), "web_search") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func writeCloudUnauthorized(c *gin.Context) {
|
||||
signinURL, err := cloudProxySigninURL()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": signinURL})
|
||||
}
|
||||
|
||||
func signCloudProxyRequest(ctx context.Context, req *http.Request) error {
|
||||
if !strings.EqualFold(req.URL.Hostname(), cloudProxySigningHost) {
|
||||
return nil
|
||||
}
|
||||
|
||||
ts := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
challenge := buildCloudSignatureChallenge(req, ts)
|
||||
signature, err := auth.Sign(ctx, []byte(challenge))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", signature)
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildCloudSignatureChallenge(req *http.Request, ts string) string {
|
||||
query := req.URL.Query()
|
||||
query.Set("ts", ts)
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
return fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI())
|
||||
}
|
||||
|
||||
func resolveCloudProxyBaseURL(rawOverride string, runMode string) (baseURL string, signingHost string, overridden bool, err error) {
|
||||
baseURL = defaultCloudProxyBaseURL
|
||||
signingHost = defaultCloudProxySigningHost
|
||||
|
||||
rawOverride = strings.TrimSpace(rawOverride)
|
||||
if rawOverride == "" {
|
||||
return baseURL, signingHost, false, nil
|
||||
}
|
||||
|
||||
u, err := url.Parse(rawOverride)
|
||||
if err != nil {
|
||||
return "", "", false, fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
if u.Scheme == "" || u.Host == "" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: scheme and host are required")
|
||||
}
|
||||
if u.User != nil {
|
||||
return "", "", false, fmt.Errorf("invalid URL: userinfo is not allowed")
|
||||
}
|
||||
if u.Path != "" && u.Path != "/" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: path is not allowed")
|
||||
}
|
||||
if u.RawQuery != "" || u.Fragment != "" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: query and fragment are not allowed")
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
if host == "" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: host is required")
|
||||
}
|
||||
|
||||
loopback := isLoopbackHost(host)
|
||||
if runMode == gin.ReleaseMode && !loopback {
|
||||
return "", "", false, fmt.Errorf("non-loopback cloud override is not allowed in release mode")
|
||||
}
|
||||
if !loopback && !strings.EqualFold(u.Scheme, "https") {
|
||||
return "", "", false, fmt.Errorf("non-loopback cloud override must use https")
|
||||
}
|
||||
|
||||
u.Path = ""
|
||||
u.RawPath = ""
|
||||
u.RawQuery = ""
|
||||
u.Fragment = ""
|
||||
|
||||
return u.String(), strings.ToLower(host), true, nil
|
||||
}
|
||||
|
||||
func isLoopbackHost(host string) bool {
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return true
|
||||
}
|
||||
|
||||
ip := net.ParseIP(host)
|
||||
return ip != nil && ip.IsLoopback()
|
||||
}
|
||||
|
||||
func copyProxyRequestHeaders(dst, src http.Header) {
|
||||
connectionTokens := connectionHeaderTokens(src)
|
||||
for key, values := range src {
|
||||
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
|
||||
continue
|
||||
}
|
||||
|
||||
dst.Del(key)
|
||||
for _, value := range values {
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func copyProxyResponseHeaders(dst, src http.Header) {
|
||||
connectionTokens := connectionHeaderTokens(src)
|
||||
for key, values := range src {
|
||||
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
|
||||
continue
|
||||
}
|
||||
|
||||
dst.Del(key)
|
||||
for _, value := range values {
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func copyProxyResponseBody(dst http.ResponseWriter, src io.Reader) error {
|
||||
flusher, canFlush := dst.(http.Flusher)
|
||||
buf := make([]byte, 32*1024)
|
||||
|
||||
for {
|
||||
n, err := src.Read(buf)
|
||||
if n > 0 {
|
||||
if _, writeErr := dst.Write(buf[:n]); writeErr != nil {
|
||||
return writeErr
|
||||
}
|
||||
if canFlush {
|
||||
// TODO(drifkin): Consider conditional flushing so non-streaming
|
||||
// responses don't flush every write and can optimize throughput.
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type jsonlFramingResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
pending []byte
|
||||
}
|
||||
|
||||
func (w *jsonlFramingResponseWriter) Flush() {
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *jsonlFramingResponseWriter) Write(p []byte) (int, error) {
|
||||
w.pending = append(w.pending, p...)
|
||||
if err := w.flushCompleteLines(); err != nil {
|
||||
return len(p), err
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *jsonlFramingResponseWriter) FlushPending() error {
|
||||
trailing := bytes.TrimSpace(w.pending)
|
||||
w.pending = nil
|
||||
if len(trailing) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := w.ResponseWriter.Write(trailing)
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *jsonlFramingResponseWriter) flushCompleteLines() error {
|
||||
for {
|
||||
newline := bytes.IndexByte(w.pending, '\n')
|
||||
if newline < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
line := bytes.TrimSpace(w.pending[:newline])
|
||||
w.pending = w.pending[newline+1:]
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := w.ResponseWriter.Write(line); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isHopByHopHeader(name string) bool {
|
||||
_, ok := hopByHopHeaders[strings.ToLower(name)]
|
||||
return ok
|
||||
}
|
||||
|
||||
func connectionHeaderTokens(header http.Header) map[string]struct{} {
|
||||
tokens := map[string]struct{}{}
|
||||
for _, raw := range header.Values("Connection") {
|
||||
for _, token := range strings.Split(raw, ",") {
|
||||
token = strings.TrimSpace(strings.ToLower(token))
|
||||
if token == "" {
|
||||
continue
|
||||
}
|
||||
tokens[token] = struct{}{}
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func isConnectionTokenHeader(name string, tokens map[string]struct{}) bool {
|
||||
if len(tokens) == 0 {
|
||||
return false
|
||||
}
|
||||
_, ok := tokens[strings.ToLower(name)]
|
||||
return ok
|
||||
}
|
||||
318
server/cloud_proxy_test.go
Normal file
318
server/cloud_proxy_test.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
func TestCopyProxyRequestHeaders_StripsConnectionTokenHeaders(t *testing.T) {
|
||||
src := http.Header{}
|
||||
src.Add("Connection", "keep-alive, X-Trace-Hop, x-alt-hop")
|
||||
src.Add("X-Trace-Hop", "drop-me")
|
||||
src.Add("X-Alt-Hop", "drop-me-too")
|
||||
src.Add("Keep-Alive", "timeout=5")
|
||||
src.Add("X-End-To-End", "keep-me")
|
||||
|
||||
dst := http.Header{}
|
||||
copyProxyRequestHeaders(dst, src)
|
||||
|
||||
if got := dst.Get("Connection"); got != "" {
|
||||
t.Fatalf("expected Connection to be stripped, got %q", got)
|
||||
}
|
||||
if got := dst.Get("Keep-Alive"); got != "" {
|
||||
t.Fatalf("expected Keep-Alive to be stripped, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Trace-Hop"); got != "" {
|
||||
t.Fatalf("expected X-Trace-Hop to be stripped via Connection token, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Alt-Hop"); got != "" {
|
||||
t.Fatalf("expected X-Alt-Hop to be stripped via Connection token, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-End-To-End"); got != "keep-me" {
|
||||
t.Fatalf("expected X-End-To-End to be forwarded, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyProxyResponseHeaders_StripsConnectionTokenHeaders(t *testing.T) {
|
||||
src := http.Header{}
|
||||
src.Add("Connection", "X-Upstream-Hop")
|
||||
src.Add("X-Upstream-Hop", "drop-me")
|
||||
src.Add("Content-Type", "application/json")
|
||||
src.Add("X-Server-Trace", "keep-me")
|
||||
|
||||
dst := http.Header{}
|
||||
copyProxyResponseHeaders(dst, src)
|
||||
|
||||
if got := dst.Get("Connection"); got != "" {
|
||||
t.Fatalf("expected Connection to be stripped, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Upstream-Hop"); got != "" {
|
||||
t.Fatalf("expected X-Upstream-Hop to be stripped via Connection token, got %q", got)
|
||||
}
|
||||
if got := dst.Get("Content-Type"); got != "application/json" {
|
||||
t.Fatalf("expected Content-Type to be forwarded, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Server-Trace"); got != "keep-me" {
|
||||
t.Fatalf("expected X-Server-Trace to be forwarded, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_Default(t *testing.T) {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("", gin.ReleaseMode)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if overridden {
|
||||
t.Fatal("expected override=false for empty input")
|
||||
}
|
||||
if baseURL != defaultCloudProxyBaseURL {
|
||||
t.Fatalf("expected default base URL %q, got %q", defaultCloudProxyBaseURL, baseURL)
|
||||
}
|
||||
if signingHost != defaultCloudProxySigningHost {
|
||||
t.Fatalf("expected default signing host %q, got %q", defaultCloudProxySigningHost, signingHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_ReleaseAllowsLoopback(t *testing.T) {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("http://localhost:8080", gin.ReleaseMode)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !overridden {
|
||||
t.Fatal("expected override=true")
|
||||
}
|
||||
if baseURL != "http://localhost:8080" {
|
||||
t.Fatalf("unexpected base URL: %q", baseURL)
|
||||
}
|
||||
if signingHost != "localhost" {
|
||||
t.Fatalf("unexpected signing host: %q", signingHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_ReleaseRejectsNonLoopback(t *testing.T) {
|
||||
_, _, _, err := resolveCloudProxyBaseURL("https://example.com", gin.ReleaseMode)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-loopback override in release mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_DevAllowsNonLoopbackHTTPS(t *testing.T) {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("https://example.com:8443", gin.DebugMode)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !overridden {
|
||||
t.Fatal("expected override=true")
|
||||
}
|
||||
if baseURL != "https://example.com:8443" {
|
||||
t.Fatalf("unexpected base URL: %q", baseURL)
|
||||
}
|
||||
if signingHost != "example.com" {
|
||||
t.Fatalf("unexpected signing host: %q", signingHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_DevRejectsNonLoopbackHTTP(t *testing.T) {
|
||||
_, _, _, err := resolveCloudProxyBaseURL("http://example.com", gin.DebugMode)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-loopback http override in dev mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCloudSignatureChallengeIncludesExistingQuery(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&foo=bar", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
got := buildCloudSignatureChallenge(req, "123")
|
||||
want := "POST,/v1/messages?beta=true&foo=bar&ts=123"
|
||||
if got != want {
|
||||
t.Fatalf("challenge mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if req.URL.RawQuery != "beta=true&foo=bar&ts=123" {
|
||||
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudPassthroughMiddleware_ZstdBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
plainBody := []byte(`{"model":"test-model:cloud","messages":[{"role":"user","content":"hi"}]}`)
|
||||
var compressed bytes.Buffer
|
||||
w, err := zstd.NewWriter(&compressed)
|
||||
if err != nil {
|
||||
t.Fatalf("zstd writer: %v", err)
|
||||
}
|
||||
if _, err := w.Write(plainBody); err != nil {
|
||||
t.Fatalf("zstd write: %v", err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("zstd close: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(compressed.Bytes()))
|
||||
req.Header.Set("Content-Encoding", "zstd")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Track whether the middleware detected the cloud model by checking
|
||||
// if c.Next() was called (non-cloud path) vs c.Abort() (cloud path).
|
||||
nextCalled := false
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/v1/responses", cloudPassthroughMiddleware("test"), func(c *gin.Context) {
|
||||
nextCalled = true
|
||||
// Verify the body is decompressed and Content-Encoding is removed.
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read body: %v", err)
|
||||
}
|
||||
model, ok := extractModelField(body)
|
||||
if !ok {
|
||||
t.Fatal("expected to extract model from decompressed body")
|
||||
}
|
||||
if model != "test-model:cloud" {
|
||||
t.Fatalf("expected model %q, got %q", "test-model:cloud", model)
|
||||
}
|
||||
if enc := c.GetHeader("Content-Encoding"); enc != "" {
|
||||
t.Fatalf("expected Content-Encoding to be removed, got %q", enc)
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
// The cloud passthrough middleware should detect the cloud model and
|
||||
// proxy (abort), so the next handler should NOT be called.
|
||||
// However, since there's no actual cloud server to proxy to, the
|
||||
// middleware will attempt to proxy and fail. We just verify it didn't
|
||||
// fall through to c.Next() due to failure to read the compressed body.
|
||||
if nextCalled {
|
||||
t.Fatal("expected cloud passthrough to detect cloud model from zstd body, but it fell through to next handler")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudPassthroughMiddleware_ZstdBodyTooLarge(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Create a body that exceeds the 20MB limit
|
||||
oversized := make([]byte, maxDecompressedBodySize+1024)
|
||||
for i := range oversized {
|
||||
oversized[i] = 'A'
|
||||
}
|
||||
|
||||
var compressed bytes.Buffer
|
||||
w, err := zstd.NewWriter(&compressed)
|
||||
if err != nil {
|
||||
t.Fatalf("zstd writer: %v", err)
|
||||
}
|
||||
if _, err := w.Write(oversized); err != nil {
|
||||
t.Fatalf("zstd write: %v", err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("zstd close: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(compressed.Bytes()))
|
||||
req.Header.Set("Content-Encoding", "zstd")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/v1/responses", cloudPassthroughMiddleware("test"), func(c *gin.Context) {
|
||||
t.Fatal("handler should not be reached for oversized body")
|
||||
})
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&ts=999", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
got := buildCloudSignatureChallenge(req, "123")
|
||||
want := "POST,/v1/messages?beta=true&ts=123"
|
||||
if got != want {
|
||||
t.Fatalf("challenge mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if req.URL.RawQuery != "beta=true&ts=123" {
|
||||
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLFramingResponseWriter_SplitsCoalescedLines(t *testing.T) {
|
||||
rec := &chunkRecorder{header: http.Header{}}
|
||||
w := &jsonlFramingResponseWriter{ResponseWriter: rec}
|
||||
|
||||
payload := []byte("{\"a\":1}\n{\"b\":2}\n")
|
||||
if n, err := w.Write(payload); err != nil {
|
||||
t.Fatalf("write failed: %v", err)
|
||||
} else if n != len(payload) {
|
||||
t.Fatalf("write byte count mismatch: got %d want %d", n, len(payload))
|
||||
}
|
||||
|
||||
if err := w.FlushPending(); err != nil {
|
||||
t.Fatalf("FlushPending failed: %v", err)
|
||||
}
|
||||
|
||||
if len(rec.chunks) != 2 {
|
||||
t.Fatalf("expected 2 framed writes, got %d", len(rec.chunks))
|
||||
}
|
||||
if got := string(rec.chunks[0]); got != `{"a":1}` {
|
||||
t.Fatalf("first chunk mismatch: got %q", got)
|
||||
}
|
||||
if got := string(rec.chunks[1]); got != `{"b":2}` {
|
||||
t.Fatalf("second chunk mismatch: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLFramingResponseWriter_FlushPendingWritesTrailingLine(t *testing.T) {
|
||||
rec := &chunkRecorder{header: http.Header{}}
|
||||
w := &jsonlFramingResponseWriter{ResponseWriter: rec}
|
||||
|
||||
if _, err := w.Write([]byte("{\"a\":1")); err != nil {
|
||||
t.Fatalf("write failed: %v", err)
|
||||
}
|
||||
if len(rec.chunks) != 0 {
|
||||
t.Fatalf("expected no writes before newline/flush, got %d", len(rec.chunks))
|
||||
}
|
||||
|
||||
if err := w.FlushPending(); err != nil {
|
||||
t.Fatalf("FlushPending failed: %v", err)
|
||||
}
|
||||
if len(rec.chunks) != 1 {
|
||||
t.Fatalf("expected 1 write after FlushPending, got %d", len(rec.chunks))
|
||||
}
|
||||
if got := string(rec.chunks[0]); got != `{"a":1` {
|
||||
t.Fatalf("trailing chunk mismatch: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
type chunkRecorder struct {
|
||||
header http.Header
|
||||
status int
|
||||
chunks [][]byte
|
||||
}
|
||||
|
||||
func (r *chunkRecorder) Header() http.Header {
|
||||
return r.header
|
||||
}
|
||||
|
||||
func (r *chunkRecorder) WriteHeader(statusCode int) {
|
||||
r.status = statusCode
|
||||
}
|
||||
|
||||
func (r *chunkRecorder) Write(p []byte) (int, error) {
|
||||
cp := append([]byte(nil), p...)
|
||||
r.chunks = append(r.chunks, cp)
|
||||
return len(p), nil
|
||||
}
|
||||
903
server/create.go
Normal file
903
server/create.go
Normal file
@@ -0,0 +1,903 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
ofs "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
var (
|
||||
errNoFilesProvided = errors.New("no files provided to convert")
|
||||
errOnlyOneAdapterSupported = errors.New("only one adapter is currently supported")
|
||||
errOnlyGGUFSupported = errors.New("supplied file was not in GGUF format")
|
||||
errUnknownType = errors.New("unknown type")
|
||||
errNeitherFromOrFiles = errors.New("neither 'from' or 'files' was specified")
|
||||
errFilePath = errors.New("file path must be relative")
|
||||
)
|
||||
|
||||
func (s *Server) CreateHandler(c *gin.Context) {
|
||||
config := &model.ConfigV2{
|
||||
OS: "linux",
|
||||
Architecture: "amd64",
|
||||
RootFS: model.RootFS{
|
||||
Type: "layers",
|
||||
},
|
||||
}
|
||||
|
||||
var r api.CreateRequest
|
||||
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
config.Renderer = r.Renderer
|
||||
config.Parser = r.Parser
|
||||
config.Requires = r.Requires
|
||||
|
||||
for v, digest := range r.Files {
|
||||
if !fs.ValidPath(v) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
|
||||
return
|
||||
}
|
||||
if digest == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": manifest.ErrInvalidDigestFormat.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, digest := range r.Adapters {
|
||||
if digest == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": manifest.ErrInvalidDigestFormat.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||
if !name.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
||||
return
|
||||
}
|
||||
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
fn := func(resp api.ProgressResponse) {
|
||||
ch <- resp
|
||||
}
|
||||
|
||||
oldManifest, _ := manifest.ParseNamedManifest(name)
|
||||
|
||||
var baseLayers []*layerGGML
|
||||
var err error
|
||||
var remote bool
|
||||
|
||||
if r.From != "" {
|
||||
slog.Debug("create model from model name", "from", r.From)
|
||||
fromRef, err := parseAndValidateModelRef(r.From)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
|
||||
fromName := fromRef.Name
|
||||
remoteHost := r.RemoteHost
|
||||
if fromRef.Source == modelSourceCloud && remoteHost == "" {
|
||||
remoteHost = cloudProxyBaseURL
|
||||
}
|
||||
|
||||
if remoteHost != "" {
|
||||
ru, err := remoteURL(remoteHost)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
|
||||
config.RemoteModel = fromRef.Base
|
||||
config.RemoteHost = ru
|
||||
remote = true
|
||||
} else {
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
baseLayers, err = parseFromModel(ctx, fromName, fn)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
|
||||
if err == nil && !remote {
|
||||
mf, mErr := manifest.ParseNamedManifest(fromName)
|
||||
if mErr == nil && mf.Config.Digest != "" {
|
||||
configPath, pErr := manifest.BlobsPath(mf.Config.Digest)
|
||||
if pErr == nil {
|
||||
if cfgFile, fErr := os.Open(configPath); fErr == nil {
|
||||
var baseConfig model.ConfigV2
|
||||
if decErr := json.NewDecoder(cfgFile).Decode(&baseConfig); decErr == nil {
|
||||
if config.Renderer == "" {
|
||||
config.Renderer = baseConfig.Renderer
|
||||
}
|
||||
if config.Parser == "" {
|
||||
config.Parser = baseConfig.Parser
|
||||
}
|
||||
if config.Requires == "" {
|
||||
config.Requires = baseConfig.Requires
|
||||
}
|
||||
if config.ModelFormat == "" {
|
||||
config.ModelFormat = baseConfig.ModelFormat
|
||||
}
|
||||
if len(config.Capabilities) == 0 {
|
||||
config.Capabilities = baseConfig.Capabilities
|
||||
}
|
||||
}
|
||||
cfgFile.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if r.Files != nil {
|
||||
baseLayers, err = convertModelFromFiles(r.Files, baseLayers, false, fn)
|
||||
if err != nil {
|
||||
for _, badReq := range []error{errNoFilesProvided, errOnlyGGUFSupported, errUnknownType} {
|
||||
if errors.Is(err, badReq) {
|
||||
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
}
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
} else {
|
||||
ch <- gin.H{"error": errNeitherFromOrFiles.Error(), "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
|
||||
var adapterLayers []*layerGGML
|
||||
if !remote && r.Adapters != nil {
|
||||
adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
|
||||
if err != nil {
|
||||
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} {
|
||||
if errors.Is(err, badReq) {
|
||||
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
}
|
||||
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if len(adapterLayers) > 0 {
|
||||
baseLayers = append(baseLayers, adapterLayers...)
|
||||
}
|
||||
|
||||
// Info is not currently exposed by Modelfiles, but allows overriding various
|
||||
// config values
|
||||
if r.Info != nil {
|
||||
caps, ok := r.Info["capabilities"]
|
||||
if ok {
|
||||
switch tcaps := caps.(type) {
|
||||
case []any:
|
||||
caps := make([]string, len(tcaps))
|
||||
for i, c := range tcaps {
|
||||
str, ok := c.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
caps[i] = str
|
||||
}
|
||||
config.Capabilities = append(config.Capabilities, caps...)
|
||||
}
|
||||
}
|
||||
|
||||
strFromInfo := func(k string) string {
|
||||
v, ok := r.Info[k]
|
||||
if ok {
|
||||
val := v.(string)
|
||||
return val
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
vFromInfo := func(k string) float64 {
|
||||
v, ok := r.Info[k]
|
||||
if ok {
|
||||
val := v.(float64)
|
||||
return val
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
config.ModelFamily = strFromInfo("model_family")
|
||||
if config.ModelFamily != "" {
|
||||
config.ModelFamilies = []string{config.ModelFamily}
|
||||
}
|
||||
|
||||
config.BaseName = strFromInfo("base_name")
|
||||
config.FileType = strFromInfo("quantization_level")
|
||||
config.ModelType = strFromInfo("parameter_size")
|
||||
config.ContextLen = int(vFromInfo("context_length"))
|
||||
config.EmbedLen = int(vFromInfo("embedding_length"))
|
||||
}
|
||||
|
||||
if err := createModel(r, name, baseLayers, config, fn); err != nil {
|
||||
if errors.Is(err, errBadTemplate) {
|
||||
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
if !envconfig.NoPrune() && oldManifest != nil {
|
||||
if err := oldManifest.RemoveLayers(); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}
|
||||
|
||||
s.refreshModelListCache(name)
|
||||
|
||||
ch <- api.ProgressResponse{Status: "success"}
|
||||
}()
|
||||
|
||||
if r.Stream != nil && !*r.Stream {
|
||||
waitForStream(c, ch)
|
||||
return
|
||||
}
|
||||
|
||||
streamResponse(c, ch)
|
||||
}
|
||||
|
||||
func remoteURL(raw string) (string, error) {
|
||||
// Special‑case: user supplied only a path ("/foo/bar").
|
||||
if strings.HasPrefix(raw, "/") {
|
||||
return (&url.URL{
|
||||
Scheme: "http",
|
||||
Host: net.JoinHostPort("localhost", "11434"),
|
||||
Path: path.Clean(raw),
|
||||
}).String(), nil
|
||||
}
|
||||
|
||||
if !strings.Contains(raw, "://") {
|
||||
raw = "http://" + raw
|
||||
}
|
||||
|
||||
if raw == "ollama.com" || raw == "http://ollama.com" {
|
||||
raw = "https://ollama.com:443"
|
||||
}
|
||||
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse error: %w", err)
|
||||
}
|
||||
|
||||
if u.Host == "" {
|
||||
u.Host = "localhost"
|
||||
}
|
||||
|
||||
hostPart, portPart, err := net.SplitHostPort(u.Host)
|
||||
if err == nil {
|
||||
u.Host = net.JoinHostPort(hostPart, portPart)
|
||||
} else {
|
||||
u.Host = net.JoinHostPort(u.Host, "11434")
|
||||
}
|
||||
|
||||
if u.Path != "" {
|
||||
u.Path = path.Clean(u.Path)
|
||||
}
|
||||
|
||||
if u.Path == "/" {
|
||||
u.Path = ""
|
||||
}
|
||||
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) {
|
||||
switch detectModelTypeFromFiles(files) {
|
||||
case "safetensors":
|
||||
layers, err := convertFromSafetensors(files, baseLayers, isAdapter, fn)
|
||||
if err != nil {
|
||||
slog.Error("error converting from safetensors", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
return layers, nil
|
||||
case "gguf":
|
||||
if len(files) == 0 {
|
||||
return nil, errNoFilesProvided
|
||||
} else if len(files) > 1 && isAdapter {
|
||||
return nil, errOnlyOneAdapterSupported
|
||||
}
|
||||
|
||||
var digest string
|
||||
var allLayers []*layerGGML
|
||||
for _, v := range files {
|
||||
digest = v
|
||||
layers, err := ggufLayers(digest, fn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allLayers = append(allLayers, layers...)
|
||||
}
|
||||
return allLayers, nil
|
||||
default:
|
||||
return nil, errUnknownType
|
||||
}
|
||||
}
|
||||
|
||||
func detectModelTypeFromFiles(files map[string]string) string {
|
||||
for fn := range files {
|
||||
if strings.HasSuffix(fn, ".safetensors") {
|
||||
return "safetensors"
|
||||
} else if strings.HasSuffix(fn, ".gguf") {
|
||||
return "gguf"
|
||||
} else {
|
||||
// try to see if we can find a gguf file even without the file extension
|
||||
blobPath, err := manifest.BlobsPath(files[fn])
|
||||
if err != nil {
|
||||
slog.Error("error getting blobs path", "file", fn)
|
||||
return ""
|
||||
}
|
||||
|
||||
f, err := os.Open(blobPath)
|
||||
if err != nil {
|
||||
slog.Error("error reading file", "error", err)
|
||||
return ""
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
buf := make([]byte, 4)
|
||||
_, err = f.Read(buf)
|
||||
if err != nil {
|
||||
slog.Error("error reading file", "error", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
ct := ggml.DetectContentType(buf)
|
||||
if ct == "gguf" {
|
||||
return "gguf"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) {
|
||||
tmpDir, err := os.MkdirTemp(envconfig.Models(), "ollama-safetensors")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
// Set up a root to validate paths
|
||||
root, err := os.OpenRoot(tmpDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer root.Close()
|
||||
|
||||
for fp, digest := range files {
|
||||
if !fs.ValidPath(fp) {
|
||||
return nil, fmt.Errorf("%w: %s", errFilePath, fp)
|
||||
}
|
||||
if _, err := root.Stat(fp); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
// Path is likely outside the root
|
||||
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
|
||||
}
|
||||
|
||||
blobPath, err := manifest.BlobsPath(digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := createLink(blobPath, filepath.Join(tmpDir, fp)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
t, err := os.CreateTemp(tmpDir, "fp16")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer t.Close()
|
||||
|
||||
var mediaType string
|
||||
if !isAdapter {
|
||||
fn(api.ProgressResponse{Status: "converting model"})
|
||||
mediaType = "application/vnd.ollama.image.model"
|
||||
if err := convert.ConvertModel(os.DirFS(tmpDir), t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
kv, err := kvFromLayers(baseLayers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fn(api.ProgressResponse{Status: "converting adapter"})
|
||||
mediaType = "application/vnd.ollama.image.adapter"
|
||||
if err := convert.ConvertAdapter(os.DirFS(tmpDir), t, kv); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := t.Seek(0, io.SeekStart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layer, err := manifest.NewLayer(t, mediaType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bin, err := layer.Open()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer bin.Close()
|
||||
|
||||
f, err := ggml.Decode(bin, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layers := []*layerGGML{{layer, f}}
|
||||
|
||||
if !isAdapter {
|
||||
return detectChatTemplate(layers)
|
||||
}
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
|
||||
for _, l := range baseLayers {
|
||||
if l.GGML != nil {
|
||||
return l.KV(), nil
|
||||
}
|
||||
}
|
||||
return ggml.KV{}, fmt.Errorf("no base model was found")
|
||||
}
|
||||
|
||||
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *model.ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
|
||||
var layers []manifest.Layer
|
||||
for _, layer := range baseLayers {
|
||||
if layer.GGML != nil {
|
||||
quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
|
||||
ft := layer.GGML.KV().FileType()
|
||||
if quantType == "" && hasSourceFP8Tensors(layer.GGML.KV()) && layer.GGML.Name() == "gguf" && layer.MediaType == "application/vnd.ollama.image.model" && slices.Contains([]string{"F16", "BF16", "F32"}, ft.String()) {
|
||||
quantType = "Q8_0"
|
||||
}
|
||||
if quantType != "" && layer.GGML.Name() == "gguf" && layer.MediaType == "application/vnd.ollama.image.model" {
|
||||
want, err := ggml.ParseFileType(quantType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !slices.Contains([]string{"F16", "BF16", "F32"}, ft.String()) {
|
||||
return errors.New("quantization is only supported for F16, BF16 and F32 models")
|
||||
} else if ft != want {
|
||||
layer, err = quantizeLayer(layer, quantType, fn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
config.ModelFormat = cmp.Or(config.ModelFormat, layer.GGML.Name())
|
||||
config.ModelFamily = cmp.Or(config.ModelFamily, layer.GGML.KV().Architecture())
|
||||
config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(layer.GGML.KV().ParameterCount()))
|
||||
config.FileType = cmp.Or(config.FileType, layer.GGML.KV().FileType().String())
|
||||
config.ModelFamilies = append(config.ModelFamilies, layer.GGML.KV().Architecture())
|
||||
|
||||
// Auto-detect renderer, parser, and stop tokens from GGUF architecture.
|
||||
// TODO: abstract this into a registry/lookup table when multiple models
|
||||
// need architecture-based renderer/parser/stop defaults.
|
||||
if config.Renderer == "" || config.Parser == "" {
|
||||
arch := layer.GGML.KV().Architecture()
|
||||
switch arch {
|
||||
case "gemma4":
|
||||
config.Renderer = cmp.Or(config.Renderer, gemma4RendererLegacy)
|
||||
config.Parser = cmp.Or(config.Parser, "gemma4")
|
||||
if _, ok := r.Parameters["stop"]; !ok {
|
||||
if r.Parameters == nil {
|
||||
r.Parameters = make(map[string]any)
|
||||
}
|
||||
r.Parameters["stop"] = []string{"<turn|>"}
|
||||
}
|
||||
case "laguna":
|
||||
config.Renderer = cmp.Or(config.Renderer, "laguna")
|
||||
config.Parser = cmp.Or(config.Parser, "laguna")
|
||||
case "nemotron_h", "nemotron_h_moe", "nemotron_h_omni":
|
||||
config.Renderer = cmp.Or(config.Renderer, "nemotron-3-nano")
|
||||
config.Parser = cmp.Or(config.Parser, "nemotron-3-nano")
|
||||
}
|
||||
}
|
||||
}
|
||||
layers = append(layers, layer.Layer)
|
||||
}
|
||||
|
||||
if r.Template != "" {
|
||||
layers, err = setTemplate(layers, r.Template)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if r.System != "" {
|
||||
layers, err = setSystem(layers, r.System)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if r.License != nil {
|
||||
switch l := r.License.(type) {
|
||||
case string:
|
||||
if l != "" {
|
||||
layers, err = setLicense(layers, l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case any:
|
||||
var licenses []string
|
||||
b, _ := json.Marshal(l) // re-marshal to JSON
|
||||
if err := json.Unmarshal(b, &licenses); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, v := range licenses {
|
||||
layers, err = setLicense(layers, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown license type: %T", l)
|
||||
}
|
||||
}
|
||||
|
||||
layers, err = setParameters(layers, r.Parameters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
layers, err = setMessages(layers, r.Messages)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configLayer, err := createConfigLayer(layers, *config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, layer := range layers {
|
||||
if layer.Status != "" {
|
||||
fn(api.ProgressResponse{Status: layer.Status})
|
||||
}
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
if err := manifest.WriteManifest(name, *configLayer, layers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasSourceFP8Tensors(kv ggml.KV) bool {
|
||||
return kv.String("source_quantization") == "hf_fp8" && len(kv.Strings("source_fp8_tensors")) > 0
|
||||
}
|
||||
|
||||
func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.ProgressResponse)) (*layerGGML, error) {
|
||||
ft := layer.GGML.KV().FileType()
|
||||
var doneBytes atomic.Uint64
|
||||
totalBytes := uint64(layer.Size) - layer.GGML.Tensors().Offset
|
||||
fnWrap := func(n uint64) {
|
||||
done := doneBytes.Add(n)
|
||||
progress := float32(done) / float32(totalBytes)
|
||||
fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantizeType), Digest: "0000000000000000000", Total: layer.Size, Completed: int64(progress * float32(layer.Size))})
|
||||
}
|
||||
ftype, err := ggml.ParseFileType(quantizeType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blob, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fp, err := os.Open(blob)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer fp.Close()
|
||||
|
||||
temp, err := os.CreateTemp(filepath.Dir(blob), quantizeType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer temp.Close()
|
||||
defer os.Remove(temp.Name())
|
||||
|
||||
if err := quantize(fp, temp, layer.GGML, ftype, fnWrap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
temp.Seek(0, io.SeekStart)
|
||||
fn(api.ProgressResponse{Status: "verifying conversion"})
|
||||
newLayer, err := manifest.NewLayer(temp, layer.MediaType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := temp.Seek(0, io.SeekStart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := ggml.Decode(temp, 1024)
|
||||
if err != nil {
|
||||
slog.Error(fmt.Sprintf("error decoding ggml: %s\n", err))
|
||||
return nil, err
|
||||
}
|
||||
return &layerGGML{newLayer, f}, nil
|
||||
}
|
||||
|
||||
func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) {
|
||||
var layers []*layerGGML
|
||||
|
||||
fn(api.ProgressResponse{Status: "parsing GGUF"})
|
||||
blobPath, err := manifest.BlobsPath(digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blob, err := os.Open(blobPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer blob.Close()
|
||||
|
||||
sr := io.NewSectionReader(blob, 0, 512)
|
||||
contentType, err := detectContentType(sr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if contentType != "gguf" {
|
||||
slog.Error(fmt.Sprintf("unsupported content type: %s", contentType))
|
||||
return nil, errOnlyGGUFSupported
|
||||
}
|
||||
|
||||
f, err := ggml.Decode(blob, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mediatype := "application/vnd.ollama.image.model"
|
||||
if f.KV().Kind() == "adapter" {
|
||||
mediatype = "application/vnd.ollama.image.adapter"
|
||||
} else if (f.KV().Uint("block_count") == 0 && f.KV().Uint("vision.block_count") > 0) || f.KV().Kind() == "projector" {
|
||||
// if a model has vision.block_count but not block_count, it is a standalone vision model
|
||||
mediatype = "application/vnd.ollama.image.projector"
|
||||
}
|
||||
|
||||
layer, err := manifest.NewLayerFromLayer(digest, mediatype, blob.Name())
|
||||
if err != nil {
|
||||
slog.Debug("could not create new layer from layer", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers = append(layers, &layerGGML{layer, f})
|
||||
|
||||
return detectChatTemplate(layers)
|
||||
}
|
||||
|
||||
func removeLayer(layers []manifest.Layer, mediatype string) []manifest.Layer {
|
||||
return slices.DeleteFunc(layers, func(layer manifest.Layer) bool {
|
||||
if layer.MediaType != mediatype {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := layer.Remove(); err != nil {
|
||||
slog.Warn("couldn't remove blob", "digest", layer.Digest, "error", err)
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func setTemplate(layers []manifest.Layer, t string) ([]manifest.Layer, error) {
|
||||
layers = removeLayer(layers, "application/vnd.ollama.image.template")
|
||||
if _, err := template.Parse(t); err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
|
||||
}
|
||||
if _, err := template.Parse(t); err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
|
||||
}
|
||||
|
||||
blob := strings.NewReader(t)
|
||||
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.template")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers = append(layers, layer)
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func setSystem(layers []manifest.Layer, s string) ([]manifest.Layer, error) {
|
||||
layers = removeLayer(layers, "application/vnd.ollama.image.system")
|
||||
if s != "" {
|
||||
blob := strings.NewReader(s)
|
||||
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.system")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func setLicense(layers []manifest.Layer, l string) ([]manifest.Layer, error) {
|
||||
blob := strings.NewReader(l)
|
||||
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.license")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer, error) {
|
||||
if p == nil {
|
||||
p = make(map[string]any)
|
||||
}
|
||||
for _, layer := range layers {
|
||||
if layer.MediaType != "application/vnd.ollama.image.params" {
|
||||
continue
|
||||
}
|
||||
|
||||
digestPath, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fn, err := os.Open(digestPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer fn.Close()
|
||||
|
||||
var existing map[string]any
|
||||
if err := json.NewDecoder(fn).Decode(&existing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for k, v := range existing {
|
||||
if _, exists := p[k]; exists {
|
||||
continue
|
||||
}
|
||||
p[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
if len(p) == 0 {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
layers = removeLayer(layers, "application/vnd.ollama.image.params")
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func setMessages(layers []manifest.Layer, m []api.Message) ([]manifest.Layer, error) {
|
||||
// this leaves the old messages intact if no new messages were specified
|
||||
// which may not be the correct behaviour
|
||||
if len(m) == 0 {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
fmt.Printf("removing old messages\n")
|
||||
layers = removeLayer(layers, "application/vnd.ollama.image.messages")
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.messages")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func createConfigLayer(layers []manifest.Layer, config model.ConfigV2) (*manifest.Layer, error) {
|
||||
digests := make([]string, len(layers))
|
||||
for i, layer := range layers {
|
||||
digests[i] = layer.Digest
|
||||
}
|
||||
config.RootFS.DiffIDs = digests
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &layer, nil
|
||||
}
|
||||
|
||||
func createLink(src, dst string) error {
|
||||
// make any subdirs for dst
|
||||
if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_ = os.Remove(dst)
|
||||
if err := os.Symlink(src, dst); err != nil {
|
||||
if err := copyFile(src, dst); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyFile(src, dst string) error {
|
||||
srcFile, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer srcFile.Close()
|
||||
|
||||
dstFile, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer dstFile.Close()
|
||||
|
||||
_, err = io.Copy(dstFile, srcFile)
|
||||
return err
|
||||
}
|
||||
258
server/create_test.go
Normal file
258
server/create_test.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
)
|
||||
|
||||
func TestConvertFromSafetensors(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
|
||||
// Helper function to create a new layer and return its digest
|
||||
makeTemp := func(content string) string {
|
||||
l, err := manifest.NewLayer(strings.NewReader(content), "application/octet-stream")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create layer: %v", err)
|
||||
}
|
||||
return l.Digest
|
||||
}
|
||||
|
||||
// Create a safetensors compatible file with empty JSON content
|
||||
var buf bytes.Buffer
|
||||
headerSize := int64(len("{}"))
|
||||
binary.Write(&buf, binary.LittleEndian, headerSize)
|
||||
buf.WriteString("{}")
|
||||
|
||||
model := makeTemp(buf.String())
|
||||
config := makeTemp(`{
|
||||
"architectures": ["LlamaForCausalLM"],
|
||||
"vocab_size": 32000
|
||||
}`)
|
||||
tokenizer := makeTemp(`{
|
||||
"version": "1.0",
|
||||
"truncation": null,
|
||||
"padding": null,
|
||||
"added_tokens": [
|
||||
{
|
||||
"id": 0,
|
||||
"content": "<|endoftext|>",
|
||||
"single_word": false,
|
||||
"lstrip": false,
|
||||
"rstrip": false,
|
||||
"normalized": false,
|
||||
"special": true
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filePath string
|
||||
wantErr error
|
||||
}{
|
||||
// Invalid
|
||||
{
|
||||
name: "InvalidRelativePathShallow",
|
||||
filePath: filepath.Join("..", "file.safetensors"),
|
||||
wantErr: errFilePath,
|
||||
},
|
||||
{
|
||||
name: "InvalidRelativePathDeep",
|
||||
filePath: filepath.Join("..", "..", "..", "..", "..", "..", "data", "file.txt"),
|
||||
wantErr: errFilePath,
|
||||
},
|
||||
{
|
||||
name: "InvalidNestedPath",
|
||||
filePath: filepath.Join("dir", "..", "..", "..", "..", "..", "other.safetensors"),
|
||||
wantErr: errFilePath,
|
||||
},
|
||||
{
|
||||
name: "AbsolutePathOutsideRoot",
|
||||
filePath: filepath.Join(os.TempDir(), "model.safetensors"),
|
||||
wantErr: errFilePath, // Should fail since it's outside tmpDir
|
||||
},
|
||||
{
|
||||
name: "ValidRelativePath",
|
||||
filePath: "model.safetensors",
|
||||
wantErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create the minimum required file map for convertFromSafetensors
|
||||
files := map[string]string{
|
||||
tt.filePath: model,
|
||||
"config.json": config,
|
||||
"tokenizer.json": tokenizer,
|
||||
}
|
||||
|
||||
_, err := convertFromSafetensors(files, nil, false, func(resp api.ProgressResponse) {})
|
||||
|
||||
if (tt.wantErr == nil && err != nil) ||
|
||||
(tt.wantErr != nil && err == nil) ||
|
||||
(tt.wantErr != nil && !errors.Is(err, tt.wantErr)) {
|
||||
t.Errorf("convertFromSafetensors() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
hasError bool
|
||||
}{
|
||||
{
|
||||
name: "absolute path",
|
||||
input: "/foo/bar",
|
||||
expected: "http://localhost:11434/foo/bar",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "absolute path with cleanup",
|
||||
input: "/foo/../bar",
|
||||
expected: "http://localhost:11434/bar",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "root path",
|
||||
input: "/",
|
||||
expected: "http://localhost:11434/",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "host without scheme",
|
||||
input: "example.com",
|
||||
expected: "http://example.com:11434",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "host with port",
|
||||
input: "example.com:8080",
|
||||
expected: "http://example.com:8080",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "full URL",
|
||||
input: "https://example.com:8080/path",
|
||||
expected: "https://example.com:8080/path",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "full URL with path cleanup",
|
||||
input: "https://example.com:8080/path/../other",
|
||||
expected: "https://example.com:8080/other",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "ollama.com special case",
|
||||
input: "ollama.com",
|
||||
expected: "https://ollama.com:443",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "http ollama.com special case",
|
||||
input: "http://ollama.com",
|
||||
expected: "https://ollama.com:443",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "URL with only host",
|
||||
input: "http://example.com",
|
||||
expected: "http://example.com:11434",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "URL with root path cleaned",
|
||||
input: "http://example.com/",
|
||||
expected: "http://example.com:11434",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid URL",
|
||||
input: "http://[::1]:namedport", // invalid port
|
||||
expected: "",
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "http://localhost:11434",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "host with scheme but no port",
|
||||
input: "http://localhost",
|
||||
expected: "http://localhost:11434",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "complex path cleanup",
|
||||
input: "/a/b/../../c/./d",
|
||||
expected: "http://localhost:11434/c/d",
|
||||
hasError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := remoteURL(tt.input)
|
||||
|
||||
if tt.hasError {
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %q, got %q", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteURL_Idempotent(t *testing.T) {
|
||||
// Test that applying remoteURL twice gives the same result as applying it once
|
||||
testInputs := []string{
|
||||
"/foo/bar",
|
||||
"example.com",
|
||||
"https://example.com:8080/path",
|
||||
"ollama.com",
|
||||
"http://localhost:11434",
|
||||
}
|
||||
|
||||
for _, input := range testInputs {
|
||||
t.Run(input, func(t *testing.T) {
|
||||
firstResult, err := remoteURL(input)
|
||||
if err != nil {
|
||||
t.Fatalf("first call failed: %v", err)
|
||||
}
|
||||
|
||||
secondResult, err := remoteURL(firstResult)
|
||||
if err != nil {
|
||||
t.Fatalf("second call failed: %v", err)
|
||||
}
|
||||
|
||||
if firstResult != secondResult {
|
||||
t.Errorf("function is not idempotent: first=%q, second=%q", firstResult, secondResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
509
server/download.go
Normal file
509
server/download.go
Normal file
@@ -0,0 +1,509 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
const maxRetries = 6
|
||||
|
||||
var (
|
||||
errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||
errPartStalled = errors.New("part stalled")
|
||||
errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL")
|
||||
)
|
||||
|
||||
var blobDownloadManager sync.Map
|
||||
|
||||
type blobDownload struct {
|
||||
Name string
|
||||
Digest string
|
||||
|
||||
Total int64
|
||||
Completed atomic.Int64
|
||||
|
||||
Parts []*blobDownloadPart
|
||||
|
||||
context.CancelFunc
|
||||
|
||||
done chan struct{}
|
||||
err error
|
||||
references atomic.Int32
|
||||
}
|
||||
|
||||
type blobDownloadPart struct {
|
||||
N int
|
||||
Offset int64
|
||||
Size int64
|
||||
Completed atomic.Int64
|
||||
|
||||
lastUpdatedMu sync.Mutex
|
||||
lastUpdated time.Time
|
||||
|
||||
*blobDownload `json:"-"`
|
||||
}
|
||||
|
||||
type jsonBlobDownloadPart struct {
|
||||
N int
|
||||
Offset int64
|
||||
Size int64
|
||||
Completed int64
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(jsonBlobDownloadPart{
|
||||
N: p.N,
|
||||
Offset: p.Offset,
|
||||
Size: p.Size,
|
||||
Completed: p.Completed.Load(),
|
||||
})
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
|
||||
var j jsonBlobDownloadPart
|
||||
if err := json.Unmarshal(b, &j); err != nil {
|
||||
return err
|
||||
}
|
||||
*p = blobDownloadPart{
|
||||
N: j.N,
|
||||
Offset: j.Offset,
|
||||
Size: j.Size,
|
||||
}
|
||||
p.Completed.Store(j.Completed)
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
numDownloadParts = 16
|
||||
minDownloadPartSize int64 = 100 * format.MegaByte
|
||||
maxDownloadPartSize int64 = 1000 * format.MegaByte
|
||||
)
|
||||
|
||||
func (p *blobDownloadPart) Name() string {
|
||||
return strings.Join([]string{
|
||||
p.blobDownload.Name, "partial", strconv.Itoa(p.N),
|
||||
}, "-")
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) StartsAt() int64 {
|
||||
return p.Offset + p.Completed.Load()
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) StopsAt() int64 {
|
||||
return p.Offset + p.Size
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
|
||||
n = len(b)
|
||||
p.blobDownload.Completed.Add(int64(n))
|
||||
p.lastUpdatedMu.Lock()
|
||||
p.lastUpdated = time.Now()
|
||||
p.lastUpdatedMu.Unlock()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
|
||||
partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.done = make(chan struct{})
|
||||
|
||||
for _, partFilePath := range partFilePaths {
|
||||
part, err := b.readPart(partFilePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.Total += part.Size
|
||||
b.Completed.Add(part.Completed.Load())
|
||||
b.Parts = append(b.Parts, part)
|
||||
}
|
||||
|
||||
if len(b.Parts) == 0 {
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
||||
|
||||
size := b.Total / numDownloadParts
|
||||
switch {
|
||||
case size < minDownloadPartSize:
|
||||
size = minDownloadPartSize
|
||||
case size > maxDownloadPartSize:
|
||||
size = maxDownloadPartSize
|
||||
}
|
||||
|
||||
var offset int64
|
||||
for offset < b.Total {
|
||||
if offset+size > b.Total {
|
||||
size = b.Total - offset
|
||||
}
|
||||
|
||||
if err := b.newPart(offset, size); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
offset += size
|
||||
}
|
||||
}
|
||||
|
||||
if len(b.Parts) > 0 {
|
||||
slog.Info(fmt.Sprintf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
|
||||
defer close(b.done)
|
||||
b.err = b.run(ctx, requestURL, opts)
|
||||
}
|
||||
|
||||
func newBackoff(maxBackoff time.Duration) func(ctx context.Context) error {
|
||||
var n int
|
||||
return func(ctx context.Context) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
n++
|
||||
|
||||
// n^2 backoff timer is a little smoother than the
|
||||
// common choice of 2^n.
|
||||
d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
|
||||
// Randomize the delay between 0.5-1.5 x msec, in order
|
||||
// to prevent accidental "thundering herd" problems.
|
||||
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
|
||||
t := time.NewTimer(d)
|
||||
defer t.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-t.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
|
||||
defer blobDownloadManager.Delete(b.Digest)
|
||||
ctx, b.CancelFunc = context.WithCancel(ctx)
|
||||
|
||||
file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0o644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
setSparse(file)
|
||||
|
||||
_ = file.Truncate(b.Total)
|
||||
|
||||
directURL, err := func() (*url.URL, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
backoff := newBackoff(10 * time.Second)
|
||||
for {
|
||||
// shallow clone opts to be used in the closure
|
||||
// without affecting the outer opts.
|
||||
newOpts := new(registryOptions)
|
||||
*newOpts = *opts
|
||||
|
||||
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) > 10 {
|
||||
return errMaxRedirectsExceeded
|
||||
}
|
||||
|
||||
// if the hostname is the same, allow the redirect
|
||||
if req.URL.Hostname() == requestURL.Hostname() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// stop at the first redirect that is not
|
||||
// the same hostname as the original
|
||||
// request.
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, nil, nil, newOpts)
|
||||
if err != nil {
|
||||
slog.Warn("failed to get direct URL; backing off and retrying", "err", err)
|
||||
if err := backoff(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusTemporaryRedirect && resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
|
||||
}
|
||||
return resp.Location()
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
g, inner := errgroup.WithContext(ctx)
|
||||
g.SetLimit(numDownloadParts)
|
||||
for i := range b.Parts {
|
||||
part := b.Parts[i]
|
||||
if part.Completed.Load() == part.Size {
|
||||
continue
|
||||
}
|
||||
|
||||
g.Go(func() error {
|
||||
var err error
|
||||
for try := 0; try < maxRetries; try++ {
|
||||
w := io.NewOffsetWriter(file, part.StartsAt())
|
||||
err = b.downloadChunk(inner, directURL, w, part)
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
|
||||
// return immediately if the context is canceled or the device is out of space
|
||||
return err
|
||||
case errors.Is(err, errPartStalled):
|
||||
try--
|
||||
continue
|
||||
case err != nil:
|
||||
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
|
||||
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
|
||||
time.Sleep(sleep)
|
||||
continue
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// explicitly close the file so we can rename it
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := range b.Parts {
|
||||
if err := os.Remove(file.Name() + "-" + strconv.Itoa(i)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := os.Rename(file.Name(), b.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
g.Go(func() error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
// rollback progress
|
||||
b.Completed.Add(-n)
|
||||
return err
|
||||
}
|
||||
|
||||
part.Completed.Add(n)
|
||||
if err := b.writePart(part.Name(), part); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// return nil or context.Canceled or UnexpectedEOF (resumable)
|
||||
return err
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if part.Completed.Load() >= part.Size {
|
||||
return nil
|
||||
}
|
||||
|
||||
part.lastUpdatedMu.Lock()
|
||||
lastUpdated := part.lastUpdated
|
||||
part.lastUpdatedMu.Unlock()
|
||||
|
||||
if !lastUpdated.IsZero() && time.Since(lastUpdated) > 30*time.Second {
|
||||
const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
|
||||
slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
|
||||
// reset last updated
|
||||
part.lastUpdatedMu.Lock()
|
||||
part.lastUpdated = time.Time{}
|
||||
part.lastUpdatedMu.Unlock()
|
||||
return errPartStalled
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
func (b *blobDownload) newPart(offset, size int64) error {
|
||||
part := blobDownloadPart{blobDownload: b, Offset: offset, Size: size, N: len(b.Parts)}
|
||||
if err := b.writePart(part.Name(), &part); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.Parts = append(b.Parts, &part)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) {
|
||||
var part blobDownloadPart
|
||||
partFile, err := os.Open(partName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer partFile.Close()
|
||||
|
||||
if err := json.NewDecoder(partFile).Decode(&part); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
part.blobDownload = b
|
||||
return &part, nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error {
|
||||
partFile, err := os.OpenFile(partName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0o644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer partFile.Close()
|
||||
|
||||
return json.NewEncoder(partFile).Encode(part)
|
||||
}
|
||||
|
||||
func (b *blobDownload) acquire() {
|
||||
b.references.Add(1)
|
||||
}
|
||||
|
||||
func (b *blobDownload) release() {
|
||||
if b.references.Add(-1) == 0 {
|
||||
b.CancelFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
|
||||
b.acquire()
|
||||
defer b.release()
|
||||
|
||||
ticker := time.NewTicker(60 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-b.done:
|
||||
return b.err
|
||||
case <-ticker.C:
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
|
||||
Digest: b.Digest,
|
||||
Total: b.Total,
|
||||
Completed: b.Completed.Load(),
|
||||
})
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type downloadOpts struct {
|
||||
n model.Name
|
||||
digest string
|
||||
regOpts *registryOptions
|
||||
fn func(api.ProgressResponse)
|
||||
}
|
||||
|
||||
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
|
||||
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
|
||||
if opts.digest == "" {
|
||||
return false, fmt.Errorf(("%s: %s"), opts.n.DisplayNamespaceModel(), "digest is empty")
|
||||
}
|
||||
|
||||
fp, err := manifest.BlobsPath(opts.digest)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
fi, err := os.Stat(fp)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
case err != nil:
|
||||
return false, err
|
||||
default:
|
||||
opts.fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("pulling %s", opts.digest[7:19]),
|
||||
Digest: opts.digest,
|
||||
Total: fi.Size(),
|
||||
Completed: fi.Size(),
|
||||
})
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
|
||||
download := data.(*blobDownload)
|
||||
if !ok {
|
||||
requestURL := opts.n.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", opts.n.DisplayNamespaceModel(), "blobs", opts.digest)
|
||||
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
|
||||
blobDownloadManager.Delete(opts.digest)
|
||||
return false, err
|
||||
}
|
||||
|
||||
//nolint:contextcheck
|
||||
go download.Run(context.Background(), requestURL, opts.regOpts)
|
||||
}
|
||||
|
||||
return false, download.Wait(ctx, opts.fn)
|
||||
}
|
||||
26
server/fixblobs.go
Normal file
26
server/fixblobs.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// fixBlobs walks the provided dir and replaces (":") to ("-") in the file
|
||||
// prefix. (e.g. sha256:1234 -> sha256-1234)
|
||||
func fixBlobs(dir string) error {
|
||||
return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
baseName := filepath.Base(path)
|
||||
typ, sha, ok := strings.Cut(baseName, ":")
|
||||
if ok && typ == "sha256" {
|
||||
newPath := filepath.Join(filepath.Dir(path), typ+"-"+sha)
|
||||
if err := os.Rename(path, newPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
83
server/fixblobs_test.go
Normal file
83
server/fixblobs_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFixBlobs(t *testing.T) {
|
||||
cases := []struct {
|
||||
path []string
|
||||
want []string
|
||||
}{
|
||||
{path: []string{"sha256-1234"}, want: []string{"sha256-1234"}},
|
||||
{path: []string{"sha256:1234"}, want: []string{"sha256-1234"}},
|
||||
{path: []string{"sha259:5678"}, want: []string{"sha259:5678"}},
|
||||
{path: []string{"sha256:abcd"}, want: []string{"sha256-abcd"}},
|
||||
{path: []string{"x/y/sha256:abcd"}, want: []string{"x/y/sha256-abcd"}},
|
||||
{path: []string{"x:y/sha256:abcd"}, want: []string{"x:y/sha256-abcd"}},
|
||||
{path: []string{"x:y/sha256:abcd"}, want: []string{"x:y/sha256-abcd"}},
|
||||
{path: []string{"x:y/sha256:abcd", "sha256:1234"}, want: []string{"x:y/sha256-abcd", "sha256-1234"}},
|
||||
{path: []string{"x:y/sha256:abcd", "sha256-1234"}, want: []string{"x:y/sha256-abcd", "sha256-1234"}},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(strings.Join(tt.path, "|"), func(t *testing.T) {
|
||||
hasColon := slices.ContainsFunc(tt.path, func(s string) bool { return strings.Contains(s, ":") })
|
||||
if hasColon && runtime.GOOS == "windows" {
|
||||
t.Skip("skipping test on windows")
|
||||
}
|
||||
|
||||
rootDir := t.TempDir()
|
||||
for _, path := range tt.path {
|
||||
fullPath := filepath.Join(rootDir, path)
|
||||
fullDir, _ := filepath.Split(fullPath)
|
||||
|
||||
t.Logf("creating dir %s", fullDir)
|
||||
if err := os.MkdirAll(fullDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("writing file %s", fullPath)
|
||||
if err := os.WriteFile(fullPath, nil, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := fixBlobs(rootDir); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := slurpFiles(os.DirFS(rootDir))
|
||||
|
||||
slices.Sort(tt.want)
|
||||
slices.Sort(got)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Fatalf("got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func slurpFiles(fsys fs.FS) []string {
|
||||
var sfs []string
|
||||
fn := func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
sfs = append(sfs, path)
|
||||
return nil
|
||||
}
|
||||
if err := fs.WalkDir(fsys, ".", fn); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return sfs
|
||||
}
|
||||
78
server/gemma4_test.go
Normal file
78
server/gemma4_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package server
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestResolveGemma4Renderer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model *Model
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil model falls back to legacy alias",
|
||||
model: nil,
|
||||
want: gemma4RendererLegacy,
|
||||
},
|
||||
{
|
||||
name: "explicit small passes through",
|
||||
model: &Model{
|
||||
Config: testConfigWithRenderer(gemma4RendererSmall),
|
||||
},
|
||||
want: gemma4RendererSmall,
|
||||
},
|
||||
{
|
||||
name: "explicit large passes through",
|
||||
model: &Model{
|
||||
Config: testConfigWithRenderer(gemma4RendererLarge),
|
||||
},
|
||||
want: gemma4RendererLarge,
|
||||
},
|
||||
{
|
||||
name: "legacy e4b tag resolves small",
|
||||
model: &Model{
|
||||
Name: "gemma4:e4b",
|
||||
ShortName: "gemma4:e4b",
|
||||
Config: testConfigWithRenderer(gemma4RendererLegacy),
|
||||
},
|
||||
want: gemma4RendererSmall,
|
||||
},
|
||||
{
|
||||
name: "legacy 31b tag resolves large",
|
||||
model: &Model{
|
||||
Name: "gemma4:31b-cloud",
|
||||
ShortName: "gemma4:31b-cloud",
|
||||
Config: testConfigWithRenderer(gemma4RendererLegacy),
|
||||
},
|
||||
want: gemma4RendererLarge,
|
||||
},
|
||||
{
|
||||
name: "legacy model type resolves small",
|
||||
model: &Model{
|
||||
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "4.3B"),
|
||||
},
|
||||
want: gemma4RendererSmall,
|
||||
},
|
||||
{
|
||||
name: "legacy model type resolves large",
|
||||
model: &Model{
|
||||
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
|
||||
},
|
||||
want: gemma4RendererLarge,
|
||||
},
|
||||
{
|
||||
name: "legacy unknown defaults small",
|
||||
model: &Model{
|
||||
Config: testConfigWithRenderer(gemma4RendererLegacy),
|
||||
},
|
||||
want: gemma4RendererSmall,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := resolveGemma4Renderer(tt.model); got != tt.want {
|
||||
t.Fatalf("resolveGemma4Renderer() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
1050
server/images.go
Normal file
1050
server/images.go
Normal file
File diff suppressed because it is too large
Load Diff
432
server/images_test.go
Normal file
432
server/images_test.go
Normal file
@@ -0,0 +1,432 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestPruneLayersSkipsRecentOrphans(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
|
||||
recentDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000001"
|
||||
oldDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000002"
|
||||
|
||||
for _, digest := range []string{recentDigest, oldDigest} {
|
||||
p, err := manifest.BlobsPath(digest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(p, nil, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
oldPath, err := manifest.BlobsPath(oldDigest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
oldTime := time.Now().Add(-layerPruneGracePeriod - time.Hour)
|
||||
if err := os.Chtimes(oldPath, oldTime, oldTime); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := PruneLayers(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
recentPath, err := manifest.BlobsPath(recentDigest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := os.Stat(recentPath); err != nil {
|
||||
t.Fatalf("recent orphan was pruned: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(oldPath); !os.IsNotExist(err) {
|
||||
t.Fatalf("old orphan still exists: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelCapabilities(t *testing.T) {
|
||||
// Create completion model (llama architecture without vision)
|
||||
completionModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
}, []*ggml.Tensor{})
|
||||
|
||||
// Create vision model (llama architecture with vision block count)
|
||||
visionModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.vision.block_count": uint32(1),
|
||||
}, []*ggml.Tensor{})
|
||||
|
||||
// Create embedding model (bert architecture with pooling type)
|
||||
embeddingModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "bert",
|
||||
"bert.pooling_type": uint32(1),
|
||||
}, []*ggml.Tensor{})
|
||||
|
||||
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
chatTemplate, err := template.Parse("{{ .prompt }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
testModels := []struct {
|
||||
name string
|
||||
model Model
|
||||
expectedCaps []model.Capability
|
||||
}{
|
||||
{
|
||||
name: "model with image generation capability via config",
|
||||
model: Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityImage},
|
||||
},
|
||||
{
|
||||
name: "model with image and vision capability (image editing)",
|
||||
model: Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image", "vision"},
|
||||
},
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityImage, model.CapabilityVision},
|
||||
},
|
||||
{
|
||||
name: "model with completion capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion},
|
||||
},
|
||||
|
||||
{
|
||||
name: "model with completion, tools, and insert capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model with tools capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
Template: toolsTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools},
|
||||
},
|
||||
{
|
||||
name: "model with vision capability",
|
||||
model: Model{
|
||||
ModelPath: visionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
|
||||
},
|
||||
{
|
||||
name: "model with vision, tools, and insert capability",
|
||||
model: Model{
|
||||
ModelPath: visionModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model with embedding capability",
|
||||
model: Model{
|
||||
ModelPath: embeddingModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityEmbedding},
|
||||
},
|
||||
{
|
||||
name: "gemma4 small safetensors suppresses vision and audio",
|
||||
model: Model{
|
||||
Config: model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Renderer: gemma4RendererSmall,
|
||||
Capabilities: []string{"vision", "audio"},
|
||||
},
|
||||
Template: chatTemplate,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "gemma4 large safetensors suppresses vision and audio",
|
||||
model: Model{
|
||||
Config: model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Renderer: gemma4RendererLarge,
|
||||
Capabilities: []string{"vision", "audio"},
|
||||
},
|
||||
Template: chatTemplate,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "legacy gemma4 safetensors suppresses vision and audio",
|
||||
model: Model{
|
||||
Config: model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Renderer: gemma4RendererLegacy,
|
||||
Capabilities: []string{"vision", "audio"},
|
||||
},
|
||||
Template: chatTemplate,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// compare two slices of model.Capability regardless of order
|
||||
compareCapabilities := func(a, b []model.Capability) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
aCount := make(map[model.Capability]int)
|
||||
for _, cap := range a {
|
||||
aCount[cap]++
|
||||
}
|
||||
|
||||
bCount := make(map[model.Capability]int)
|
||||
for _, cap := range b {
|
||||
bCount[cap]++
|
||||
}
|
||||
|
||||
for cap, count := range aCount {
|
||||
if bCount[cap] != count {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
for _, tt := range testModels {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test Capabilities method
|
||||
caps := tt.model.Capabilities()
|
||||
if !compareCapabilities(caps, tt.expectedCaps) {
|
||||
t.Errorf("Expected capabilities %v, got %v", tt.expectedCaps, caps)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelCheckCapabilities(t *testing.T) {
|
||||
// Create simple model file for tests that don't depend on GGUF content
|
||||
completionModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
}, []*ggml.Tensor{})
|
||||
|
||||
// Create vision model (llama architecture with vision block count)
|
||||
visionModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.vision.block_count": uint32(1),
|
||||
}, []*ggml.Tensor{})
|
||||
|
||||
// Create embedding model (bert architecture with pooling type)
|
||||
embeddingModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "bert",
|
||||
"bert.pooling_type": uint32(1),
|
||||
}, []*ggml.Tensor{})
|
||||
|
||||
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
chatTemplate, err := template.Parse("{{ .prompt }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model Model
|
||||
checkCaps []model.Capability
|
||||
expectedErrMsg string
|
||||
}{
|
||||
{
|
||||
name: "completion model without tools capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityTools},
|
||||
expectedErrMsg: "does not support tools",
|
||||
},
|
||||
{
|
||||
name: "model with all needed capabilities",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model missing insert capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
Template: toolsTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityInsert},
|
||||
expectedErrMsg: "does not support insert",
|
||||
},
|
||||
{
|
||||
name: "model missing vision capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
Template: toolsTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityVision},
|
||||
expectedErrMsg: "does not support vision",
|
||||
},
|
||||
{
|
||||
name: "model with vision capability",
|
||||
model: Model{
|
||||
ModelPath: visionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityVision},
|
||||
},
|
||||
{
|
||||
name: "model with embedding capability",
|
||||
model: Model{
|
||||
ModelPath: embeddingModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityEmbedding},
|
||||
},
|
||||
{
|
||||
name: "unknown capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{"unknown"},
|
||||
expectedErrMsg: "unknown capability",
|
||||
},
|
||||
{
|
||||
name: "model missing image generation capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityImage},
|
||||
expectedErrMsg: "does not support image generation",
|
||||
},
|
||||
{
|
||||
name: "model with image generation capability",
|
||||
model: Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityImage},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test CheckCapabilities method
|
||||
err := tt.model.CheckCapabilities(tt.checkCaps...)
|
||||
if tt.expectedErrMsg == "" {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error containing %q, got nil", tt.expectedErrMsg)
|
||||
} else if !strings.Contains(err.Error(), tt.expectedErrMsg) {
|
||||
t.Errorf("Expected error containing %q, got: %v", tt.expectedErrMsg, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullModelManifest(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
manifest string
|
||||
}{
|
||||
{
|
||||
name: "pretty printed",
|
||||
manifest: `{ "schemaVersion": 2, "mediaType": "application/vnd.docker.distribution.manifest.v2+json",
|
||||
"config": { "digest": "sha256:abc", "mediaType": "application/vnd.docker.container.image.v1+json", "size": 50 },
|
||||
"layers": [{ "digest": "sha256:t1", "mediaType": "application/vnd.ollama.image.tensor", "size": 1024, "name": "model.weight" }]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "non-standard field order",
|
||||
manifest: `{"layers":[{"size":999,"digest":"sha256:def","mediaType":"application/vnd.ollama.image.model"}],"schemaVersion":2,"config":{"size":50,"digest":"sha256:abc","mediaType":"application/vnd.docker.container.image.v1+json"},"mediaType":"application/vnd.docker.distribution.manifest.v2+json"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(tt.manifest))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
n := model.ParseName("test/model:latest")
|
||||
n.ProtocolScheme = "http"
|
||||
n.Host = strings.TrimPrefix(ts.URL, "http://")
|
||||
|
||||
mf, data, err := pullModelManifest(t.Context(), n, ®istryOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Raw bytes must be byte-for-byte identical to what the server sent
|
||||
if string(data) != tt.manifest {
|
||||
t.Fatalf("raw bytes differ from server response")
|
||||
}
|
||||
|
||||
// SHA256 of returned data must match the expected registry digest
|
||||
expectedDigest := fmt.Sprintf("%x", sha256.Sum256([]byte(tt.manifest)))
|
||||
gotDigest := fmt.Sprintf("%x", sha256.Sum256(data))
|
||||
if gotDigest != expectedDigest {
|
||||
t.Fatalf("digest mismatch\ngot: %s\nwant: %s", gotDigest, expectedDigest)
|
||||
}
|
||||
|
||||
// Parsed manifest must still be usable
|
||||
if mf.SchemaVersion != 2 {
|
||||
t.Fatalf("schemaVersion = %d, want 2", mf.SchemaVersion)
|
||||
}
|
||||
if mf.Config.Digest == "" {
|
||||
t.Fatal("config digest is empty")
|
||||
}
|
||||
if len(mf.Layers) == 0 {
|
||||
t.Fatal("expected at least one layer")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
144
server/inference_request_log.go
Normal file
144
server/inference_request_log.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
type inferenceRequestLogger struct {
|
||||
dir string
|
||||
counter uint64
|
||||
}
|
||||
|
||||
func newInferenceRequestLogger() (*inferenceRequestLogger, error) {
|
||||
dir, err := os.MkdirTemp("", "ollama-request-logs-*")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &inferenceRequestLogger{dir: dir}, nil
|
||||
}
|
||||
|
||||
func (s *Server) initRequestLogging() error {
|
||||
if !envconfig.DebugLogRequests() {
|
||||
return nil
|
||||
}
|
||||
|
||||
requestLogger, err := newInferenceRequestLogger()
|
||||
if err != nil {
|
||||
return fmt.Errorf("enable OLLAMA_DEBUG_LOG_REQUESTS: %w", err)
|
||||
}
|
||||
|
||||
s.requestLogger = requestLogger
|
||||
slog.Info(fmt.Sprintf("request debug logging enabled; inference request logs will be stored in %s and include request bodies and replay curl commands", requestLogger.dir))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) withInferenceRequestLogging(route string, handlers ...gin.HandlerFunc) []gin.HandlerFunc {
|
||||
if s.requestLogger == nil {
|
||||
return handlers
|
||||
}
|
||||
|
||||
return append([]gin.HandlerFunc{s.requestLogger.middleware(route)}, handlers...)
|
||||
}
|
||||
|
||||
func (l *inferenceRequestLogger) middleware(route string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request == nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
method := c.Request.Method
|
||||
host := c.Request.Host
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
|
||||
var body []byte
|
||||
if c.Request.Body != nil {
|
||||
var err error
|
||||
body, err = io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
slog.Warn("failed to read request body for debug logging", "route", route, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
l.log(route, method, scheme, host, contentType, body)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *inferenceRequestLogger) log(route, method, scheme, host, contentType string, body []byte) {
|
||||
if l == nil || l.dir == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
if host == "" || scheme == "" {
|
||||
base := envconfig.Host()
|
||||
if host == "" {
|
||||
host = base.Host
|
||||
}
|
||||
if scheme == "" {
|
||||
scheme = base.Scheme
|
||||
}
|
||||
}
|
||||
|
||||
routeForFilename := sanitizeRouteForFilename(route)
|
||||
timestamp := fmt.Sprintf("%s-%06d", time.Now().UTC().Format("20060102T150405.000000000Z"), atomic.AddUint64(&l.counter, 1))
|
||||
bodyFilename := fmt.Sprintf("%s_%s_body.json", timestamp, routeForFilename)
|
||||
curlFilename := fmt.Sprintf("%s_%s_request.sh", timestamp, routeForFilename)
|
||||
bodyPath := filepath.Join(l.dir, bodyFilename)
|
||||
curlPath := filepath.Join(l.dir, curlFilename)
|
||||
|
||||
if err := os.WriteFile(bodyPath, body, 0o600); err != nil {
|
||||
slog.Warn("failed to write debug request body", "route", route, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s://%s%s", scheme, host, route)
|
||||
curl := fmt.Sprintf("#!/bin/sh\nSCRIPT_DIR=\"$(CDPATH= cd -- \"$(dirname -- \"$0\")\" && pwd)\"\ncurl --request %s --url %q --header %q --data-binary @\"${SCRIPT_DIR}/%s\"\n", method, url, "Content-Type: "+contentType, bodyFilename)
|
||||
if err := os.WriteFile(curlPath, []byte(curl), 0o600); err != nil {
|
||||
slog.Warn("failed to write debug request replay command", "route", route, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info(fmt.Sprintf("logged to %s, replay using curl with `sh %s`", bodyPath, curlPath))
|
||||
}
|
||||
|
||||
func sanitizeRouteForFilename(route string) string {
|
||||
route = strings.TrimPrefix(route, "/")
|
||||
if route == "" {
|
||||
return "root"
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(route))
|
||||
for _, r := range route {
|
||||
if ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') || ('0' <= r && r <= '9') {
|
||||
b.WriteRune(r)
|
||||
} else {
|
||||
b.WriteByte('_')
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
544
server/internal/cache/blob/cache.go
vendored
Normal file
544
server/internal/cache/blob/cache.go
vendored
Normal file
@@ -0,0 +1,544 @@
|
||||
// Package blob implements a content-addressable disk cache for blobs and
|
||||
// manifests.
|
||||
package blob
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/internal/names"
|
||||
)
|
||||
|
||||
// Entry contains metadata about a blob in the cache.
|
||||
type Entry struct {
|
||||
Digest Digest
|
||||
Size int64
|
||||
Time time.Time // when added to the cache
|
||||
}
|
||||
|
||||
// DiskCache caches blobs and manifests on disk.
|
||||
//
|
||||
// The cache is rooted at a directory, which is created if it does not exist.
|
||||
//
|
||||
// Blobs are stored in the "blobs" subdirectory, and manifests are stored in the
|
||||
// "manifests" subdirectory. A example directory structure might look like:
|
||||
//
|
||||
// <dir>/
|
||||
// blobs/
|
||||
// sha256-<digest> - <blob data>
|
||||
// manifests/
|
||||
// <host>/
|
||||
// <namespace>/
|
||||
// <name>/
|
||||
// <tag> - <manifest data>
|
||||
//
|
||||
// The cache is safe for concurrent use.
|
||||
//
|
||||
// Name casing is preserved in the cache, but is not significant when resolving
|
||||
// names. For example, "Foo" and "foo" are considered the same name.
|
||||
//
|
||||
// The cache is not safe for concurrent use. It guards concurrent writes, but
|
||||
// does not prevent duplicated effort. Because blobs are immutable, duplicate
|
||||
// writes should result in the same file being written to disk.
|
||||
type DiskCache struct {
|
||||
// Dir specifies the top-level directory where blobs and manifest
|
||||
// pointers are stored.
|
||||
dir string
|
||||
now func() time.Time
|
||||
|
||||
testHookBeforeFinalWrite func(f *os.File)
|
||||
}
|
||||
|
||||
// PutBytes is a convenience function for c.Put(d, strings.NewReader(s), int64(len(s))).
|
||||
func PutBytes[S string | []byte](c *DiskCache, d Digest, data S) error {
|
||||
return c.Put(d, bytes.NewReader([]byte(data)), int64(len(data)))
|
||||
}
|
||||
|
||||
// Open opens a cache rooted at the given directory. If the directory does not
|
||||
// exist, it is created. If the directory is not a directory, an error is
|
||||
// returned.
|
||||
func Open(dir string) (*DiskCache, error) {
|
||||
if dir == "" {
|
||||
return nil, errors.New("blob: empty directory name")
|
||||
}
|
||||
|
||||
info, err := os.Stat(dir)
|
||||
if err == nil && !info.IsDir() {
|
||||
return nil, fmt.Errorf("%q is not a directory", dir)
|
||||
}
|
||||
if err := os.MkdirAll(dir, 0o777); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
subdirs := []string{"blobs", "manifests"}
|
||||
for _, subdir := range subdirs {
|
||||
if err := os.MkdirAll(filepath.Join(dir, subdir), 0o777); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(bmizerany): support shards
|
||||
c := &DiskCache{
|
||||
dir: dir,
|
||||
now: time.Now,
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func readAndSum(filename string, limit int64) (data []byte, _ Digest, err error) {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, Digest{}, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
h := sha256.New()
|
||||
r := io.TeeReader(f, h)
|
||||
data, err = io.ReadAll(io.LimitReader(r, limit))
|
||||
if err != nil {
|
||||
return nil, Digest{}, err
|
||||
}
|
||||
var d Digest
|
||||
h.Sum(d.sum[:0])
|
||||
return data, d, nil
|
||||
}
|
||||
|
||||
//lint:ignore U1000 used for debugging purposes as needed in tests
|
||||
var debug = false
|
||||
|
||||
// debugger returns a function that can be used to add a step to the error message.
|
||||
// The error message will be a list of steps that were taken before the error occurred.
|
||||
// The steps are added in the order they are called.
|
||||
//
|
||||
// To set the error message, call the returned function with an empty string.
|
||||
//
|
||||
//lint:ignore U1000 used for debugging purposes as needed in tests
|
||||
func debugger(err *error) func(step string) {
|
||||
if !debug {
|
||||
return func(string) {}
|
||||
}
|
||||
var steps []string
|
||||
return func(step string) {
|
||||
if step == "" && *err != nil {
|
||||
*err = fmt.Errorf("%q: %w", steps, *err)
|
||||
return
|
||||
}
|
||||
steps = append(steps, step)
|
||||
if len(steps) > 100 {
|
||||
// shift hints in case of a bug that causes a lot of hints
|
||||
copy(steps, steps[1:])
|
||||
steps = steps[:100]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve resolves a name to a digest. The name is expected to
|
||||
// be in either of the following forms:
|
||||
//
|
||||
// @<digest>
|
||||
// <name>@<digest>
|
||||
// <name>
|
||||
//
|
||||
// If a digest is provided, it is returned as is and nothing else happens.
|
||||
//
|
||||
// If a name is provided for a manifest that exists in the cache, the digest
|
||||
// of the manifest is returned. If there is no manifest in the cache, it
|
||||
// returns [fs.ErrNotExist].
|
||||
//
|
||||
// To cover the case where a manifest may change without the cache knowing
|
||||
// (e.g. it was reformatted or modified by hand), the manifest data read and
|
||||
// hashed is passed to a PutBytes call to ensure that the manifest is in the
|
||||
// blob store. This is done to ensure that future calls to [Get] succeed in
|
||||
// these cases.
|
||||
func (c *DiskCache) Resolve(name string) (Digest, error) {
|
||||
name, digest := splitNameDigest(name)
|
||||
if digest != "" {
|
||||
return ParseDigest(digest)
|
||||
}
|
||||
|
||||
// We want to address manifests files by digest using Get. That requires
|
||||
// them to be blobs. This cannot be directly accomplished by looking in
|
||||
// the blob store because manifests can change without Ollama knowing
|
||||
// (e.g. a user modifies a manifests by hand then pushes it to update
|
||||
// their model). We also need to support the blob caches inherited from
|
||||
// older versions of Ollama, which do not store manifests in the blob
|
||||
// store, so for these cases, we need to handle adding the manifests to
|
||||
// the blob store, just in time.
|
||||
//
|
||||
// So now we read the manifests file, hash it, and copy it to the blob
|
||||
// store if it's not already there.
|
||||
//
|
||||
// This should be cheap because manifests are small, and accessed
|
||||
// infrequently.
|
||||
file, err := c.manifestPath(name)
|
||||
if err != nil {
|
||||
return Digest{}, err
|
||||
}
|
||||
|
||||
data, d, err := readAndSum(file, 1<<20)
|
||||
if err != nil {
|
||||
return Digest{}, err
|
||||
}
|
||||
|
||||
// Ideally we'd read the "manifest" file as a manifest to the blob file,
|
||||
// but we are not changing this yet, so copy the manifest to the blob
|
||||
// store so it can be addressed by digest subsequent calls to Get.
|
||||
if err := PutBytes(c, d, data); err != nil {
|
||||
return Digest{}, err
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// Put writes a new blob to the cache, identified by its digest. The operation
|
||||
// reads content from r, which must precisely match both the specified size and
|
||||
// digest.
|
||||
//
|
||||
// Concurrent write safety is achieved through file locking. The implementation
|
||||
// guarantees write integrity by enforcing size limits and content validation
|
||||
// before allowing the file to reach its final state.
|
||||
func (c *DiskCache) Put(d Digest, r io.Reader, size int64) error {
|
||||
return c.copyNamedFile(c.GetFile(d), r, d, size)
|
||||
}
|
||||
|
||||
// Import imports a blob from the provided reader into the cache. It reads the
|
||||
// entire content of the reader, calculates its digest, and stores it in the
|
||||
// cache.
|
||||
//
|
||||
// Import should be considered unsafe for use with untrusted data, such as data
|
||||
// read from a network. The caller is responsible for ensuring the integrity of
|
||||
// the data being imported.
|
||||
func (c *DiskCache) Import(r io.Reader, size int64) (Digest, error) {
|
||||
// users that want to change the temp dir can set TEMPDIR.
|
||||
f, err := os.CreateTemp("", "blob-")
|
||||
if err != nil {
|
||||
return Digest{}, err
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
|
||||
// Copy the blob to a temporary file.
|
||||
h := sha256.New()
|
||||
r = io.TeeReader(r, h)
|
||||
n, err := io.Copy(f, r)
|
||||
if err != nil {
|
||||
return Digest{}, err
|
||||
}
|
||||
if n != size {
|
||||
return Digest{}, fmt.Errorf("blob: expected %d bytes, got %d", size, n)
|
||||
}
|
||||
|
||||
// Check the digest.
|
||||
var d Digest
|
||||
h.Sum(d.sum[:0])
|
||||
if err := f.Close(); err != nil {
|
||||
return Digest{}, err
|
||||
}
|
||||
name := c.GetFile(d)
|
||||
// Rename the temporary file to the final file.
|
||||
if err := os.Rename(f.Name(), name); err != nil {
|
||||
return Digest{}, err
|
||||
}
|
||||
os.Chtimes(name, c.now(), c.now()) // mainly for tests
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// Get retrieves a blob from the cache using the provided digest. The operation
|
||||
// fails if the digest is malformed or if any errors occur during blob
|
||||
// retrieval.
|
||||
func (c *DiskCache) Get(d Digest) (Entry, error) {
|
||||
name := c.GetFile(d)
|
||||
info, err := os.Stat(name)
|
||||
if err != nil {
|
||||
return Entry{}, err
|
||||
}
|
||||
if info.Size() == 0 {
|
||||
return Entry{}, fs.ErrNotExist
|
||||
}
|
||||
return Entry{
|
||||
Digest: d,
|
||||
Size: info.Size(),
|
||||
Time: info.ModTime(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Link creates a symbolic reference in the cache that maps the provided name
|
||||
// to a blob identified by its digest, making it retrievable by name using
|
||||
// [Resolve].
|
||||
//
|
||||
// It returns an error if either the name or digest is invalid, or if link
|
||||
// creation encounters any issues.
|
||||
func (c *DiskCache) Link(name string, d Digest) error {
|
||||
manifest, err := c.manifestPath(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f, err := os.OpenFile(c.GetFile(d), os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// TODO(bmizerany): test this happens only if the blob was found to
|
||||
// avoid leaving debris
|
||||
if err := os.MkdirAll(filepath.Dir(manifest), 0o777); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Copy manifest to cache directory.
|
||||
return c.copyNamedFile(manifest, f, d, info.Size())
|
||||
}
|
||||
|
||||
// Unlink unlinks the manifest by name from the cache. If the name is not
|
||||
// found. If a manifest is removed ok will be true, otherwise false. If an
|
||||
// error occurs, it returns ok false, and the error.
|
||||
func (c *DiskCache) Unlink(name string) (ok bool, _ error) {
|
||||
manifest, err := c.manifestPath(name)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
err = os.Remove(manifest)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return false, nil
|
||||
}
|
||||
return true, err
|
||||
}
|
||||
|
||||
// GetFile returns the absolute path to the file, in the cache, for the given
|
||||
// digest. It does not check if the file exists.
|
||||
//
|
||||
// The returned path should not be stored, used outside the lifetime of the
|
||||
// cache, or interpreted in any way.
|
||||
func (c *DiskCache) GetFile(d Digest) string {
|
||||
filename := fmt.Sprintf("sha256-%x", d.sum)
|
||||
return absJoin(c.dir, "blobs", filename)
|
||||
}
|
||||
|
||||
// Links returns a sequence of link names. The sequence is in lexical order.
|
||||
// Names are converted from their relative path form to their name form but are
|
||||
// not guaranteed to be valid. Callers should validate the names before using.
|
||||
func (c *DiskCache) Links() iter.Seq2[string, error] {
|
||||
return func(yield func(string, error) bool) {
|
||||
for path, err := range c.links() {
|
||||
if err != nil {
|
||||
yield("", err)
|
||||
return
|
||||
}
|
||||
if !yield(pathToName(path), nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pathToName converts a path to a name. It is the inverse of nameToPath. The
|
||||
// path is assumed to be in filepath.ToSlash format.
|
||||
func pathToName(s string) string {
|
||||
s = strings.TrimPrefix(s, "manifests/")
|
||||
rr := []rune(s)
|
||||
for i := len(rr) - 1; i > 0; i-- {
|
||||
if rr[i] == '/' {
|
||||
rr[i] = ':'
|
||||
return string(rr)
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// manifestPath finds the first manifest file on disk that matches the given
|
||||
// name using a case-insensitive comparison. If no manifest file is found, it
|
||||
// returns the path where the manifest file would be if it existed.
|
||||
//
|
||||
// If two manifest files exists on disk that match the given name using a
|
||||
// case-insensitive comparison, the one that sorts first, lexically, is
|
||||
// returned.
|
||||
func (c *DiskCache) manifestPath(name string) (string, error) {
|
||||
np, err := nameToPath(name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
maybe := filepath.Join("manifests", np)
|
||||
for l, err := range c.links() {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.EqualFold(maybe, l) {
|
||||
return filepath.Join(c.dir, l), nil
|
||||
}
|
||||
}
|
||||
return filepath.Join(c.dir, maybe), nil
|
||||
}
|
||||
|
||||
// links returns a sequence of links in the cache in lexical order.
|
||||
func (c *DiskCache) links() iter.Seq2[string, error] {
|
||||
// TODO(bmizerany): reuse empty dirnames if exist
|
||||
return func(yield func(string, error) bool) {
|
||||
fsys := os.DirFS(c.dir)
|
||||
manifests, err := fs.Glob(fsys, "manifests/*/*/*/*")
|
||||
if err != nil {
|
||||
yield("", err)
|
||||
return
|
||||
}
|
||||
for _, manifest := range manifests {
|
||||
if !yield(manifest, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type checkWriter struct {
|
||||
size int64
|
||||
d Digest
|
||||
f *os.File
|
||||
h hash.Hash
|
||||
|
||||
w io.Writer // underlying writer; set by creator
|
||||
n int64
|
||||
err error
|
||||
|
||||
testHookBeforeFinalWrite func(*os.File)
|
||||
}
|
||||
|
||||
func (w *checkWriter) seterr(err error) error {
|
||||
if w.err == nil {
|
||||
w.err = err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Write writes p to the underlying hash and writer. The last write to the
|
||||
// underlying writer is guaranteed to be the last byte of p as verified by the
|
||||
// hash.
|
||||
func (w *checkWriter) Write(p []byte) (int, error) {
|
||||
if w.err != nil {
|
||||
return 0, w.err
|
||||
}
|
||||
|
||||
_, err := w.h.Write(p)
|
||||
if err != nil {
|
||||
return 0, w.seterr(err)
|
||||
}
|
||||
nextSize := w.n + int64(len(p))
|
||||
if nextSize == w.size {
|
||||
// last write. check hash.
|
||||
sum := w.h.Sum(nil)
|
||||
if !bytes.Equal(sum, w.d.sum[:]) {
|
||||
return 0, w.seterr(fmt.Errorf("file content changed underfoot"))
|
||||
}
|
||||
if w.testHookBeforeFinalWrite != nil {
|
||||
w.testHookBeforeFinalWrite(w.f)
|
||||
}
|
||||
}
|
||||
if nextSize > w.size {
|
||||
return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size))
|
||||
}
|
||||
n, err := w.w.Write(p)
|
||||
w.n += int64(n)
|
||||
return n, w.seterr(err)
|
||||
}
|
||||
|
||||
// copyNamedFile copies file into name, expecting it to have the given Digest
|
||||
// and size, if that file is not present already.
|
||||
func (c *DiskCache) copyNamedFile(name string, file io.Reader, out Digest, size int64) error {
|
||||
info, err := os.Stat(name)
|
||||
if err == nil && info.Size() == size {
|
||||
// File already exists with correct size. This is good enough.
|
||||
// We can skip expensive hash checks.
|
||||
//
|
||||
// TODO: Do the hash check, but give caller a way to skip it.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copy file to cache directory.
|
||||
mode := os.O_RDWR | os.O_CREATE
|
||||
if err == nil && info.Size() > size { // shouldn't happen but fix in case
|
||||
mode |= os.O_TRUNC
|
||||
}
|
||||
f, err := os.OpenFile(name, mode, 0o666)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
if size == 0 {
|
||||
// File now exists with correct size.
|
||||
// Only one possible zero-length file, so contents are OK too.
|
||||
// Early return here makes sure there's a "last byte" for code below.
|
||||
return nil
|
||||
}
|
||||
|
||||
// From here on, if any of the I/O writing the file fails,
|
||||
// we make a best-effort attempt to truncate the file f
|
||||
// before returning, to avoid leaving bad bytes in the file.
|
||||
|
||||
// Copy file to f, but also into h to double-check hash.
|
||||
cw := &checkWriter{
|
||||
d: out,
|
||||
size: size,
|
||||
h: sha256.New(),
|
||||
f: f,
|
||||
w: f,
|
||||
|
||||
testHookBeforeFinalWrite: c.testHookBeforeFinalWrite,
|
||||
}
|
||||
n, err := io.Copy(cw, file)
|
||||
if err != nil {
|
||||
f.Truncate(0)
|
||||
return err
|
||||
}
|
||||
if n < size {
|
||||
f.Truncate(0)
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
if err := f.Close(); err != nil {
|
||||
// Data might not have been written,
|
||||
// but file may look like it is the right size.
|
||||
// To be extra careful, remove cached file.
|
||||
os.Remove(name)
|
||||
return err
|
||||
}
|
||||
os.Chtimes(name, c.now(), c.now()) // mainly for tests
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func splitNameDigest(s string) (name, digest string) {
|
||||
i := strings.LastIndexByte(s, '@')
|
||||
if i < 0 {
|
||||
return s, ""
|
||||
}
|
||||
return s[:i], s[i+1:]
|
||||
}
|
||||
|
||||
var errInvalidName = errors.New("invalid name")
|
||||
|
||||
func nameToPath(name string) (_ string, err error) {
|
||||
n := names.Parse(name)
|
||||
if !n.IsFullyQualified() {
|
||||
return "", errInvalidName
|
||||
}
|
||||
return filepath.Join(n.Host(), n.Namespace(), n.Model(), n.Tag()), nil
|
||||
}
|
||||
|
||||
func absJoin(pp ...string) string {
|
||||
abs, err := filepath.Abs(filepath.Join(pp...))
|
||||
if err != nil {
|
||||
panic(err) // this should never happen
|
||||
}
|
||||
return abs
|
||||
}
|
||||
688
server/internal/cache/blob/cache_test.go
vendored
Normal file
688
server/internal/cache/blob/cache_test.go
vendored
Normal file
@@ -0,0 +1,688 @@
|
||||
package blob
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/testutil"
|
||||
)
|
||||
|
||||
func init() {
|
||||
debug = true
|
||||
}
|
||||
|
||||
var epoch = func() time.Time {
|
||||
d := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
if d.IsZero() {
|
||||
panic("time zero")
|
||||
}
|
||||
return d
|
||||
}()
|
||||
|
||||
func TestOpenErrors(t *testing.T) {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
dir string
|
||||
err string
|
||||
}{
|
||||
{t.TempDir(), ""},
|
||||
{"", "empty directory name"},
|
||||
{exe, "not a directory"},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.dir, func(t *testing.T) {
|
||||
_, err := Open(tt.dir)
|
||||
if tt.err == "" {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.err) {
|
||||
t.Fatalf("err = %v, want %q", err, tt.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFile(t *testing.T) {
|
||||
t.Chdir(t.TempDir())
|
||||
|
||||
c, err := Open(".")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
d := mkdigest("1")
|
||||
got := c.GetFile(d)
|
||||
cleaned := filepath.Clean(got)
|
||||
if cleaned != got {
|
||||
t.Fatalf("got is unclean: %q", got)
|
||||
}
|
||||
if !filepath.IsAbs(got) {
|
||||
t.Fatal("got is not absolute")
|
||||
}
|
||||
abs, _ := filepath.Abs(c.dir)
|
||||
if !strings.HasPrefix(got, abs) {
|
||||
t.Fatalf("got is not local to %q", c.dir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasic(t *testing.T) {
|
||||
c, err := Open(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
now := epoch
|
||||
c.now = func() time.Time { return now }
|
||||
|
||||
checkEntry := entryChecker(t, c)
|
||||
checkFailed := func(err error) {
|
||||
if err == nil {
|
||||
t.Helper()
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
_, err = c.Resolve("invalid")
|
||||
checkFailed(err)
|
||||
|
||||
_, err = c.Resolve("h/n/m:t")
|
||||
checkFailed(err)
|
||||
|
||||
dx := mkdigest("x")
|
||||
|
||||
d, err := c.Resolve(fmt.Sprintf("h/n/m:t@%s", dx))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if d != dx {
|
||||
t.Fatalf("d = %v, want %v", d, dx)
|
||||
}
|
||||
|
||||
_, err = c.Get(Digest{})
|
||||
checkFailed(err)
|
||||
|
||||
// not committed yet
|
||||
_, err = c.Get(dx)
|
||||
checkFailed(err)
|
||||
|
||||
err = PutBytes(c, dx, "!")
|
||||
checkFailed(err)
|
||||
|
||||
err = PutBytes(c, dx, "x")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
checkEntry(dx, 1, now)
|
||||
|
||||
t0 := now
|
||||
now = now.Add(1*time.Hour + 1*time.Minute)
|
||||
err = PutBytes(c, dx, "x")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// check not updated
|
||||
checkEntry(dx, 1, t0)
|
||||
}
|
||||
|
||||
type sleepFunc func(d time.Duration) time.Time
|
||||
|
||||
func openTester(t *testing.T) (*DiskCache, sleepFunc) {
|
||||
t.Helper()
|
||||
c, err := Open(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
now := epoch
|
||||
c.now = func() time.Time { return now }
|
||||
return c, func(d time.Duration) time.Time {
|
||||
now = now.Add(d)
|
||||
return now
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifestPath(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
|
||||
c, sleep := openTester(t)
|
||||
|
||||
d1 := mkdigest("1")
|
||||
err := PutBytes(c, d1, "1")
|
||||
check(err)
|
||||
|
||||
err = c.Link("h/n/m:t", d1)
|
||||
check(err)
|
||||
|
||||
t0 := sleep(0)
|
||||
sleep(1 * time.Hour)
|
||||
err = c.Link("h/n/m:t", d1) // nop expected
|
||||
check(err)
|
||||
|
||||
file := must(c.manifestPath("h/n/m:t"))
|
||||
info, err := os.Stat(file)
|
||||
check(err)
|
||||
testutil.CheckTime(t, info.ModTime(), t0)
|
||||
}
|
||||
|
||||
func TestManifestExistsWithoutBlob(t *testing.T) {
|
||||
t.Chdir(t.TempDir())
|
||||
|
||||
check := testutil.Checker(t)
|
||||
|
||||
c, err := Open(".")
|
||||
check(err)
|
||||
|
||||
checkEntry := entryChecker(t, c)
|
||||
|
||||
man := must(c.manifestPath("h/n/m:t"))
|
||||
os.MkdirAll(filepath.Dir(man), 0o777)
|
||||
testutil.WriteFile(t, man, "1")
|
||||
|
||||
got, err := c.Resolve("h/n/m:t")
|
||||
check(err)
|
||||
|
||||
want := mkdigest("1")
|
||||
if got != want {
|
||||
t.Fatalf("got = %v, want %v", got, want)
|
||||
}
|
||||
|
||||
e, err := c.Get(got)
|
||||
check(err)
|
||||
checkEntry(got, 1, e.Time)
|
||||
}
|
||||
|
||||
func TestPut(t *testing.T) {
|
||||
c, sleep := openTester(t)
|
||||
|
||||
check := testutil.Checker(t)
|
||||
checkEntry := entryChecker(t, c)
|
||||
|
||||
d := mkdigest("hello, world")
|
||||
err := PutBytes(c, d, "hello")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
||||
got, err := c.Get(d)
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Fatalf("expected error, got %v", got)
|
||||
}
|
||||
|
||||
// Put a valid blob
|
||||
err = PutBytes(c, d, "hello, world")
|
||||
check(err)
|
||||
checkEntry(d, 12, sleep(0))
|
||||
|
||||
// Put a blob with content that does not hash to the digest
|
||||
err = PutBytes(c, d, "hello")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
checkNotExists(t, c, d)
|
||||
|
||||
// Put the valid blob back and check it
|
||||
err = PutBytes(c, d, "hello, world")
|
||||
check(err)
|
||||
checkEntry(d, 12, sleep(0))
|
||||
|
||||
// Put a blob that errors during Read
|
||||
err = c.Put(d, &errOnBangReader{s: "!"}, 1)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
checkNotExists(t, c, d)
|
||||
|
||||
// Put valid blob back and check it
|
||||
err = PutBytes(c, d, "hello, world")
|
||||
check(err)
|
||||
checkEntry(d, 12, sleep(0))
|
||||
|
||||
// Put a blob with mismatched size
|
||||
err = c.Put(d, strings.NewReader("hello, world"), 11)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
checkNotExists(t, c, d)
|
||||
|
||||
// Final byte does not match the digest (testing commit phase)
|
||||
err = PutBytes(c, d, "hello, world$")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
checkNotExists(t, c, d)
|
||||
|
||||
reset := c.setTestHookBeforeFinalWrite(func(f *os.File) {
|
||||
// change mode to read-only
|
||||
f.Truncate(0)
|
||||
f.Chmod(0o400)
|
||||
f.Close()
|
||||
f1, err := os.OpenFile(f.Name(), os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { f1.Close() })
|
||||
*f = *f1
|
||||
})
|
||||
defer reset()
|
||||
|
||||
err = PutBytes(c, d, "hello, world")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
checkNotExists(t, c, d)
|
||||
reset()
|
||||
}
|
||||
|
||||
func TestImport(t *testing.T) {
|
||||
c, _ := openTester(t)
|
||||
|
||||
checkEntry := entryChecker(t, c)
|
||||
|
||||
want := mkdigest("x")
|
||||
got, err := c.Import(strings.NewReader("x"), 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want != got {
|
||||
t.Fatalf("digest = %v, want %v", got, want)
|
||||
}
|
||||
checkEntry(want, 1, epoch)
|
||||
|
||||
got, err = c.Import(strings.NewReader("x"), 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want != got {
|
||||
t.Fatalf("digest = %v, want %v", got, want)
|
||||
}
|
||||
checkEntry(want, 1, epoch)
|
||||
}
|
||||
|
||||
func (c *DiskCache) setTestHookBeforeFinalWrite(h func(*os.File)) (reset func()) {
|
||||
old := c.testHookBeforeFinalWrite
|
||||
c.testHookBeforeFinalWrite = h
|
||||
return func() { c.testHookBeforeFinalWrite = old }
|
||||
}
|
||||
|
||||
func TestPutGetZero(t *testing.T) {
|
||||
c, sleep := openTester(t)
|
||||
|
||||
check := testutil.Checker(t)
|
||||
checkEntry := entryChecker(t, c)
|
||||
|
||||
d := mkdigest("x")
|
||||
err := PutBytes(c, d, "x")
|
||||
check(err)
|
||||
checkEntry(d, 1, sleep(0))
|
||||
|
||||
err = os.Truncate(c.GetFile(d), 0)
|
||||
check(err)
|
||||
|
||||
_, err = c.Get(d)
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Fatalf("err = %v, want fs.ErrNotExist", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutZero(t *testing.T) {
|
||||
c, _ := openTester(t)
|
||||
d := mkdigest("x")
|
||||
err := c.Put(d, strings.NewReader("x"), 0) // size == 0 (not size of content)
|
||||
testutil.Check(t, err)
|
||||
checkNotExists(t, c, d)
|
||||
}
|
||||
|
||||
func TestCommit(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
|
||||
c, err := Open(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
checkEntry := entryChecker(t, c)
|
||||
|
||||
now := epoch
|
||||
c.now = func() time.Time { return now }
|
||||
|
||||
d1 := mkdigest("1")
|
||||
err = c.Link("h/n/m:t", d1)
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Fatalf("err = %v, want fs.ErrNotExist", err)
|
||||
}
|
||||
|
||||
err = PutBytes(c, d1, "1")
|
||||
check(err)
|
||||
|
||||
err = c.Link("h/n/m:t", d1)
|
||||
check(err)
|
||||
|
||||
got, err := c.Resolve("h/n/m:t")
|
||||
check(err)
|
||||
if got != d1 {
|
||||
t.Fatalf("d = %v, want %v", got, d1)
|
||||
}
|
||||
|
||||
// commit again, more than 1 byte
|
||||
d2 := mkdigest("22")
|
||||
err = PutBytes(c, d2, "22")
|
||||
check(err)
|
||||
err = c.Link("h/n/m:t", d2)
|
||||
check(err)
|
||||
checkEntry(d2, 2, now)
|
||||
|
||||
filename := must(c.manifestPath("h/n/m:t"))
|
||||
data, err := os.ReadFile(filename)
|
||||
check(err)
|
||||
if string(data) != "22" {
|
||||
t.Fatalf("data = %q, want %q", data, "22")
|
||||
}
|
||||
|
||||
t0 := now
|
||||
now = now.Add(1 * time.Hour)
|
||||
err = c.Link("h/n/m:t", d2) // same contents; nop
|
||||
check(err)
|
||||
info, err := os.Stat(filename)
|
||||
check(err)
|
||||
testutil.CheckTime(t, info.ModTime(), t0)
|
||||
}
|
||||
|
||||
func TestManifestInvalidBlob(t *testing.T) {
|
||||
c, _ := openTester(t)
|
||||
d := mkdigest("1")
|
||||
err := c.Link("h/n/m:t", d)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
checkNotExists(t, c, d)
|
||||
|
||||
err = PutBytes(c, d, "1")
|
||||
testutil.Check(t, err)
|
||||
err = os.WriteFile(c.GetFile(d), []byte("invalid"), 0o666)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = c.Link("h/n/m:t", d)
|
||||
if !strings.Contains(err.Error(), "underfoot") {
|
||||
t.Fatalf("err = %v, want error to contain %q", err, "underfoot")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifestNameReuse(t *testing.T) {
|
||||
t.Run("case-insensitive", func(t *testing.T) {
|
||||
// This should run on all file system types.
|
||||
testManifestNameReuse(t)
|
||||
})
|
||||
t.Run("case-sensitive", func(t *testing.T) {
|
||||
useCaseInsensitiveTempDir(t)
|
||||
testManifestNameReuse(t)
|
||||
})
|
||||
}
|
||||
|
||||
func testManifestNameReuse(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
|
||||
c, _ := openTester(t)
|
||||
|
||||
d1 := mkdigest("1")
|
||||
err := PutBytes(c, d1, "1")
|
||||
check(err)
|
||||
err = c.Link("h/n/m:t", d1)
|
||||
check(err)
|
||||
|
||||
d2 := mkdigest("22")
|
||||
err = PutBytes(c, d2, "22")
|
||||
check(err)
|
||||
err = c.Link("H/N/M:T", d2)
|
||||
check(err)
|
||||
|
||||
var g [2]Digest
|
||||
g[0], err = c.Resolve("h/n/m:t")
|
||||
check(err)
|
||||
g[1], err = c.Resolve("H/N/M:T")
|
||||
check(err)
|
||||
|
||||
w := [2]Digest{d2, d2}
|
||||
if g != w {
|
||||
t.Fatalf("g = %v, want %v", g, w)
|
||||
}
|
||||
|
||||
var got []string
|
||||
for l, err := range c.links() {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got = append(got, l)
|
||||
}
|
||||
want := []string{"manifests/h/n/m/t"}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("got = %v, want %v", got, want)
|
||||
}
|
||||
|
||||
// relink with different case
|
||||
unlinked, err := c.Unlink("h/n/m:t")
|
||||
check(err)
|
||||
if !unlinked {
|
||||
t.Fatal("expected unlinked")
|
||||
}
|
||||
err = c.Link("h/n/m:T", d1)
|
||||
check(err)
|
||||
|
||||
got = got[:0]
|
||||
for l, err := range c.links() {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got = append(got, l)
|
||||
}
|
||||
|
||||
// we should have only one link that is same case as the last link
|
||||
want = []string{"manifests/h/n/m/T"}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("got = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifestFile(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"", ""},
|
||||
|
||||
// valid names
|
||||
{"h/n/m:t", "/manifests/h/n/m/t"},
|
||||
{"hh/nn/mm:tt", "/manifests/hh/nn/mm/tt"},
|
||||
|
||||
{"%/%/%/%", ""},
|
||||
|
||||
// already a path
|
||||
{"h/n/m/t", ""},
|
||||
|
||||
// refs are not names
|
||||
{"h/n/m:t@sha256-1", ""},
|
||||
{"m@sha256-1", ""},
|
||||
{"n/m:t@sha256-1", ""},
|
||||
}
|
||||
|
||||
c, _ := openTester(t)
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
got, err := c.manifestPath(tt.in)
|
||||
if err != nil && tt.want != "" {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if err == nil && tt.want == "" {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
dir := filepath.ToSlash(c.dir)
|
||||
got = filepath.ToSlash(got)
|
||||
got = strings.TrimPrefix(got, dir)
|
||||
if got != tt.want {
|
||||
t.Fatalf("got = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNames(t *testing.T) {
|
||||
c, _ := openTester(t)
|
||||
check := testutil.Checker(t)
|
||||
|
||||
check(PutBytes(c, mkdigest("1"), "1"))
|
||||
check(PutBytes(c, mkdigest("2"), "2"))
|
||||
|
||||
check(c.Link("h/n/m:t", mkdigest("1")))
|
||||
check(c.Link("h/n/m:u", mkdigest("2")))
|
||||
|
||||
var got []string
|
||||
for l, err := range c.Links() {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got = append(got, l)
|
||||
}
|
||||
want := []string{"h/n/m:t", "h/n/m:u"}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("got = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func mkdigest(s string) Digest {
|
||||
return Digest{sha256.Sum256([]byte(s))}
|
||||
}
|
||||
|
||||
func checkNotExists(t *testing.T, c *DiskCache, d Digest) {
|
||||
t.Helper()
|
||||
_, err := c.Get(d)
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Fatalf("err = %v, want fs.ErrNotExist", err)
|
||||
}
|
||||
}
|
||||
|
||||
func entryChecker(t *testing.T, c *DiskCache) func(Digest, int64, time.Time) {
|
||||
t.Helper()
|
||||
return func(d Digest, size int64, mod time.Time) {
|
||||
t.Helper()
|
||||
t.Run("checkEntry:"+d.String(), func(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
defer func() {
|
||||
if t.Failed() {
|
||||
dumpCacheContents(t, c)
|
||||
}
|
||||
}()
|
||||
|
||||
e, err := c.Get(d)
|
||||
if size == 0 && errors.Is(err, fs.ErrNotExist) {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if e.Digest != d {
|
||||
t.Errorf("e.Digest = %v, want %v", e.Digest, d)
|
||||
}
|
||||
if e.Size != size {
|
||||
t.Fatalf("e.Size = %v, want %v", e.Size, size)
|
||||
}
|
||||
|
||||
testutil.CheckTime(t, e.Time, mod)
|
||||
info, err := os.Stat(c.GetFile(d))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if info.Size() != size {
|
||||
t.Fatalf("info.Size = %v, want %v", info.Size(), size)
|
||||
}
|
||||
testutil.CheckTime(t, info.ModTime(), mod)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func must[T any](v T, err error) T {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func TestNameToPath(t *testing.T) {
|
||||
_, err := nameToPath("h/n/m:t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
type errOnBangReader struct {
|
||||
s string
|
||||
n int
|
||||
}
|
||||
|
||||
func (e *errOnBangReader) Read(p []byte) (int, error) {
|
||||
if len(p) < 1 {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
if e.n >= len(p) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if e.s[e.n] == '!' {
|
||||
return 0, errors.New("bang")
|
||||
}
|
||||
p[0] = e.s[e.n]
|
||||
e.n++
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func dumpCacheContents(t *testing.T, c *DiskCache) {
|
||||
t.Helper()
|
||||
|
||||
var b strings.Builder
|
||||
fsys := os.DirFS(c.dir)
|
||||
fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
|
||||
t.Helper()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
info, err := d.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Format like ls:
|
||||
//
|
||||
// ; ls -la
|
||||
// drwxr-xr-x 224 Jan 13 14:22 blob/sha256-123
|
||||
// drwxr-xr-x 224 Jan 13 14:22 manifest/h/n/m
|
||||
|
||||
fmt.Fprintf(&b, " %s % 4d %s %s\n",
|
||||
info.Mode(),
|
||||
info.Size(),
|
||||
info.ModTime().Format("Jan 2 15:04"),
|
||||
path,
|
||||
)
|
||||
return nil
|
||||
})
|
||||
t.Log()
|
||||
t.Logf("cache contents:\n%s", b.String())
|
||||
}
|
||||
93
server/internal/cache/blob/casecheck_test.go
vendored
Normal file
93
server/internal/cache/blob/casecheck_test.go
vendored
Normal file
@@ -0,0 +1,93 @@
|
||||
package blob
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func isCaseSensitive(dir string) bool {
|
||||
defer func() {
|
||||
os.Remove(filepath.Join(dir, "_casecheck"))
|
||||
}()
|
||||
|
||||
exists := func(file string) bool {
|
||||
_, err := os.Stat(file)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
file := filepath.Join(dir, "_casecheck")
|
||||
FILE := filepath.Join(dir, "_CASECHECK")
|
||||
if exists(file) || exists(FILE) {
|
||||
panic(fmt.Sprintf("_casecheck already exists in %q; remove and try again.", dir))
|
||||
}
|
||||
|
||||
err := os.WriteFile(file, nil, 0o666)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return !exists(FILE)
|
||||
}
|
||||
|
||||
func isCI() bool {
|
||||
return os.Getenv("CI") != ""
|
||||
}
|
||||
|
||||
const volumeHint = `
|
||||
|
||||
Unable to locate case-insensitive TMPDIR on darwin.
|
||||
|
||||
To run tests, create the case-insensitive volume /Volumes/data:
|
||||
|
||||
$ sudo diskutil apfs addVolume disk1 APFSX data -mountpoint /Volumes/data
|
||||
|
||||
or run with:
|
||||
|
||||
CI=1 go test ./...
|
||||
|
||||
`
|
||||
|
||||
// useCaseInsensitiveTempDir sets TMPDIR to a case-insensitive directory
|
||||
// can find one, otherwise it skips the test if the CI environment variable is
|
||||
// set, or GOOS is not darwin.
|
||||
func useCaseInsensitiveTempDir(t *testing.T) bool {
|
||||
if isCaseSensitive(os.TempDir()) {
|
||||
// Use the default temp dir if it is already case-sensitive.
|
||||
return true
|
||||
}
|
||||
if runtime.GOOS == "darwin" {
|
||||
// If darwin, check for the special case-sensitive volume and
|
||||
// use it if available.
|
||||
const volume = "/Volumes/data"
|
||||
_, err := os.Stat(volume)
|
||||
if err == nil {
|
||||
tmpdir := filepath.Join(volume, "tmp")
|
||||
os.MkdirAll(tmpdir, 0o700)
|
||||
t.Setenv("TMPDIR", tmpdir)
|
||||
return true
|
||||
}
|
||||
if isCI() {
|
||||
// Special case darwin in CI; it is not case-sensitive
|
||||
// by default, and we will be testing other platforms
|
||||
// that are case-sensitive, so we'll have the test
|
||||
// being skipped covered there.
|
||||
t.Skip("Skipping test in CI for darwin; TMPDIR is not case-insensitive.")
|
||||
}
|
||||
}
|
||||
|
||||
if !isCI() {
|
||||
// Require devs to always tests with a case-insensitive TMPDIR.
|
||||
|
||||
// TODO(bmizerany): Print platform-specific instructions or
|
||||
// link to docs on that topic.
|
||||
lines := strings.Split(volumeHint, "\n")
|
||||
for _, line := range lines {
|
||||
t.Skip(line)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
73
server/internal/cache/blob/chunked.go
vendored
Normal file
73
server/internal/cache/blob/chunked.go
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
package blob
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Chunk represents a range of bytes in a blob.
|
||||
type Chunk struct {
|
||||
Start int64
|
||||
End int64
|
||||
}
|
||||
|
||||
// Size returns end minus start plus one.
|
||||
func (c Chunk) Size() int64 {
|
||||
return c.End - c.Start + 1
|
||||
}
|
||||
|
||||
// Chunker writes to a blob in chunks.
|
||||
// Its zero value is invalid. Use [DiskCache.Chunked] to create a new Chunker.
|
||||
type Chunker struct {
|
||||
digest Digest
|
||||
size int64
|
||||
f *os.File // nil means pre-validated
|
||||
}
|
||||
|
||||
// Chunked returns a new Chunker, ready for use storing a blob of the given
|
||||
// size in chunks.
|
||||
//
|
||||
// Use [Chunker.Put] to write data to the blob at specific offsets.
|
||||
func (c *DiskCache) Chunked(d Digest, size int64) (*Chunker, error) {
|
||||
name := c.GetFile(d)
|
||||
info, err := os.Stat(name)
|
||||
if err == nil && info.Size() == size {
|
||||
return &Chunker{}, nil
|
||||
}
|
||||
f, err := os.OpenFile(name, os.O_CREATE|os.O_WRONLY, 0o666)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Chunker{digest: d, size: size, f: f}, nil
|
||||
}
|
||||
|
||||
// Put copies chunk.Size() bytes from r to the blob at the given offset,
|
||||
// merging the data with the existing blob. It returns an error if any. As a
|
||||
// special case, if r has less than chunk.Size() bytes, Put returns
|
||||
// io.ErrUnexpectedEOF.
|
||||
func (c *Chunker) Put(chunk Chunk, d Digest, r io.Reader) error {
|
||||
if c.f == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cw := &checkWriter{
|
||||
d: d,
|
||||
size: chunk.Size(),
|
||||
h: sha256.New(),
|
||||
f: c.f,
|
||||
w: io.NewOffsetWriter(c.f, chunk.Start),
|
||||
}
|
||||
|
||||
_, err := io.CopyN(cw, r, chunk.Size())
|
||||
if err != nil && errors.Is(err, io.EOF) {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the underlying file.
|
||||
func (c *Chunker) Close() error {
|
||||
return c.f.Close()
|
||||
}
|
||||
99
server/internal/cache/blob/digest.go
vendored
Normal file
99
server/internal/cache/blob/digest.go
vendored
Normal file
@@ -0,0 +1,99 @@
|
||||
package blob
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ErrInvalidDigest = errors.New("invalid digest")
|
||||
|
||||
// Digest is a blob identifier that is the SHA-256 hash of a blob's content.
|
||||
//
|
||||
// It is comparable and can be used as a map key.
|
||||
type Digest struct {
|
||||
sum [32]byte
|
||||
}
|
||||
|
||||
// ParseDigest parses a digest from a string. If the string is not a valid
|
||||
// digest, a call to the returned digest's IsValid method will return false.
|
||||
//
|
||||
// The input string may be in one of two forms:
|
||||
//
|
||||
// - ("sha256-<hex>"), where <hex> is a 64-character hexadecimal string.
|
||||
// - ("sha256:<hex>"), where <hex> is a 64-character hexadecimal string.
|
||||
//
|
||||
// The [Digest.String] method will return the canonical form of the
|
||||
// digest, "sha256:<hex>".
|
||||
func ParseDigest[S ~[]byte | ~string](v S) (Digest, error) {
|
||||
s := string(v)
|
||||
i := strings.IndexAny(s, ":-")
|
||||
var zero Digest
|
||||
if i < 0 {
|
||||
return zero, ErrInvalidDigest
|
||||
}
|
||||
|
||||
prefix, sum := s[:i], s[i+1:]
|
||||
if prefix != "sha256" || len(sum) != 64 {
|
||||
return zero, ErrInvalidDigest
|
||||
}
|
||||
|
||||
var d Digest
|
||||
_, err := hex.Decode(d.sum[:], []byte(sum))
|
||||
if err != nil {
|
||||
return zero, ErrInvalidDigest
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func DigestFromBytes[S ~[]byte | ~string](v S) Digest {
|
||||
return Digest{sha256.Sum256([]byte(v))}
|
||||
}
|
||||
|
||||
// String returns the string representation of the digest in the conventional
|
||||
// form "sha256:<hex>".
|
||||
func (d Digest) String() string {
|
||||
return fmt.Sprintf("sha256:%x", d.sum[:])
|
||||
}
|
||||
|
||||
func (d Digest) Short() string {
|
||||
return fmt.Sprintf("%x", d.sum[:4])
|
||||
}
|
||||
|
||||
func (d Digest) Sum() [32]byte {
|
||||
return d.sum
|
||||
}
|
||||
|
||||
func (d Digest) Compare(other Digest) int {
|
||||
return slices.Compare(d.sum[:], other.sum[:])
|
||||
}
|
||||
|
||||
// IsValid returns true if the digest is valid, i.e. if it is the SHA-256 hash
|
||||
// of some content.
|
||||
func (d Digest) IsValid() bool {
|
||||
return d != (Digest{})
|
||||
}
|
||||
|
||||
// MarshalText implements the encoding.TextMarshaler interface. It returns an
|
||||
// error if [Digest.IsValid] returns false.
|
||||
func (d Digest) MarshalText() ([]byte, error) {
|
||||
return []byte(d.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements the encoding.TextUnmarshaler interface, and only
|
||||
// works for a zero digest. If [Digest.IsValid] returns true, it returns an
|
||||
// error.
|
||||
func (d *Digest) UnmarshalText(text []byte) error {
|
||||
if *d != (Digest{}) {
|
||||
return errors.New("digest: illegal UnmarshalText on valid digest")
|
||||
}
|
||||
v, err := ParseDigest(string(text))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*d = v
|
||||
return nil
|
||||
}
|
||||
63
server/internal/cache/blob/digest_test.go
vendored
Normal file
63
server/internal/cache/blob/digest_test.go
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
package blob
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseDigest(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
valid bool
|
||||
}{
|
||||
{"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", true},
|
||||
{"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", true},
|
||||
|
||||
// too short
|
||||
{"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", false},
|
||||
{"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", false},
|
||||
|
||||
// too long
|
||||
{"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", false},
|
||||
{"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", false},
|
||||
|
||||
// invalid prefix
|
||||
{"sha255-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false},
|
||||
{"sha255:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false},
|
||||
{"sha256!0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false},
|
||||
|
||||
// invalid hex
|
||||
{"sha256-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", false},
|
||||
{"sha256:XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", false},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
got, err := ParseDigest(tt.in)
|
||||
if tt.valid && err != nil {
|
||||
t.Errorf("ParseDigest(%q) = %v, %v; want valid", tt.in, got, err)
|
||||
}
|
||||
want := "sha256:" + tt.in[7:]
|
||||
if tt.valid && got.String() != want {
|
||||
t.Errorf("ParseDigest(%q).String() = %q, want %q", tt.in, got.String(), want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDigestMarshalText(t *testing.T) {
|
||||
const s = `"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"`
|
||||
var d Digest
|
||||
if err := json.Unmarshal([]byte(s), &d); err != nil {
|
||||
t.Errorf("json.Unmarshal: %v", err)
|
||||
}
|
||||
out, err := json.Marshal(d)
|
||||
if err != nil {
|
||||
t.Errorf("json.Marshal: %v", err)
|
||||
}
|
||||
want := `"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"`
|
||||
if string(out) != want {
|
||||
t.Errorf("json.Marshal: got %s, want %s", out, want)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(`"invalid"`), &Digest{}); err == nil {
|
||||
t.Errorf("json.Unmarshal: expected error")
|
||||
}
|
||||
}
|
||||
1197
server/internal/client/ollama/registry.go
Normal file
1197
server/internal/client/ollama/registry.go
Normal file
File diff suppressed because it is too large
Load Diff
51
server/internal/client/ollama/registry_synctest_test.go
Normal file
51
server/internal/client/ollama/registry_synctest_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
// TODO: go:build goexperiment.synctest
|
||||
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPullDownloadTimeout(t *testing.T) {
|
||||
rc, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
defer t.Log("upstream", r.Method, r.URL.Path)
|
||||
switch {
|
||||
case strings.HasPrefix(r.URL.Path, "/v2/library/smol/manifests/"):
|
||||
io.WriteString(w, `{
|
||||
"layers": [{"digest": "sha256:1111111111111111111111111111111111111111111111111111111111111111", "size": 3}]
|
||||
}`)
|
||||
case strings.HasPrefix(r.URL.Path, "/v2/library/smol/blobs/sha256:1111111111111111111111111111111111111111111111111111111111111111"):
|
||||
// Get headers out to client and then hang on the response
|
||||
w.WriteHeader(200)
|
||||
w.(http.Flusher).Flush()
|
||||
|
||||
// Hang on the response and unblock when the client
|
||||
// gives up
|
||||
<-r.Context().Done()
|
||||
default:
|
||||
t.Fatalf("unexpected request: %s", r.URL.Path)
|
||||
}
|
||||
})
|
||||
rc.ReadTimeout = 100 * time.Millisecond
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- rc.Pull(ctx, "http://example.com/library/smol")
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
want := context.DeadlineExceeded
|
||||
if !errors.Is(err, want) {
|
||||
t.Errorf("err = %v, want %v", err, want)
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Error("timeout waiting for Pull to finish")
|
||||
}
|
||||
}
|
||||
953
server/internal/client/ollama/registry_test.go
Normal file
953
server/internal/client/ollama/registry_test.go
Normal file
@@ -0,0 +1,953 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/testutil"
|
||||
)
|
||||
|
||||
func ExampleRegistry_cancelOnFirstError() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ctx = WithTrace(ctx, &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
if err != nil {
|
||||
// Discontinue pulling layers if there is an
|
||||
// error instead of continuing to pull more
|
||||
// data.
|
||||
cancel()
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
var r Registry
|
||||
if err := r.Pull(ctx, "model"); err != nil {
|
||||
// panic for demo purposes
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifestMarshalJSON(t *testing.T) {
|
||||
// All manifests should contain an "empty" config object.
|
||||
var m Manifest
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Contains(data, []byte(`"config":{"digest":"sha256:`)) {
|
||||
t.Error("expected manifest to contain empty config")
|
||||
t.Fatalf("got:\n%s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
var errRoundTrip = errors.New("forced roundtrip error")
|
||||
|
||||
type recordRoundTripper http.HandlerFunc
|
||||
|
||||
func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
w := httptest.NewRecorder()
|
||||
rr(w, req)
|
||||
if w.Code == 499 {
|
||||
return nil, errRoundTrip
|
||||
}
|
||||
resp := w.Result()
|
||||
// For some reason, Response.Request is not set by httptest.NewRecorder, so we
|
||||
// set it manually.
|
||||
resp.Request = req
|
||||
return w.Result(), nil
|
||||
}
|
||||
|
||||
// newClient constructs a cache with predefined manifests for testing. The manifests are:
|
||||
//
|
||||
// empty: no data
|
||||
// zero: no layers
|
||||
// single: one layer with the contents "exists"
|
||||
// multiple: two layers with the contents "exists" and "here"
|
||||
// notfound: a layer that does not exist in the cache
|
||||
// null: one null layer (e.g. [null])
|
||||
// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
|
||||
// invalid: a layer with invalid JSON data
|
||||
//
|
||||
// Tests that want to ensure the client does not communicate with the upstream
|
||||
// registry should pass a nil handler, which will cause a panic if
|
||||
// communication is attempted.
|
||||
//
|
||||
// To simulate a network error, pass a handler that returns a 499 status code.
|
||||
func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
||||
t.Helper()
|
||||
|
||||
c, err := blob.Open(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mklayer := func(data string) *Layer {
|
||||
return &Layer{
|
||||
Digest: importBytes(t, c, data),
|
||||
Size: int64(len(data)),
|
||||
}
|
||||
}
|
||||
|
||||
r := &Registry{
|
||||
Cache: c,
|
||||
HTTPClient: &http.Client{
|
||||
Transport: recordRoundTripper(upstreamRegistry),
|
||||
},
|
||||
}
|
||||
|
||||
link := func(name string, manifest string) {
|
||||
n, err := r.parseName(name)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := c.Link(n.String(), d); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
commit := func(name string, layers ...*Layer) {
|
||||
t.Helper()
|
||||
data, err := json.Marshal(&Manifest{Layers: layers})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
link(name, string(data))
|
||||
}
|
||||
|
||||
link("empty", "")
|
||||
commit("zero")
|
||||
commit("single", mklayer("exists"))
|
||||
commit("multiple", mklayer("exists"), mklayer("present"))
|
||||
commit("notfound", &Layer{Digest: blob.DigestFromBytes("notfound"), Size: int64(len("notfound"))})
|
||||
commit("null", nil)
|
||||
commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
|
||||
link("invalid", "!!!!!")
|
||||
|
||||
return r, c
|
||||
}
|
||||
|
||||
func okHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func checkErrCode(t *testing.T, err error, status int, code string) {
|
||||
t.Helper()
|
||||
var e *Error
|
||||
if !errors.As(err, &e) || e.status != status || e.Code != code {
|
||||
t.Errorf("err = %v; want %v %v", err, status, code)
|
||||
}
|
||||
}
|
||||
|
||||
func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest {
|
||||
d, err := c.Import(strings.NewReader(data), int64(len(data)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
|
||||
t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }}
|
||||
return WithTrace(ctx, t), t
|
||||
}
|
||||
|
||||
func TestPushZero(t *testing.T) {
|
||||
rc, _ := newClient(t, okHandler)
|
||||
err := rc.Push(t.Context(), "empty", nil)
|
||||
if !errors.Is(err, ErrManifestInvalid) {
|
||||
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushSingle(t *testing.T) {
|
||||
rc, _ := newClient(t, okHandler)
|
||||
err := rc.Push(t.Context(), "single", nil)
|
||||
testutil.Check(t, err)
|
||||
}
|
||||
|
||||
func TestPushMultiple(t *testing.T) {
|
||||
rc, _ := newClient(t, okHandler)
|
||||
err := rc.Push(t.Context(), "multiple", nil)
|
||||
testutil.Check(t, err)
|
||||
}
|
||||
|
||||
func TestPushNotFound(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Errorf("unexpected request: %v", r)
|
||||
})
|
||||
err := rc.Push(t.Context(), "notfound", nil)
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushNullLayer(t *testing.T) {
|
||||
rc, _ := newClient(t, nil)
|
||||
err := rc.Push(t.Context(), "null", nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
||||
t.Errorf("err = %v; want invalid manifest", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushSizeMismatch(t *testing.T) {
|
||||
rc, _ := newClient(t, nil)
|
||||
ctx, _ := withTraceUnexpected(t.Context())
|
||||
got := rc.Push(ctx, "sizemismatch", nil)
|
||||
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
|
||||
t.Errorf("err = %v; want size mismatch", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushInvalid(t *testing.T) {
|
||||
rc, _ := newClient(t, nil)
|
||||
err := rc.Push(t.Context(), "invalid", nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
||||
t.Errorf("err = %v; want invalid manifest", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushExistsAtRemote(t *testing.T) {
|
||||
var pushed bool
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/uploads/") {
|
||||
if !pushed {
|
||||
// First push. Return an uploadURL.
|
||||
pushed = true
|
||||
w.Header().Set("Location", "http://blob.store/blobs/123")
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
return
|
||||
}
|
||||
|
||||
io.Copy(io.Discard, r.Body)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
rc.MaxStreams = 1 // prevent concurrent uploads
|
||||
|
||||
var errs []error
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(_ *Layer, n int64, err error) {
|
||||
// uploading one at a time so no need to lock
|
||||
errs = append(errs, err)
|
||||
},
|
||||
})
|
||||
|
||||
check := testutil.Checker(t)
|
||||
|
||||
err := rc.Push(ctx, "single", nil)
|
||||
check(err)
|
||||
|
||||
if !errors.Is(errors.Join(errs...), nil) {
|
||||
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
|
||||
}
|
||||
|
||||
err = rc.Push(ctx, "single", nil)
|
||||
check(err)
|
||||
}
|
||||
|
||||
func TestPushRemoteError(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||
w.WriteHeader(500)
|
||||
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
|
||||
return
|
||||
}
|
||||
})
|
||||
got := rc.Push(t.Context(), "single", nil)
|
||||
checkErrCode(t, got, 500, "blob_error")
|
||||
}
|
||||
|
||||
func TestPushLocationError(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Location", ":///x")
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
})
|
||||
got := rc.Push(t.Context(), "single", nil)
|
||||
wantContains := "invalid upload URL"
|
||||
if got == nil || !strings.Contains(got.Error(), wantContains) {
|
||||
t.Errorf("err = %v; want to contain %v", got, wantContains)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushUploadRoundtripError(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Host == "blob.store" {
|
||||
w.WriteHeader(499) // force RoundTrip error on upload
|
||||
return
|
||||
}
|
||||
w.Header().Set("Location", "http://blob.store/blobs/123")
|
||||
})
|
||||
got := rc.Push(t.Context(), "single", nil)
|
||||
if !errors.Is(got, errRoundTrip) {
|
||||
t.Errorf("got = %v; want %v", got, errRoundTrip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushUploadFileOpenError(t *testing.T) {
|
||||
rc, c := newClient(t, okHandler)
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, _ int64, err error) {
|
||||
// Remove the file just before it is opened for upload,
|
||||
// but after the initial Stat that happens before the
|
||||
// upload starts
|
||||
os.Remove(c.GetFile(l.Digest))
|
||||
},
|
||||
})
|
||||
got := rc.Push(ctx, "single", nil)
|
||||
if !errors.Is(got, fs.ErrNotExist) {
|
||||
t.Errorf("got = %v; want fs.ErrNotExist", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushCommitRoundtripError(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||
panic("unexpected")
|
||||
}
|
||||
w.WriteHeader(499) // force RoundTrip error
|
||||
})
|
||||
err := rc.Push(t.Context(), "zero", nil)
|
||||
if !errors.Is(err, errRoundTrip) {
|
||||
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPullInvalidName(t *testing.T) {
|
||||
rc, _ := newRegistryClient(t, nil)
|
||||
err := rc.Pull(t.Context(), "://")
|
||||
if !errors.Is(err, ErrNameInvalid) {
|
||||
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPullInvalidManifest(t *testing.T) {
|
||||
cases := []string{
|
||||
"",
|
||||
"null",
|
||||
"!!!",
|
||||
`{"layers":[]}`,
|
||||
}
|
||||
|
||||
for _, resp := range cases {
|
||||
rc, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, resp)
|
||||
})
|
||||
err := rc.Pull(t.Context(), "http://example.com/a/b")
|
||||
if !errors.Is(err, ErrManifestInvalid) {
|
||||
t.Errorf("err = %v; want invalid manifest", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryResolveByDigest(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
|
||||
exists := blob.DigestFromBytes("exists")
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v2/alice/palace/blobs/"+exists.String() {
|
||||
w.WriteHeader(499) // should not hit manifest endpoint
|
||||
}
|
||||
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
|
||||
})
|
||||
|
||||
_, err := rc.Resolve(t.Context(), "alice/palace@"+exists.String())
|
||||
check(err)
|
||||
}
|
||||
|
||||
func TestInsecureSkipVerify(t *testing.T) {
|
||||
exists := blob.DigestFromBytes("exists")
|
||||
|
||||
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
const name = "library/insecure"
|
||||
|
||||
var rc Registry
|
||||
url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
|
||||
_, err := rc.Resolve(t.Context(), url)
|
||||
if err == nil || !strings.Contains(err.Error(), "failed to verify") {
|
||||
t.Errorf("err = %v; want cert verification failure", err)
|
||||
}
|
||||
|
||||
url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name)
|
||||
_, err = rc.Resolve(t.Context(), url)
|
||||
testutil.Check(t, err)
|
||||
}
|
||||
|
||||
func TestErrorUnmarshal(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
data string
|
||||
want *Error
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "errors empty",
|
||||
data: `{"errors":[]}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "errors empty",
|
||||
data: `{"errors":[]}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "errors single",
|
||||
data: `{"errors":[{"code":"blob_unknown"}]}`,
|
||||
want: &Error{Code: "blob_unknown", Message: ""},
|
||||
},
|
||||
{
|
||||
name: "errors multiple",
|
||||
data: `{"errors":[{"code":"blob_unknown"},{"code":"blob_error"}]}`,
|
||||
want: &Error{Code: "blob_unknown", Message: ""},
|
||||
},
|
||||
{
|
||||
name: "error empty",
|
||||
data: `{"error":""}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "error very empty",
|
||||
data: `{}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "error message",
|
||||
data: `{"error":"message", "code":"code"}`,
|
||||
want: &Error{Code: "code", Message: "message"},
|
||||
},
|
||||
{
|
||||
name: "invalid value",
|
||||
data: `{"error": 1}`,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got Error
|
||||
err := json.Unmarshal([]byte(tt.data), &got)
|
||||
if err != nil {
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
t.Errorf("Unmarshal() error = %v", err)
|
||||
// fallthrough and check got
|
||||
}
|
||||
if tt.want == nil {
|
||||
tt.want = &Error{}
|
||||
}
|
||||
if !reflect.DeepEqual(got, *tt.want) {
|
||||
t.Errorf("got = %v; want %v", got, *tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseNameExtendedErrors tests that parseName returns errors messages with enough
|
||||
// detail for users to debug naming issues they may encounter. Previous to this
|
||||
// test, the error messages were not very helpful and each problem was reported
|
||||
// as the same message.
|
||||
//
|
||||
// It is only for testing error messages, not that all invalids and valids are
|
||||
// covered. Those are in other tests for names.Name and blob.Digest.
|
||||
func TestParseNameExtendedErrors(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
want string
|
||||
}{}
|
||||
|
||||
var r Registry
|
||||
for _, tt := range cases {
|
||||
_, _, _, err := r.parseNameExtended(tt.name)
|
||||
if !errors.Is(err, tt.err) {
|
||||
t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), tt.want) {
|
||||
t.Errorf("[%s]: err =\n\t%v\nwant\n\t%v", tt.name, err, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNameExtended(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
scheme string
|
||||
name string
|
||||
digest string
|
||||
err string
|
||||
}{
|
||||
{in: "http://m", scheme: "http", name: "m"},
|
||||
{in: "https+insecure://m", scheme: "https+insecure", name: "m"},
|
||||
{in: "http+insecure://m", err: "unsupported scheme"},
|
||||
|
||||
{in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"},
|
||||
|
||||
{in: "", err: "invalid or missing name"},
|
||||
{in: "m", scheme: "https", name: "m"},
|
||||
{in: "://", err: "invalid or missing name"},
|
||||
{in: "@sha256:deadbeef", err: "invalid digest"},
|
||||
{in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
var r Registry
|
||||
scheme, n, digest, err := r.parseNameExtended(tt.in)
|
||||
if err != nil {
|
||||
if tt.err == "" {
|
||||
t.Errorf("err = %v; want nil", err)
|
||||
} else if !strings.Contains(err.Error(), tt.err) {
|
||||
t.Errorf("err = %v; want %q", err, tt.err)
|
||||
}
|
||||
} else if tt.err != "" {
|
||||
t.Errorf("err = nil; want %q", tt.err)
|
||||
}
|
||||
if err == nil && !n.IsFullyQualified() {
|
||||
t.Errorf("name = %q; want fully qualified", n)
|
||||
}
|
||||
|
||||
if scheme != tt.scheme {
|
||||
t.Errorf("scheme = %q; want %q", scheme, tt.scheme)
|
||||
}
|
||||
|
||||
// smoke-test name is superset of tt.name
|
||||
if !strings.Contains(n.String(), tt.name) {
|
||||
t.Errorf("name = %q; want %q", n, tt.name)
|
||||
}
|
||||
|
||||
tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String())
|
||||
if digest.String() != tt.digest {
|
||||
t.Errorf("digest = %q; want %q", digest, tt.digest)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnlink(t *testing.T) {
|
||||
t.Run("found by name", func(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
|
||||
rc, _ := newRegistryClient(t, nil)
|
||||
// make a blob and link it
|
||||
d := blob.DigestFromBytes("{}")
|
||||
err := blob.PutBytes(rc.Cache, d, "{}")
|
||||
check(err)
|
||||
err = rc.Cache.Link("registry.ollama.ai/library/single:latest", d)
|
||||
check(err)
|
||||
|
||||
// confirm linked
|
||||
_, err = rc.ResolveLocal("single")
|
||||
check(err)
|
||||
|
||||
// unlink
|
||||
_, err = rc.Unlink("single")
|
||||
check(err)
|
||||
|
||||
// confirm unlinked
|
||||
_, err = rc.ResolveLocal("single")
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Errorf("err = %v; want fs.ErrNotExist", err)
|
||||
}
|
||||
})
|
||||
t.Run("not found by name", func(t *testing.T) {
|
||||
rc, _ := newRegistryClient(t, nil)
|
||||
ok, err := rc.Unlink("manifestNotFound")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ok {
|
||||
t.Error("expected not found")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Many tests from here out, in this file are based on a single blob, "abc",
|
||||
// with the checksum of its sha256 hash. The checksum is:
|
||||
//
|
||||
// "abc" -> sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad
|
||||
//
|
||||
// Using the literal value instead of a constant with fmt.Xprintf calls proved
|
||||
// to be the most readable and maintainable approach. The sum is consistently
|
||||
// used in the tests and unique so searches do not yield false positives.
|
||||
|
||||
func checkRequest(t *testing.T, req *http.Request, method, path string) {
|
||||
t.Helper()
|
||||
if got := req.URL.Path; got != path {
|
||||
t.Errorf("URL = %q, want %q", got, path)
|
||||
}
|
||||
if req.Method != method {
|
||||
t.Errorf("Method = %q, want %q", req.Method, method)
|
||||
}
|
||||
}
|
||||
|
||||
func newRegistryClient(t *testing.T, upstream http.HandlerFunc) (*Registry, context.Context) {
|
||||
s := httptest.NewServer(upstream)
|
||||
t.Cleanup(s.Close)
|
||||
cache, err := blob.Open(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
},
|
||||
})
|
||||
|
||||
rc := &Registry{
|
||||
Cache: cache,
|
||||
HTTPClient: &http.Client{Transport: &http.Transport{
|
||||
Dial: func(network, addr string) (net.Conn, error) {
|
||||
return net.Dial(network, s.Listener.Addr().String())
|
||||
},
|
||||
}},
|
||||
}
|
||||
return rc, ctx
|
||||
}
|
||||
|
||||
func TestPullChunked(t *testing.T) {
|
||||
var steps atomic.Int64
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch steps.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
case 3, 4:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
switch rng := r.Header.Get("Range"); rng {
|
||||
case "bytes=0-1":
|
||||
io.WriteString(w, "ab")
|
||||
case "bytes=2-2":
|
||||
t.Logf("writing c")
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected range %q", rng)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", steps.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
testutil.Check(t, err)
|
||||
|
||||
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
||||
testutil.Check(t, err)
|
||||
|
||||
if g := steps.Load(); g != 4 {
|
||||
t.Fatalf("got %d steps, want 4", g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullCached(t *testing.T) {
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
})
|
||||
|
||||
check := testutil.Checker(t)
|
||||
|
||||
// Premeptively cache the blob
|
||||
d, err := blob.ParseDigest("sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
check(err)
|
||||
err = blob.PutBytes(c.Cache, d, []byte("abc"))
|
||||
check(err)
|
||||
|
||||
// Pull only the manifest, which should be enough to resolve the cached blob
|
||||
err = c.Pull(ctx, "http://o.com/library/abc")
|
||||
check(err)
|
||||
}
|
||||
|
||||
func TestPullManifestError(t *testing.T) {
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
io.WriteString(w, `{"errors":[{"code":"MANIFEST_UNKNOWN"}]}`)
|
||||
})
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
var got *Error
|
||||
if !errors.Is(err, ErrModelNotFound) {
|
||||
t.Fatalf("err = %v, want %v", got, ErrModelNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullLayerError(t *testing.T) {
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `!`)
|
||||
})
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
var want *json.SyntaxError
|
||||
if !errors.As(err, &want) {
|
||||
t.Fatalf("err = %T, want %T", err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullLayerChecksumError(t *testing.T) {
|
||||
var step atomic.Int64
|
||||
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
case 3:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
io.WriteString(w, `{"errors":[{"code":"BLOB_UNKNOWN"}]}`)
|
||||
case 4:
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.MaxStreams = 1
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
var written atomic.Int64
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
written.Add(n)
|
||||
},
|
||||
})
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
var got *Error
|
||||
if !errors.As(err, &got) || got.Code != "BLOB_UNKNOWN" {
|
||||
t.Fatalf("err = %v, want %v", err, got)
|
||||
}
|
||||
|
||||
if g := written.Load(); g != 1 {
|
||||
t.Fatalf("wrote %d bytes, want 1", g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullChunksumStreamError(t *testing.T) {
|
||||
var step atomic.Int64
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
|
||||
// Write one valid chunksum and one invalid chunksum
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) // valid
|
||||
fmt.Fprint(w, "sha256:!") // invalid
|
||||
case 3:
|
||||
io.WriteString(w, "ab")
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
got := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if !errors.Is(got, ErrIncomplete) {
|
||||
t.Fatalf("err = %v, want %v", got, ErrIncomplete)
|
||||
}
|
||||
}
|
||||
|
||||
type flushAfterWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (f *flushAfterWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = f.w.Write(p)
|
||||
f.w.(http.Flusher).Flush() // panic if not a flusher
|
||||
return
|
||||
}
|
||||
|
||||
func TestPullChunksumStreaming(t *testing.T) {
|
||||
csr, csw := io.Pipe()
|
||||
defer csw.Close()
|
||||
|
||||
var step atomic.Int64
|
||||
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fw := &flushAfterWriter{w} // ensure client gets data as it arrives by aggressively flushing
|
||||
_, err := io.Copy(fw, csr)
|
||||
if err != nil {
|
||||
t.Errorf("copy: %v", err)
|
||||
}
|
||||
case 3:
|
||||
io.WriteString(w, "ab")
|
||||
case 4:
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
update := make(chan int64, 1)
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
if n > 0 {
|
||||
update <- n
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
errc <- c.Pull(ctx, "http://o.com/library/abc")
|
||||
}()
|
||||
|
||||
// Send first chunksum and ensure it kicks off work immediately
|
||||
fmt.Fprintf(csw, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
if g := <-update; g != 2 {
|
||||
t.Fatalf("got %d, want 2", g)
|
||||
}
|
||||
|
||||
// now send the second chunksum and ensure it kicks off work immediately
|
||||
fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
if g := <-update; g != 3 {
|
||||
t.Fatalf("got %d, want 3", g)
|
||||
}
|
||||
csw.Close()
|
||||
testutil.Check(t, <-errc)
|
||||
}
|
||||
|
||||
func TestPullChunksumsCached(t *testing.T) {
|
||||
var step atomic.Int64
|
||||
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
case 3, 4:
|
||||
switch rng := r.Header.Get("Range"); rng {
|
||||
case "bytes=0-1":
|
||||
io.WriteString(w, "ab")
|
||||
case "bytes=2-2":
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected range %q", rng)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.MaxStreams = 1 // force serial processing of chunksums
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
// Cancel the pull after the first chunksum is processed, but before
|
||||
// the second chunksum is processed (which is waiting because
|
||||
// MaxStreams=1). This should cause the second chunksum to error out
|
||||
// leaving the blob incomplete.
|
||||
ctx = WithTrace(ctx, &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
if n > 0 {
|
||||
cancel()
|
||||
}
|
||||
},
|
||||
})
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("err = %v, want %v", err, context.Canceled)
|
||||
}
|
||||
|
||||
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Fatalf("err = %v, want nil", err)
|
||||
}
|
||||
|
||||
// Reset state and pull again to ensure the blob chunks that should
|
||||
// have been cached are, and the remaining chunk was downloaded, making
|
||||
// the blob complete.
|
||||
step.Store(0)
|
||||
var written atomic.Int64
|
||||
var cached atomic.Int64
|
||||
ctx = WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
if errors.Is(err, ErrCached) {
|
||||
cached.Add(n)
|
||||
}
|
||||
written.Add(n)
|
||||
},
|
||||
})
|
||||
|
||||
check := testutil.Checker(t)
|
||||
|
||||
err = c.Pull(ctx, "http://o.com/library/abc")
|
||||
check(err)
|
||||
|
||||
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
||||
check(err)
|
||||
|
||||
if g := written.Load(); g != 5 {
|
||||
t.Fatalf("wrote %d bytes, want 3", g)
|
||||
}
|
||||
if g := cached.Load(); g != 2 { // "ab" should have been cached
|
||||
t.Fatalf("cached %d bytes, want 5", g)
|
||||
}
|
||||
}
|
||||
72
server/internal/client/ollama/trace.go
Normal file
72
server/internal/client/ollama/trace.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Trace is a set of functions that are called to report progress during blob
|
||||
// downloads and uploads.
|
||||
//
|
||||
// Use [WithTrace] to attach a Trace to a context for use with [Registry.Push]
|
||||
// and [Registry.Pull].
|
||||
type Trace struct {
|
||||
// Update is called during [Registry.Push] and [Registry.Pull] to
|
||||
// report the progress of blob uploads and downloads.
|
||||
//
|
||||
// The n argument is the number of bytes transferred so far, and err is
|
||||
// any error that has occurred. If n == 0, and err is nil, the download
|
||||
// or upload has just started. If err is [ErrCached], the download or
|
||||
// upload has been skipped because the blob is already present in the
|
||||
// local cache or remote registry, respectively. Otherwise, if err is
|
||||
// non-nil, the download or upload has failed. When l.Size == n, and
|
||||
// err is nil, the download or upload has completed.
|
||||
//
|
||||
// A function assigned must be safe for concurrent use. The function is
|
||||
// called synchronously and so should not block or take long to run.
|
||||
Update func(_ *Layer, n int64, _ error)
|
||||
}
|
||||
|
||||
func (t *Trace) update(l *Layer, n int64, err error) {
|
||||
if t.Update != nil {
|
||||
t.Update(l, n, err)
|
||||
}
|
||||
}
|
||||
|
||||
type traceKey struct{}
|
||||
|
||||
// WithTrace adds a trace to the context for transfer progress reporting.
|
||||
func WithTrace(ctx context.Context, t *Trace) context.Context {
|
||||
old := traceFromContext(ctx)
|
||||
if old == t {
|
||||
// No change, return the original context. This also prevents
|
||||
// infinite recursion below, if the caller passes the same
|
||||
// Trace.
|
||||
return ctx
|
||||
}
|
||||
|
||||
// Create a new Trace that wraps the old one, if any. If we used the
|
||||
// same pointer t, we end up with a recursive structure.
|
||||
composed := &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
if old != nil {
|
||||
old.update(l, n, err)
|
||||
}
|
||||
t.update(l, n, err)
|
||||
},
|
||||
}
|
||||
return context.WithValue(ctx, traceKey{}, composed)
|
||||
}
|
||||
|
||||
var emptyTrace = &Trace{}
|
||||
|
||||
// traceFromContext returns the Trace associated with ctx, or an empty Trace if
|
||||
// none is found.
|
||||
//
|
||||
// It never returns nil.
|
||||
func traceFromContext(ctx context.Context) *Trace {
|
||||
t, _ := ctx.Value(traceKey{}).(*Trace)
|
||||
if t == nil {
|
||||
return emptyTrace
|
||||
}
|
||||
return t
|
||||
}
|
||||
45
server/internal/internal/backoff/backoff.go
Normal file
45
server/internal/internal/backoff/backoff.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package backoff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"iter"
|
||||
"math/rand/v2"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
|
||||
var n int
|
||||
return func(yield func(int, error) bool) {
|
||||
var t *time.Timer
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
yield(n, ctx.Err())
|
||||
return
|
||||
}
|
||||
|
||||
if !yield(n, nil) {
|
||||
return
|
||||
}
|
||||
|
||||
n++
|
||||
|
||||
// n^2 backoff timer is a little smoother than the
|
||||
// common choice of 2^n.
|
||||
d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
|
||||
// Randomize the delay between 0.5-1.5 x msec, in order
|
||||
// to prevent accidental "thundering herd" problems.
|
||||
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
|
||||
|
||||
if t == nil {
|
||||
t = time.NewTimer(d)
|
||||
} else {
|
||||
t.Reset(d)
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Stop()
|
||||
case <-t.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
38
server/internal/internal/backoff/backoff_synctest_test.go
Normal file
38
server/internal/internal/backoff/backoff_synctest_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package backoff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLoop(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
last := -1
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
for n, err := range Loop(ctx, 100*time.Millisecond) {
|
||||
if !errors.Is(err, ctx.Err()) {
|
||||
t.Errorf("err = %v, want nil", err)
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if n != last+1 {
|
||||
t.Errorf("n = %d, want %d", n, last+1)
|
||||
}
|
||||
last = n
|
||||
if n > 5 {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
if last != 6 {
|
||||
t.Errorf("last = %d, want 6", last)
|
||||
}
|
||||
})
|
||||
}
|
||||
24
server/internal/internal/backoff/backoff_test.go
Normal file
24
server/internal/internal/backoff/backoff_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package backoff
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoopAllocs(t *testing.T) {
|
||||
for i := range 3 {
|
||||
got := testing.AllocsPerRun(1000, func() {
|
||||
for tick := range Loop(t.Context(), 1) {
|
||||
if tick >= i {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
want := float64(0)
|
||||
if i > 0 {
|
||||
want = 3 // due to time.NewTimer
|
||||
}
|
||||
if got > want {
|
||||
t.Errorf("[%d ticks]: allocs = %v, want 0", i, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
228
server/internal/internal/names/name.go
Normal file
228
server/internal/internal/names/name.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package names
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/internal/stringsx"
|
||||
)
|
||||
|
||||
const MaxNameLength = 350 + 1 + 80 + 1 + 80 + 1 + 80 // <host>/<namespace>/<model>:<tag>
|
||||
|
||||
type Name struct {
|
||||
// Make incomparable to enfoce use of Compare / Equal for
|
||||
// case-insensitive comparisons.
|
||||
_ [0]func()
|
||||
|
||||
h string
|
||||
n string
|
||||
m string
|
||||
t string
|
||||
}
|
||||
|
||||
// Parse parses and assembles a Name from a name string. The
|
||||
// format of a valid name string is:
|
||||
//
|
||||
// s:
|
||||
// { host } "/" { namespace } "/" { model } ":" { tag }
|
||||
// { host } "/" { namespace } "/" { model }
|
||||
// { namespace } "/" { model } ":" { tag }
|
||||
// { namespace } "/" { model }
|
||||
// { model } ":" { tag }
|
||||
// { model }
|
||||
// host:
|
||||
// pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." | ":" }*
|
||||
// length: [1, 350]
|
||||
// namespace:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" }*
|
||||
// length: [1, 80]
|
||||
// model:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
|
||||
// length: [1, 80]
|
||||
// tag:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
|
||||
// length: [1, 80]
|
||||
//
|
||||
// The name returned is not guaranteed to be valid. If it is not valid, the
|
||||
// field values are left in an undefined state. Use [Name.IsValid] to check
|
||||
// if the name is valid.
|
||||
func Parse(s string) Name {
|
||||
if len(s) > MaxNameLength {
|
||||
return Name{}
|
||||
}
|
||||
|
||||
var n Name
|
||||
var tail string
|
||||
var c byte
|
||||
for {
|
||||
s, tail, c = cutLastAny(s, "/:")
|
||||
switch c {
|
||||
case ':':
|
||||
n.t = tail
|
||||
continue // look for model
|
||||
case '/':
|
||||
n.h, n.n, _ = cutLastAny(s, "/")
|
||||
n.m = tail
|
||||
return n
|
||||
case 0:
|
||||
n.m = tail
|
||||
return n
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Split splits an extended name string into its scheme, name, and digest
|
||||
// parts.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// http://ollama.com/bmizerany/smol:latest@digest
|
||||
// https://ollama.com/bmizerany/smol:latest
|
||||
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
|
||||
// model@digest
|
||||
// @digest
|
||||
func Split(s string) (scheme, name, digest string) {
|
||||
i := strings.Index(s, "://")
|
||||
if i >= 0 {
|
||||
scheme = s[:i]
|
||||
s = s[i+3:]
|
||||
}
|
||||
i = strings.LastIndex(s, "@")
|
||||
if i >= 0 {
|
||||
digest = s[i+1:]
|
||||
s = s[:i]
|
||||
}
|
||||
return scheme, s, digest
|
||||
}
|
||||
|
||||
// Merge merges two names into a single name. Non-empty host, namespace, and
|
||||
// tag parts of a take precedence over fields in b. The model field is left as
|
||||
// is.
|
||||
//
|
||||
// The returned name is not guaranteed to be valid. Use [Name.IsValid] to check
|
||||
// if the name is valid.
|
||||
func Merge(a, b Name) Name {
|
||||
a.h = cmp.Or(a.h, b.h)
|
||||
a.n = cmp.Or(a.n, b.n)
|
||||
a.t = cmp.Or(a.t, b.t)
|
||||
return a
|
||||
}
|
||||
|
||||
// IsValid returns true if the name is valid.
|
||||
func (n Name) IsValid() bool {
|
||||
if n.h != "" && !isValidPart(partHost, n.h) {
|
||||
return false
|
||||
}
|
||||
if n.n != "" && !isValidPart(partNamespace, n.n) {
|
||||
return false
|
||||
}
|
||||
if n.t != "" && !isValidPart(partTag, n.t) {
|
||||
return false
|
||||
}
|
||||
|
||||
// at bare minimum, model must be present and valid
|
||||
return n.m != "" && isValidPart(partModel, n.m)
|
||||
}
|
||||
|
||||
func (n Name) IsFullyQualified() bool {
|
||||
return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != ""
|
||||
}
|
||||
|
||||
const (
|
||||
partHost = iota
|
||||
partNamespace
|
||||
partModel
|
||||
partTag
|
||||
)
|
||||
|
||||
func isValidPart(kind int, s string) bool {
|
||||
maxlen := 80
|
||||
if kind == partHost {
|
||||
maxlen = 350
|
||||
}
|
||||
if len(s) > maxlen {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := range s {
|
||||
if i == 0 {
|
||||
if !isAlphanumericOrUnderscore(s[i]) {
|
||||
return false
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch s[i] {
|
||||
case '_', '-':
|
||||
case '.':
|
||||
if kind == partNamespace {
|
||||
return false
|
||||
}
|
||||
case ':':
|
||||
if kind != partHost {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
if !isAlphanumericOrUnderscore(s[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isAlphanumericOrUnderscore(c byte) bool {
|
||||
return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_'
|
||||
}
|
||||
|
||||
func (n Name) Host() string { return n.h }
|
||||
func (n Name) Namespace() string { return n.n }
|
||||
func (n Name) Model() string { return n.m }
|
||||
func (n Name) Tag() string { return n.t }
|
||||
|
||||
// Compare compares n and o case-insensitively. It returns 0 if n and o are
|
||||
// equal, -1 if n sorts before o, and 1 if n sorts after o.
|
||||
func (n Name) Compare(o Name) int {
|
||||
return cmp.Or(
|
||||
stringsx.CompareFold(n.h, o.h),
|
||||
stringsx.CompareFold(n.n, o.n),
|
||||
stringsx.CompareFold(n.m, o.m),
|
||||
stringsx.CompareFold(n.t, o.t),
|
||||
)
|
||||
}
|
||||
|
||||
// String returns the fully qualified name in the format
|
||||
// <namespace>/<model>:<tag>.
|
||||
func (n Name) String() string {
|
||||
var b strings.Builder
|
||||
if n.h != "" {
|
||||
b.WriteString(n.h)
|
||||
b.WriteByte('/')
|
||||
}
|
||||
if n.n != "" {
|
||||
b.WriteString(n.n)
|
||||
b.WriteByte('/')
|
||||
}
|
||||
b.WriteString(n.m)
|
||||
if n.t != "" {
|
||||
b.WriteByte(':')
|
||||
b.WriteString(n.t)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (n Name) GoString() string {
|
||||
return fmt.Sprintf("<Name %q %q %q %q>", n.h, n.n, n.m, n.t)
|
||||
}
|
||||
|
||||
// cutLastAny is like strings.Cut but scans in reverse for the last character
|
||||
// in chars. If no character is found, before is the empty string and after is
|
||||
// s. The returned sep is the byte value of the character in chars if one was
|
||||
// found; otherwise it is 0.
|
||||
func cutLastAny(s, chars string) (before, after string, sep byte) {
|
||||
i := strings.LastIndexAny(s, chars)
|
||||
if i >= 0 {
|
||||
return s[:i], s[i+1:], s[i]
|
||||
}
|
||||
return "", s, 0
|
||||
}
|
||||
220
server/internal/internal/names/name_test.go
Normal file
220
server/internal/internal/names/name_test.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package names
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseName(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want Name
|
||||
}{
|
||||
{"", Name{}},
|
||||
{"m:t", Name{m: "m", t: "t"}},
|
||||
{"m", Name{m: "m"}},
|
||||
{"/m", Name{m: "m"}},
|
||||
{"/n/m:t", Name{n: "n", m: "m", t: "t"}},
|
||||
{"n/m", Name{n: "n", m: "m"}},
|
||||
{"n/m:t", Name{n: "n", m: "m", t: "t"}},
|
||||
{"n/m", Name{n: "n", m: "m"}},
|
||||
{"n/m", Name{n: "n", m: "m"}},
|
||||
{strings.Repeat("m", MaxNameLength+1), Name{}},
|
||||
{"h/n/m:t", Name{h: "h", n: "n", m: "m", t: "t"}},
|
||||
{"ollama.com/library/_:latest", Name{h: "ollama.com", n: "library", m: "_", t: "latest"}},
|
||||
|
||||
// Invalids
|
||||
// TODO: {"n:t/m:t", Name{}},
|
||||
// TODO: {"/h/n/m:t", Name{}},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
got := Parse(tt.in)
|
||||
if got.Compare(tt.want) != 0 {
|
||||
t.Errorf("parseName(%q) = %#v, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestString(t *testing.T) {
|
||||
cases := []string{
|
||||
"",
|
||||
"m:t",
|
||||
"m:t",
|
||||
"m",
|
||||
"n/m",
|
||||
"n/m:t",
|
||||
"n/m",
|
||||
"n/m",
|
||||
"h/n/m:t",
|
||||
"ollama.com/library/_:latest",
|
||||
|
||||
// Special cased to "round trip" without the leading slash.
|
||||
"/m",
|
||||
"/n/m:t",
|
||||
}
|
||||
for _, s := range cases {
|
||||
t.Run(s, func(t *testing.T) {
|
||||
s = strings.TrimPrefix(s, "/")
|
||||
if g := Parse(s).String(); g != s {
|
||||
t.Errorf("parse(%q).String() = %q", s, g)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseExtended(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
|
||||
wantScheme string
|
||||
wantName Name
|
||||
wantDigest string
|
||||
}{
|
||||
{"", "", Name{}, ""},
|
||||
{"m", "", Name{m: "m"}, ""},
|
||||
{"http://m", "http", Name{m: "m"}, ""},
|
||||
{"http+insecure://m", "http+insecure", Name{m: "m"}, ""},
|
||||
{"http://m@sha256:deadbeef", "http", Name{m: "m"}, "sha256:deadbeef"},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
scheme, name, digest := Split(tt.in)
|
||||
n := Parse(name)
|
||||
if scheme != tt.wantScheme || n.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
|
||||
t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMerge(t *testing.T) {
|
||||
cases := []struct {
|
||||
a, b string
|
||||
want string
|
||||
}{
|
||||
{"", "", ""},
|
||||
{"m", "", "m"},
|
||||
{"", "m", ""},
|
||||
{"x", "y", "x"},
|
||||
{"o.com/n/m:t", "o.com/n/m:t", "o.com/n/m:t"},
|
||||
{"o.com/n/m:t", "o.com/n/_:t", "o.com/n/m:t"},
|
||||
|
||||
{"bmizerany/smol", "ollama.com/library/_:latest", "ollama.com/bmizerany/smol:latest"},
|
||||
{"localhost:8080/bmizerany/smol", "ollama.com/library/_:latest", "localhost:8080/bmizerany/smol:latest"},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
a, b := Parse(tt.a), Parse(tt.b)
|
||||
got := Merge(a, b)
|
||||
if got.Compare(Parse(tt.want)) != 0 {
|
||||
t.Errorf("merge(%q, %q) = %#v, want %q", tt.a, tt.b, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStringRoundTrip(t *testing.T) {
|
||||
cases := []string{
|
||||
"",
|
||||
"m",
|
||||
"m:t",
|
||||
"n/m",
|
||||
"n/m:t",
|
||||
"n/m:t",
|
||||
"n/m",
|
||||
"n/m",
|
||||
"h/n/m:t",
|
||||
"ollama.com/library/_:latest",
|
||||
}
|
||||
for _, s := range cases {
|
||||
t.Run(s, func(t *testing.T) {
|
||||
if got := Parse(s).String(); got != s {
|
||||
t.Errorf("parse(%q).String() = %q", s, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var junkName Name
|
||||
|
||||
func BenchmarkParseName(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for range b.N {
|
||||
junkName = Parse("h/n/m:t")
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
part80 = "88888888888888888888888888888888888888888888888888888888888888888888888888888888"
|
||||
part350 = "33333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333"
|
||||
)
|
||||
|
||||
var testCases = map[string]bool{ // name -> valid
|
||||
"": false,
|
||||
|
||||
"_why/_the/_lucky:_stiff": true,
|
||||
|
||||
// minimal
|
||||
"h/n/m:t": true,
|
||||
|
||||
"host/namespace/model:tag": true,
|
||||
"host/namespace/model": true,
|
||||
"namespace/model": true,
|
||||
"model": true,
|
||||
|
||||
// long (but valid)
|
||||
part80 + "/" + part80 + "/" + part80 + ":" + part80: true,
|
||||
part350 + "/" + part80 + "/" + part80 + ":" + part80: true,
|
||||
|
||||
// too long
|
||||
part80 + "/" + part80 + "/" + part80 + ":" + part350: false,
|
||||
"x" + part350 + "/" + part80 + "/" + part80 + ":" + part80: false,
|
||||
|
||||
"h/nn/mm:t": true, // bare minimum part sizes
|
||||
|
||||
// unqualified
|
||||
"m": true,
|
||||
"n/m:": true,
|
||||
"h/n/m": true,
|
||||
"@t": false,
|
||||
"m@d": false,
|
||||
|
||||
// invalids
|
||||
"^": false,
|
||||
"mm:": true,
|
||||
"/nn/mm": true,
|
||||
"//": false, // empty model
|
||||
"//mm": true,
|
||||
"hh//": false, // empty model
|
||||
"//mm:@": false,
|
||||
"00@": false,
|
||||
"@": false,
|
||||
|
||||
// not starting with alphanum
|
||||
"-hh/nn/mm:tt": false,
|
||||
"hh/-nn/mm:tt": false,
|
||||
"hh/nn/-mm:tt": false,
|
||||
"hh/nn/mm:-tt": false,
|
||||
|
||||
// smells like a flag
|
||||
"-h": false,
|
||||
|
||||
// hosts
|
||||
"host:https/namespace/model:tag": true,
|
||||
|
||||
// colon in non-host part before tag
|
||||
"host/name:space/model:tag": false,
|
||||
}
|
||||
|
||||
func TestParseNameValidation(t *testing.T) {
|
||||
for s, valid := range testCases {
|
||||
got := Parse(s)
|
||||
if got.IsValid() != valid {
|
||||
t.Logf("got: %v", got)
|
||||
t.Errorf("Parse(%q).IsValid() = %v; want !%[2]v", s, got.IsValid())
|
||||
}
|
||||
}
|
||||
}
|
||||
52
server/internal/internal/stringsx/stringsx.go
Normal file
52
server/internal/internal/stringsx/stringsx.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package stringsx provides additional string manipulation functions
|
||||
// that aren't in the standard library's strings package or go4.org/mem.
|
||||
package stringsx
|
||||
|
||||
import (
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// CompareFold returns -1, 0, or 1 depending on whether a < b, a == b, or a > b,
|
||||
// like cmp.Compare, but case insensitively.
|
||||
func CompareFold(a, b string) int {
|
||||
// Track our position in both strings
|
||||
ia, ib := 0, 0
|
||||
for ia < len(a) && ib < len(b) {
|
||||
ra, wa := nextRuneLower(a[ia:])
|
||||
rb, wb := nextRuneLower(b[ib:])
|
||||
if ra < rb {
|
||||
return -1
|
||||
}
|
||||
if ra > rb {
|
||||
return 1
|
||||
}
|
||||
ia += wa
|
||||
ib += wb
|
||||
if wa == 0 || wb == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If we've reached here, one or both strings are exhausted
|
||||
// The shorter string is "less than" if they match up to this point
|
||||
switch {
|
||||
case ia == len(a) && ib == len(b):
|
||||
return 0
|
||||
case ia == len(a):
|
||||
return -1
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
// nextRuneLower returns the next rune in the string, lowercased, along with its
|
||||
// original (consumed) width in bytes. If the string is empty, it returns
|
||||
// (utf8.RuneError, 0)
|
||||
func nextRuneLower(s string) (r rune, width int) {
|
||||
r, width = utf8.DecodeRuneInString(s)
|
||||
return unicode.ToLower(r), width
|
||||
}
|
||||
78
server/internal/internal/stringsx/stringsx_test.go
Normal file
78
server/internal/internal/stringsx/stringsx_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package stringsx
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompareFold(t *testing.T) {
|
||||
tests := []struct {
|
||||
a, b string
|
||||
}{
|
||||
// Basic ASCII cases
|
||||
{"", ""},
|
||||
{"a", "a"},
|
||||
{"a", "A"},
|
||||
{"A", "a"},
|
||||
{"a", "b"},
|
||||
{"b", "a"},
|
||||
{"abc", "ABC"},
|
||||
{"ABC", "abc"},
|
||||
{"abc", "abd"},
|
||||
{"abd", "abc"},
|
||||
|
||||
// Length differences
|
||||
{"abc", "ab"},
|
||||
{"ab", "abc"},
|
||||
|
||||
// Unicode cases
|
||||
{"世界", "世界"},
|
||||
{"Hello世界", "hello世界"},
|
||||
{"世界Hello", "世界hello"},
|
||||
{"世界", "世界x"},
|
||||
{"世界x", "世界"},
|
||||
|
||||
// Special case folding examples
|
||||
{"ß", "ss"}, // German sharp s
|
||||
{"fi", "fi"}, // fi ligature
|
||||
{"Σ", "σ"}, // Greek sigma
|
||||
{"İ", "i\u0307"}, // Turkish dotted I
|
||||
|
||||
// Mixed cases
|
||||
{"HelloWorld", "helloworld"},
|
||||
{"HELLOWORLD", "helloworld"},
|
||||
{"helloworld", "HELLOWORLD"},
|
||||
{"HelloWorld", "helloworld"},
|
||||
{"helloworld", "HelloWorld"},
|
||||
|
||||
// Edge cases
|
||||
{" ", " "},
|
||||
{"1", "1"},
|
||||
{"123", "123"},
|
||||
{"!@#", "!@#"},
|
||||
}
|
||||
|
||||
wants := []int{}
|
||||
for _, tt := range tests {
|
||||
got := CompareFold(tt.a, tt.b)
|
||||
want := cmp.Compare(strings.ToLower(tt.a), strings.ToLower(tt.b))
|
||||
if got != want {
|
||||
t.Errorf("CompareFold(%q, %q) = %v, want %v", tt.a, tt.b, got, want)
|
||||
}
|
||||
wants = append(wants, want)
|
||||
}
|
||||
|
||||
if n := testing.AllocsPerRun(1000, func() {
|
||||
for i, tt := range tests {
|
||||
if CompareFold(tt.a, tt.b) != wants[i] {
|
||||
panic("unexpected")
|
||||
}
|
||||
}
|
||||
}); n > 0 {
|
||||
t.Errorf("allocs = %v; want 0", int(n))
|
||||
}
|
||||
}
|
||||
201
server/internal/internal/syncs/line.go
Normal file
201
server/internal/internal/syncs/line.go
Normal file
@@ -0,0 +1,201 @@
|
||||
// Package syncs provides synchronization primitives.
|
||||
package syncs
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var closedChan = func() chan struct{} {
|
||||
ch := make(chan struct{})
|
||||
close(ch)
|
||||
return ch
|
||||
}()
|
||||
|
||||
// Ticket represents a ticket in a sequence of tickets. The zero value is
|
||||
// invalid. Use [Line.Take] to get a valid ticket.
|
||||
//
|
||||
// A Ticket is not safe for concurrent use.
|
||||
type Ticket struct {
|
||||
ahead chan struct{} // ticket ahead of this one
|
||||
ch chan struct{}
|
||||
}
|
||||
|
||||
// Ready returns a channel that is closed when the ticket before this one is
|
||||
// done.
|
||||
//
|
||||
// It is incorrect to wait on Ready after the ticket is done.
|
||||
func (t *Ticket) Ready() chan struct{} {
|
||||
return cmp.Or(t.ahead, closedChan)
|
||||
}
|
||||
|
||||
// Done signals that this ticket is done and that the next ticket in line can
|
||||
// proceed.
|
||||
//
|
||||
// The first call to [Done] unblocks the ticket after it, if any. Subsequent
|
||||
// calls are no-ops.
|
||||
func (t *Ticket) Done() {
|
||||
if t.ch != nil {
|
||||
close(t.ch)
|
||||
}
|
||||
t.ch = nil
|
||||
}
|
||||
|
||||
// Line is an ordered sequence of tickets waiting for their turn to proceed.
|
||||
//
|
||||
// To get a ticket use [Line.Take].
|
||||
// To signal that a ticket is done use [Ticket.Done].
|
||||
// To wait your turn use [Ticket.Ready].
|
||||
//
|
||||
// A Line is not safe for concurrent use.
|
||||
type Line struct {
|
||||
last chan struct{} // last ticket in line
|
||||
}
|
||||
|
||||
func (q *Line) Take() *Ticket {
|
||||
t := &Ticket{
|
||||
ahead: q.last,
|
||||
ch: make(chan struct{}),
|
||||
}
|
||||
q.last = t.ch
|
||||
return t
|
||||
}
|
||||
|
||||
// RelayReader implements an [io.WriterTo] that yields the passed
|
||||
// writer to its [WriteTo] method each [io.WriteCloser] taken from [Take], in
|
||||
// the order they are taken. Each [io.WriteCloser] blocks until the previous
|
||||
// one is closed, or a call to [RelayReader.CloseWithError] is made.
|
||||
//
|
||||
// The zero value is invalid. Use [NewWriteToLine] to get a valid RelayReader.
|
||||
//
|
||||
// It is not safe for concurrent use.
|
||||
type RelayReader struct {
|
||||
line Line
|
||||
t *Ticket
|
||||
w io.Writer
|
||||
n int64
|
||||
|
||||
mu sync.Mutex
|
||||
err error // set by CloseWithError
|
||||
closedCh chan struct{} // closed if err is set
|
||||
}
|
||||
|
||||
var (
|
||||
_ io.Closer = (*RelayReader)(nil)
|
||||
_ io.WriterTo = (*RelayReader)(nil)
|
||||
_ io.Reader = (*RelayReader)(nil)
|
||||
)
|
||||
|
||||
func NewRelayReader() *RelayReader {
|
||||
var q RelayReader
|
||||
q.closedCh = make(chan struct{})
|
||||
q.t = q.line.Take()
|
||||
return &q
|
||||
}
|
||||
|
||||
// CloseWithError terminates the line, unblocking any writer waiting for its
|
||||
// turn with the error, or [io.EOF] if err is nil. It is safe to call
|
||||
// [CloseWithError] multiple times and across multiple goroutines.
|
||||
//
|
||||
// If the line is already closed, [CloseWithError] is a no-op.
|
||||
//
|
||||
// It never returns an error.
|
||||
func (q *RelayReader) CloseWithError(err error) error {
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
if q.err == nil {
|
||||
q.err = cmp.Or(q.err, err, io.EOF)
|
||||
close(q.closedCh)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the line. Any writer waiting for its turn will be unblocked
|
||||
// with an [io.ErrClosedPipe] error.
|
||||
//
|
||||
// It never returns an error.
|
||||
func (q *RelayReader) Close() error {
|
||||
return q.CloseWithError(nil)
|
||||
}
|
||||
|
||||
func (q *RelayReader) closed() <-chan struct{} {
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
return q.closedCh
|
||||
}
|
||||
|
||||
func (q *RelayReader) Read(p []byte) (int, error) {
|
||||
panic("RelayReader.Read is for show only; use WriteTo")
|
||||
}
|
||||
|
||||
// WriteTo yields the writer w to the first writer in line and blocks until the
|
||||
// first call to [Close].
|
||||
//
|
||||
// It is safe to call [Take] concurrently with [WriteTo].
|
||||
func (q *RelayReader) WriteTo(dst io.Writer) (int64, error) {
|
||||
select {
|
||||
case <-q.closed():
|
||||
return 0, io.ErrClosedPipe
|
||||
default:
|
||||
}
|
||||
|
||||
// We have a destination writer; let the relay begin.
|
||||
q.w = dst
|
||||
q.t.Done()
|
||||
<-q.closed()
|
||||
return q.n, nil
|
||||
}
|
||||
|
||||
// Take returns a writer that will be passed to the next writer in line.
|
||||
//
|
||||
// It is not safe for use across multiple goroutines.
|
||||
func (q *RelayReader) Take() io.WriteCloser {
|
||||
return &relayWriter{q: q, t: q.line.Take()}
|
||||
}
|
||||
|
||||
type relayWriter struct {
|
||||
q *RelayReader
|
||||
t *Ticket
|
||||
ready bool
|
||||
}
|
||||
|
||||
var _ io.StringWriter = (*relayWriter)(nil)
|
||||
|
||||
// Write writes to the writer passed to [RelayReader.WriteTo] as soon as the
|
||||
// writer is ready. It returns io.ErrClosedPipe if the line is closed before
|
||||
// the writer is ready.
|
||||
func (w *relayWriter) Write(p []byte) (int, error) {
|
||||
if !w.awaitTurn() {
|
||||
return 0, w.q.err
|
||||
}
|
||||
n, err := w.q.w.Write(p)
|
||||
w.q.n += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *relayWriter) WriteString(s string) (int, error) {
|
||||
if !w.awaitTurn() {
|
||||
return 0, w.q.err
|
||||
}
|
||||
return io.WriteString(w.q.w, s)
|
||||
}
|
||||
|
||||
// Close signals that the writer is done, unblocking the next writer in line.
|
||||
func (w *relayWriter) Close() error {
|
||||
w.t.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *relayWriter) awaitTurn() (ok bool) {
|
||||
if t.ready {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-t.t.Ready():
|
||||
t.ready = true
|
||||
return true
|
||||
case <-t.q.closed():
|
||||
return false
|
||||
}
|
||||
}
|
||||
65
server/internal/internal/syncs/line_test.go
Normal file
65
server/internal/internal/syncs/line_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package syncs
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
)
|
||||
|
||||
func TestPipelineReadWriterTo(t *testing.T) {
|
||||
for range 10 {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
q := NewRelayReader()
|
||||
|
||||
tickets := []struct {
|
||||
io.WriteCloser
|
||||
s string
|
||||
}{
|
||||
{q.Take(), "you"},
|
||||
{q.Take(), " say hi,"},
|
||||
{q.Take(), " and "},
|
||||
{q.Take(), "I say "},
|
||||
{q.Take(), "hello"},
|
||||
}
|
||||
|
||||
rand.Shuffle(len(tickets), func(i, j int) {
|
||||
tickets[i], tickets[j] = tickets[j], tickets[i]
|
||||
})
|
||||
|
||||
var g Group
|
||||
for i, t := range tickets {
|
||||
g.Go(func() {
|
||||
defer t.Close()
|
||||
if i%2 == 0 {
|
||||
// Use [relayWriter.WriteString]
|
||||
io.WriteString(t.WriteCloser, t.s)
|
||||
} else {
|
||||
t.Write([]byte(t.s))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
var got bytes.Buffer
|
||||
var copyErr error // checked at end
|
||||
g.Go(func() {
|
||||
_, copyErr = io.Copy(&got, q)
|
||||
})
|
||||
|
||||
synctest.Wait()
|
||||
|
||||
q.Close()
|
||||
g.Wait()
|
||||
|
||||
if copyErr != nil {
|
||||
t.Fatal(copyErr)
|
||||
}
|
||||
|
||||
want := "you say hi, and I say hello"
|
||||
if got.String() != want {
|
||||
t.Fatalf("got %q, want %q", got.String(), want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
41
server/internal/internal/syncs/syncs.go
Normal file
41
server/internal/internal/syncs/syncs.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package syncs
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Group is a [sync.WaitGroup] with a Go method.
|
||||
type Group struct {
|
||||
wg sync.WaitGroup
|
||||
n atomic.Int64
|
||||
}
|
||||
|
||||
func (g *Group) Go(f func()) {
|
||||
g.wg.Add(1)
|
||||
go func() {
|
||||
g.n.Add(1) // Now we are running
|
||||
defer func() {
|
||||
g.wg.Done()
|
||||
g.n.Add(-1) // Now we are done
|
||||
}()
|
||||
f()
|
||||
}()
|
||||
}
|
||||
|
||||
// Running returns the number of goroutines that are currently running.
|
||||
//
|
||||
// If a call to [Running] returns zero, and a call to [Wait] is made without
|
||||
// any calls to [Go], then [Wait] will return immediately. This is true even if
|
||||
// a goroutine is started and finishes between the two calls.
|
||||
//
|
||||
// It is possible for [Running] to return non-zero and for [Wait] to return
|
||||
// immediately. This can happen if the all running goroutines finish between
|
||||
// the two calls.
|
||||
func (g *Group) Running() int64 {
|
||||
return g.n.Load()
|
||||
}
|
||||
|
||||
func (g *Group) Wait() {
|
||||
g.wg.Wait()
|
||||
}
|
||||
116
server/internal/manifest/manifest.go
Normal file
116
server/internal/manifest/manifest.go
Normal file
@@ -0,0 +1,116 @@
|
||||
// Package manifest provides documentation for the Ollama manifest format.
|
||||
// This package contains no code.
|
||||
//
|
||||
// # Manifests
|
||||
//
|
||||
// A manifest is a JSON object that describes a model. The JSON object has a
|
||||
// single field "layers" which is a list of layers that make up the model.
|
||||
// A layer is a single, logical unit of a model. Layers are stored in the cache
|
||||
// as files with the name of the digest of the layer. Layers are pushed and
|
||||
// pulled from the registry as blobs.
|
||||
//
|
||||
// A layer is represented as a JSON object with the following fields:
|
||||
//
|
||||
// - "digest": The digest of the layer.
|
||||
// - "mediaType": The media type of the layer.
|
||||
// - "size": The size of the layer in bytes.
|
||||
//
|
||||
// Layers are typically stored in a blob store, such as a registry, and are
|
||||
// referenced by their digest. This package does not define how layers are
|
||||
// stored or retrieved.
|
||||
//
|
||||
// # Configuration Layer
|
||||
//
|
||||
// The configuration of a model is represented as a layer with the media type:
|
||||
//
|
||||
// application/vnd.ollama.image.config; type=<type>
|
||||
//
|
||||
// The "type" parameter in the media type specifies the format of the
|
||||
// configuration (e.g., "safetensor" or "gguf").
|
||||
//
|
||||
// There may be only one configuration layer in a model.
|
||||
//
|
||||
// # Template Layer
|
||||
//
|
||||
// The model template is a layer with the media type:
|
||||
//
|
||||
// application/vnd.ollama.image.template; [name=<name>]
|
||||
//
|
||||
// The "name" parameter in the media type specifies the name of the template as
|
||||
// for lookup at runtime. The name is optional and may be omitted. If omitted,
|
||||
// the template is the default template for the model.
|
||||
//
|
||||
// # Tensor Layers
|
||||
//
|
||||
// The tensors of a model are represented as layers with the media type:
|
||||
//
|
||||
// application/vnd.ollama.image.tensor; name=<name>; dtype=<dtype>; shape=<shape>
|
||||
//
|
||||
// The "name" parameter in the media type specifies the name of the tensor as
|
||||
// defined in the model's configuration and are bound only by the rules for
|
||||
// names as defined in the configuration format, as represented by the
|
||||
// configuration's "type".
|
||||
//
|
||||
// The "dtype" parameter in the media type specifies the data type of the tensor
|
||||
// as a string.
|
||||
//
|
||||
// TODO: Define more specifically how to represent data types as strings.
|
||||
//
|
||||
// The "shape" parameter in the media type specifies the shape of the tensor as
|
||||
// a comma-separated list of integers; one per dimension.
|
||||
//
|
||||
// # Tokenization Layers
|
||||
//
|
||||
// The tokenization of a model is represented as a layer with the media type:
|
||||
//
|
||||
// application/vnd.ollama.image.tokenizer
|
||||
//
|
||||
// The configuration of the tokenizer is represented as a layer with the media type:
|
||||
//
|
||||
// application/vnd.ollama.image.tokenizer.config
|
||||
//
|
||||
// # Miscellaneous Layers
|
||||
//
|
||||
// These extra layer mime types are reserved:
|
||||
//
|
||||
// application/vnd.ollama.image.license
|
||||
//
|
||||
// This layer contains one of the many licenses for the model in plain text.
|
||||
//
|
||||
// # Example Manifest
|
||||
//
|
||||
// The following is an example manifest containing a configuration, a model
|
||||
// template, and two tensors (digests shortened for brevity):
|
||||
//
|
||||
// {
|
||||
// "layers": [{
|
||||
// "digest": "sha256:a...",
|
||||
// "mediaType": "application/vnd.ollama.image.config; type=safetensors",
|
||||
// "size": 1234
|
||||
// },{
|
||||
// "digest": "sha256:b...",
|
||||
// "mediaType": "application/vnd.ollama.image.template",
|
||||
// "size": 5678
|
||||
// },{
|
||||
// "digest": "sha256:c...",
|
||||
// "mediaType": "application/vnd.ollama.image.tensor; name=input; dtype=F32; shape=1,2,3",
|
||||
// "size": 9012
|
||||
// },{
|
||||
// "digest": "sha256:d...",
|
||||
// "mediaType": "application/vnd.ollama.image.tensor; name=output; dtype=I32; shape=4,5,6",
|
||||
// "size": 3456
|
||||
// }]
|
||||
// }
|
||||
//
|
||||
// # Legacy Media Types
|
||||
//
|
||||
// The appliaction/vnd.ollama.image.model media type is deprecated, but will
|
||||
// remain supported for backwards compatibility, for some undefined amount of
|
||||
// time. New models should use the media types defined above.
|
||||
//
|
||||
// # Reserved media types
|
||||
//
|
||||
// The media type prefix "application/vnd.ollama.image." is reserved for
|
||||
// defining new media types for layers known to Ollama. Currently, all other
|
||||
// prefixes are ignored by official Ollama registry clients.
|
||||
package manifest
|
||||
417
server/internal/registry/server.go
Normal file
417
server/internal/registry/server.go
Normal file
@@ -0,0 +1,417 @@
|
||||
// Package registry implements an http.Handler for handling local Ollama API
|
||||
// model management requests. See [Local] for details.
|
||||
package registry
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/server/internal/internal/backoff"
|
||||
)
|
||||
|
||||
// Local implements an http.Handler for handling local Ollama API model
|
||||
// management requests, such as pushing, pulling, and deleting models.
|
||||
//
|
||||
// It can be arranged for all unknown requests to be passed through to a
|
||||
// fallback handler, if one is provided.
|
||||
type Local struct {
|
||||
Client *ollama.Registry // required
|
||||
Logger *slog.Logger // required
|
||||
|
||||
// Fallback, if set, is used to handle requests that are not handled by
|
||||
// this handler.
|
||||
Fallback http.Handler
|
||||
|
||||
// Prune, if set, is called to prune the local disk cache after a model
|
||||
// is deleted.
|
||||
Prune func() error // optional
|
||||
}
|
||||
|
||||
// serverError is like ollama.Error, but with a Status field for the HTTP
|
||||
// response code. We want to avoid adding that field to ollama.Error because it
|
||||
// would always be 0 to clients (we don't want to leak the status code in
|
||||
// errors), and so it would be confusing to have a field that is always 0.
|
||||
type serverError struct {
|
||||
Status int `json:"-"`
|
||||
|
||||
// TODO(bmizerany): Decide if we want to keep this and maybe
|
||||
// bring back later.
|
||||
Code string `json:"code"`
|
||||
|
||||
Message string `json:"error"`
|
||||
}
|
||||
|
||||
func (e serverError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// Common API errors
|
||||
var (
|
||||
errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
|
||||
errNotFound = &serverError{404, "not_found", "not found"}
|
||||
errModelNotFound = &serverError{404, "not_found", "model not found"}
|
||||
errInternalError = &serverError{500, "internal_error", "internal server error"}
|
||||
)
|
||||
|
||||
type statusCodeRecorder struct {
|
||||
_status int // use status() to get the status code
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
func (r *statusCodeRecorder) WriteHeader(status int) {
|
||||
if r._status == 0 {
|
||||
r._status = status
|
||||
r.ResponseWriter.WriteHeader(status)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *statusCodeRecorder) Write(b []byte) (int, error) {
|
||||
r._status = r.status()
|
||||
return r.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
var (
|
||||
_ http.ResponseWriter = (*statusCodeRecorder)(nil)
|
||||
_ http.CloseNotifier = (*statusCodeRecorder)(nil)
|
||||
_ http.Flusher = (*statusCodeRecorder)(nil)
|
||||
)
|
||||
|
||||
// CloseNotify implements the http.CloseNotifier interface, for Gin. Remove with Gin.
|
||||
//
|
||||
// It panics if the underlying ResponseWriter is not a CloseNotifier.
|
||||
func (r *statusCodeRecorder) CloseNotify() <-chan bool {
|
||||
return r.ResponseWriter.(http.CloseNotifier).CloseNotify()
|
||||
}
|
||||
|
||||
// Flush implements the http.Flusher interface, for Gin. Remove with Gin.
|
||||
//
|
||||
// It panics if the underlying ResponseWriter is not a Flusher.
|
||||
func (r *statusCodeRecorder) Flush() {
|
||||
r.ResponseWriter.(http.Flusher).Flush()
|
||||
}
|
||||
|
||||
func (r *statusCodeRecorder) status() int {
|
||||
return cmp.Or(r._status, 200)
|
||||
}
|
||||
|
||||
func (s *Local) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
rec := &statusCodeRecorder{ResponseWriter: w}
|
||||
s.serveHTTP(rec, r)
|
||||
}
|
||||
|
||||
func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
|
||||
var errattr slog.Attr
|
||||
proxied, err := func() (bool, error) {
|
||||
switch r.URL.Path {
|
||||
case "/api/delete":
|
||||
return false, s.handleDelete(rec, r)
|
||||
case "/api/pull":
|
||||
return false, s.handlePull(rec, r)
|
||||
default:
|
||||
if s.Fallback != nil {
|
||||
s.Fallback.ServeHTTP(rec, r)
|
||||
return true, nil
|
||||
}
|
||||
return false, errNotFound
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
// We always log the error, so fill in the error log attribute
|
||||
errattr = slog.String("error", err.Error())
|
||||
|
||||
var e *serverError
|
||||
switch {
|
||||
case errors.As(err, &e):
|
||||
case errors.Is(err, ollama.ErrNameInvalid):
|
||||
e = &serverError{400, "bad_request", err.Error()}
|
||||
default:
|
||||
e = errInternalError
|
||||
}
|
||||
|
||||
data, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
// unreachable
|
||||
panic(err)
|
||||
}
|
||||
rec.Header().Set("Content-Type", "application/json")
|
||||
rec.WriteHeader(e.Status)
|
||||
rec.Write(data)
|
||||
|
||||
// fallthrough to log
|
||||
}
|
||||
|
||||
if !proxied {
|
||||
// we're only responsible for logging if we handled the request
|
||||
var level slog.Level
|
||||
if rec.status() >= 500 {
|
||||
level = slog.LevelError
|
||||
} else if rec.status() >= 400 {
|
||||
level = slog.LevelWarn
|
||||
}
|
||||
|
||||
s.Logger.LogAttrs(r.Context(), level, "http",
|
||||
errattr, // report first in line to make it easy to find
|
||||
|
||||
// TODO(bmizerany): Write a test to ensure that we are logging
|
||||
// all of this correctly. That also goes for the level+error
|
||||
// logic above.
|
||||
slog.Int("status", rec.status()),
|
||||
slog.String("method", r.Method),
|
||||
slog.String("path", r.URL.Path),
|
||||
slog.Int64("content-length", r.ContentLength),
|
||||
slog.String("remote", r.RemoteAddr),
|
||||
slog.String("proto", r.Proto),
|
||||
slog.String("query", r.URL.RawQuery),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
type params struct {
|
||||
// DeprecatedName is the name of the model to push, pull, or delete,
|
||||
// but is deprecated. New clients should use [Model] instead.
|
||||
//
|
||||
// Use [model()] to get the model name for both old and new API requests.
|
||||
DeprecatedName string `json:"name"`
|
||||
|
||||
// Model is the name of the model to push, pull, or delete.
|
||||
//
|
||||
// Use [model()] to get the model name for both old and new API requests.
|
||||
Model string `json:"model"`
|
||||
|
||||
// AllowNonTLS is a flag that indicates a client using HTTP
|
||||
// is doing so, deliberately.
|
||||
//
|
||||
// Deprecated: This field is ignored and only present for this
|
||||
// deprecation message. It should be removed in a future release.
|
||||
//
|
||||
// Users can just use http or https+insecure to show intent to
|
||||
// communicate they want to do insecure things, without awkward and
|
||||
// confusing flags such as this.
|
||||
AllowNonTLS bool `json:"insecure"`
|
||||
|
||||
// Stream, if true, will make the server send progress updates in a
|
||||
// streaming of JSON objects. If false, the server will send a single
|
||||
// JSON object with the final status as "success", or an error object
|
||||
// if an error occurred.
|
||||
//
|
||||
// Unfortunately, this API was designed to be a bit awkward. Stream is
|
||||
// defined to default to true if not present, so we need a way to check
|
||||
// if the client decisively set it to false. So, we use a pointer to a
|
||||
// bool. Gross.
|
||||
//
|
||||
// Use [stream()] to get the correct value for this field.
|
||||
Stream *bool `json:"stream"`
|
||||
}
|
||||
|
||||
// model returns the model name for both old and new API requests.
|
||||
func (p params) model() string {
|
||||
return cmp.Or(p.Model, p.DeprecatedName)
|
||||
}
|
||||
|
||||
func (p params) stream() bool {
|
||||
if p.Stream == nil {
|
||||
return true
|
||||
}
|
||||
return *p.Stream
|
||||
}
|
||||
|
||||
func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != "DELETE" {
|
||||
return errMethodNotAllowed
|
||||
}
|
||||
p, err := decodeUserJSON[*params](r.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ok, err := s.Client.Unlink(p.model())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return errModelNotFound
|
||||
}
|
||||
if s.Prune != nil {
|
||||
return s.Prune()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type progressUpdateJSON struct {
|
||||
Error string `json:"error,omitempty,omitzero"`
|
||||
Status string `json:"status,omitempty,omitzero"`
|
||||
Digest blob.Digest `json:"digest,omitempty,omitzero"`
|
||||
Total int64 `json:"total,omitempty,omitzero"`
|
||||
Completed int64 `json:"completed,omitempty,omitzero"`
|
||||
}
|
||||
|
||||
func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != "POST" {
|
||||
return errMethodNotAllowed
|
||||
}
|
||||
|
||||
p, err := decodeUserJSON[*params](r.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
if !p.stream() {
|
||||
if err := s.Client.Pull(r.Context(), p.model()); err != nil {
|
||||
if errors.Is(err, ollama.ErrModelNotFound) {
|
||||
return errModelNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
enc.Encode(progressUpdateJSON{Status: "success"})
|
||||
return nil
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
var progress []progressUpdateJSON
|
||||
flushProgress := func() {
|
||||
mu.Lock()
|
||||
progress := slices.Clone(progress) // make a copy and release lock before encoding to the wire
|
||||
mu.Unlock()
|
||||
for _, p := range progress {
|
||||
enc.Encode(p)
|
||||
}
|
||||
fl, _ := w.(http.Flusher)
|
||||
if fl != nil {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
t := time.NewTicker(1<<63 - 1) // "unstarted" timer
|
||||
start := sync.OnceFunc(func() {
|
||||
flushProgress() // flush initial state
|
||||
t.Reset(100 * time.Millisecond)
|
||||
})
|
||||
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
|
||||
Update: func(l *ollama.Layer, n int64, err error) {
|
||||
if err != nil && !errors.Is(err, ollama.ErrCached) {
|
||||
s.Logger.Error("pulling", "model", p.model(), "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
for i, p := range progress {
|
||||
if p.Digest == l.Digest {
|
||||
progress[i].Completed = n
|
||||
return
|
||||
}
|
||||
}
|
||||
progress = append(progress, progressUpdateJSON{
|
||||
Digest: l.Digest,
|
||||
Total: l.Size,
|
||||
})
|
||||
}()
|
||||
|
||||
// Block flushing progress updates until every
|
||||
// layer is accounted for. Clients depend on a
|
||||
// complete model size to calculate progress
|
||||
// correctly; if they use an incomplete total,
|
||||
// progress indicators would erratically jump
|
||||
// as new layers are registered.
|
||||
start()
|
||||
},
|
||||
})
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() (err error) {
|
||||
defer func() { done <- err }()
|
||||
for _, err := range backoff.Loop(ctx, 3*time.Second) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err := s.Client.Pull(ctx, p.model())
|
||||
if canRetry(err) {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
|
||||
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
flushProgress()
|
||||
case err := <-done:
|
||||
flushProgress()
|
||||
if err != nil {
|
||||
if errors.Is(err, ollama.ErrModelNotFound) {
|
||||
return &serverError{
|
||||
Status: 404,
|
||||
Code: "not_found",
|
||||
Message: fmt.Sprintf("model %q not found", p.model()),
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Emulate old client pull progress (for now):
|
||||
enc.Encode(progressUpdateJSON{Status: "verifying sha256 digest"})
|
||||
enc.Encode(progressUpdateJSON{Status: "writing manifest"})
|
||||
enc.Encode(progressUpdateJSON{Status: "success"})
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeUserJSON[T any](r io.Reader) (T, error) {
|
||||
var v T
|
||||
err := json.NewDecoder(r).Decode(&v)
|
||||
if err == nil {
|
||||
return v, nil
|
||||
}
|
||||
var zero T
|
||||
|
||||
// Not sure why, but I can't seem to be able to use:
|
||||
//
|
||||
// errors.As(err, &json.UnmarshalTypeError{})
|
||||
//
|
||||
// This is working fine in stdlib, so I'm not sure what rules changed
|
||||
// and why this no longer works here. So, we do it the verbose way.
|
||||
var a *json.UnmarshalTypeError
|
||||
var b *json.SyntaxError
|
||||
if errors.As(err, &a) || errors.As(err, &b) {
|
||||
err = &serverError{Status: 400, Message: err.Error(), Code: "bad_request"}
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = &serverError{Status: 400, Message: "empty request body", Code: "bad_request"}
|
||||
}
|
||||
return zero, err
|
||||
}
|
||||
|
||||
func canRetry(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var oe *ollama.Error
|
||||
if errors.As(err, &oe) {
|
||||
return oe.Temporary()
|
||||
}
|
||||
s := err.Error()
|
||||
return cmp.Or(
|
||||
errors.Is(err, context.DeadlineExceeded),
|
||||
strings.Contains(s, "unreachable"),
|
||||
strings.Contains(s, "no route to host"),
|
||||
strings.Contains(s, "connection reset by peer"),
|
||||
)
|
||||
}
|
||||
302
server/internal/registry/server_test.go
Normal file
302
server/internal/registry/server_test.go
Normal file
@@ -0,0 +1,302 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/server/internal/testutil"
|
||||
"golang.org/x/tools/txtar"
|
||||
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
type panicTransport struct{}
|
||||
|
||||
func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
panic("unexpected RoundTrip call")
|
||||
}
|
||||
|
||||
var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
|
||||
|
||||
// bytesResetter is an interface for types that can be reset and return a byte
|
||||
// slice, only. This is to prevent inadvertent use of bytes.Buffer.Read/Write
|
||||
// etc for the purpose of checking logs.
|
||||
type bytesResetter interface {
|
||||
Bytes() []byte
|
||||
Reset()
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
err := os.CopyFS(dir, os.DirFS("testdata/models"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c, err := blob.Open(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client := panicOnRoundTrip
|
||||
if upstreamRegistry != nil {
|
||||
s := httptest.NewTLSServer(upstreamRegistry)
|
||||
t.Cleanup(s.Close)
|
||||
tr := s.Client().Transport.(*http.Transport).Clone()
|
||||
tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "tcp", s.Listener.Addr().String())
|
||||
}
|
||||
client = &http.Client{Transport: tr}
|
||||
}
|
||||
|
||||
rc := &ollama.Registry{
|
||||
Cache: c,
|
||||
HTTPClient: client,
|
||||
Mask: "example.com/library/_:latest",
|
||||
}
|
||||
|
||||
l := &Local{
|
||||
Client: rc,
|
||||
Logger: testutil.Slogger(t),
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
ctx := ollama.WithTrace(t.Context(), &ollama.Trace{
|
||||
Update: func(l *ollama.Layer, n int64, err error) {
|
||||
t.Logf("update: %s %d %v", l.Digest, n, err)
|
||||
},
|
||||
})
|
||||
req := httptest.NewRequestWithContext(ctx, method, path, strings.NewReader(body))
|
||||
return s.sendRequest(t, req)
|
||||
}
|
||||
|
||||
func (s *Local) sendRequest(t *testing.T, req *http.Request) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
type invalidReader struct{}
|
||||
|
||||
func (r *invalidReader) Read(p []byte) (int, error) {
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
|
||||
// captureLogs is a helper to capture logs from the server. It returns a
|
||||
// shallow copy of the server with a new logger and a bytesResetter for the
|
||||
// logs.
|
||||
func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
|
||||
t.Helper()
|
||||
log, logs := testutil.SlogBuffer()
|
||||
l := *s // shallow copy
|
||||
l.Logger = log
|
||||
return &l, logs
|
||||
}
|
||||
|
||||
func TestServerDelete(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
|
||||
s := newTestServer(t, nil)
|
||||
|
||||
_, err := s.Client.ResolveLocal("smol")
|
||||
check(err)
|
||||
|
||||
got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
|
||||
if got.Code != 200 {
|
||||
t.Fatalf("Code = %d; want 200", got.Code)
|
||||
}
|
||||
|
||||
_, err = s.Client.ResolveLocal("smol")
|
||||
if err == nil {
|
||||
t.Fatal("expected smol to have been deleted")
|
||||
}
|
||||
|
||||
got = s.send(t, "DELETE", "/api/delete", `!`)
|
||||
checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
|
||||
|
||||
got = s.send(t, "GET", "/api/delete", `{"model": "smol"}`)
|
||||
checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
|
||||
|
||||
got = s.send(t, "DELETE", "/api/delete", ``)
|
||||
checkErrorResponse(t, got, 400, "bad_request", "empty request body")
|
||||
|
||||
got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
|
||||
checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
|
||||
|
||||
got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
|
||||
checkErrorResponse(t, got, 404, "not_found", "not found")
|
||||
|
||||
s, logs := captureLogs(t, s)
|
||||
req := httptest.NewRequestWithContext(t.Context(), "DELETE", "/api/delete", &invalidReader{})
|
||||
got = s.sendRequest(t, req)
|
||||
checkErrorResponse(t, got, 500, "internal_error", "internal server error")
|
||||
ok, err := regexp.Match(`ERROR.*error="invalid argument"`, logs.Bytes())
|
||||
check(err)
|
||||
if !ok {
|
||||
t.Logf("logs:\n%s", logs)
|
||||
t.Fatalf("expected log to contain ERROR with invalid argument")
|
||||
}
|
||||
}
|
||||
|
||||
//go:embed testdata/registry.txt
|
||||
var registryTXT []byte
|
||||
|
||||
var registryFS = sync.OnceValue(func() fs.FS {
|
||||
// Txtar gets hung up on \r\n line endings, so we need to convert them
|
||||
// to \n when parsing the txtar on Windows.
|
||||
data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
|
||||
a := txtar.Parse(data)
|
||||
fsys, err := txtar.FS(a)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return fsys
|
||||
})
|
||||
|
||||
func TestServerPull(t *testing.T) {
|
||||
modelsHandler := http.FileServerFS(registryFS())
|
||||
s := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/v2/library/BOOM/manifests/latest":
|
||||
w.WriteHeader(999)
|
||||
io.WriteString(w, `{"error": "boom"}`)
|
||||
case "/v2/library/unknown/manifests/latest":
|
||||
w.WriteHeader(404)
|
||||
io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
|
||||
default:
|
||||
t.Logf("serving blob: %s", r.URL.Path)
|
||||
modelsHandler.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
|
||||
t.Helper()
|
||||
if got.Code != 200 {
|
||||
t.Errorf("Code = %d; want 200", got.Code)
|
||||
}
|
||||
gotlines := got.Body.String()
|
||||
if strings.TrimSpace(gotlines) == "" {
|
||||
gotlines = "<empty>"
|
||||
}
|
||||
t.Logf("got:\n%s", gotlines)
|
||||
for want := range strings.Lines(wantlines) {
|
||||
want = strings.TrimSpace(want)
|
||||
want, unwanted := strings.CutPrefix(want, "!")
|
||||
want = strings.TrimSpace(want)
|
||||
if !unwanted && !strings.Contains(gotlines, want) {
|
||||
t.Errorf("\t! missing %q in body", want)
|
||||
}
|
||||
if unwanted && strings.Contains(gotlines, want) {
|
||||
t.Errorf("\t! unexpected %q in body", want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
got := s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
|
||||
checkResponse(got, `
|
||||
{"status":"pulling manifest"}
|
||||
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
|
||||
{"status":"verifying sha256 digest"}
|
||||
{"status":"writing manifest"}
|
||||
{"status":"success"}
|
||||
`)
|
||||
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
|
||||
checkResponse(got, `
|
||||
{"code":"not_found","error":"model \"unknown\" not found"}
|
||||
`)
|
||||
|
||||
got = s.send(t, "DELETE", "/api/pull", `{"model": "smol"}`)
|
||||
checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
|
||||
|
||||
got = s.send(t, "POST", "/api/pull", `!`)
|
||||
checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
|
||||
|
||||
got = s.send(t, "POST", "/api/pull", ``)
|
||||
checkErrorResponse(t, got, 400, "bad_request", "empty request body")
|
||||
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
|
||||
checkResponse(got, `
|
||||
{"code":"bad_request","error":"invalid or missing name: \"\""}
|
||||
`)
|
||||
|
||||
// Non-streaming pulls
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "://", "stream": false}`)
|
||||
checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "smol", "stream": false}`)
|
||||
checkResponse(got, `
|
||||
{"status":"success"}
|
||||
!digest
|
||||
!total
|
||||
!completed
|
||||
`)
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "unknown", "stream": false}`)
|
||||
checkErrorResponse(t, got, 404, "not_found", "model not found")
|
||||
}
|
||||
|
||||
func TestServerUnknownPath(t *testing.T) {
|
||||
s := newTestServer(t, nil)
|
||||
got := s.send(t, "DELETE", "/api/unknown", `{}`)
|
||||
checkErrorResponse(t, got, 404, "not_found", "not found")
|
||||
|
||||
var fellback bool
|
||||
s.Fallback = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fellback = true
|
||||
})
|
||||
got = s.send(t, "DELETE", "/api/unknown", `{}`)
|
||||
if !fellback {
|
||||
t.Fatal("expected Fallback to be called")
|
||||
}
|
||||
if got.Code != 200 {
|
||||
t.Fatalf("Code = %d; want 200", got.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
|
||||
t.Helper()
|
||||
|
||||
var printedBody bool
|
||||
errorf := func(format string, args ...any) {
|
||||
t.Helper()
|
||||
if !printedBody {
|
||||
t.Logf("BODY:\n%s", got.Body.String())
|
||||
printedBody = true
|
||||
}
|
||||
t.Errorf(format, args...)
|
||||
}
|
||||
|
||||
if got.Code != status {
|
||||
errorf("Code = %d; want %d", got.Code, status)
|
||||
}
|
||||
|
||||
// unmarshal the error as *ollama.Error (proving *serverError is an *ollama.Error)
|
||||
var e *ollama.Error
|
||||
if err := json.Unmarshal(got.Body.Bytes(), &e); err != nil {
|
||||
errorf("unmarshal error: %v", err)
|
||||
t.FailNow()
|
||||
}
|
||||
if e.Code != code {
|
||||
errorf("Code = %q; want %q", e.Code, code)
|
||||
}
|
||||
if !strings.Contains(e.Message, msg) {
|
||||
errorf("Message = %q; want to contain %q", e.Message, msg)
|
||||
}
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json","config":{"mediaType":"application/vnd.docker.container.image.v1+json","digest":"sha256:ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116","size":267},"layers":[{"mediaType":"application/vnd.ollama.image.model","digest":"sha256:a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99","size":24}]}
|
||||
1
server/internal/registry/testdata/models/manifests/example.com/library/smol/latest
vendored
Normal file
1
server/internal/registry/testdata/models/manifests/example.com/library/smol/latest
vendored
Normal file
@@ -0,0 +1 @@
|
||||
{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json","config":{"mediaType":"application/vnd.docker.container.image.v1+json","digest":"sha256:ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116","size":267},"layers":[{"mediaType":"application/vnd.ollama.image.model","digest":"sha256:a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99","size":24}]}
|
||||
22
server/internal/registry/testdata/registry.txt
vendored
Normal file
22
server/internal/registry/testdata/registry.txt
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
-- v2/library/smol/manifests/latest --
|
||||
{
|
||||
"schemaVersion": 2,
|
||||
"mediaType": "application/vnd.docker.distribution.manifest.v2+json",
|
||||
"config": {
|
||||
"mediaType": "application/vnd.docker.container.image.v1+json",
|
||||
"digest": "sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356",
|
||||
"size": 3
|
||||
},
|
||||
"layers": [
|
||||
{
|
||||
"mediaType": "application/vnd.ollama.image.model",
|
||||
"digest": "sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312",
|
||||
"size": 5
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
-- v2/library/smol/blobs/sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312 --
|
||||
GGUF
|
||||
-- v2/library/smol/blobs/sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356 --
|
||||
{}
|
||||
102
server/internal/testutil/testutil.go
Normal file
102
server/internal/testutil/testutil.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// LogWriter returns an [io.Writer] that logs each Write using t.Log.
|
||||
func LogWriter(t *testing.T) io.Writer {
|
||||
return testWriter{t}
|
||||
}
|
||||
|
||||
type testWriter struct{ t *testing.T }
|
||||
|
||||
func (w testWriter) Write(b []byte) (int, error) {
|
||||
w.t.Logf("%s", b)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
// Slogger returns a [*slog.Logger] that writes each message
|
||||
// using t.Log.
|
||||
func Slogger(t *testing.T) *slog.Logger {
|
||||
return slog.New(slog.NewTextHandler(LogWriter(t), nil))
|
||||
}
|
||||
|
||||
// SlogBuffer returns a [*slog.Logger] that writes each message to out.
|
||||
func SlogBuffer() (lg *slog.Logger, out *bytes.Buffer) {
|
||||
var buf bytes.Buffer
|
||||
lg = slog.New(slog.NewTextHandler(&buf, nil))
|
||||
return lg, &buf
|
||||
}
|
||||
|
||||
// Check calls t.Fatal(err) if err is not nil.
|
||||
func Check(t *testing.T, err error) {
|
||||
if err != nil {
|
||||
t.Helper()
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// CheckFunc exists so other packages do not need to invent their own type for
|
||||
// taking a Check function.
|
||||
type CheckFunc func(err error)
|
||||
|
||||
// Checker returns a check function that
|
||||
// calls t.Fatal if err is not nil.
|
||||
func Checker(t *testing.T) (check func(err error)) {
|
||||
return func(err error) {
|
||||
if err != nil {
|
||||
t.Helper()
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StopPanic runs f but silently recovers from any panic f causes.
|
||||
// The normal usage is:
|
||||
//
|
||||
// testutil.StopPanic(func() {
|
||||
// callThatShouldPanic()
|
||||
// t.Errorf("callThatShouldPanic did not panic")
|
||||
// })
|
||||
func StopPanic(f func()) {
|
||||
defer func() { recover() }()
|
||||
f()
|
||||
}
|
||||
|
||||
// CheckTime calls t.Fatalf if got != want. Included in the error message is
|
||||
// want.Sub(got) to help diagnose the difference, along with their values in
|
||||
// UTC.
|
||||
func CheckTime(t *testing.T, got, want time.Time) {
|
||||
t.Helper()
|
||||
if !got.Equal(want) {
|
||||
t.Fatalf("got %v, want %v (%v)", got.UTC(), want.UTC(), want.Sub(got))
|
||||
}
|
||||
}
|
||||
|
||||
// WriteFile writes data to a file named name. It makes the directory if it
|
||||
// doesn't exist and sets the file mode to perm.
|
||||
//
|
||||
// The name must be a relative path and must not contain .. or start with a /;
|
||||
// otherwise WriteFile will panic.
|
||||
func WriteFile[S []byte | string](t testing.TB, name string, data S) {
|
||||
t.Helper()
|
||||
|
||||
if filepath.IsAbs(name) {
|
||||
t.Fatalf("WriteFile: name must be a relative path, got %q", name)
|
||||
}
|
||||
name = filepath.Clean(name)
|
||||
dir := filepath.Dir(name)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(name, []byte(data), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
90
server/laguna_quantization_test.go
Normal file
90
server/laguna_quantization_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
fsggml "github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
func TestLagunaGGUFQuantization(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
tensor string
|
||||
originalType fsggml.TensorType
|
||||
requestedType fsggml.TensorType
|
||||
fileType fsggml.FileType
|
||||
blockCount int
|
||||
wantType fsggml.TensorType
|
||||
wantQuantize bool
|
||||
}{
|
||||
{
|
||||
name: "non_routed_weights_preserved",
|
||||
tensor: "blk.1.attn_q.weight",
|
||||
originalType: fsggml.TensorTypeBF16,
|
||||
requestedType: fsggml.TensorTypeQ8_0,
|
||||
fileType: fsggml.FileTypeQ8_0,
|
||||
blockCount: 2,
|
||||
wantType: fsggml.TensorTypeBF16,
|
||||
wantQuantize: false,
|
||||
},
|
||||
{
|
||||
name: "shared_expert_weights_preserved",
|
||||
tensor: "blk.1.ffn_gate_shexp.weight",
|
||||
originalType: fsggml.TensorTypeBF16,
|
||||
requestedType: fsggml.TensorTypeQ4_K,
|
||||
fileType: fsggml.FileTypeQ4_K_M,
|
||||
blockCount: 2,
|
||||
wantType: fsggml.TensorTypeBF16,
|
||||
wantQuantize: false,
|
||||
},
|
||||
{
|
||||
name: "routed_gate_q8",
|
||||
tensor: "blk.1.ffn_gate_exps.weight",
|
||||
originalType: fsggml.TensorTypeBF16,
|
||||
requestedType: fsggml.TensorTypeQ8_0,
|
||||
fileType: fsggml.FileTypeQ8_0,
|
||||
blockCount: 2,
|
||||
wantType: fsggml.TensorTypeQ8_0,
|
||||
wantQuantize: true,
|
||||
},
|
||||
{
|
||||
name: "routed_down_q4_promoted",
|
||||
tensor: "blk.1.ffn_down_exps.weight",
|
||||
originalType: fsggml.TensorTypeBF16,
|
||||
requestedType: fsggml.TensorTypeQ4_K,
|
||||
fileType: fsggml.FileTypeQ4_K_M,
|
||||
blockCount: 2,
|
||||
wantType: fsggml.TensorTypeQ6_K,
|
||||
wantQuantize: true,
|
||||
},
|
||||
{
|
||||
name: "routed_down_q4_not_promoted_when_q8_requested",
|
||||
tensor: "blk.1.ffn_down_exps.weight",
|
||||
originalType: fsggml.TensorTypeBF16,
|
||||
requestedType: fsggml.TensorTypeQ8_0,
|
||||
fileType: fsggml.FileTypeQ4_K_M,
|
||||
blockCount: 2,
|
||||
wantType: fsggml.TensorTypeQ8_0,
|
||||
wantQuantize: true,
|
||||
},
|
||||
{
|
||||
name: "routed_down_q4_k_s_promoted",
|
||||
tensor: "blk.0.ffn_down_exps.weight",
|
||||
originalType: fsggml.TensorTypeBF16,
|
||||
requestedType: fsggml.TensorTypeQ4_K,
|
||||
fileType: fsggml.FileTypeQ4_K_S,
|
||||
blockCount: 8,
|
||||
wantType: fsggml.TensorTypeQ5_K,
|
||||
wantQuantize: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotType, gotQuantize := lagunaGGUFQuantization(tt.tensor, tt.originalType, tt.requestedType, tt.fileType, tt.blockCount)
|
||||
if gotType != tt.wantType || gotQuantize != tt.wantQuantize {
|
||||
t.Fatalf("lagunaGGUFQuantization(%q) = (%s, %v), want (%s, %v)", tt.tensor, gotType, gotQuantize, tt.wantType, tt.wantQuantize)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
44
server/logprob.go
Normal file
44
server/logprob.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
// toAPILogprobs converts llm.Logprobs to api.Logprobs
|
||||
func toAPILogprobs(logprobs []llm.Logprob) []api.Logprob {
|
||||
result := make([]api.Logprob, len(logprobs))
|
||||
for i, lp := range logprobs {
|
||||
result[i] = api.Logprob{
|
||||
TokenLogprob: api.TokenLogprob{
|
||||
Token: lp.Token,
|
||||
Bytes: stringToByteInts(lp.Token),
|
||||
Logprob: lp.Logprob,
|
||||
},
|
||||
}
|
||||
if len(lp.TopLogprobs) > 0 {
|
||||
result[i].TopLogprobs = make([]api.TokenLogprob, len(lp.TopLogprobs))
|
||||
for j, tlp := range lp.TopLogprobs {
|
||||
result[i].TopLogprobs[j] = api.TokenLogprob{
|
||||
Token: tlp.Token,
|
||||
Bytes: stringToByteInts(tlp.Token),
|
||||
Logprob: tlp.Logprob,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func stringToByteInts(s string) []int {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
raw := []byte(s)
|
||||
ints := make([]int, len(raw))
|
||||
for i, b := range raw {
|
||||
ints[i] = int(b)
|
||||
}
|
||||
return ints
|
||||
}
|
||||
129
server/model.go
Normal file
129
server/model.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
var intermediateBlobs map[string]string = make(map[string]string)
|
||||
|
||||
type layerGGML struct {
|
||||
manifest.Layer
|
||||
*ggml.GGML
|
||||
}
|
||||
|
||||
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
|
||||
m, err := manifest.ParseNamedManifest(name)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m, err = manifest.ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case err != nil:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, srcLayer := range m.Layers {
|
||||
layer, err := manifest.NewLayerFromLayer(srcLayer.Digest, srcLayer.MediaType, name.DisplayShortest())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layer.Name = srcLayer.Name
|
||||
|
||||
switch layer.MediaType {
|
||||
case "application/vnd.ollama.image.model",
|
||||
"application/vnd.ollama.image.projector",
|
||||
"application/vnd.ollama.image.adapter":
|
||||
blobpath, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blob, err := os.Open(blobpath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer blob.Close()
|
||||
|
||||
f, err := ggml.Decode(blob, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers = append(layers, &layerGGML{layer, f})
|
||||
default:
|
||||
layers = append(layers, &layerGGML{layer, nil})
|
||||
}
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
|
||||
for _, layer := range layers {
|
||||
if s := layer.GGML.KV().ChatTemplate(); s != "" {
|
||||
if t, err := template.Named(s); err != nil {
|
||||
slog.Debug("template detection", "error", err, "template", s)
|
||||
} else {
|
||||
layer, err := manifest.NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layer.Status = fmt.Sprintf("using autodetected template %s", t.Name)
|
||||
layers = append(layers, &layerGGML{layer, nil})
|
||||
|
||||
if t.Parameters != nil {
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(t.Parameters); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers = append(layers, &layerGGML{layer, nil})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func detectContentType(r io.Reader) (string, error) {
|
||||
var b bytes.Buffer
|
||||
if _, err := io.Copy(&b, r); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if contentType := ggml.DetectContentType(b.Bytes()); contentType != "" {
|
||||
return contentType, nil
|
||||
}
|
||||
|
||||
if contentType := http.DetectContentType(b.Bytes()); contentType != "application/octet-stream" {
|
||||
return contentType, nil
|
||||
}
|
||||
|
||||
return "unknown", nil
|
||||
}
|
||||
32
server/model_caches.go
Normal file
32
server/model_caches.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package server
|
||||
|
||||
import "context"
|
||||
|
||||
type modelCaches struct {
|
||||
recommendations *modelRecommendationsCache
|
||||
show *modelShowCache
|
||||
modelList *modelListCache
|
||||
}
|
||||
|
||||
func newModelCaches() *modelCaches {
|
||||
return &modelCaches{
|
||||
recommendations: newModelRecommendationsCache(),
|
||||
show: newModelShowCache(),
|
||||
modelList: newModelListCache(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *modelCaches) Start(ctx context.Context) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if c.recommendations != nil {
|
||||
c.recommendations.Start(ctx)
|
||||
}
|
||||
if c.show != nil {
|
||||
c.show.Start(ctx)
|
||||
}
|
||||
if c.modelList != nil {
|
||||
c.modelList.Start(ctx)
|
||||
}
|
||||
}
|
||||
824
server/model_list_cache.go
Normal file
824
server/model_list_cache.go
Normal file
@@ -0,0 +1,824 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/model/parsers"
|
||||
ollamatemplate "github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/thinking"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type modelListSummary struct {
|
||||
Model string
|
||||
Name string
|
||||
RemoteModel string
|
||||
RemoteHost string
|
||||
Size int64
|
||||
Digest string
|
||||
ModifiedAt time.Time
|
||||
Details api.ModelDetails
|
||||
Capabilities []model.Capability
|
||||
}
|
||||
|
||||
type modelListCacheEntry struct {
|
||||
Digest string
|
||||
Summary modelListSummary
|
||||
}
|
||||
|
||||
type modelListCache struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
entries map[string]modelListCacheEntry
|
||||
|
||||
once sync.Once
|
||||
readyOnce sync.Once
|
||||
ready chan struct{}
|
||||
hydrateErr error
|
||||
build func(model.Name, *manifest.Manifest) (modelListSummary, error)
|
||||
}
|
||||
|
||||
func newModelListCache() *modelListCache {
|
||||
return &modelListCache{
|
||||
entries: make(map[string]modelListCacheEntry),
|
||||
ready: make(chan struct{}),
|
||||
build: buildModelListSummary,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *modelListCache) Start(ctx context.Context) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.once.Do(func() {
|
||||
slog.Debug("starting model list cache")
|
||||
go func() {
|
||||
err := c.hydrate(ctx)
|
||||
c.markReady(err)
|
||||
if err != nil {
|
||||
if ctx != nil && ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
slog.Warn("model list cache hydration failed", "error", err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func (c *modelListCache) hydrate(ctx context.Context) error {
|
||||
start := time.Now()
|
||||
|
||||
manifests, err := manifest.Manifests(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var hydrated, failed int
|
||||
for name, mf := range manifests {
|
||||
if ctx != nil {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
summary, err := c.build(name, mf)
|
||||
if err != nil {
|
||||
failed++
|
||||
slog.Warn("failed to hydrate model list cache", "model", name.String(), "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.set(name, mf.Digest(), summary)
|
||||
hydrated++
|
||||
}
|
||||
|
||||
slog.Info("model list cache hydration complete", "models", hydrated, "failures", failed, "elapsed", time.Since(start))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *modelListCache) markReady(err error) {
|
||||
c.mu.Lock()
|
||||
c.hydrateErr = err
|
||||
c.mu.Unlock()
|
||||
|
||||
c.readyOnce.Do(func() {
|
||||
close(c.ready)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *modelListCache) Wait(ctx context.Context) error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-c.ready:
|
||||
c.mu.RLock()
|
||||
err := c.hydrateErr
|
||||
c.mu.RUnlock()
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *modelListCache) List(ctx context.Context) ([]api.ListModelResponse, error) {
|
||||
if err := c.Wait(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := c.syncManifests(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
models := make([]api.ListModelResponse, 0, len(c.entries))
|
||||
for _, entry := range c.entries {
|
||||
models = append(models, entry.Summary.ListModelResponse())
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
sortListModelResponses(models)
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (c *modelListCache) syncManifests(ctx context.Context) error {
|
||||
manifests, err := manifest.Manifests(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
current := make(map[string]string, len(c.entries))
|
||||
for name, entry := range c.entries {
|
||||
current[name] = entry.Digest
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
type update struct {
|
||||
name model.Name
|
||||
digest string
|
||||
summary modelListSummary
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{}, len(manifests))
|
||||
stale := make(map[string]struct{})
|
||||
var updates []update
|
||||
for name, mf := range manifests {
|
||||
if ctx != nil {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
key := name.String()
|
||||
digest := mf.Digest()
|
||||
seen[key] = struct{}{}
|
||||
if current[key] == digest {
|
||||
continue
|
||||
}
|
||||
|
||||
summary, err := c.build(name, mf)
|
||||
if err != nil {
|
||||
slog.Warn("failed to refresh model list cache", "model", key, "error", err)
|
||||
if _, ok := current[key]; ok {
|
||||
stale[key] = struct{}{}
|
||||
}
|
||||
continue
|
||||
}
|
||||
updates = append(updates, update{name: name, digest: digest, summary: summary})
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
for name := range c.entries {
|
||||
if _, ok := seen[name]; !ok {
|
||||
delete(c.entries, name)
|
||||
continue
|
||||
}
|
||||
if _, ok := stale[name]; ok {
|
||||
delete(c.entries, name)
|
||||
}
|
||||
}
|
||||
for _, update := range updates {
|
||||
c.entries[update.name.String()] = modelListCacheEntry{
|
||||
Digest: update.digest,
|
||||
Summary: cloneModelListSummary(update.summary),
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *modelListCache) RefreshModel(name model.Name) error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !name.IsFullyQualified() {
|
||||
var err error
|
||||
name, err = getExistingName(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
mf, err := manifest.ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
c.DeleteModel(name)
|
||||
return err
|
||||
}
|
||||
|
||||
summary, err := c.build(name, mf)
|
||||
if err != nil {
|
||||
c.DeleteModel(name)
|
||||
return err
|
||||
}
|
||||
|
||||
c.set(name, mf.Digest(), summary)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *modelListCache) DeleteModel(name model.Name) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
delete(c.entries, name.String())
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *modelListCache) Get(name model.Name) (modelListSummary, bool) {
|
||||
if c == nil {
|
||||
return modelListSummary{}, false
|
||||
}
|
||||
|
||||
if !name.IsFullyQualified() {
|
||||
if existing, err := getExistingName(name); err == nil {
|
||||
name = existing
|
||||
}
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
entry, ok := c.entries[name.String()]
|
||||
c.mu.RUnlock()
|
||||
if !ok {
|
||||
return modelListSummary{}, false
|
||||
}
|
||||
|
||||
return cloneModelListSummary(entry.Summary), true
|
||||
}
|
||||
|
||||
func (c *modelListCache) Len() int {
|
||||
if c == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return len(c.entries)
|
||||
}
|
||||
|
||||
func (c *modelListCache) set(name model.Name, digest string, summary modelListSummary) {
|
||||
c.mu.Lock()
|
||||
c.entries[name.String()] = modelListCacheEntry{
|
||||
Digest: digest,
|
||||
Summary: cloneModelListSummary(summary),
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func buildModelListSummary(name model.Name, mf *manifest.Manifest) (modelListSummary, error) {
|
||||
cfg, err := readModelListConfig(mf)
|
||||
if err != nil {
|
||||
return modelListSummary{}, err
|
||||
}
|
||||
|
||||
var modified time.Time
|
||||
if fi := mf.FileInfo(); fi != nil {
|
||||
modified = fi.ModTime()
|
||||
}
|
||||
|
||||
summary := modelListSummary{
|
||||
Model: name.DisplayShortest(),
|
||||
Name: name.DisplayShortest(),
|
||||
RemoteModel: cfg.RemoteModel,
|
||||
RemoteHost: cfg.RemoteHost,
|
||||
Size: mf.Size(),
|
||||
Digest: mf.Digest(),
|
||||
ModifiedAt: modified,
|
||||
Details: api.ModelDetails{
|
||||
Format: cfg.ModelFormat,
|
||||
Family: cfg.ModelFamily,
|
||||
Families: append([]string(nil), cfg.ModelFamilies...),
|
||||
ParameterSize: cfg.ModelType,
|
||||
QuantizationLevel: cfg.FileType,
|
||||
ContextLength: cfg.ContextLen,
|
||||
EmbeddingLength: cfg.EmbedLen,
|
||||
},
|
||||
}
|
||||
|
||||
modelPath, projectorCount, tmpl, err := readModelListLayers(mf, &summary)
|
||||
if err != nil {
|
||||
return modelListSummary{}, err
|
||||
}
|
||||
|
||||
if cfg.RemoteHost == "" && cfg.RemoteModel == "" && modelPath != "" {
|
||||
info, err := readModelListGGUF(modelPath)
|
||||
if err != nil {
|
||||
slog.Debug("failed to read gguf model metadata", "model", name.String(), "error", err)
|
||||
} else {
|
||||
summary.Capabilities = appendModelListCapabilities(summary.Capabilities, info.Capabilities...)
|
||||
if summary.Details.ContextLength == 0 {
|
||||
summary.Details.ContextLength = info.ContextLength
|
||||
}
|
||||
if summary.Details.EmbeddingLength == 0 {
|
||||
summary.Details.EmbeddingLength = info.EmbeddingLength
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, c := range cfg.Capabilities {
|
||||
summary.Capabilities = appendModelListCapability(summary.Capabilities, model.Capability(c))
|
||||
}
|
||||
|
||||
builtinParser := parsers.ParserForName(cfg.Parser)
|
||||
if tmpl != nil {
|
||||
vars, err := tmpl.Vars()
|
||||
if err != nil {
|
||||
slog.Warn("model template contains errors", "model", name.String(), "error", err)
|
||||
}
|
||||
if slices.Contains(vars, "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) {
|
||||
summary.Capabilities = appendModelListCapability(summary.Capabilities, model.CapabilityTools)
|
||||
}
|
||||
if slices.Contains(vars, "suffix") {
|
||||
summary.Capabilities = appendModelListCapability(summary.Capabilities, model.CapabilityInsert)
|
||||
}
|
||||
|
||||
openingTag, closingTag := thinking.InferTags(tmpl.Template)
|
||||
hasTags := openingTag != "" && closingTag != ""
|
||||
isGptoss := slices.Contains([]string{"gptoss", "gpt-oss"}, cfg.ModelFamily)
|
||||
if !slices.Contains(summary.Capabilities, model.CapabilityThinking) &&
|
||||
(hasTags || isGptoss || (builtinParser != nil && builtinParser.HasThinkingSupport())) {
|
||||
summary.Capabilities = appendModelListCapability(summary.Capabilities, model.CapabilityThinking)
|
||||
}
|
||||
}
|
||||
|
||||
if projectorCount > 0 {
|
||||
summary.Capabilities = appendModelListCapability(summary.Capabilities, model.CapabilityVision)
|
||||
}
|
||||
|
||||
if cfg.ModelFormat == "safetensors" && isGemma4Renderer(cfg.Renderer) {
|
||||
summary.Capabilities = slices.DeleteFunc(summary.Capabilities, func(c model.Capability) bool {
|
||||
return c == model.CapabilityVision || c == model.CapabilityAudio
|
||||
})
|
||||
}
|
||||
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
func readModelListConfig(mf *manifest.Manifest) (model.ConfigV2, error) {
|
||||
var cfg model.ConfigV2
|
||||
if mf == nil || mf.Config.Digest == "" {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
f, err := mf.Config.Open()
|
||||
if err != nil {
|
||||
return cfg, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := json.NewDecoder(f).Decode(&cfg); err != nil {
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func readModelListLayers(mf *manifest.Manifest, summary *modelListSummary) (string, int, *ollamatemplate.Template, error) {
|
||||
var modelPath string
|
||||
var projectorCount int
|
||||
tmpl := ollamatemplate.DefaultTemplate
|
||||
|
||||
for _, layer := range mf.Layers {
|
||||
switch layer.MediaType {
|
||||
case "application/vnd.ollama.image.model":
|
||||
filename, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return "", 0, nil, err
|
||||
}
|
||||
modelPath = filename
|
||||
summary.Details.ParentModel = layer.From
|
||||
case "application/vnd.ollama.image.projector":
|
||||
projectorCount++
|
||||
case "application/vnd.ollama.image.prompt",
|
||||
"application/vnd.ollama.image.template":
|
||||
filename, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return "", 0, nil, err
|
||||
}
|
||||
bts, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return "", 0, nil, err
|
||||
}
|
||||
|
||||
tmpl, err = ollamatemplate.Parse(string(bts))
|
||||
if err != nil {
|
||||
return "", 0, nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return modelPath, projectorCount, tmpl, nil
|
||||
}
|
||||
|
||||
type modelListGGUF struct {
|
||||
Capabilities []model.Capability
|
||||
ContextLength int
|
||||
EmbeddingLength int
|
||||
}
|
||||
|
||||
const (
|
||||
modelListGGUFMagicLE = 0x46554747
|
||||
modelListGGUFMagicBE = 0x47475546
|
||||
)
|
||||
|
||||
const (
|
||||
modelListGGUFTypeUint8 uint32 = iota
|
||||
modelListGGUFTypeInt8
|
||||
modelListGGUFTypeUint16
|
||||
modelListGGUFTypeInt16
|
||||
modelListGGUFTypeUint32
|
||||
modelListGGUFTypeInt32
|
||||
modelListGGUFTypeFloat32
|
||||
modelListGGUFTypeBool
|
||||
modelListGGUFTypeString
|
||||
modelListGGUFTypeArray
|
||||
modelListGGUFTypeUint64
|
||||
modelListGGUFTypeInt64
|
||||
modelListGGUFTypeFloat64
|
||||
)
|
||||
|
||||
// readModelListGGUF scans only the small GGUF header values launch needs
|
||||
// and stops before tokenizer arrays. Using gguf.File.KeyValue for missing keys
|
||||
// can otherwise advance through large arrays just to discover absence.
|
||||
func readModelListGGUF(path string) (modelListGGUF, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return modelListGGUF{}, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
r := bufio.NewReaderSize(f, 32<<10)
|
||||
var magic uint32
|
||||
if err := binary.Read(r, binary.LittleEndian, &magic); err != nil {
|
||||
return modelListGGUF{}, err
|
||||
}
|
||||
|
||||
var byteOrder binary.ByteOrder = binary.LittleEndian
|
||||
switch magic {
|
||||
case modelListGGUFMagicLE:
|
||||
case modelListGGUFMagicBE:
|
||||
byteOrder = binary.BigEndian
|
||||
default:
|
||||
return modelListGGUF{}, fmt.Errorf("invalid file magic")
|
||||
}
|
||||
|
||||
var version uint32
|
||||
if err := binary.Read(r, byteOrder, &version); err != nil {
|
||||
return modelListGGUF{}, err
|
||||
}
|
||||
|
||||
var numKV uint64
|
||||
switch version {
|
||||
case 1:
|
||||
var header struct {
|
||||
NumTensor uint32
|
||||
NumKV uint32
|
||||
}
|
||||
if err := binary.Read(r, byteOrder, &header); err != nil {
|
||||
return modelListGGUF{}, err
|
||||
}
|
||||
numKV = uint64(header.NumKV)
|
||||
default:
|
||||
var header struct {
|
||||
NumTensor uint64
|
||||
NumKV uint64
|
||||
}
|
||||
if err := binary.Read(r, byteOrder, &header); err != nil {
|
||||
return modelListGGUF{}, err
|
||||
}
|
||||
numKV = header.NumKV
|
||||
}
|
||||
|
||||
info := modelListGGUF{}
|
||||
var architecture string
|
||||
var hasPoolingType bool
|
||||
|
||||
for range numKV {
|
||||
key, err := readModelListGGUFString(r, byteOrder, version)
|
||||
if err != nil {
|
||||
return modelListGGUF{}, err
|
||||
}
|
||||
|
||||
var valueType uint32
|
||||
if err := binary.Read(r, byteOrder, &valueType); err != nil {
|
||||
return modelListGGUF{}, err
|
||||
}
|
||||
|
||||
if key == "general.architecture" {
|
||||
value, err := readModelListGGUFStringValue(r, byteOrder, version, valueType)
|
||||
if err != nil {
|
||||
return modelListGGUF{}, err
|
||||
}
|
||||
architecture = value
|
||||
continue
|
||||
}
|
||||
|
||||
if architecture != "" && strings.HasPrefix(key, "tokenizer.") {
|
||||
break
|
||||
}
|
||||
|
||||
if architecture != "" && strings.HasPrefix(key, architecture+".") {
|
||||
switch strings.TrimPrefix(key, architecture+".") {
|
||||
case "pooling_type":
|
||||
hasPoolingType = true
|
||||
case "vision.block_count":
|
||||
info.Capabilities = appendModelListCapability(info.Capabilities, model.CapabilityVision)
|
||||
case "audio.block_count":
|
||||
info.Capabilities = appendModelListCapability(info.Capabilities, model.CapabilityAudio)
|
||||
case "context_length":
|
||||
value, err := readModelListGGUFIntValue(r, byteOrder, version, valueType)
|
||||
if err != nil {
|
||||
return modelListGGUF{}, err
|
||||
}
|
||||
info.ContextLength = value
|
||||
continue
|
||||
case "embedding_length":
|
||||
value, err := readModelListGGUFIntValue(r, byteOrder, version, valueType)
|
||||
if err != nil {
|
||||
return modelListGGUF{}, err
|
||||
}
|
||||
info.EmbeddingLength = value
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err := skipModelListGGUFValue(r, byteOrder, version, valueType); err != nil {
|
||||
return modelListGGUF{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if hasPoolingType {
|
||||
info.Capabilities = appendModelListCapability(info.Capabilities, model.CapabilityEmbedding)
|
||||
} else {
|
||||
info.Capabilities = appendModelListCapability(info.Capabilities, model.CapabilityCompletion)
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func readModelListGGUFStringValue(r io.Reader, byteOrder binary.ByteOrder, version uint32, valueType uint32) (string, error) {
|
||||
if valueType != modelListGGUFTypeString {
|
||||
if err := skipModelListGGUFValue(r, byteOrder, version, valueType); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "", fmt.Errorf("unexpected gguf string type %d", valueType)
|
||||
}
|
||||
return readModelListGGUFString(r, byteOrder, version)
|
||||
}
|
||||
|
||||
func readModelListGGUFIntValue(r io.Reader, byteOrder binary.ByteOrder, version uint32, valueType uint32) (int, error) {
|
||||
switch valueType {
|
||||
case modelListGGUFTypeUint8:
|
||||
var value uint8
|
||||
if err := binary.Read(r, byteOrder, &value); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(value), nil
|
||||
case modelListGGUFTypeInt8:
|
||||
var value int8
|
||||
if err := binary.Read(r, byteOrder, &value); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(value), nil
|
||||
case modelListGGUFTypeUint16:
|
||||
var value uint16
|
||||
if err := binary.Read(r, byteOrder, &value); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(value), nil
|
||||
case modelListGGUFTypeInt16:
|
||||
var value int16
|
||||
if err := binary.Read(r, byteOrder, &value); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(value), nil
|
||||
case modelListGGUFTypeUint32:
|
||||
var value uint32
|
||||
if err := binary.Read(r, byteOrder, &value); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(value), nil
|
||||
case modelListGGUFTypeInt32:
|
||||
var value int32
|
||||
if err := binary.Read(r, byteOrder, &value); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(value), nil
|
||||
case modelListGGUFTypeUint64:
|
||||
var value uint64
|
||||
if err := binary.Read(r, byteOrder, &value); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(value), nil
|
||||
case modelListGGUFTypeInt64:
|
||||
var value int64
|
||||
if err := binary.Read(r, byteOrder, &value); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(value), nil
|
||||
default:
|
||||
if err := skipModelListGGUFValue(r, byteOrder, version, valueType); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return 0, fmt.Errorf("unexpected gguf integer type %d", valueType)
|
||||
}
|
||||
}
|
||||
|
||||
func skipModelListGGUFValue(r io.Reader, byteOrder binary.ByteOrder, version uint32, valueType uint32) error {
|
||||
switch valueType {
|
||||
case modelListGGUFTypeUint8, modelListGGUFTypeInt8, modelListGGUFTypeBool:
|
||||
return discardModelListGGUFBytes(r, 1)
|
||||
case modelListGGUFTypeUint16, modelListGGUFTypeInt16:
|
||||
return discardModelListGGUFBytes(r, 2)
|
||||
case modelListGGUFTypeUint32, modelListGGUFTypeInt32, modelListGGUFTypeFloat32:
|
||||
return discardModelListGGUFBytes(r, 4)
|
||||
case modelListGGUFTypeUint64, modelListGGUFTypeInt64, modelListGGUFTypeFloat64:
|
||||
return discardModelListGGUFBytes(r, 8)
|
||||
case modelListGGUFTypeString:
|
||||
return skipModelListGGUFString(r, byteOrder, version)
|
||||
case modelListGGUFTypeArray:
|
||||
var arrayType uint32
|
||||
if err := binary.Read(r, byteOrder, &arrayType); err != nil {
|
||||
return err
|
||||
}
|
||||
var count uint64
|
||||
if err := binary.Read(r, byteOrder, &count); err != nil {
|
||||
return err
|
||||
}
|
||||
return skipModelListGGUFArray(r, byteOrder, version, arrayType, count)
|
||||
default:
|
||||
return fmt.Errorf("unsupported gguf value type %d", valueType)
|
||||
}
|
||||
}
|
||||
|
||||
func skipModelListGGUFArray(r io.Reader, byteOrder binary.ByteOrder, version uint32, arrayType uint32, count uint64) error {
|
||||
var size uint64
|
||||
switch arrayType {
|
||||
case modelListGGUFTypeUint8, modelListGGUFTypeInt8, modelListGGUFTypeBool:
|
||||
size = 1
|
||||
case modelListGGUFTypeUint16, modelListGGUFTypeInt16:
|
||||
size = 2
|
||||
case modelListGGUFTypeUint32, modelListGGUFTypeInt32, modelListGGUFTypeFloat32:
|
||||
size = 4
|
||||
case modelListGGUFTypeUint64, modelListGGUFTypeInt64, modelListGGUFTypeFloat64:
|
||||
size = 8
|
||||
case modelListGGUFTypeString:
|
||||
for range count {
|
||||
if err := skipModelListGGUFString(r, byteOrder, version); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported gguf array type %d", arrayType)
|
||||
}
|
||||
return discardModelListGGUFBytes(r, int64(count*size))
|
||||
}
|
||||
|
||||
func readModelListGGUFString(r io.Reader, byteOrder binary.ByteOrder, version uint32) (string, error) {
|
||||
var length uint64
|
||||
if err := binary.Read(r, byteOrder, &length); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if length == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
bts := make([]byte, length)
|
||||
if _, err := io.ReadFull(r, bts); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if version == 1 && bts[len(bts)-1] == 0 {
|
||||
bts = bts[:len(bts)-1]
|
||||
}
|
||||
return string(bts), nil
|
||||
}
|
||||
|
||||
func skipModelListGGUFString(r io.Reader, byteOrder binary.ByteOrder, version uint32) error {
|
||||
var length uint64
|
||||
if err := binary.Read(r, byteOrder, &length); err != nil {
|
||||
return err
|
||||
}
|
||||
return discardModelListGGUFBytes(r, int64(length))
|
||||
}
|
||||
|
||||
func discardModelListGGUFBytes(r io.Reader, n int64) error {
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := io.CopyN(io.Discard, r, n)
|
||||
return err
|
||||
}
|
||||
|
||||
func appendModelListCapabilities(capabilities []model.Capability, values ...model.Capability) []model.Capability {
|
||||
for _, capability := range values {
|
||||
capabilities = appendModelListCapability(capabilities, capability)
|
||||
}
|
||||
return capabilities
|
||||
}
|
||||
|
||||
func appendModelListCapability(capabilities []model.Capability, capability model.Capability) []model.Capability {
|
||||
if capability == "" || slices.Contains(capabilities, capability) {
|
||||
return capabilities
|
||||
}
|
||||
return append(capabilities, capability)
|
||||
}
|
||||
|
||||
func cloneModelListSummary(summary modelListSummary) modelListSummary {
|
||||
summary.Details.Families = append([]string(nil), summary.Details.Families...)
|
||||
summary.Capabilities = append([]model.Capability(nil), summary.Capabilities...)
|
||||
return summary
|
||||
}
|
||||
|
||||
func (s modelListSummary) ListModelResponse() api.ListModelResponse {
|
||||
resp := api.ListModelResponse{
|
||||
Model: s.Model,
|
||||
Name: s.Name,
|
||||
RemoteModel: s.RemoteModel,
|
||||
RemoteHost: s.RemoteHost,
|
||||
Size: s.Size,
|
||||
Digest: s.Digest,
|
||||
ModifiedAt: s.ModifiedAt,
|
||||
Details: api.ModelDetails{
|
||||
ParentModel: s.Details.ParentModel,
|
||||
Format: s.Details.Format,
|
||||
Family: s.Details.Family,
|
||||
Families: append([]string(nil), s.Details.Families...),
|
||||
ParameterSize: s.Details.ParameterSize,
|
||||
QuantizationLevel: s.Details.QuantizationLevel,
|
||||
ContextLength: s.Details.ContextLength,
|
||||
EmbeddingLength: s.Details.EmbeddingLength,
|
||||
},
|
||||
}
|
||||
|
||||
resp.Capabilities = append([]model.Capability(nil), s.Capabilities...)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func sortListModelResponses(models []api.ListModelResponse) {
|
||||
slices.SortStableFunc(models, func(i, j api.ListModelResponse) int {
|
||||
// Preserve the existing /api/tags order: most recently modified first.
|
||||
return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) refreshModelListCache(name model.Name) {
|
||||
if s == nil || s.modelCaches == nil || s.modelCaches.modelList == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.modelCaches.modelList.RefreshModel(name); err != nil {
|
||||
slog.Warn("failed to refresh model list cache", "model", name.String(), "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) deleteModelListCache(name model.Name) {
|
||||
if s == nil || s.modelCaches == nil || s.modelCaches.modelList == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.modelCaches.modelList.DeleteModel(name)
|
||||
}
|
||||
239
server/model_list_cache_test.go
Normal file
239
server/model_list_cache_test.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestModelListCacheHydratesSummary(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
createListCacheModel(t, "list-cache", map[string]any{
|
||||
"test.context_length": uint32(4096),
|
||||
"test.embedding_length": uint32(384),
|
||||
}, "{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
||||
|
||||
cache := newModelListCache()
|
||||
if err := cache.hydrate(context.Background()); err != nil {
|
||||
t.Fatalf("hydrate failed: %v", err)
|
||||
}
|
||||
|
||||
summary, ok := cache.Get(model.ParseName("list-cache"))
|
||||
if !ok {
|
||||
t.Fatal("list summary missing")
|
||||
}
|
||||
|
||||
if summary.Model != "list-cache:latest" || summary.Name != "list-cache:latest" {
|
||||
t.Fatalf("summary model/name = %q/%q, want list-cache:latest", summary.Model, summary.Name)
|
||||
}
|
||||
if summary.Digest == "" {
|
||||
t.Fatal("summary digest is empty")
|
||||
}
|
||||
if summary.Size == 0 {
|
||||
t.Fatal("summary size is zero")
|
||||
}
|
||||
if summary.Details.Family != "test" || summary.Details.Format != "gguf" {
|
||||
t.Fatalf("summary details = %+v, want gguf/test", summary.Details)
|
||||
}
|
||||
if summary.Details.ContextLength != 4096 {
|
||||
t.Fatalf("context length = %d, want 4096", summary.Details.ContextLength)
|
||||
}
|
||||
if summary.Details.EmbeddingLength != 384 {
|
||||
t.Fatalf("embedding length = %d, want 384", summary.Details.EmbeddingLength)
|
||||
}
|
||||
|
||||
for _, capability := range []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert} {
|
||||
if !slices.Contains(summary.Capabilities, capability) {
|
||||
t.Fatalf("capabilities = %v, want %s", summary.Capabilities, capability)
|
||||
}
|
||||
}
|
||||
|
||||
listModel := summary.ListModelResponse()
|
||||
if !slices.Contains(listModel.Capabilities, model.CapabilityTools) ||
|
||||
listModel.Details.ContextLength != 4096 ||
|
||||
listModel.Details.EmbeddingLength != 384 {
|
||||
t.Fatalf("list response = %+v, want capabilities/context/embedding", listModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelListCacheRefreshUpdatesEntry(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
createListCacheModel(t, "list-refresh", map[string]any{"test.context_length": uint32(1024)}, "")
|
||||
|
||||
cache := newModelListCache()
|
||||
if err := cache.hydrate(context.Background()); err != nil {
|
||||
t.Fatalf("hydrate failed: %v", err)
|
||||
}
|
||||
|
||||
name := model.ParseName("list-refresh")
|
||||
first, ok := cache.Get(name)
|
||||
if !ok {
|
||||
t.Fatal("list summary missing")
|
||||
}
|
||||
|
||||
changeShowCacheManifest(t, "list-refresh")
|
||||
if err := cache.RefreshModel(name); err != nil {
|
||||
t.Fatalf("refresh failed: %v", err)
|
||||
}
|
||||
|
||||
refreshed, ok := cache.Get(name)
|
||||
if !ok {
|
||||
t.Fatal("refreshed list summary missing")
|
||||
}
|
||||
if refreshed.Digest == first.Digest {
|
||||
t.Fatalf("digest did not change after refresh: %s", refreshed.Digest)
|
||||
}
|
||||
if cache.Len() != 1 {
|
||||
t.Fatalf("cache entries = %d, want 1", cache.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelListCacheMutationHooks(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
cache := newModelListCache()
|
||||
s := Server{modelCaches: &modelCaches{modelList: cache}}
|
||||
|
||||
_, digest := createBinFile(t, map[string]any{"test.context_length": uint32(2048)}, nil)
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "list-hooks",
|
||||
Files: map[string]string{"model.gguf": digest},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("create model status = %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
if _, ok := cache.Get(model.ParseName("list-hooks")); !ok {
|
||||
t.Fatal("create did not refresh model list cache")
|
||||
}
|
||||
|
||||
w = createRequest(t, s.CopyHandler, api.CopyRequest{
|
||||
Source: "list-hooks",
|
||||
Destination: "list-hooks-copy",
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("copy model status = %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if _, ok := cache.Get(model.ParseName("list-hooks-copy")); !ok {
|
||||
t.Fatal("copy did not refresh model list cache")
|
||||
}
|
||||
|
||||
w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Model: "list-hooks-copy"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("delete model status = %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if _, ok := cache.Get(model.ParseName("list-hooks-copy")); ok {
|
||||
t.Fatal("delete did not remove model list cache entry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelListCacheSyncsManifestChanges(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
createListCacheModel(t, "list-sync-a", map[string]any{"test.context_length": uint32(1024)}, "")
|
||||
|
||||
cache := newModelListCache()
|
||||
cache.Start(context.Background())
|
||||
if err := cache.Wait(context.Background()); err != nil {
|
||||
t.Fatalf("wait failed: %v", err)
|
||||
}
|
||||
|
||||
createListCacheModel(t, "list-sync-b", map[string]any{"test.context_length": uint32(2048)}, "")
|
||||
models, err := cache.List(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("list failed: %v", err)
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(models))
|
||||
for _, m := range models {
|
||||
names = append(names, m.Name)
|
||||
}
|
||||
for _, want := range []string{"list-sync-a:latest", "list-sync-b:latest"} {
|
||||
if !slices.Contains(names, want) {
|
||||
t.Fatalf("names = %v, want %s", names, want)
|
||||
}
|
||||
}
|
||||
|
||||
var other Server
|
||||
w := createRequest(t, other.DeleteHandler, api.DeleteRequest{Model: "list-sync-a"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("delete model status = %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
models, err = cache.List(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("list after delete failed: %v", err)
|
||||
}
|
||||
names = names[:0]
|
||||
for _, m := range models {
|
||||
names = append(names, m.Name)
|
||||
}
|
||||
if slices.Contains(names, "list-sync-a:latest") || !slices.Contains(names, "list-sync-b:latest") {
|
||||
t.Fatalf("names after delete = %v, want only list-sync-b", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelListCacheSyncDropsStaleEntryOnRefreshFailure(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
createListCacheModel(t, "list-stale", map[string]any{"test.context_length": uint32(1024)}, "")
|
||||
|
||||
cache := newModelListCache()
|
||||
cache.Start(context.Background())
|
||||
if err := cache.Wait(context.Background()); err != nil {
|
||||
t.Fatalf("wait failed: %v", err)
|
||||
}
|
||||
|
||||
name := model.ParseName("list-stale")
|
||||
if _, ok := cache.Get(name); !ok {
|
||||
t.Fatal("list summary missing")
|
||||
}
|
||||
|
||||
changeShowCacheManifest(t, "list-stale")
|
||||
cache.build = func(model.Name, *manifest.Manifest) (modelListSummary, error) {
|
||||
return modelListSummary{}, errors.New("refresh failed")
|
||||
}
|
||||
|
||||
models, err := cache.List(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("list failed: %v", err)
|
||||
}
|
||||
if len(models) != 0 {
|
||||
t.Fatalf("models = %+v, want stale entry removed", models)
|
||||
}
|
||||
if _, ok := cache.Get(name); ok {
|
||||
t.Fatal("stale entry remained in cache after refresh failure")
|
||||
}
|
||||
}
|
||||
|
||||
func createListCacheModel(t *testing.T, name string, kv map[string]any, tmpl string) {
|
||||
t.Helper()
|
||||
_, digest := createBinFile(t, kv, nil)
|
||||
|
||||
req := api.CreateRequest{
|
||||
Model: name,
|
||||
Files: map[string]string{"model.gguf": digest},
|
||||
Stream: &stream,
|
||||
}
|
||||
if tmpl != "" {
|
||||
req.Template = tmpl
|
||||
}
|
||||
|
||||
var s Server
|
||||
w := createRequest(t, s.CreateHandler, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("create model status = %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
402
server/model_recommendations.go
Normal file
402
server/model_recommendations.go
Normal file
@@ -0,0 +1,402 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
const modelRecommendationsURL = "https://ollama.com/api/experimental/model-recommendations"
|
||||
|
||||
var (
|
||||
modelRecommendationsRefreshInterval = 4 * time.Hour
|
||||
modelRecommendationsFetchTimeout = 3 * time.Second
|
||||
modelRecommendationsReadRefreshCooldown = 5 * time.Second
|
||||
modelRecommendationsBackoffSteps = []time.Duration{
|
||||
5 * time.Minute,
|
||||
15 * time.Minute,
|
||||
time.Hour,
|
||||
4 * time.Hour,
|
||||
}
|
||||
|
||||
errModelRecommendationsNoCloud = errors.New("cloud disabled")
|
||||
)
|
||||
|
||||
type modelRecommendationsCache struct {
|
||||
mu sync.RWMutex
|
||||
recommendations []api.ModelRecommendation
|
||||
refreshing bool
|
||||
nextReadRefreshAfter time.Time
|
||||
once sync.Once
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func newModelRecommendationsCache() *modelRecommendationsCache {
|
||||
return &modelRecommendationsCache{
|
||||
recommendations: cloneModelRecommendations(defaultModelRecommendations),
|
||||
client: http.DefaultClient,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) Start(ctx context.Context) {
|
||||
c.once.Do(func() {
|
||||
slog.Debug("starting model recommendations cache",
|
||||
"default_recommendations", len(defaultModelRecommendations),
|
||||
"refresh_interval", modelRecommendationsRefreshInterval.String(),
|
||||
"fetch_timeout", modelRecommendationsFetchTimeout.String(),
|
||||
)
|
||||
go c.run(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) Get() []api.ModelRecommendation {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return cloneModelRecommendations(c.recommendations)
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) GetSWR(ctx context.Context) []api.ModelRecommendation {
|
||||
recs := c.Get()
|
||||
c.triggerRefreshOnRead(ctx)
|
||||
return recs
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) set(recs []api.ModelRecommendation) {
|
||||
c.mu.Lock()
|
||||
c.recommendations = cloneModelRecommendations(recs)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) beginRefresh() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.refreshing {
|
||||
return false
|
||||
}
|
||||
c.refreshing = true
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) beginReadRefresh() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
now := time.Now()
|
||||
if c.refreshing || now.Before(c.nextReadRefreshAfter) {
|
||||
return false
|
||||
}
|
||||
|
||||
c.refreshing = true
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) endRefresh() {
|
||||
c.mu.Lock()
|
||||
c.refreshing = false
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) endReadRefresh() {
|
||||
c.mu.Lock()
|
||||
c.refreshing = false
|
||||
c.nextReadRefreshAfter = time.Now().Add(modelRecommendationsReadRefreshCooldown)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) refreshIfIdle(ctx context.Context) (bool, error) {
|
||||
if !c.beginRefresh() {
|
||||
return false, nil
|
||||
}
|
||||
defer c.endRefresh()
|
||||
return true, c.refresh(ctx)
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) triggerRefreshOnRead(ctx context.Context) {
|
||||
if !c.beginReadRefresh() {
|
||||
return
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
ctx = context.WithoutCancel(ctx)
|
||||
|
||||
slog.Debug("triggering model recommendations refresh on read")
|
||||
go func() {
|
||||
defer c.endReadRefresh()
|
||||
|
||||
if err := c.refresh(ctx); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, errModelRecommendationsNoCloud):
|
||||
slog.Debug("skipping model recommendations read refresh because cloud is disabled")
|
||||
default:
|
||||
slog.Warn("model recommendations read refresh failed", "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) run(ctx context.Context) {
|
||||
c.loadSnapshot()
|
||||
|
||||
failures := 0
|
||||
for {
|
||||
started, err := c.refreshIfIdle(ctx)
|
||||
switch {
|
||||
case !started:
|
||||
failures = 0
|
||||
slog.Debug("skipping timer model recommendations refresh because refresh is already running")
|
||||
case err == nil:
|
||||
failures = 0
|
||||
case errors.Is(err, errModelRecommendationsNoCloud):
|
||||
failures = 0
|
||||
slog.Debug("skipping model recommendations refresh because cloud is disabled")
|
||||
default:
|
||||
failures++
|
||||
slog.Warn("model recommendations refresh failed", "error", err)
|
||||
}
|
||||
|
||||
var wait time.Duration
|
||||
if failures == 0 {
|
||||
wait = withJitter(modelRecommendationsRefreshInterval)
|
||||
} else {
|
||||
wait = withJitter(modelRecommendationsBackoffSteps[min(failures-1, len(modelRecommendationsBackoffSteps)-1)])
|
||||
}
|
||||
slog.Info("model recommendations cache sleep scheduled", "wait", wait.String(), "consecutive_failures", failures)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Debug("stopping model recommendations cache")
|
||||
return
|
||||
case <-time.After(wait):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) refresh(ctx context.Context) error {
|
||||
if envconfig.NoCloud() {
|
||||
return errModelRecommendationsNoCloud
|
||||
}
|
||||
slog.Debug("refreshing model recommendations from remote", "url", modelRecommendationsURL)
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, modelRecommendationsFetchTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelRecommendationsURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return fmt.Errorf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var payload api.ModelRecommendationsResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recs, err := validateModelRecommendations(payload.Recommendations)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.set(recs)
|
||||
slog.Debug("model recommendations refreshed", "count", len(recs))
|
||||
if err := c.persistSnapshot(recs); err != nil {
|
||||
slog.Warn("failed to persist model recommendations snapshot", "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) loadSnapshot() {
|
||||
path, err := modelRecommendationsSnapshotPath()
|
||||
if err != nil {
|
||||
slog.Warn("failed to resolve model recommendations snapshot path", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
slog.Warn("failed to read model recommendations snapshot", "path", path, "error", err)
|
||||
} else {
|
||||
slog.Debug("model recommendations snapshot not found", "path", path)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var snap api.ModelRecommendationsResponse
|
||||
if err := json.Unmarshal(data, &snap); err != nil {
|
||||
slog.Warn("failed to parse model recommendations snapshot", "path", path, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
recs, err := validateModelRecommendations(snap.Recommendations)
|
||||
if err != nil {
|
||||
slog.Warn("ignoring invalid model recommendations snapshot", "path", path, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.set(recs)
|
||||
slog.Debug("loaded model recommendations snapshot", "path", path, "count", len(recs))
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) persistSnapshot(recs []api.ModelRecommendation) error {
|
||||
path, err := modelRecommendationsSnapshotPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
payload := api.ModelRecommendationsResponse{Recommendations: recs}
|
||||
data, err := json.MarshalIndent(payload, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmp, err := os.CreateTemp(filepath.Dir(path), ".model-recommendations-*.tmp")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tmpPath := tmp.Name()
|
||||
defer os.Remove(tmpPath)
|
||||
|
||||
if _, err := tmp.Write(data); err != nil {
|
||||
_ = tmp.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmp.Sync(); err != nil {
|
||||
_ = tmp.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmp.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, path); err != nil {
|
||||
return err
|
||||
}
|
||||
slog.Debug("persisted model recommendations snapshot", "path", path, "count", len(recs))
|
||||
return nil
|
||||
}
|
||||
|
||||
func modelRecommendationsSnapshotPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "cache", "model-recommendations.json"), nil
|
||||
}
|
||||
|
||||
func validateModelRecommendations(recs []api.ModelRecommendation) ([]api.ModelRecommendation, error) {
|
||||
if len(recs) == 0 {
|
||||
return nil, errors.New("empty recommendations")
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{}, len(recs))
|
||||
valid := make([]api.ModelRecommendation, 0, len(recs))
|
||||
for _, rec := range recs {
|
||||
rec.Model = strings.TrimSpace(rec.Model)
|
||||
rec.Description = strings.TrimSpace(rec.Description)
|
||||
rec.RequiredPlan = strings.TrimSpace(rec.RequiredPlan)
|
||||
|
||||
if rec.Model == "" {
|
||||
return nil, errors.New("recommendation missing model")
|
||||
}
|
||||
if _, ok := seen[rec.Model]; ok {
|
||||
return nil, fmt.Errorf("duplicate recommendation %q", rec.Model)
|
||||
}
|
||||
seen[rec.Model] = struct{}{}
|
||||
|
||||
if isCloudRecommendation(rec.Model) && (rec.ContextLength <= 0 || rec.MaxOutputTokens <= 0) {
|
||||
slog.Warn("dropping cloud recommendation missing limits", "model", rec.Model)
|
||||
continue
|
||||
}
|
||||
valid = append(valid, rec)
|
||||
}
|
||||
|
||||
if len(valid) == 0 {
|
||||
return nil, errors.New("no valid recommendations")
|
||||
}
|
||||
|
||||
return valid, nil
|
||||
}
|
||||
|
||||
func isCloudRecommendation(modelName string) bool {
|
||||
return strings.HasSuffix(modelName, ":cloud") || strings.HasSuffix(modelName, "-cloud")
|
||||
}
|
||||
|
||||
func withJitter(d time.Duration) time.Duration {
|
||||
if d <= 0 {
|
||||
return d
|
||||
}
|
||||
// jitter in range [0.8x, 1.2x]
|
||||
factor := 0.8 + rand.Float64()*0.4
|
||||
return time.Duration(float64(d) * factor)
|
||||
}
|
||||
|
||||
func cloneModelRecommendations(in []api.ModelRecommendation) []api.ModelRecommendation {
|
||||
out := make([]api.ModelRecommendation, len(in))
|
||||
copy(out, in)
|
||||
return out
|
||||
}
|
||||
|
||||
var defaultModelRecommendations = []api.ModelRecommendation{
|
||||
{
|
||||
Model: "kimi-k2.6:cloud",
|
||||
Description: "State-of-the-art coding, long-horizon execution, and multimodal agent swarm capability",
|
||||
ContextLength: 262_144,
|
||||
MaxOutputTokens: 262_144,
|
||||
},
|
||||
{
|
||||
Model: "glm-5.1:cloud",
|
||||
Description: "Reasoning and code generation",
|
||||
ContextLength: 202_752,
|
||||
MaxOutputTokens: 131_072,
|
||||
},
|
||||
{
|
||||
Model: "qwen3.5:cloud",
|
||||
Description: "Reasoning, coding, and agentic tool use with vision",
|
||||
ContextLength: 262_144,
|
||||
MaxOutputTokens: 32_768,
|
||||
},
|
||||
{
|
||||
Model: "minimax-m2.7:cloud",
|
||||
Description: "Fast, efficient coding and real-world productivity",
|
||||
ContextLength: 204_800,
|
||||
MaxOutputTokens: 128_000,
|
||||
},
|
||||
{
|
||||
Model: "gemma4",
|
||||
Description: "Reasoning and code generation locally",
|
||||
VRAMBytes: 12 * format.GigaByte,
|
||||
},
|
||||
{
|
||||
Model: "qwen3.5",
|
||||
Description: "Reasoning, coding, and visual understanding locally",
|
||||
VRAMBytes: 14 * format.GigaByte,
|
||||
},
|
||||
}
|
||||
619
server/model_recommendations_test.go
Normal file
619
server/model_recommendations_test.go
Normal file
@@ -0,0 +1,619 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
func TestModelRecommendationsDefaultOrder(t *testing.T) {
|
||||
want := []string{
|
||||
"kimi-k2.6:cloud",
|
||||
"glm-5.1:cloud",
|
||||
"qwen3.5:cloud",
|
||||
"minimax-m2.7:cloud",
|
||||
"gemma4",
|
||||
"qwen3.5",
|
||||
}
|
||||
|
||||
if got := modelRecommendationNames(defaultModelRecommendations); !slices.Equal(got, want) {
|
||||
t.Fatalf("recommendations = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRecommendationsCacheRefreshAppliesServerSideChanges(t *testing.T) {
|
||||
setupModelRecommendationsTestEnv(t, "")
|
||||
|
||||
first := []api.ModelRecommendation{
|
||||
{Model: " first-cloud:cloud ", Description: " first ", ContextLength: 2048, MaxOutputTokens: 512},
|
||||
{Model: " first-local ", Description: " first local ", VRAMBytes: 3 * format.GigaByte},
|
||||
}
|
||||
second := []api.ModelRecommendation{
|
||||
{Model: "second-cloud:cloud", Description: "second", ContextLength: 4096, MaxOutputTokens: 1024},
|
||||
{Model: "second-local", Description: "second local", VRAMBytes: 6 * format.GigaByte},
|
||||
}
|
||||
|
||||
calls := 0
|
||||
cache := newModelRecommendationsCache()
|
||||
cache.client = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if req.Method != http.MethodGet {
|
||||
t.Fatalf("method = %q, want GET", req.Method)
|
||||
}
|
||||
if req.URL.String() != modelRecommendationsURL {
|
||||
t.Fatalf("url = %q, want %q", req.URL.String(), modelRecommendationsURL)
|
||||
}
|
||||
|
||||
calls++
|
||||
payload := api.ModelRecommendationsResponse{Recommendations: first}
|
||||
if calls > 1 {
|
||||
payload.Recommendations = second
|
||||
}
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal payload failed: %v", err)
|
||||
}
|
||||
return jsonHTTPResponse(http.StatusOK, string(data)), nil
|
||||
})}
|
||||
|
||||
if err := cache.refresh(context.Background()); err != nil {
|
||||
t.Fatalf("first refresh failed: %v", err)
|
||||
}
|
||||
if got, want := cache.Get(), []api.ModelRecommendation{
|
||||
{Model: "first-cloud:cloud", Description: "first", ContextLength: 2048, MaxOutputTokens: 512},
|
||||
{Model: "first-local", Description: "first local", VRAMBytes: 3 * format.GigaByte},
|
||||
}; !slices.Equal(got, want) {
|
||||
t.Fatalf("after first refresh recommendations = %#v, want %#v", got, want)
|
||||
}
|
||||
|
||||
if err := cache.refresh(context.Background()); err != nil {
|
||||
t.Fatalf("second refresh failed: %v", err)
|
||||
}
|
||||
if got, want := cache.Get(), second; !slices.Equal(got, want) {
|
||||
t.Fatalf("after second refresh recommendations = %#v, want %#v", got, want)
|
||||
}
|
||||
|
||||
path, err := modelRecommendationsSnapshotPath()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot path failed: %v", err)
|
||||
}
|
||||
snapshotData, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("read snapshot failed: %v", err)
|
||||
}
|
||||
var snapshot api.ModelRecommendationsResponse
|
||||
if err := json.Unmarshal(snapshotData, &snapshot); err != nil {
|
||||
t.Fatalf("unmarshal snapshot failed: %v", err)
|
||||
}
|
||||
if !slices.Equal(snapshot.Recommendations, second) {
|
||||
t.Fatalf("snapshot recommendations = %#v, want %#v", snapshot.Recommendations, second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRecommendationsCacheRefreshErrorCasesPreserveCurrentData(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
transport roundTripFunc
|
||||
errSubstr string
|
||||
}{
|
||||
{
|
||||
name: "transport error",
|
||||
transport: func(*http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("network down")
|
||||
},
|
||||
errSubstr: "network down",
|
||||
},
|
||||
{
|
||||
name: "remote status error",
|
||||
transport: func(*http.Request) (*http.Response, error) {
|
||||
return jsonHTTPResponse(http.StatusInternalServerError, "upstream broken"), nil
|
||||
},
|
||||
errSubstr: "status 500: upstream broken",
|
||||
},
|
||||
{
|
||||
name: "invalid json payload",
|
||||
transport: func(*http.Request) (*http.Response, error) {
|
||||
return jsonHTTPResponse(http.StatusOK, "{"), nil
|
||||
},
|
||||
errSubstr: "unexpected EOF",
|
||||
},
|
||||
{
|
||||
name: "duplicate recommendations",
|
||||
transport: func(*http.Request) (*http.Response, error) {
|
||||
return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"dup","description":"a"},{"model":"dup","description":"b"}]}`), nil
|
||||
},
|
||||
errSubstr: `duplicate recommendation "dup"`,
|
||||
},
|
||||
{
|
||||
name: "empty recommendations",
|
||||
transport: func(*http.Request) (*http.Response, error) {
|
||||
return jsonHTTPResponse(http.StatusOK, `{"recommendations":[]}`), nil
|
||||
},
|
||||
errSubstr: "empty recommendations",
|
||||
},
|
||||
{
|
||||
name: "only invalid cloud recommendations",
|
||||
transport: func(*http.Request) (*http.Response, error) {
|
||||
return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"bad:cloud","description":"missing limits"}]}`), nil
|
||||
},
|
||||
errSubstr: "no valid recommendations",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
setupModelRecommendationsTestEnv(t, "")
|
||||
|
||||
cache := newModelRecommendationsCache()
|
||||
stable := []api.ModelRecommendation{{Model: "stable-local", Description: "stable desc", VRAMBytes: 2 * format.GigaByte}}
|
||||
cache.set(stable)
|
||||
cache.client = &http.Client{Transport: tc.transport}
|
||||
|
||||
err := cache.refresh(context.Background())
|
||||
if err == nil {
|
||||
t.Fatalf("refresh returned nil error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.errSubstr) {
|
||||
t.Fatalf("error = %q, want substring %q", err.Error(), tc.errSubstr)
|
||||
}
|
||||
|
||||
if got := cache.Get(); !slices.Equal(got, stable) {
|
||||
t.Fatalf("recommendations changed on error: got %#v, want %#v", got, stable)
|
||||
}
|
||||
|
||||
path, pathErr := modelRecommendationsSnapshotPath()
|
||||
if pathErr != nil {
|
||||
t.Fatalf("snapshot path failed: %v", pathErr)
|
||||
}
|
||||
if _, statErr := os.Stat(path); !errors.Is(statErr, os.ErrNotExist) {
|
||||
t.Fatalf("snapshot file should not be written on error, stat err = %v", statErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRecommendationsCacheRefreshNoCloudShortCircuits(t *testing.T) {
|
||||
setupModelRecommendationsTestEnv(t, "1")
|
||||
|
||||
called := false
|
||||
cache := newModelRecommendationsCache()
|
||||
cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
called = true
|
||||
return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"should-not-be-used","description":"n/a"}]}`), nil
|
||||
})}
|
||||
|
||||
err := cache.refresh(context.Background())
|
||||
if !errors.Is(err, errModelRecommendationsNoCloud) {
|
||||
t.Fatalf("refresh error = %v, want %v", err, errModelRecommendationsNoCloud)
|
||||
}
|
||||
if called {
|
||||
t.Fatalf("remote endpoint should not be called when cloud is disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRecommendationsSnapshotPersistAndLoad(t *testing.T) {
|
||||
setupModelRecommendationsTestEnv(t, "")
|
||||
|
||||
want := []api.ModelRecommendation{
|
||||
{Model: "persist-cloud:cloud", Description: "persisted", ContextLength: 8192, MaxOutputTokens: 2048},
|
||||
{Model: "persist-local", Description: "persisted local", VRAMBytes: 5 * format.GigaByte},
|
||||
}
|
||||
|
||||
writer := newModelRecommendationsCache()
|
||||
if err := writer.persistSnapshot(want); err != nil {
|
||||
t.Fatalf("persistSnapshot failed: %v", err)
|
||||
}
|
||||
|
||||
loader := newModelRecommendationsCache()
|
||||
loader.set([]api.ModelRecommendation{{Model: "old", Description: "old"}})
|
||||
loader.loadSnapshot()
|
||||
|
||||
if got := loader.Get(); !slices.Equal(got, want) {
|
||||
t.Fatalf("loaded recommendations = %#v, want %#v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRecommendationsLoadSnapshotInvalidDoesNotOverwrite(t *testing.T) {
|
||||
setupModelRecommendationsTestEnv(t, "")
|
||||
|
||||
path, err := modelRecommendationsSnapshotPath()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot path failed: %v", err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
t.Fatalf("mkdir failed: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(path, []byte("{invalid"), 0o644); err != nil {
|
||||
t.Fatalf("write invalid snapshot failed: %v", err)
|
||||
}
|
||||
|
||||
cache := newModelRecommendationsCache()
|
||||
existing := []api.ModelRecommendation{{Model: "existing", Description: "existing description"}}
|
||||
cache.set(existing)
|
||||
cache.loadSnapshot()
|
||||
|
||||
if got := cache.Get(); !slices.Equal(got, existing) {
|
||||
t.Fatalf("recommendations overwritten by invalid snapshot: got %#v, want %#v", got, existing)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateModelRecommendationsTrimsAndDropsInvalidCloudEntries(t *testing.T) {
|
||||
input := []api.ModelRecommendation{
|
||||
{Model: " good-cloud:cloud ", Description: " good cloud ", ContextLength: 1024, MaxOutputTokens: 256, RequiredPlan: " pro "},
|
||||
{Model: "bad-cloud:cloud", Description: "missing limits"},
|
||||
{Model: " good-local ", Description: " good local ", VRAMBytes: 2 * format.GigaByte},
|
||||
}
|
||||
|
||||
got, err := validateModelRecommendations(input)
|
||||
if err != nil {
|
||||
t.Fatalf("validateModelRecommendations failed: %v", err)
|
||||
}
|
||||
|
||||
want := []api.ModelRecommendation{
|
||||
{Model: "good-cloud:cloud", Description: "good cloud", ContextLength: 1024, MaxOutputTokens: 256, RequiredPlan: "pro"},
|
||||
{Model: "good-local", Description: "good local", VRAMBytes: 2 * format.GigaByte},
|
||||
}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("validated recommendations = %#v, want %#v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateModelRecommendationsDoesNotSynthesizeRequiredPlans(t *testing.T) {
|
||||
input := []api.ModelRecommendation{
|
||||
{Model: "kimi-k2.6:cloud", Description: "coding", ContextLength: 262_144, MaxOutputTokens: 262_144},
|
||||
{Model: "qwen3.5:cloud", Description: "reasoning", ContextLength: 262_144, MaxOutputTokens: 32_768},
|
||||
{Model: "custom:cloud", Description: "custom", ContextLength: 4096, MaxOutputTokens: 1024},
|
||||
{Model: "minimax-m2.7:cloud", Description: "custom", ContextLength: 204_800, MaxOutputTokens: 128_000, RequiredPlan: "team"},
|
||||
}
|
||||
|
||||
got, err := validateModelRecommendations(input)
|
||||
if err != nil {
|
||||
t.Fatalf("validateModelRecommendations failed: %v", err)
|
||||
}
|
||||
|
||||
byName := make(map[string]api.ModelRecommendation, len(got))
|
||||
for _, rec := range got {
|
||||
byName[rec.Model] = rec
|
||||
}
|
||||
|
||||
if rec := byName["kimi-k2.6:cloud"]; rec.RequiredPlan != "" {
|
||||
t.Fatalf("kimi required plan should not be synthesized: %#v", rec)
|
||||
}
|
||||
if rec := byName["qwen3.5:cloud"]; rec.RequiredPlan != "" {
|
||||
t.Fatalf("qwen required plan should not be synthesized: %#v", rec)
|
||||
}
|
||||
if rec := byName["custom:cloud"]; rec.RequiredPlan != "" {
|
||||
t.Fatalf("custom required plan should not be synthesized: %#v", rec)
|
||||
}
|
||||
if rec := byName["minimax-m2.7:cloud"]; rec.RequiredPlan != "team" {
|
||||
t.Fatalf("explicit required plan should not be overwritten: %#v", rec)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRecommendationsHandlerReturnsDefaults(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
ctx.Request = httptest.NewRequest(http.MethodGet, "/api/experimental/model-recommendations", nil)
|
||||
|
||||
s := &Server{}
|
||||
s.ModelRecommendationsExperimentalHandler(ctx)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
got := decodeRecommendationNames(t, w)
|
||||
want := modelRecommendationNames(defaultModelRecommendations)
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("models = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRecommendationsHandlerUsesCache(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupModelRecommendationsTestEnv(t, "1")
|
||||
|
||||
cache := newModelRecommendationsCache()
|
||||
cache.set([]api.ModelRecommendation{{Model: "test-model", Description: "test description"}})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
ctx.Request = httptest.NewRequest(http.MethodGet, "/api/experimental/model-recommendations", nil)
|
||||
|
||||
s := &Server{modelCaches: &modelCaches{recommendations: cache}}
|
||||
s.ModelRecommendationsExperimentalHandler(ctx)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
got := decodeRecommendationNames(t, w)
|
||||
if !slices.Equal(got, []string{"test-model"}) {
|
||||
t.Fatalf("models = %v, want %v", got, []string{"test-model"})
|
||||
}
|
||||
waitForCacheIdle(t, cache)
|
||||
}
|
||||
|
||||
func TestModelRecommendationsRouteRegistration(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupModelRecommendationsTestEnv(t, "1")
|
||||
|
||||
cache := newModelRecommendationsCache()
|
||||
cache.set([]api.ModelRecommendation{{Model: "route-model", Description: "route description"}})
|
||||
s := &Server{modelCaches: &modelCaches{recommendations: cache}}
|
||||
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRoutes failed: %v", err)
|
||||
}
|
||||
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/api/experimental/model-recommendations", nil)
|
||||
getResp := httptest.NewRecorder()
|
||||
router.ServeHTTP(getResp, getReq)
|
||||
if getResp.Code != http.StatusOK {
|
||||
t.Fatalf("GET status = %d, want %d", getResp.Code, http.StatusOK)
|
||||
}
|
||||
if got := decodeRecommendationNames(t, getResp); !slices.Equal(got, []string{"route-model"}) {
|
||||
t.Fatalf("GET models = %v, want %v", got, []string{"route-model"})
|
||||
}
|
||||
|
||||
postReq := httptest.NewRequest(http.MethodPost, "/api/experimental/model-recommendations", nil)
|
||||
postResp := httptest.NewRecorder()
|
||||
router.ServeHTTP(postResp, postReq)
|
||||
if postResp.Code != http.StatusMethodNotAllowed {
|
||||
t.Fatalf("POST status = %d, want %d", postResp.Code, http.StatusMethodNotAllowed)
|
||||
}
|
||||
waitForCacheIdle(t, cache)
|
||||
}
|
||||
|
||||
func TestModelRecommendationsGetSWRTriggersRefreshOnRead(t *testing.T) {
|
||||
setupModelRecommendationsTestEnv(t, "")
|
||||
|
||||
cache := newModelRecommendationsCache()
|
||||
old := []api.ModelRecommendation{{Model: "old", Description: "old"}}
|
||||
newRecs := []api.ModelRecommendation{{Model: "new-cloud:cloud", Description: "new", ContextLength: 1024, MaxOutputTokens: 256}}
|
||||
cache.set(old)
|
||||
|
||||
refreshDone := make(chan struct{})
|
||||
cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
defer close(refreshDone)
|
||||
return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"new-cloud:cloud","description":"new","context_length":1024,"max_output_tokens":256}]}`), nil
|
||||
})}
|
||||
|
||||
gotImmediate := cache.GetSWR(context.Background())
|
||||
if !slices.Equal(gotImmediate, old) {
|
||||
t.Fatalf("GetSWR should return current cache immediately: got %#v, want %#v", gotImmediate, old)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-refreshDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for async refresh")
|
||||
}
|
||||
|
||||
waitForCondition(t, 2*time.Second, func() bool {
|
||||
return slices.Equal(cache.Get(), newRecs)
|
||||
})
|
||||
waitForCacheIdle(t, cache)
|
||||
}
|
||||
|
||||
func TestModelRecommendationsGetSWRSkipsWhenRefreshAlreadyInFlight(t *testing.T) {
|
||||
setupModelRecommendationsTestEnv(t, "")
|
||||
|
||||
cache := newModelRecommendationsCache()
|
||||
cache.set([]api.ModelRecommendation{{Model: "old", Description: "old"}})
|
||||
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
var calls atomic.Int32
|
||||
|
||||
cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
n := calls.Add(1)
|
||||
if n == 1 {
|
||||
close(started)
|
||||
}
|
||||
<-release
|
||||
return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"updated","description":"ok"}]}`), nil
|
||||
})}
|
||||
|
||||
cache.GetSWR(context.Background())
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for first refresh call")
|
||||
}
|
||||
|
||||
for range 5 {
|
||||
cache.GetSWR(context.Background())
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Fatalf("calls during in-flight refresh = %d, want 1", got)
|
||||
}
|
||||
|
||||
close(release)
|
||||
waitForCacheIdle(t, cache)
|
||||
}
|
||||
|
||||
func TestModelRecommendationsGetSWRThrottlesRefreshAfterCompletion(t *testing.T) {
|
||||
setupModelRecommendationsTestEnv(t, "")
|
||||
withModelRecommendationsReadRefreshCooldown(t, 100*time.Millisecond)
|
||||
|
||||
cache := newModelRecommendationsCache()
|
||||
cache.set([]api.ModelRecommendation{{Model: "old", Description: "old"}})
|
||||
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
var calls atomic.Int32
|
||||
cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
if calls.Add(1) == 1 {
|
||||
close(started)
|
||||
<-release
|
||||
}
|
||||
return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"updated","description":"ok"}]}`), nil
|
||||
})}
|
||||
|
||||
cache.GetSWR(context.Background())
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for first refresh call")
|
||||
}
|
||||
|
||||
time.Sleep(2 * modelRecommendationsReadRefreshCooldown)
|
||||
close(release)
|
||||
waitForCacheIdle(t, cache)
|
||||
|
||||
cache.GetSWR(context.Background())
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Fatalf("calls during read refresh cooldown = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRecommendationsGetSWRRetriesAfterReadRefreshCooldown(t *testing.T) {
|
||||
setupModelRecommendationsTestEnv(t, "")
|
||||
withModelRecommendationsReadRefreshCooldown(t, 100*time.Millisecond)
|
||||
|
||||
cache := newModelRecommendationsCache()
|
||||
old := []api.ModelRecommendation{{Model: "old", Description: "old"}}
|
||||
cache.set(old)
|
||||
|
||||
var calls atomic.Int32
|
||||
cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
if calls.Add(1) == 1 {
|
||||
return nil, errors.New("temporary upstream failure")
|
||||
}
|
||||
return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"recovered","description":"ok"}]}`), nil
|
||||
})}
|
||||
|
||||
cache.GetSWR(context.Background())
|
||||
waitForCondition(t, 2*time.Second, func() bool { return calls.Load() >= 1 })
|
||||
waitForCacheIdle(t, cache)
|
||||
|
||||
if !slices.Equal(cache.Get(), old) {
|
||||
t.Fatalf("cache should remain unchanged after failed refresh, got %#v", cache.Get())
|
||||
}
|
||||
|
||||
cache.GetSWR(context.Background())
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Fatalf("calls during read refresh cooldown after failure = %d, want 1", got)
|
||||
}
|
||||
|
||||
waitForCondition(t, 2*time.Second, func() bool {
|
||||
cache.GetSWR(context.Background())
|
||||
return calls.Load() >= 2
|
||||
})
|
||||
waitForCondition(t, 2*time.Second, func() bool {
|
||||
return slices.Equal(cache.Get(), []api.ModelRecommendation{{Model: "recovered", Description: "ok"}})
|
||||
})
|
||||
waitForCacheIdle(t, cache)
|
||||
}
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func jsonHTTPResponse(statusCode int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
}
|
||||
|
||||
func setupModelRecommendationsTestEnv(t *testing.T, noCloudEnv string) {
|
||||
t.Helper()
|
||||
home := t.TempDir()
|
||||
t.Setenv("HOME", home)
|
||||
t.Setenv("USERPROFILE", home)
|
||||
t.Setenv("HOMEDRIVE", filepath.VolumeName(home))
|
||||
t.Setenv("HOMEPATH", strings.TrimPrefix(home, filepath.VolumeName(home)))
|
||||
|
||||
// Use explicit false rather than empty to avoid platform/env ambiguity.
|
||||
if noCloudEnv == "" {
|
||||
noCloudEnv = "false"
|
||||
}
|
||||
t.Setenv("OLLAMA_NO_CLOUD", noCloudEnv)
|
||||
envconfig.ReloadServerConfig()
|
||||
t.Cleanup(envconfig.ReloadServerConfig)
|
||||
}
|
||||
|
||||
func withModelRecommendationsReadRefreshCooldown(t *testing.T, d time.Duration) {
|
||||
t.Helper()
|
||||
old := modelRecommendationsReadRefreshCooldown
|
||||
modelRecommendationsReadRefreshCooldown = d
|
||||
t.Cleanup(func() {
|
||||
modelRecommendationsReadRefreshCooldown = old
|
||||
})
|
||||
}
|
||||
|
||||
func waitForCondition(t *testing.T, timeout time.Duration, cond func() bool) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if cond() {
|
||||
return
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
t.Fatal("timed out waiting for condition")
|
||||
}
|
||||
|
||||
func waitForCacheIdle(t *testing.T, cache *modelRecommendationsCache) {
|
||||
t.Helper()
|
||||
waitForCondition(t, 2*time.Second, func() bool {
|
||||
cache.mu.RLock()
|
||||
refreshing := cache.refreshing
|
||||
cache.mu.RUnlock()
|
||||
return !refreshing
|
||||
})
|
||||
}
|
||||
|
||||
func decodeRecommendationNames(t *testing.T, w *httptest.ResponseRecorder) []string {
|
||||
t.Helper()
|
||||
|
||||
var resp struct {
|
||||
Recommendations []struct {
|
||||
Model string `json:"model"`
|
||||
} `json:"recommendations"`
|
||||
}
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(resp.Recommendations))
|
||||
for _, rec := range resp.Recommendations {
|
||||
names = append(names, rec.Model)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
func modelRecommendationNames(recs []api.ModelRecommendation) []string {
|
||||
names := make([]string, len(recs))
|
||||
for i, rec := range recs {
|
||||
names[i] = rec.Model
|
||||
}
|
||||
return names
|
||||
}
|
||||
81
server/model_resolver.go
Normal file
81
server/model_resolver.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type modelSource = modelref.ModelSource
|
||||
|
||||
const (
|
||||
modelSourceUnspecified modelSource = modelref.ModelSourceUnspecified
|
||||
modelSourceLocal modelSource = modelref.ModelSourceLocal
|
||||
modelSourceCloud modelSource = modelref.ModelSourceCloud
|
||||
)
|
||||
|
||||
var (
|
||||
errConflictingModelSource = modelref.ErrConflictingSourceSuffix
|
||||
errModelRequired = modelref.ErrModelRequired
|
||||
)
|
||||
|
||||
type parsedModelRef struct {
|
||||
// Original is the caller-provided model string before source parsing.
|
||||
// Example: "gpt-oss:20b:cloud".
|
||||
Original string
|
||||
// Base is the model string after source suffix normalization.
|
||||
// Example: "gpt-oss:20b:cloud" -> "gpt-oss:20b".
|
||||
Base string
|
||||
// Name is Base parsed as a fully-qualified model.Name with defaults applied.
|
||||
// Example: "registry.ollama.ai/library/gpt-oss:20b".
|
||||
Name model.Name
|
||||
// Source captures explicit source intent from the original input.
|
||||
// Example: "gpt-oss:20b:cloud" -> modelSourceCloud.
|
||||
Source modelSource
|
||||
}
|
||||
|
||||
func parseAndValidateModelRef(raw string) (parsedModelRef, error) {
|
||||
var zero parsedModelRef
|
||||
|
||||
parsed, err := modelref.ParseRef(raw)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
name := model.ParseName(parsed.Base)
|
||||
if !name.IsValid() {
|
||||
return zero, model.Unqualified(name)
|
||||
}
|
||||
|
||||
return parsedModelRef{
|
||||
Original: parsed.Original,
|
||||
Base: parsed.Base,
|
||||
Name: name,
|
||||
Source: parsed.Source,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseNormalizePullModelRef(raw string) (parsedModelRef, error) {
|
||||
var zero parsedModelRef
|
||||
|
||||
parsedRef, err := modelref.ParseRef(raw)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
normalizedName, _, err := modelref.NormalizePullName(raw)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
name := model.ParseName(normalizedName)
|
||||
if !name.IsValid() {
|
||||
return zero, model.Unqualified(name)
|
||||
}
|
||||
|
||||
return parsedModelRef{
|
||||
Original: parsedRef.Original,
|
||||
Base: normalizedName,
|
||||
Name: name,
|
||||
Source: parsedRef.Source,
|
||||
}, nil
|
||||
}
|
||||
170
server/model_resolver_test.go
Normal file
170
server/model_resolver_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseModelSelector(t *testing.T) {
|
||||
t.Run("cloud suffix", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("gpt-oss:20b:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceCloud {
|
||||
t.Fatalf("expected source cloud, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "gpt-oss:20b" {
|
||||
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
|
||||
}
|
||||
|
||||
if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b" {
|
||||
t.Fatalf("unexpected resolved name: %q", got.Name.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("legacy cloud suffix", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("gpt-oss:20b-cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceCloud {
|
||||
t.Fatalf("expected source cloud, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "gpt-oss:20b" {
|
||||
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bare dash cloud name is not explicit cloud", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("my-cloud-model")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceUnspecified {
|
||||
t.Fatalf("expected source unspecified, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "my-cloud-model" {
|
||||
t.Fatalf("expected base my-cloud-model, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local suffix", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("qwen3:8b:local")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceLocal {
|
||||
t.Fatalf("expected source local, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "qwen3:8b" {
|
||||
t.Fatalf("expected base qwen3:8b, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("conflicting source suffixes fail", func(t *testing.T) {
|
||||
_, err := parseAndValidateModelRef("foo:cloud:local")
|
||||
if !errors.Is(err, errConflictingModelSource) {
|
||||
t.Fatalf("expected errConflictingModelSource, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unspecified source", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("llama3")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceUnspecified {
|
||||
t.Fatalf("expected source unspecified, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Name.Tag != "latest" {
|
||||
t.Fatalf("expected default latest tag, got %q", got.Name.Tag)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unknown suffix is treated as tag", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("gpt-oss:clod")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceUnspecified {
|
||||
t.Fatalf("expected source unspecified, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Name.Tag != "clod" {
|
||||
t.Fatalf("expected tag clod, got %q", got.Name.Tag)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty model fails", func(t *testing.T) {
|
||||
_, err := parseAndValidateModelRef("")
|
||||
if !errors.Is(err, errModelRequired) {
|
||||
t.Fatalf("expected errModelRequired, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid model fails", func(t *testing.T) {
|
||||
_, err := parseAndValidateModelRef("::cloud")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid model")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unqualified") {
|
||||
t.Fatalf("expected unqualified model error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestParsePullModelRef(t *testing.T) {
|
||||
t.Run("explicit local is normalized", func(t *testing.T) {
|
||||
got, err := parseNormalizePullModelRef("gpt-oss:20b:local")
|
||||
if err != nil {
|
||||
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceLocal {
|
||||
t.Fatalf("expected source local, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "gpt-oss:20b" {
|
||||
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit cloud with size maps to legacy cloud suffix", func(t *testing.T) {
|
||||
got, err := parseNormalizePullModelRef("gpt-oss:20b:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
|
||||
}
|
||||
if got.Base != "gpt-oss:20b-cloud" {
|
||||
t.Fatalf("expected base gpt-oss:20b-cloud, got %q", got.Base)
|
||||
}
|
||||
if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b-cloud" {
|
||||
t.Fatalf("unexpected resolved name: %q", got.Name.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit cloud without size maps to cloud tag", func(t *testing.T) {
|
||||
got, err := parseNormalizePullModelRef("qwen3:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
|
||||
}
|
||||
if got.Base != "qwen3:cloud" {
|
||||
t.Fatalf("expected base qwen3:cloud, got %q", got.Base)
|
||||
}
|
||||
if got.Name.String() != "registry.ollama.ai/library/qwen3:cloud" {
|
||||
t.Fatalf("unexpected resolved name: %q", got.Name.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
692
server/model_show_cache.go
Normal file
692
server/model_show_cache.go
Normal file
@@ -0,0 +1,692 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
/*
|
||||
The /api/show cache stores full api.ShowResponse values because callers use
|
||||
more than capabilities: launch flows also need context length, embeddings
|
||||
metadata, quantization details, remote metadata, and model-specific fields.
|
||||
TODO(parthsareen): Consider removing show cache if /api/tags grows to cover
|
||||
the remaining callers.
|
||||
|
||||
Local model entries are stored lazily by canonical model name and verbose flag,
|
||||
with the manifest digest recorded in the entry. The manifest digest is the
|
||||
freshness boundary: if the model content changes, the digest changes, so the
|
||||
previous response is replaced instead of accumulating under an old digest key.
|
||||
Requests with System or Options overlays bypass the cache because those overlays
|
||||
mutate the effective show response.
|
||||
|
||||
Cloud model entries are keyed by normalized cloud base model name and verbose.
|
||||
They use stale-while-revalidate behavior: a warm read returns the cached
|
||||
response immediately and starts a throttled background refresh for that model.
|
||||
Cold cloud reads preserve existing proxy behavior. Local and cloud entries live
|
||||
in separate maps, so a local "qwen3.5" and an explicit "qwen3.5:cloud" cannot
|
||||
collide. The cloud suffix is request routing intent; api.ShowResponse does not
|
||||
carry a model-name field to reconstruct on the way out.
|
||||
|
||||
The cache is process-local. Cloud startup hydration runs asynchronously from
|
||||
cloud tags, while local show responses are populated on demand. No show
|
||||
responses are written to or read from ~/.ollama/cache/show. That keeps cache
|
||||
lifetime tied to the server process and avoids snapshot freshness and
|
||||
invalidation cases for this iteration.
|
||||
*/
|
||||
|
||||
const (
|
||||
modelShowCloudFetchTimeout = 3 * time.Second
|
||||
modelShowCloudReadRefreshCooldown = 5 * time.Second
|
||||
modelShowCloudHydrationConcurrency = 4
|
||||
)
|
||||
|
||||
var errModelShowNoCloud = errors.New("cloud disabled")
|
||||
|
||||
// modelShowCache owns process-local show response caches for local and cloud
|
||||
// models. All cached responses are cloned at read/write boundaries so
|
||||
// handler-specific mutations, such as user-agent compatibility tweaks, cannot
|
||||
// leak back into the cache.
|
||||
type modelShowCache struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
local map[modelShowLocalKey]modelShowLocalEntry
|
||||
cloud map[modelShowCloudKey]*api.ShowResponse
|
||||
|
||||
cloudRefreshing map[modelShowCloudKey]bool
|
||||
cloudNextReadRefreshAfter map[modelShowCloudKey]time.Time
|
||||
|
||||
once sync.Once
|
||||
client *http.Client
|
||||
getModelInfo func(api.ShowRequest) (*api.ShowResponse, error)
|
||||
}
|
||||
|
||||
// modelShowLocalKey describes the local cache slot for a model response. The
|
||||
// manifest digest is stored in the entry instead of the key so a pulled or
|
||||
// recreated model overwrites the previous response for the same model/verbose
|
||||
// variant instead of leaving stale digest-keyed entries behind.
|
||||
//
|
||||
// Deleted models are not eagerly pruned from this process-local cache. Manifest
|
||||
// resolution happens before local cache lookup, so stale delete entries are not
|
||||
// served and disappear on process restart.
|
||||
type modelShowLocalKey struct {
|
||||
Model string
|
||||
Verbose bool
|
||||
}
|
||||
|
||||
type modelShowLocalEntry struct {
|
||||
Digest string
|
||||
Response *api.ShowResponse
|
||||
}
|
||||
|
||||
// modelShowCloudKey intentionally excludes any local digest because cloud
|
||||
// models are refreshed through SWR and normalized by cloud base model name.
|
||||
type modelShowCloudKey struct {
|
||||
Model string
|
||||
Verbose bool
|
||||
}
|
||||
|
||||
func newModelShowCache() *modelShowCache {
|
||||
return &modelShowCache{
|
||||
local: make(map[modelShowLocalKey]modelShowLocalEntry),
|
||||
cloud: make(map[modelShowCloudKey]*api.ShowResponse),
|
||||
cloudRefreshing: make(map[modelShowCloudKey]bool),
|
||||
cloudNextReadRefreshAfter: make(map[modelShowCloudKey]time.Time),
|
||||
client: http.DefaultClient,
|
||||
getModelInfo: GetModelInfo,
|
||||
}
|
||||
}
|
||||
|
||||
// modelShowCacheable returns whether a request can use the shared show cache.
|
||||
// System and Options overlays are request-specific response variants, so v1
|
||||
// bypasses caching for those rather than expanding the key space.
|
||||
func modelShowCacheable(req api.ShowRequest) bool {
|
||||
return req.System == "" && len(req.Options) == 0
|
||||
}
|
||||
|
||||
// Start kicks off non-blocking startup hydration for cloud entries. Local show
|
||||
// responses stay lazy because even non-verbose show must load GGUF metadata that
|
||||
// is expensive for large model stores.
|
||||
func (c *modelShowCache) Start(ctx context.Context) {
|
||||
c.once.Do(func() {
|
||||
slog.Debug("starting model show cache")
|
||||
go c.runStartup(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
// runStartup hydrates the cloud cache. It is only called in a goroutine from
|
||||
// Start, so cloud requests cannot delay the listener from accepting traffic.
|
||||
func (c *modelShowCache) runStartup(ctx context.Context) {
|
||||
if err := c.hydrateCloud(ctx); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled):
|
||||
case errors.Is(err, errModelShowNoCloud):
|
||||
slog.Debug("skipping model show cloud cache hydration because cloud is disabled")
|
||||
default:
|
||||
slog.Warn("model show cloud cache hydration failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetLocal returns a cached local show response when the current manifest
|
||||
// digest matches. On a miss, it falls back to GetModelInfo, stores non-remote
|
||||
// local responses, and returns a clone to the caller.
|
||||
func (c *modelShowCache) GetLocal(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
key, digest, err := modelShowLocalKeyForRequest(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp, ok := c.getLocal(key, digest); ok {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
req.Model = key.Model
|
||||
resp, err := c.getModelInfo(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.RemoteHost == "" {
|
||||
c.setLocal(key, digest, resp)
|
||||
}
|
||||
|
||||
return cloneShowResponse(resp), nil
|
||||
}
|
||||
|
||||
// GetCloudSWR returns a cached cloud show response and triggers a throttled
|
||||
// background refresh. The boolean is false on a cold miss so callers can
|
||||
// preserve existing synchronous proxy behavior.
|
||||
func (c *modelShowCache) GetCloudSWR(ctx context.Context, req api.ShowRequest) (*api.ShowResponse, bool) {
|
||||
key := modelShowCloudKeyForModel(req.Model, req.Verbose)
|
||||
resp, ok := c.getCloud(key)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
c.triggerCloudRefreshOnRead(ctx, key)
|
||||
return resp, true
|
||||
}
|
||||
|
||||
func (c *modelShowCache) getLocal(key modelShowLocalKey, digest string) (*api.ShowResponse, bool) {
|
||||
c.mu.RLock()
|
||||
entry, ok := c.local[key]
|
||||
c.mu.RUnlock()
|
||||
if !ok || entry.Digest != digest || entry.Response == nil {
|
||||
return nil, false
|
||||
}
|
||||
return cloneShowResponse(entry.Response), true
|
||||
}
|
||||
|
||||
func (c *modelShowCache) setLocal(key modelShowLocalKey, digest string, resp *api.ShowResponse) {
|
||||
c.mu.Lock()
|
||||
c.local[key] = modelShowLocalEntry{
|
||||
Digest: digest,
|
||||
Response: cloneShowResponse(resp),
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *modelShowCache) hasLocal(key modelShowLocalKey, digest string) bool {
|
||||
c.mu.RLock()
|
||||
entry, ok := c.local[key]
|
||||
c.mu.RUnlock()
|
||||
return ok && entry.Digest == digest && entry.Response != nil
|
||||
}
|
||||
|
||||
func (c *modelShowCache) getCloud(key modelShowCloudKey) (*api.ShowResponse, bool) {
|
||||
c.mu.RLock()
|
||||
resp, ok := c.cloud[key]
|
||||
c.mu.RUnlock()
|
||||
if !ok || resp == nil {
|
||||
return nil, false
|
||||
}
|
||||
return cloneShowResponse(resp), true
|
||||
}
|
||||
|
||||
func (c *modelShowCache) setCloud(key modelShowCloudKey, resp *api.ShowResponse) {
|
||||
c.mu.Lock()
|
||||
c.cloud[key] = cloneShowResponse(resp)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *modelShowCache) beginCloudReadRefresh(key modelShowCloudKey) bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if c.cloudRefreshing[key] || now.Before(c.cloudNextReadRefreshAfter[key]) {
|
||||
return false
|
||||
}
|
||||
|
||||
c.cloudRefreshing[key] = true
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *modelShowCache) endCloudReadRefresh(key modelShowCloudKey) {
|
||||
c.mu.Lock()
|
||||
c.cloudRefreshing[key] = false
|
||||
c.cloudNextReadRefreshAfter[key] = time.Now().Add(modelShowCloudReadRefreshCooldown)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// triggerCloudRefreshOnRead starts the revalidation side of SWR. The refresh
|
||||
// uses context.WithoutCancel so a completed client request does not cancel the
|
||||
// cache update it initiated.
|
||||
func (c *modelShowCache) triggerCloudRefreshOnRead(ctx context.Context, key modelShowCloudKey) {
|
||||
if !c.beginCloudReadRefresh(key) {
|
||||
return
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
ctx = context.WithoutCancel(ctx)
|
||||
|
||||
slog.Debug("triggering model show cloud refresh on read", "model", key.Model, "verbose", key.Verbose)
|
||||
go func() {
|
||||
defer c.endCloudReadRefresh(key)
|
||||
|
||||
if err := c.refreshCloud(ctx, key); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, errModelShowNoCloud):
|
||||
slog.Debug("skipping model show cloud read refresh because cloud is disabled", "model", key.Model)
|
||||
default:
|
||||
slog.Warn("model show cloud read refresh failed", "model", key.Model, "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// refreshCloud fetches and stores one cloud show response. Refresh failures are
|
||||
// returned without touching the existing cached entry, which preserves stale
|
||||
// data for future reads.
|
||||
func (c *modelShowCache) refreshCloud(ctx context.Context, key modelShowCloudKey) error {
|
||||
if disabled, _ := internalcloud.Status(); disabled {
|
||||
return errModelShowNoCloud
|
||||
}
|
||||
|
||||
resp, err := c.fetchCloudShow(ctx, key.Model, key.Verbose)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.setCloud(key, resp)
|
||||
return nil
|
||||
}
|
||||
|
||||
// hydrateLocal scans manifests at startup and refreshes only entries missing
|
||||
// for the current digest. It hydrates non-verbose responses only, avoiding an
|
||||
// expensive tensor walk for users who have never asked for verbose show data.
|
||||
func (c *modelShowCache) hydrateLocal(ctx context.Context) error {
|
||||
manifests, err := manifest.Manifests(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, mf := range manifests {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if modelShowManifestIsRemote(mf) {
|
||||
continue
|
||||
}
|
||||
|
||||
modelName := name.String()
|
||||
digest := mf.Digest()
|
||||
key := modelShowLocalKey{
|
||||
Model: modelName,
|
||||
Verbose: false,
|
||||
}
|
||||
if c.hasLocal(key, digest) {
|
||||
continue
|
||||
}
|
||||
|
||||
resp, err := c.getModelInfo(api.ShowRequest{Model: modelName})
|
||||
if err != nil {
|
||||
slog.Warn("failed to hydrate local model show cache", "model", modelName, "error", err)
|
||||
continue
|
||||
}
|
||||
if resp.RemoteHost != "" {
|
||||
continue
|
||||
}
|
||||
|
||||
c.setLocal(key, digest, resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// hydrateCloud refreshes cloud show entries by listing cloud tags and fetching
|
||||
// /api/show for each returned model with bounded concurrency. Per-model show
|
||||
// failures are logged and skipped so one bad cloud entry does not prevent the
|
||||
// rest of the cache from warming.
|
||||
func (c *modelShowCache) hydrateCloud(ctx context.Context) error {
|
||||
if disabled, _ := internalcloud.Status(); disabled {
|
||||
return errModelShowNoCloud
|
||||
}
|
||||
|
||||
models, err := c.fetchCloudTags(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
jobs := make(chan string)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
worker := func() {
|
||||
defer wg.Done()
|
||||
for modelName := range jobs {
|
||||
if ctx.Err() != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
key := modelShowCloudKeyForModel(modelName, false)
|
||||
resp, err := c.fetchCloudShow(ctx, key.Model, key.Verbose)
|
||||
if err != nil {
|
||||
slog.Warn("failed to hydrate cloud model show cache", "model", key.Model, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.setCloud(key, resp)
|
||||
}
|
||||
}
|
||||
|
||||
workers := min(modelShowCloudHydrationConcurrency, max(1, len(models)))
|
||||
for range workers {
|
||||
wg.Add(1)
|
||||
go worker()
|
||||
}
|
||||
|
||||
sendLoop:
|
||||
for _, modelName := range models {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break sendLoop
|
||||
case jobs <- modelName:
|
||||
}
|
||||
}
|
||||
close(jobs)
|
||||
wg.Wait()
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchCloudTags returns de-duplicated cloud model names normalized to their
|
||||
// show-cache key form. It accepts either ListModelResponse.Model or the legacy
|
||||
// Name field because /api/tags responses may contain both.
|
||||
func (c *modelShowCache) fetchCloudTags(ctx context.Context) ([]string, error) {
|
||||
var payload api.ListResponse
|
||||
if err := c.doCloudJSON(ctx, http.MethodGet, "/api/tags", nil, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{}, len(payload.Models))
|
||||
models := make([]string, 0, len(payload.Models))
|
||||
for _, item := range payload.Models {
|
||||
name := strings.TrimSpace(item.Model)
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(item.Name)
|
||||
}
|
||||
name = modelShowNormalizeCloudModel(name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
models = append(models, name)
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (c *modelShowCache) fetchCloudShow(ctx context.Context, modelName string, verbose bool) (*api.ShowResponse, error) {
|
||||
payload := api.ShowRequest{
|
||||
Model: modelShowNormalizeCloudModel(modelName),
|
||||
Verbose: verbose,
|
||||
}
|
||||
|
||||
var resp api.ShowResponse
|
||||
if err := c.doCloudJSON(ctx, http.MethodPost, "/api/show", payload, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.ModelInfo == nil {
|
||||
resp.ModelInfo = map[string]any{}
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// doCloudJSON is the cache's direct cloud client. It mirrors the cloud proxy's
|
||||
// signing and client-version behavior but uses an internal timeout because
|
||||
// hydration and refreshes must not hang indefinitely.
|
||||
func (c *modelShowCache) doCloudJSON(ctx context.Context, method, path string, payload any, out any) error {
|
||||
reqCtx, cancel := context.WithTimeout(ctx, modelShowCloudFetchTimeout)
|
||||
defer cancel()
|
||||
|
||||
baseURL, err := url.Parse(cloudProxyBaseURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
targetURL := baseURL.ResolveReference(&url.URL{Path: path})
|
||||
|
||||
var body io.Reader
|
||||
if payload != nil {
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body = bytes.NewReader(data)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(reqCtx, method, targetURL.String(), body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if payload != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
if clientVersion := strings.TrimSpace(version.Version); clientVersion != "" {
|
||||
req.Header.Set(cloudProxyClientVersionHeader, clientVersion)
|
||||
}
|
||||
|
||||
if err := cloudProxySignRequest(req.Context(), req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
return modelShowStatusError(resp, data)
|
||||
}
|
||||
|
||||
if out == nil {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(data, out)
|
||||
}
|
||||
|
||||
// modelShowStatusError preserves the important error shape from cloud
|
||||
// responses, including AuthorizationError for 401s and StatusError otherwise.
|
||||
func modelShowStatusError(resp *http.Response, body []byte) error {
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
err := api.AuthorizationError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Status: resp.Status,
|
||||
}
|
||||
_ = json.Unmarshal(body, &err)
|
||||
if err.Status == "" {
|
||||
err.Status = resp.Status
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
statusErr := api.StatusError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Status: resp.Status,
|
||||
}
|
||||
if err := json.Unmarshal(body, &statusErr); err != nil || statusErr.ErrorMessage == "" {
|
||||
statusErr.ErrorMessage = strings.TrimSpace(string(body))
|
||||
}
|
||||
return statusErr
|
||||
}
|
||||
|
||||
// modelShowLocalKeyForRequest normalizes a local show request to the canonical
|
||||
// on-disk model name and returns the current manifest digest used to validate
|
||||
// the cached entry.
|
||||
func modelShowLocalKeyForRequest(req api.ShowRequest) (modelShowLocalKey, string, error) {
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
return modelShowLocalKey{}, "", model.Unqualified(name)
|
||||
}
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
return modelShowLocalKey{}, "", err
|
||||
}
|
||||
|
||||
mf, err := manifest.ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
return modelShowLocalKey{}, "", err
|
||||
}
|
||||
|
||||
return modelShowLocalKey{
|
||||
Model: name.String(),
|
||||
Verbose: req.Verbose,
|
||||
}, mf.Digest(), nil
|
||||
}
|
||||
|
||||
func modelShowCloudKeyForModel(modelName string, verbose bool) modelShowCloudKey {
|
||||
return modelShowCloudKey{
|
||||
Model: modelShowNormalizeCloudModel(modelName),
|
||||
Verbose: verbose,
|
||||
}
|
||||
}
|
||||
|
||||
// modelShowNormalizeCloudModel strips explicit cloud source syntax, including
|
||||
// legacy "-cloud" tags, so :cloud and -cloud forms share a cache entry.
|
||||
func modelShowNormalizeCloudModel(modelName string) string {
|
||||
modelName = strings.TrimSpace(modelName)
|
||||
if modelName == "" {
|
||||
return ""
|
||||
}
|
||||
if base, stripped := modelref.StripCloudSourceTag(modelName); stripped {
|
||||
return strings.TrimSpace(base)
|
||||
}
|
||||
return modelName
|
||||
}
|
||||
|
||||
// modelShowManifestIsRemote checks whether a manifest represents a local stub
|
||||
// for a remote model. Startup hydration skips these so the local content cache
|
||||
// does not store entries whose freshness is governed by cloud state.
|
||||
func modelShowManifestIsRemote(mf *manifest.Manifest) bool {
|
||||
if mf == nil || mf.Config.Digest == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
f, err := mf.Config.Open()
|
||||
if err != nil {
|
||||
slog.Warn("failed to open manifest config while checking model show cache eligibility", "error", err)
|
||||
return false
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var cfg model.ConfigV2
|
||||
if err := json.NewDecoder(f).Decode(&cfg); err != nil {
|
||||
slog.Warn("failed to decode manifest config while checking model show cache eligibility", "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return cfg.RemoteHost != "" || cfg.RemoteModel != ""
|
||||
}
|
||||
|
||||
// cloneShowResponse deep-copies mutable fields of api.ShowResponse before
|
||||
// storing or returning cached entries. The response contains maps and slices,
|
||||
// and some handlers mutate ModelInfo before writing JSON.
|
||||
func cloneShowResponse(in *api.ShowResponse) *api.ShowResponse {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := *in
|
||||
out.Details.Families = slices.Clone(in.Details.Families)
|
||||
out.Messages = cloneMessages(in.Messages)
|
||||
out.Capabilities = slices.Clone(in.Capabilities)
|
||||
out.ModelInfo = cloneAnyMap(in.ModelInfo)
|
||||
out.ProjectorInfo = cloneAnyMap(in.ProjectorInfo)
|
||||
out.Tensors = cloneTensors(in.Tensors)
|
||||
return &out
|
||||
}
|
||||
|
||||
func cloneMessages(in []api.Message) []api.Message {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := make([]api.Message, len(in))
|
||||
for i, msg := range in {
|
||||
out[i] = msg
|
||||
if msg.Images != nil {
|
||||
out[i].Images = make([]api.ImageData, len(msg.Images))
|
||||
for j, image := range msg.Images {
|
||||
out[i].Images[j] = slices.Clone(image)
|
||||
}
|
||||
}
|
||||
out[i].ToolCalls = slices.Clone(msg.ToolCalls)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneTensors(in []api.Tensor) []api.Tensor {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := make([]api.Tensor, len(in))
|
||||
for i, tensor := range in {
|
||||
out[i] = tensor
|
||||
out[i].Shape = slices.Clone(tensor.Shape)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneAnyMap(in map[string]any) map[string]any {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = cloneAny(v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneAny(v any) any {
|
||||
switch v := v.(type) {
|
||||
case map[string]any:
|
||||
return cloneAnyMap(v)
|
||||
case []any:
|
||||
out := make([]any, len(v))
|
||||
for i, item := range v {
|
||||
out[i] = cloneAny(item)
|
||||
}
|
||||
return out
|
||||
case []string:
|
||||
return slices.Clone(v)
|
||||
case []bool:
|
||||
return slices.Clone(v)
|
||||
case []int:
|
||||
return slices.Clone(v)
|
||||
case []int8:
|
||||
return slices.Clone(v)
|
||||
case []int16:
|
||||
return slices.Clone(v)
|
||||
case []int32:
|
||||
return slices.Clone(v)
|
||||
case []int64:
|
||||
return slices.Clone(v)
|
||||
case []uint:
|
||||
return slices.Clone(v)
|
||||
case []uint8:
|
||||
return slices.Clone(v)
|
||||
case []uint16:
|
||||
return slices.Clone(v)
|
||||
case []uint32:
|
||||
return slices.Clone(v)
|
||||
case []uint64:
|
||||
return slices.Clone(v)
|
||||
case []float32:
|
||||
return slices.Clone(v)
|
||||
case []float64:
|
||||
return slices.Clone(v)
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
520
server/model_show_cache_test.go
Normal file
520
server/model_show_cache_test.go
Normal file
@@ -0,0 +1,520 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
modelpkg "github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestModelShowCacheLocalHitUsesManifestDigest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
createShowCacheModel(t, "show-cache-local", map[string]any{"test.context_length": uint32(1024)})
|
||||
|
||||
cache := newModelShowCache()
|
||||
calls := 0
|
||||
cache.getModelInfo = func(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
calls++
|
||||
return showCacheTestResponse(calls, req.Verbose), nil
|
||||
}
|
||||
|
||||
first, err := cache.GetLocal(api.ShowRequest{Model: "show-cache-local"})
|
||||
if err != nil {
|
||||
t.Fatalf("first GetLocal failed: %v", err)
|
||||
}
|
||||
second, err := cache.GetLocal(api.ShowRequest{Model: "show-cache-local"})
|
||||
if err != nil {
|
||||
t.Fatalf("second GetLocal failed: %v", err)
|
||||
}
|
||||
|
||||
if calls != 1 {
|
||||
t.Fatalf("getModelInfo calls = %d, want 1", calls)
|
||||
}
|
||||
if first.ModelInfo["call"] != 1 || second.ModelInfo["call"] != 1 {
|
||||
t.Fatalf("cached call markers = %v / %v, want both 1", first.ModelInfo["call"], second.ModelInfo["call"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelShowCacheLocalManifestDigestChangeRefreshes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
createShowCacheModel(t, "show-cache-refresh", map[string]any{"test.context_length": uint32(1024)})
|
||||
|
||||
cache := newModelShowCache()
|
||||
calls := 0
|
||||
cache.getModelInfo = func(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
calls++
|
||||
return showCacheTestResponse(calls, req.Verbose), nil
|
||||
}
|
||||
|
||||
if _, err := cache.GetLocal(api.ShowRequest{Model: "show-cache-refresh"}); err != nil {
|
||||
t.Fatalf("first GetLocal failed: %v", err)
|
||||
}
|
||||
changeShowCacheManifest(t, "show-cache-refresh")
|
||||
refreshed, err := cache.GetLocal(api.ShowRequest{Model: "show-cache-refresh"})
|
||||
if err != nil {
|
||||
t.Fatalf("refreshed GetLocal failed: %v", err)
|
||||
}
|
||||
|
||||
if calls != 2 {
|
||||
t.Fatalf("getModelInfo calls = %d, want 2", calls)
|
||||
}
|
||||
if refreshed.ModelInfo["call"] != 2 {
|
||||
t.Fatalf("refreshed call marker = %v, want 2", refreshed.ModelInfo["call"])
|
||||
}
|
||||
if len(cache.local) != 1 {
|
||||
t.Fatalf("local cache entries = %d, want 1", len(cache.local))
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelShowCacheLocalVerboseVariantsAreSeparate(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
createShowCacheModel(t, "show-cache-verbose", map[string]any{"test.context_length": uint32(1024)})
|
||||
|
||||
cache := newModelShowCache()
|
||||
calls := 0
|
||||
cache.getModelInfo = func(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
calls++
|
||||
return showCacheTestResponse(calls, req.Verbose), nil
|
||||
}
|
||||
|
||||
plain, err := cache.GetLocal(api.ShowRequest{Model: "show-cache-verbose"})
|
||||
if err != nil {
|
||||
t.Fatalf("plain GetLocal failed: %v", err)
|
||||
}
|
||||
verbose, err := cache.GetLocal(api.ShowRequest{Model: "show-cache-verbose", Verbose: true})
|
||||
if err != nil {
|
||||
t.Fatalf("verbose GetLocal failed: %v", err)
|
||||
}
|
||||
plainAgain, err := cache.GetLocal(api.ShowRequest{Model: "show-cache-verbose"})
|
||||
if err != nil {
|
||||
t.Fatalf("plain repeat GetLocal failed: %v", err)
|
||||
}
|
||||
|
||||
if calls != 2 {
|
||||
t.Fatalf("getModelInfo calls = %d, want 2", calls)
|
||||
}
|
||||
if plain.ModelInfo["verbose"] != false || verbose.ModelInfo["verbose"] != true || plainAgain.ModelInfo["call"] != 1 {
|
||||
t.Fatalf("unexpected verbose cache markers: plain=%v verbose=%v plainAgainCall=%v", plain.ModelInfo, verbose.ModelInfo, plainAgain.ModelInfo["call"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelShowCacheLocalHydrationSkipsUnchangedInMemory(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
createShowCacheModel(t, "show-cache-hydrate", map[string]any{"test.context_length": uint32(1024)})
|
||||
|
||||
cache := newModelShowCache()
|
||||
calls := 0
|
||||
cache.getModelInfo = func(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
calls++
|
||||
return showCacheTestResponse(calls, req.Verbose), nil
|
||||
}
|
||||
|
||||
if err := cache.hydrateLocal(context.Background()); err != nil {
|
||||
t.Fatalf("first hydrateLocal failed: %v", err)
|
||||
}
|
||||
if err := cache.hydrateLocal(context.Background()); err != nil {
|
||||
t.Fatalf("second hydrateLocal failed: %v", err)
|
||||
}
|
||||
resp, err := cache.GetLocal(api.ShowRequest{Model: "show-cache-hydrate"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetLocal after hydration failed: %v", err)
|
||||
}
|
||||
|
||||
if calls != 1 {
|
||||
t.Fatalf("getModelInfo calls after unchanged in-memory hydration = %d, want 1", calls)
|
||||
}
|
||||
if resp.ModelInfo["call"] != 1 {
|
||||
t.Fatalf("hydrated call marker = %v, want 1", resp.ModelInfo["call"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelShowCacheStartupSkipsLocalHydration(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
t.Setenv("OLLAMA_NO_CLOUD", "1")
|
||||
createShowCacheModel(t, "show-cache-startup", map[string]any{"test.context_length": uint32(1024)})
|
||||
|
||||
cache := newModelShowCache()
|
||||
cache.getModelInfo = func(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
t.Fatalf("startup should not hydrate local show cache, got request: %+v", req)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
cache.runStartup(context.Background())
|
||||
|
||||
if len(cache.local) != 0 {
|
||||
t.Fatalf("local cache entries = %d, want 0", len(cache.local))
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelShowCacheBypassesSystemAndOptionsOverlays(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
createShowCacheModel(t, "show-cache-overlay", map[string]any{"test.context_length": uint32(1024)})
|
||||
|
||||
cache := newModelShowCache()
|
||||
key, digest, err := modelShowLocalKeyForRequest(api.ShowRequest{Model: "show-cache-overlay"})
|
||||
if err != nil {
|
||||
t.Fatalf("local key failed: %v", err)
|
||||
}
|
||||
cache.setLocal(key, digest, &api.ShowResponse{System: "cached", ModelInfo: map[string]any{}})
|
||||
|
||||
s := Server{modelCaches: &modelCaches{show: cache}}
|
||||
w := createRequest(t, s.ShowHandler, api.ShowRequest{
|
||||
Model: "show-cache-overlay",
|
||||
System: "overlay-system",
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp api.ShowResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode response failed: %v", err)
|
||||
}
|
||||
if resp.System != "overlay-system" {
|
||||
t.Fatalf("system = %q, want overlay-system", resp.System)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.ShowHandler, api.ShowRequest{
|
||||
Model: "show-cache-overlay",
|
||||
Options: map[string]any{"num_ctx": float64(8192)},
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("options overlay status = %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode options response failed: %v", err)
|
||||
}
|
||||
if resp.System == "cached" {
|
||||
t.Fatalf("options overlay unexpectedly returned cached response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelShowCacheLocalAndCloudSameBaseDoNotCollide(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
createShowCacheModel(t, "show-cache-dual", map[string]any{"test.context_length": uint32(1024)})
|
||||
|
||||
cache := newModelShowCache()
|
||||
cache.getModelInfo = func(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
return &api.ShowResponse{
|
||||
Details: api.ModelDetails{Format: "local"},
|
||||
ModelInfo: map[string]any{},
|
||||
}, nil
|
||||
}
|
||||
cloudKey := modelShowCloudKeyForModel("show-cache-dual:cloud", false)
|
||||
cache.setCloud(cloudKey, &api.ShowResponse{
|
||||
Details: api.ModelDetails{Format: "cloud"},
|
||||
ModelInfo: map[string]any{},
|
||||
})
|
||||
cache.mu.Lock()
|
||||
cache.cloudNextReadRefreshAfter[cloudKey] = time.Now().Add(time.Hour)
|
||||
cache.mu.Unlock()
|
||||
|
||||
s := Server{modelCaches: &modelCaches{show: cache}}
|
||||
|
||||
w := createRequest(t, s.ShowHandler, api.ShowRequest{Model: "show-cache-dual:cloud"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("cloud status = %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var cloudResp api.ShowResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&cloudResp); err != nil {
|
||||
t.Fatalf("decode cloud response failed: %v", err)
|
||||
}
|
||||
if cloudResp.Details.Format != "cloud" {
|
||||
t.Fatalf("cloud format = %q, want cloud", cloudResp.Details.Format)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "show-cache-dual"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("local status = %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var localResp api.ShowResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&localResp); err != nil {
|
||||
t.Fatalf("decode local response failed: %v", err)
|
||||
}
|
||||
if localResp.Details.Format != "local" {
|
||||
t.Fatalf("local format = %q, want local", localResp.Details.Format)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelShowCacheCloudWarmHitReturnsStaleAndRefreshes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
refreshDone := make(chan struct{})
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/show" {
|
||||
t.Fatalf("unexpected upstream path %q", r.URL.Path)
|
||||
}
|
||||
defer close(refreshDone)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"details":{"format":"updated"},"model_info":{"source":"fresh"}}`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
withCloudProxyBaseURL(t, upstream.URL)
|
||||
|
||||
cache := newModelShowCache()
|
||||
cache.client = upstream.Client()
|
||||
cache.setCloud(modelShowCloudKeyForModel("kimi-k2.5:cloud", false), &api.ShowResponse{
|
||||
Details: api.ModelDetails{Format: "cached"},
|
||||
ModelInfo: map[string]any{"source": "stale"},
|
||||
})
|
||||
|
||||
s := Server{modelCaches: &modelCaches{show: cache}}
|
||||
w := createRequest(t, s.ShowHandler, api.ShowRequest{Model: "kimi-k2.5:cloud"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp api.ShowResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode response failed: %v", err)
|
||||
}
|
||||
if resp.Details.Format != "cached" {
|
||||
t.Fatalf("format = %q, want cached", resp.Details.Format)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-refreshDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for cloud refresh")
|
||||
}
|
||||
waitForCondition(t, 2*time.Second, func() bool {
|
||||
resp, ok := cache.getCloud(modelShowCloudKeyForModel("kimi-k2.5:cloud", false))
|
||||
return ok && resp.Details.Format == "updated"
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelShowCacheCloudColdMissFallsBackToProxy(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
var capturedPath, capturedBody string
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedPath = r.URL.Path
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
capturedBody = string(body)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"details":{"format":"cold"},"model_info":{}}`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
withCloudProxyBaseURL(t, upstream.URL)
|
||||
|
||||
s := &Server{modelCaches: &modelCaches{show: newModelShowCache()}}
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRoutes failed: %v", err)
|
||||
}
|
||||
local := httptest.NewServer(router)
|
||||
defer local.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/show", bytes.NewBufferString(`{"model":"kimi-k2.5:cloud"}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := local.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("status = %d, want 200: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
if capturedPath != "/api/show" {
|
||||
t.Fatalf("upstream path = %q, want /api/show", capturedPath)
|
||||
}
|
||||
if !strings.Contains(capturedBody, `"model":"kimi-k2.5"`) {
|
||||
t.Fatalf("expected normalized model in upstream body, got %q", capturedBody)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelShowCacheCloudHydrationUsesTagsAndShow(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
var mu sync.Mutex
|
||||
var showModels []string
|
||||
var tagsCalled bool
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch r.URL.Path {
|
||||
case "/api/tags":
|
||||
if r.Method != http.MethodGet {
|
||||
t.Fatalf("tags method = %s, want GET", r.Method)
|
||||
}
|
||||
mu.Lock()
|
||||
tagsCalled = true
|
||||
mu.Unlock()
|
||||
_, _ = w.Write([]byte(`{"models":[{"name":"alpha:cloud"},{"model":"beta"}]}`))
|
||||
case "/api/show":
|
||||
var req api.ShowRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode show request failed: %v", err)
|
||||
}
|
||||
mu.Lock()
|
||||
showModels = append(showModels, req.Model)
|
||||
mu.Unlock()
|
||||
_ = json.NewEncoder(w).Encode(api.ShowResponse{
|
||||
Details: api.ModelDetails{Format: req.Model},
|
||||
ModelInfo: map[string]any{"model": req.Model},
|
||||
})
|
||||
default:
|
||||
t.Fatalf("unexpected upstream path %q", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer upstream.Close()
|
||||
withCloudProxyBaseURL(t, upstream.URL)
|
||||
|
||||
cache := newModelShowCache()
|
||||
cache.client = upstream.Client()
|
||||
if err := cache.hydrateCloud(context.Background()); err != nil {
|
||||
t.Fatalf("hydrateCloud failed: %v", err)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
gotTagsCalled := tagsCalled
|
||||
gotShowModels := slices.Clone(showModels)
|
||||
mu.Unlock()
|
||||
slices.Sort(gotShowModels)
|
||||
|
||||
if !gotTagsCalled {
|
||||
t.Fatal("expected /api/tags to be called")
|
||||
}
|
||||
if !slices.Equal(gotShowModels, []string{"alpha", "beta"}) {
|
||||
t.Fatalf("show models = %v, want [alpha beta]", gotShowModels)
|
||||
}
|
||||
for _, modelName := range gotShowModels {
|
||||
resp, ok := cache.getCloud(modelShowCloudKeyForModel(modelName, false))
|
||||
if !ok {
|
||||
t.Fatalf("missing cached cloud show response for %s", modelName)
|
||||
}
|
||||
if resp.Details.Format != modelName {
|
||||
t.Fatalf("cached format for %s = %q", modelName, resp.Details.Format)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelShowCacheCloudKeyNormalizesSourceTags(t *testing.T) {
|
||||
tests := map[string]string{
|
||||
" kimi-k2.5:cloud ": "kimi-k2.5",
|
||||
"gpt-oss:20b-cloud": "gpt-oss:20b",
|
||||
"qwen3": "qwen3",
|
||||
}
|
||||
|
||||
for input, want := range tests {
|
||||
if got := modelShowCloudKeyForModel(input, false).Model; got != want {
|
||||
t.Fatalf("cloud key model for %q = %q, want %q", input, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelShowCacheCloudDisabledDoesNotServeStale(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
t.Cleanup(envconfig.ReloadServerConfig)
|
||||
t.Setenv("OLLAMA_NO_CLOUD", "1")
|
||||
envconfig.ReloadServerConfig()
|
||||
|
||||
cache := newModelShowCache()
|
||||
cache.setCloud(modelShowCloudKeyForModel("kimi-k2.5:cloud", false), &api.ShowResponse{
|
||||
Details: api.ModelDetails{Format: "cached"},
|
||||
ModelInfo: map[string]any{},
|
||||
})
|
||||
|
||||
if err := cache.hydrateCloud(context.Background()); !errors.Is(err, errModelShowNoCloud) {
|
||||
t.Fatalf("hydrateCloud error = %v, want %v", err, errModelShowNoCloud)
|
||||
}
|
||||
|
||||
s := Server{modelCaches: &modelCaches{show: cache}}
|
||||
w := createRequest(t, s.ShowHandler, api.ShowRequest{Model: "kimi-k2.5:cloud"})
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("status = %d, want 403: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), internalcloud.DisabledError(cloudErrRemoteModelDetailsUnavailable)) {
|
||||
t.Fatalf("unexpected disabled response: %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func createShowCacheModel(t *testing.T, name string, kv map[string]any) {
|
||||
t.Helper()
|
||||
_, digest := createBinFile(t, kv, nil)
|
||||
var s Server
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: name,
|
||||
Files: map[string]string{"model.gguf": digest},
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("create model status = %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func changeShowCacheManifest(t *testing.T, name string) {
|
||||
t.Helper()
|
||||
n, err := getExistingName(modelpkg.ParseName(name))
|
||||
if err != nil {
|
||||
t.Fatalf("get existing name: %v", err)
|
||||
}
|
||||
mf, err := manifest.ParseNamedManifest(n)
|
||||
if err != nil {
|
||||
t.Fatalf("parse manifest: %v", err)
|
||||
}
|
||||
layer, err := manifest.NewLayer(strings.NewReader("changed"), "application/vnd.ollama.image.system")
|
||||
if err != nil {
|
||||
t.Fatalf("new layer: %v", err)
|
||||
}
|
||||
layers := append([]manifest.Layer(nil), mf.Layers...)
|
||||
layers = append(layers, layer)
|
||||
if err := manifest.WriteManifest(n, mf.Config, layers); err != nil {
|
||||
t.Fatalf("write manifest: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func showCacheTestResponse(call int, verbose bool) *api.ShowResponse {
|
||||
return &api.ShowResponse{
|
||||
Details: api.ModelDetails{
|
||||
Format: "gguf",
|
||||
Family: "test",
|
||||
},
|
||||
Capabilities: []modelpkg.Capability{modelpkg.CapabilityCompletion},
|
||||
ModelInfo: map[string]any{
|
||||
"call": call,
|
||||
"verbose": verbose,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func withCloudProxyBaseURL(t *testing.T, url string) {
|
||||
t.Helper()
|
||||
original := cloudProxyBaseURL
|
||||
cloudProxyBaseURL = url
|
||||
t.Cleanup(func() {
|
||||
cloudProxyBaseURL = original
|
||||
})
|
||||
}
|
||||
144
server/prompt.go
Normal file
144
server/prompt.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/model/renderers"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
type tokenizeFunc func(context.Context, string) ([]int, error)
|
||||
|
||||
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
|
||||
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
|
||||
// latest message and 2) system messages
|
||||
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, _ error) {
|
||||
var system []api.Message
|
||||
|
||||
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
|
||||
// Clip images are represented as 768 tokens, each an embedding
|
||||
imageNumTokens := 768
|
||||
|
||||
lastMsgIdx := len(msgs) - 1
|
||||
currMsgIdx := 0
|
||||
|
||||
if truncate {
|
||||
// Start with all messages and remove from the front until it fits in context
|
||||
for i := 0; i <= lastMsgIdx; i++ {
|
||||
// Collect system messages from the portion we're about to skip
|
||||
system = make([]api.Message, 0)
|
||||
for j := range i {
|
||||
if msgs[j].Role == "system" {
|
||||
system = append(system, msgs[j])
|
||||
}
|
||||
}
|
||||
|
||||
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
s, err := tokenize(ctx, p)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
ctxLen := len(s)
|
||||
if m.ProjectorPaths != nil {
|
||||
for _, msg := range msgs[i:] {
|
||||
ctxLen += imageNumTokens * len(msg.Images)
|
||||
}
|
||||
}
|
||||
|
||||
if ctxLen <= opts.NumCtx {
|
||||
currMsgIdx = i
|
||||
break
|
||||
}
|
||||
|
||||
// Must always include at least the last message
|
||||
if i == lastMsgIdx {
|
||||
currMsgIdx = lastMsgIdx
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if currMsgIdx > 0 {
|
||||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
|
||||
}
|
||||
|
||||
renderMsgs := slices.Clone(msgs)
|
||||
|
||||
for cnt, msg := range renderMsgs[currMsgIdx:] {
|
||||
if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
|
||||
return "", nil, errors.New("this model only supports one image while more than one image requested")
|
||||
}
|
||||
|
||||
var prefix string
|
||||
prompt := msg.Content
|
||||
|
||||
for _, i := range msg.Images {
|
||||
imgData := llm.ImageData{
|
||||
ID: len(images),
|
||||
Data: i,
|
||||
}
|
||||
images = append(images, imgData)
|
||||
|
||||
if m.Config.Renderer != "" {
|
||||
continue
|
||||
}
|
||||
|
||||
imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
|
||||
if !strings.Contains(prompt, "[img]") {
|
||||
prefix += imgTag
|
||||
} else {
|
||||
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
|
||||
}
|
||||
}
|
||||
|
||||
if m.Config.Renderer != "" {
|
||||
continue
|
||||
}
|
||||
|
||||
renderMsgs[currMsgIdx+cnt].Content = prefix + prompt
|
||||
}
|
||||
|
||||
// truncate any messages that do not fit into the context window
|
||||
p, err := renderPrompt(m, append(system, renderMsgs[currMsgIdx:]...), tools, think)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return p, images, nil
|
||||
}
|
||||
|
||||
func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
|
||||
if m.Config.Renderer != "" {
|
||||
rendererName := resolveRendererName(m)
|
||||
rendered, err := renderers.RenderWithRenderer(rendererName, msgs, tools, think)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return rendered, nil
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
thinkVal := false
|
||||
thinkLevel := ""
|
||||
if think != nil {
|
||||
thinkVal = think.Bool()
|
||||
thinkLevel = think.String()
|
||||
}
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: msgs, Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return b.String(), nil
|
||||
}
|
||||
606
server/prompt_test.go
Normal file
606
server/prompt_test.go
Normal file
@@ -0,0 +1,606 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func testConfigWithRenderer(renderer string) model.ConfigV2 {
|
||||
return model.ConfigV2{Renderer: renderer}
|
||||
}
|
||||
|
||||
func testConfigWithRendererAndType(renderer, modelType string) model.ConfigV2 {
|
||||
return model.ConfigV2{Renderer: renderer, ModelType: modelType}
|
||||
}
|
||||
|
||||
func TestChatPrompt(t *testing.T) {
|
||||
type expect struct {
|
||||
prompt string
|
||||
images [][]byte
|
||||
error error
|
||||
}
|
||||
|
||||
tmpl, err := template.Parse(`
|
||||
{{- if .System }}{{ .System }} {{ end }}
|
||||
{{- if .Prompt }}{{ .Prompt }} {{ end }}
|
||||
{{- if .Response }}{{ .Response }} {{ end }}`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
model Model
|
||||
limit int
|
||||
truncate bool
|
||||
msgs []api.Message
|
||||
expect
|
||||
}{
|
||||
{
|
||||
name: "messages",
|
||||
model: visionModel,
|
||||
limit: 64,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "truncate messages",
|
||||
model: visionModel,
|
||||
limit: 1,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "A test. And a thumping good one at that, I'd wager. ",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "truncate messages with image",
|
||||
model: visionModel,
|
||||
limit: 64,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0]A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("something"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "truncate messages with images",
|
||||
model: visionModel,
|
||||
limit: 64,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0]A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "messages with images",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0]You're a test, Harry! I-I'm a what? [img-1]A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("something"),
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "message with image tag",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "You're a test, Harry! [img-0] I-I'm a what? [img-1]A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("something"),
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "messages with interleaved images",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "user", Images: []api.ImageData{[]byte("something")}},
|
||||
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "You're a test, Harry!\n\n[img-0]\n\n[img-1] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("something"),
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "truncate message with interleaved images",
|
||||
model: visionModel,
|
||||
limit: 1024,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "user", Images: []api.ImageData{[]byte("something")}},
|
||||
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "message with system prompt",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are the Test Who Lived."},
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "out of order system",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "system", Content: "You are the Test Who Lived."},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple images same prompt",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Compare these two pictures of hotdogs", Images: []api.ImageData{[]byte("one hotdog"), []byte("two hotdogs")}},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0][img-1]Compare these two pictures of hotdogs ",
|
||||
images: [][]byte{[]byte("one hotdog"), []byte("two hotdogs")},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no truncate with limit exceeded",
|
||||
model: visionModel,
|
||||
limit: 10,
|
||||
truncate: false,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
model := tt.model
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||
think := false
|
||||
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
|
||||
if tt.error == nil && err != nil {
|
||||
t.Fatal(err)
|
||||
} else if tt.error != nil && err != tt.error {
|
||||
t.Fatalf("expected err '%q', got '%q'", tt.error, err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
if len(images) != len(tt.images) {
|
||||
t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
|
||||
}
|
||||
|
||||
for i := range images {
|
||||
if images[i].ID != i {
|
||||
t.Errorf("expected ID %d, got %d", i, images[i].ID)
|
||||
}
|
||||
|
||||
if len(model.Config.ModelFamilies) == 0 {
|
||||
if !bytes.Equal(images[i].Data, tt.images[i]) {
|
||||
t.Errorf("expected %q, got %q", tt.images[i], images[i].Data)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatPromptTokenizeCalls(t *testing.T) {
|
||||
tmpl, err := template.Parse(`
|
||||
{{- if .System }}{{ .System }} {{ end }}
|
||||
{{- if .Prompt }}{{ .Prompt }} {{ end }}
|
||||
{{- if .Response }}{{ .Response }} {{ end }}`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
model := Model{Template: tmpl}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
limit int
|
||||
msgs []api.Message
|
||||
maxTokenizes int
|
||||
}{
|
||||
{
|
||||
name: "all messages fit",
|
||||
limit: 2048,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "message 1"},
|
||||
{Role: "assistant", Content: "response 1"},
|
||||
{Role: "user", Content: "message 2"},
|
||||
{Role: "assistant", Content: "response 2"},
|
||||
{Role: "user", Content: "message 3"},
|
||||
},
|
||||
maxTokenizes: 1,
|
||||
},
|
||||
{
|
||||
name: "truncate to last message",
|
||||
limit: 5,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "message 1"},
|
||||
{Role: "assistant", Content: "response 1"},
|
||||
{Role: "user", Content: "message 2"},
|
||||
{Role: "assistant", Content: "response 2"},
|
||||
{Role: "user", Content: "message 3"},
|
||||
},
|
||||
maxTokenizes: 5,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenizeCount := 0
|
||||
countingTokenize := func(ctx context.Context, s string) ([]int, error) {
|
||||
tokenizeCount++
|
||||
tokens, err := mockRunner{}.Tokenize(ctx, s)
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||
think := false
|
||||
_, _, err := chatPrompt(t.Context(), &model, countingTokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if tokenizeCount > tt.maxTokenizes {
|
||||
t.Errorf("tokenize called %d times, expected at most %d", tokenizeCount, tt.maxTokenizes)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatPromptRendererDoesNotRewriteMessageContent(t *testing.T) {
|
||||
msgs := []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "what do these photos have in common?",
|
||||
Images: []api.ImageData{[]byte("img-1"), []byte("img-2"), []byte("img-3")},
|
||||
},
|
||||
}
|
||||
originalContent := msgs[0].Content
|
||||
|
||||
m := Model{
|
||||
Config: model.ConfigV2{Renderer: "qwen3-vl-instruct"},
|
||||
ProjectorPaths: []string{"vision"},
|
||||
}
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
|
||||
think := false
|
||||
|
||||
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if msgs[0].Content != originalContent {
|
||||
t.Fatalf("renderer path should not mutate message content: got %q, want %q", msgs[0].Content, originalContent)
|
||||
}
|
||||
|
||||
if got, want := len(images), 3; got != want {
|
||||
t.Fatalf("len(images) = %d, want %d", got, want)
|
||||
}
|
||||
|
||||
if prompt == "" {
|
||||
t.Fatal("prompt is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
|
||||
msgs := []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "extract text",
|
||||
Images: []api.ImageData{[]byte("img-1"), []byte("img-2")},
|
||||
},
|
||||
}
|
||||
|
||||
m := Model{
|
||||
Config: model.ConfigV2{Renderer: "glm-ocr"},
|
||||
ProjectorPaths: []string{"vision"},
|
||||
}
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
|
||||
think := false
|
||||
|
||||
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got, want := len(images), 2; got != want {
|
||||
t.Fatalf("len(images) = %d, want %d", got, want)
|
||||
}
|
||||
|
||||
if !strings.Contains(prompt, "<|user|>\n[img-0][img-1] extract text") {
|
||||
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatPromptRendererAddsToolImageTags(t *testing.T) {
|
||||
msgs := []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "look at this file",
|
||||
Images: []api.ImageData{[]byte("img-1")},
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_read",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "Read",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
Content: "attached image",
|
||||
Images: []api.ImageData{[]byte("img-2")},
|
||||
ToolCallID: "call_read",
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
renderer string
|
||||
wantUserTag string
|
||||
wantToolContent string
|
||||
}{
|
||||
{
|
||||
name: "gemma4",
|
||||
renderer: "gemma4",
|
||||
wantUserTag: "<|turn>user\n[img-0] look at this file<turn|>\n",
|
||||
wantToolContent: "[img-1] attached image",
|
||||
},
|
||||
{
|
||||
name: "qwen3-vl",
|
||||
renderer: "qwen3-vl-instruct",
|
||||
wantUserTag: "<|im_start|>user\n[img-0] look at this file<|im_end|>\n",
|
||||
wantToolContent: "<tool_response>\n[img-1] attached image\n</tool_response>",
|
||||
},
|
||||
{
|
||||
name: "qwen3.5",
|
||||
renderer: "qwen3.5",
|
||||
wantUserTag: "<|im_start|>user\n[img-0] look at this file<|im_end|>\n",
|
||||
wantToolContent: "<tool_response>\n[img-1] attached image\n</tool_response>",
|
||||
},
|
||||
{
|
||||
name: "glm-ocr",
|
||||
renderer: "glm-ocr",
|
||||
wantUserTag: "<|user|>\n[img-0] look at this file",
|
||||
wantToolContent: "<tool_response>\n[img-1] attached image\n</tool_response>",
|
||||
},
|
||||
{
|
||||
name: "nemotron-3-nano",
|
||||
renderer: "nemotron-3-nano",
|
||||
wantUserTag: "<|im_start|>user\n[img-0] look at this file<|im_end|>\n",
|
||||
wantToolContent: "<tool_response>\n[img-1] attached image\n</tool_response>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := Model{
|
||||
Config: model.ConfigV2{Renderer: tt.renderer},
|
||||
ProjectorPaths: []string{"vision"},
|
||||
}
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
|
||||
think := false
|
||||
|
||||
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got, want := len(images), 2; got != want {
|
||||
t.Fatalf("len(images) = %d, want %d", got, want)
|
||||
}
|
||||
|
||||
if !strings.Contains(prompt, tt.wantUserTag) {
|
||||
t.Fatalf("prompt missing user image tag, got: %q", prompt)
|
||||
}
|
||||
|
||||
if !strings.Contains(prompt, tt.wantToolContent) {
|
||||
t.Fatalf("prompt missing tool image tag, got: %q", prompt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatPromptRendererPreservesExplicitImagePlaceholders(t *testing.T) {
|
||||
msgs := []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "compare [img] and [img]",
|
||||
Images: []api.ImageData{[]byte("img-1"), []byte("img-2")},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
renderer string
|
||||
wantSnippet string
|
||||
}{
|
||||
{
|
||||
name: "gemma4",
|
||||
renderer: "gemma4",
|
||||
wantSnippet: "<|turn>user\ncompare [img-0] and [img-1]<turn|>\n",
|
||||
},
|
||||
{
|
||||
name: "qwen3-vl",
|
||||
renderer: "qwen3-vl-instruct",
|
||||
wantSnippet: "<|im_start|>user\ncompare [img-0] and [img-1]<|im_end|>\n",
|
||||
},
|
||||
{
|
||||
name: "qwen3.5",
|
||||
renderer: "qwen3.5",
|
||||
wantSnippet: "<|im_start|>user\ncompare [img-0] and [img-1]<|im_end|>\n",
|
||||
},
|
||||
{
|
||||
name: "glm-ocr",
|
||||
renderer: "glm-ocr",
|
||||
wantSnippet: "<|user|>\ncompare [img-0] and [img-1]",
|
||||
},
|
||||
{
|
||||
name: "nemotron-3-nano",
|
||||
renderer: "nemotron-3-nano",
|
||||
wantSnippet: "<|im_start|>user\ncompare [img-0] and [img-1]<|im_end|>\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := Model{
|
||||
Config: model.ConfigV2{Renderer: tt.renderer},
|
||||
ProjectorPaths: []string{"vision"},
|
||||
}
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
|
||||
think := false
|
||||
|
||||
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got, want := len(images), 2; got != want {
|
||||
t.Fatalf("len(images) = %d, want %d", got, want)
|
||||
}
|
||||
|
||||
if !strings.Contains(prompt, tt.wantSnippet) {
|
||||
t.Fatalf("prompt missing replaced placeholders, got: %q", prompt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderPromptResolvesDynamicGemma4Renderer(t *testing.T) {
|
||||
msgs := []api.Message{{Role: "user", Content: "Hello"}}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model Model
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "small from name",
|
||||
model: Model{
|
||||
Name: "gemma4:e4b",
|
||||
ShortName: "gemma4:e4b",
|
||||
Config: testConfigWithRenderer(gemma4RendererLegacy),
|
||||
},
|
||||
want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "large from model type",
|
||||
model: Model{
|
||||
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
|
||||
},
|
||||
want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := renderPrompt(&tt.model, msgs, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||
t.Fatalf("rendered prompt mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
414
server/quantization.go
Normal file
414
server/quantization.go
Normal file
@@ -0,0 +1,414 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
fsggml "github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/ml/backend/ggml"
|
||||
)
|
||||
|
||||
type quantizer struct {
|
||||
*os.File
|
||||
offset uint64
|
||||
from, to *fsggml.Tensor
|
||||
progressFn func(n uint64)
|
||||
}
|
||||
|
||||
func (q quantizer) WriteTo(w io.Writer) (int64, error) {
|
||||
quantize := q.from.Kind != q.to.Kind
|
||||
sr := io.NewSectionReader(q, int64(q.offset), int64(q.from.Size()))
|
||||
if !quantize {
|
||||
n, err := io.Copy(w, sr)
|
||||
q.progressFn(q.from.Size())
|
||||
return n, err
|
||||
}
|
||||
data, err := io.ReadAll(sr)
|
||||
if err != nil {
|
||||
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
|
||||
return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err)
|
||||
}
|
||||
if uint64(len(data)) < q.from.Size() {
|
||||
return 0, fmt.Errorf("tensor %s data size %d is less than expected %d from shape %v", q.from.Name, len(data), q.from.Size(), q.from.Shape)
|
||||
}
|
||||
var f32s []float32
|
||||
newType := fsggml.TensorType(q.to.Kind)
|
||||
if fsggml.TensorType(q.from.Kind) == fsggml.TensorTypeF32 {
|
||||
f32s = unsafe.Slice((*float32)(unsafe.Pointer(&data[0])), q.from.Elements())
|
||||
} else {
|
||||
f32s = ggml.ConvertToF32(data, q.from.Kind, q.from.Elements())
|
||||
}
|
||||
data = ggml.Quantize(newType, f32s, q.from.Shape)
|
||||
n, err := w.Write(data)
|
||||
q.progressFn(q.from.Size())
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
type quantizeState struct {
|
||||
nAttnV int // Number of attn_*v* weight tensors
|
||||
nFfnDown int // Number of ffn_down tensors
|
||||
iAttnV int // Running counter of number of attn_v tensors that have been processed
|
||||
iFfnDown int // Running counter of number of ffn_down tensors that have been processed
|
||||
hasOutput bool // used to figure out if a model shares tok_embd with the output weight
|
||||
preserveSourceFP8ToQ8 bool
|
||||
preserveSourceQ4 bool
|
||||
sourceFP8Tensors map[string]struct{}
|
||||
}
|
||||
|
||||
func useMoreBits(iLayer, nLayers int) bool {
|
||||
return iLayer < (nLayers/8) || iLayer >= 7*nLayers/8 || (iLayer-nLayers/8)%3 == 2
|
||||
}
|
||||
|
||||
func qwen3LinearAttnQuantType(name string) (fsggml.TensorType, bool) {
|
||||
switch {
|
||||
// Full attention
|
||||
case strings.HasSuffix(name, ".attn_q.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".attn_k.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".attn_v.weight"):
|
||||
return fsggml.TensorTypeQ6_K, true
|
||||
case strings.HasSuffix(name, ".attn_output.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
|
||||
// Linear attention (Gated Delta Net) after split
|
||||
case strings.HasSuffix(name, ".attn_qkv.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".attn_gate.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
|
||||
// SSM
|
||||
case strings.HasSuffix(name, ".ssm_ba.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".ssm_beta.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".ssm_alpha.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".ssm_out.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
|
||||
// MoE experts + shared experts
|
||||
case strings.HasSuffix(name, ".ffn_down_exps.weight"):
|
||||
return fsggml.TensorTypeQ6_K, true
|
||||
case strings.HasSuffix(name, ".ffn_down_shexp.weight"):
|
||||
return fsggml.TensorTypeQ6_K, true
|
||||
case strings.HasSuffix(name, ".ffn_gate_exps.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".ffn_gate_shexp.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".ffn_up_exps.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
case strings.HasSuffix(name, ".ffn_up_shexp.weight"):
|
||||
return fsggml.TensorTypeQ4_K, true
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func isLagunaGGUFRoutedExpertWeight(name string) bool {
|
||||
return strings.HasSuffix(name, ".weight") && (strings.Contains(name, "ffn_gate_exps") ||
|
||||
strings.Contains(name, "ffn_up_exps") ||
|
||||
strings.Contains(name, "ffn_down_exps"))
|
||||
}
|
||||
|
||||
func lagunaGGUFBlockIndex(name string) (int, bool) {
|
||||
if !strings.HasPrefix(name, "blk.") {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
parts := strings.SplitN(strings.TrimPrefix(name, "blk."), ".", 2)
|
||||
if len(parts) != 2 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
i, err := strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return i, true
|
||||
}
|
||||
|
||||
func lagunaGGUFQuantization(name string, originalType, requestedType fsggml.TensorType, ftype fsggml.FileType, blockCount int) (fsggml.TensorType, bool) {
|
||||
if !isLagunaGGUFRoutedExpertWeight(name) {
|
||||
return originalType, false
|
||||
}
|
||||
|
||||
if strings.HasSuffix(name, ".ffn_down_exps.weight") {
|
||||
if i, ok := lagunaGGUFBlockIndex(name); ok && blockCount > 0 {
|
||||
switch ftype {
|
||||
case fsggml.FileTypeQ4_K_M:
|
||||
if requestedType != fsggml.TensorTypeQ8_0 && useMoreBits(i, blockCount) {
|
||||
return fsggml.TensorTypeQ6_K, true
|
||||
}
|
||||
case fsggml.FileTypeQ4_K_S:
|
||||
if requestedType != fsggml.TensorTypeQ8_0 && i < blockCount/8 {
|
||||
return fsggml.TensorTypeQ5_K, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return requestedType, true
|
||||
}
|
||||
|
||||
func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType, name string, shape []uint64, ftype fsggml.FileType) fsggml.TensorType {
|
||||
// Ported from llama_tensor_get_type, removed unsupported quantization types
|
||||
nExperts := max(1, kv.Uint("expert_count", 0))
|
||||
if name == "output.weight" || name == "output_norm.weight" || (!qs.hasOutput && name == "token_embd.weight") {
|
||||
nx := shape[0]
|
||||
qk_k := newType.BlockSize()
|
||||
if nx%qk_k != 0 {
|
||||
newType = fsggml.TensorTypeQ8_0
|
||||
} else if newType != fsggml.TensorTypeQ8_0 {
|
||||
newType = fsggml.TensorTypeQ6_K
|
||||
}
|
||||
} else if strings.Contains(name, "attn_v.weight") {
|
||||
if newType != fsggml.TensorTypeQ8_0 && (ftype == fsggml.FileTypeQ4_K_M) &&
|
||||
useMoreBits(qs.iAttnV, qs.nAttnV) {
|
||||
newType = fsggml.TensorTypeQ6_K
|
||||
} else if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_S && qs.iAttnV < 4 {
|
||||
newType = fsggml.TensorTypeQ5_K
|
||||
}
|
||||
|
||||
// TODO
|
||||
// if (qs.model.type == LLM_TYPE_70B) {
|
||||
// // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
|
||||
// // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
|
||||
// // nearly negligible increase in model size by quantizing this tensor with more bits:
|
||||
// if (newType == GGML_TYPE_Q3_K || newType == GGML_TYPE_Q4_K) newType = GGML_TYPE_Q5_K;
|
||||
// }
|
||||
|
||||
if nExperts == 8 {
|
||||
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
|
||||
newType = fsggml.TensorTypeQ8_0
|
||||
}
|
||||
qs.iAttnV++
|
||||
} else if strings.Contains(name, "attn_k.weight") {
|
||||
if nExperts == 8 {
|
||||
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
|
||||
newType = fsggml.TensorTypeQ8_0
|
||||
}
|
||||
} else if strings.Contains(name, "attn_k_b.weight") ||
|
||||
strings.Contains(name, "attn_v_b.weight") ||
|
||||
strings.Contains(name, "attn_kv_a_mqa.weight") ||
|
||||
strings.Contains(name, "attn_q_a.weight") ||
|
||||
strings.Contains(name, "attn_q_b.weight") {
|
||||
// MLA tensors need higher precision to avoid quality degradation
|
||||
newType = fsggml.TensorTypeQ8_0
|
||||
} else if strings.Contains(name, "ffn_down") {
|
||||
// For MoE models, ffn_down.weight (dense) and ffn_down_exps.weight (expert) both
|
||||
// exist per layer and should get the same useMoreBits treatment. Dense sorts before
|
||||
// expert alphabetically, so dense increments the counter and expert uses counter-1.
|
||||
var iLayer int
|
||||
if strings.Contains(name, "_exps") {
|
||||
if kv.Architecture() == "laguna" {
|
||||
goto finalize
|
||||
}
|
||||
iLayer = max(0, qs.iFfnDown-1)
|
||||
} else {
|
||||
iLayer = qs.iFfnDown
|
||||
qs.iFfnDown++
|
||||
}
|
||||
n_layer := qs.nFfnDown
|
||||
if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_M {
|
||||
if useMoreBits(iLayer, n_layer) {
|
||||
newType = fsggml.TensorTypeQ6_K
|
||||
}
|
||||
} else if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_S && iLayer < n_layer/8 {
|
||||
newType = fsggml.TensorTypeQ5_K
|
||||
}
|
||||
} else if strings.Contains(name, "attn_output.weight") {
|
||||
if newType != fsggml.TensorTypeQ8_0 && nExperts == 8 {
|
||||
if ftype == fsggml.FileTypeQ4_K_S || ftype == fsggml.FileTypeQ4_K_M {
|
||||
newType = fsggml.TensorTypeQ5_K
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(name, "attn_qkv.weight") {
|
||||
if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_M {
|
||||
newType = fsggml.TensorTypeQ5_K
|
||||
}
|
||||
}
|
||||
|
||||
finalize:
|
||||
if newType.IsQuantized() {
|
||||
nx := shape[0]
|
||||
qk_k := newType.BlockSize()
|
||||
|
||||
// Check if first dimension is divisible by block size
|
||||
if nx%qk_k != 0 {
|
||||
// Store the original type for logging
|
||||
originalType := newType
|
||||
|
||||
// Select appropriate fallback based on original type
|
||||
switch newType {
|
||||
case fsggml.TensorTypeQ4_K:
|
||||
newType = fsggml.TensorTypeQ5_0
|
||||
case fsggml.TensorTypeQ5_K:
|
||||
newType = fsggml.TensorTypeQ5_1
|
||||
case fsggml.TensorTypeQ6_K:
|
||||
newType = fsggml.TensorTypeQ8_0
|
||||
}
|
||||
|
||||
// Final check - if still incompatible, fall back to F16
|
||||
if nx%newType.BlockSize() != 0 {
|
||||
newType = fsggml.TensorTypeF16
|
||||
}
|
||||
|
||||
slog.Warn(fmt.Sprintf("tensor cols %d are not divisible by %d, required for %s - using fallback quantization %s",
|
||||
nx, qk_k, originalType.String(), newType.String()))
|
||||
}
|
||||
}
|
||||
return newType
|
||||
}
|
||||
|
||||
func quantize(in, out *os.File, orig *fsggml.GGML, newFileType fsggml.FileType, progressFn func(n uint64)) error {
|
||||
kv := maps.Clone(orig.KV())
|
||||
kv["general.file_type"] = newFileType
|
||||
// kv["general.quantization_version"] = ggml.QuantizationVersion()
|
||||
qs := &quantizeState{
|
||||
sourceFP8Tensors: sourceFP8TensorSet(kv),
|
||||
}
|
||||
hasSourceFP8 := hasSourceFP8Tensors(kv)
|
||||
qs.preserveSourceFP8ToQ8 = hasSourceFP8 && newFileType == fsggml.FileTypeQ8_0
|
||||
qs.preserveSourceQ4 = hasSourceFP8 && slices.Contains([]fsggml.FileType{fsggml.FileTypeQ4_K_M, fsggml.FileTypeQ4_K_S}, newFileType)
|
||||
// Build up the quantize state so newType can adjust types
|
||||
layerCount := 0
|
||||
for k, l := range orig.Tensors().GroupLayers() {
|
||||
if strings.HasPrefix(k, "blk.") {
|
||||
layerCount++
|
||||
}
|
||||
for _, tensor := range l {
|
||||
if strings.Contains(tensor.Name, "attn_v.weight") ||
|
||||
strings.Contains(tensor.Name, "attn_qkv.weight") ||
|
||||
strings.Contains(tensor.Name, "attn_kv_b.weight") {
|
||||
qs.nAttnV++
|
||||
} else if tensor.Name == "output.weight" {
|
||||
qs.hasOutput = true
|
||||
}
|
||||
}
|
||||
}
|
||||
qs.nFfnDown = layerCount
|
||||
|
||||
origTensors := orig.Tensors().Items()
|
||||
outputTensors := make([]*fsggml.Tensor, len(origTensors))
|
||||
for i, tensor := range origTensors {
|
||||
newType := newType(tensor, kv, qs, newFileType)
|
||||
newTensor := &fsggml.Tensor{
|
||||
Name: tensor.Name,
|
||||
Shape: tensor.Shape,
|
||||
Kind: uint32(newType),
|
||||
}
|
||||
outputTensors[i] = newTensor
|
||||
outputTensors[i].WriterTo = quantizer{
|
||||
File: in,
|
||||
offset: orig.Tensors().Offset + tensor.Offset,
|
||||
from: tensor,
|
||||
to: newTensor,
|
||||
progressFn: progressFn,
|
||||
}
|
||||
}
|
||||
return fsggml.WriteGGUF(out, kv, outputTensors)
|
||||
}
|
||||
|
||||
func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.FileType) fsggml.TensorType {
|
||||
defaultType := ftype.ToTensorType()
|
||||
name := t.Name
|
||||
quantize := strings.HasSuffix(name, "weight")
|
||||
|
||||
// don't quantize vision or audio encoder tensors
|
||||
quantize = quantize && !strings.HasPrefix(name, "v.")
|
||||
quantize = quantize && !strings.HasPrefix(name, "a.")
|
||||
quantize = quantize && !strings.Contains(name, "mm.")
|
||||
|
||||
// quantize only 2D and 3D tensors (experts)
|
||||
quantize = quantize && (len(t.Shape) >= 2)
|
||||
|
||||
// do not quantize norm tensors
|
||||
quantize = quantize && !strings.Contains(name, "_norm.weight")
|
||||
|
||||
// do not quantize expert gating tensors
|
||||
quantize = quantize && !strings.Contains(name, "ffn_gate_inp.weight")
|
||||
quantize = quantize && !strings.Contains(name, "ffn_gate_inp_shexp.weight")
|
||||
|
||||
// do not quantize positional embeddings and token types (BERT)
|
||||
quantize = quantize && (name != "position_embd.weight")
|
||||
quantize = quantize && (name != "token_types.weight")
|
||||
|
||||
// do not quantize Mamba's small yet 2D weights
|
||||
// NOTE: can't use LLM_TN here because the layer number is not known
|
||||
quantize = quantize && !strings.Contains(name, "ssm_conv1d.weight")
|
||||
|
||||
// do not quantize LFM2's shortconv kernel weights
|
||||
quantize = quantize && !strings.Contains(name, "shortconv.conv.weight")
|
||||
|
||||
// do not quantize RWKV's time_mix_first tensors
|
||||
quantize = quantize && !strings.Contains(name, "time_mix_first.weight")
|
||||
quantize = quantize && !strings.Contains(name, "time_mix_w1.weight")
|
||||
quantize = quantize && !strings.Contains(name, "time_mix_w2.weight")
|
||||
quantize = quantize && !strings.Contains(name, "time_mix_decay_w1.weight")
|
||||
quantize = quantize && !strings.Contains(name, "time_mix_decay_w2.weight")
|
||||
quantize = quantize && !strings.Contains(name, "time_mix_lerp_fused.weight")
|
||||
|
||||
// do not quantize relative position bias (T5)
|
||||
quantize = quantize && !strings.Contains(name, "attn_rel_b.weight")
|
||||
|
||||
quantize = quantize && !strings.Contains(name, "per_layer_token_embd.weight")
|
||||
|
||||
newType := fsggml.TensorType(t.Kind)
|
||||
if quantize {
|
||||
if qs.preserveSourceFP8ToQ8 {
|
||||
if _, ok := qs.sourceFP8Tensors[name]; !ok {
|
||||
return newType
|
||||
}
|
||||
}
|
||||
|
||||
if slices.Contains([]string{"qwen3next", "qwen35", "qwen35moe"}, kv.Architecture()) && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) {
|
||||
if qt, ok := qwen3LinearAttnQuantType(name); ok {
|
||||
return qt
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Consider extracting architecture-specific GGUF quantization policy
|
||||
// from server so different quantization backends can share one source of
|
||||
// truth for model-family specializations.
|
||||
// get more optimal quantization type based on the tensor shape, layer, etc.
|
||||
if qs.preserveSourceQ4 {
|
||||
if _, ok := qs.sourceFP8Tensors[name]; !ok {
|
||||
defaultType = fsggml.TensorTypeQ8_0
|
||||
}
|
||||
}
|
||||
if kv.Architecture() == "laguna" {
|
||||
var ok bool
|
||||
defaultType, ok = lagunaGGUFQuantization(name, newType, defaultType, ftype, int(kv.Uint("block_count", 0)))
|
||||
if !ok {
|
||||
return newType
|
||||
}
|
||||
}
|
||||
newType = getTensorNewType(kv, qs, defaultType, t.Name, t.Shape, ftype)
|
||||
if newType != defaultType {
|
||||
slog.Debug("tensor quantization adjusted for better quality", "name", t.Name, "requested", defaultType, "quantization", newType)
|
||||
}
|
||||
}
|
||||
return newType
|
||||
}
|
||||
|
||||
func sourceFP8TensorSet(kv fsggml.KV) map[string]struct{} {
|
||||
names := kv.Strings("source_fp8_tensors")
|
||||
if len(names) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make(map[string]struct{}, len(names))
|
||||
for _, name := range names {
|
||||
out[name] = struct{}{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
834
server/quantization_test.go
Normal file
834
server/quantization_test.go
Normal file
@@ -0,0 +1,834 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
fsggml "github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/ml/backend/ggml"
|
||||
)
|
||||
|
||||
func TestGetTensorNewType(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
kv map[string]any
|
||||
qs quantizeState
|
||||
newType fsggml.TensorType
|
||||
tensor_name string
|
||||
shape []uint64
|
||||
ftype fsggml.FileType
|
||||
expected fsggml.TensorType
|
||||
expectedPanic string
|
||||
}{
|
||||
{
|
||||
name: "output_unsupported",
|
||||
kv: map[string]any{},
|
||||
newType: fsggml.TensorTypeQ4_0,
|
||||
tensor_name: "output.weight",
|
||||
shape: []uint64{100, 100},
|
||||
ftype: fsggml.FileTypeF32,
|
||||
expected: fsggml.TensorTypeF16,
|
||||
},
|
||||
{
|
||||
name: "output_Q8",
|
||||
kv: map[string]any{},
|
||||
newType: fsggml.TensorTypeQ4_0,
|
||||
tensor_name: "output.weight",
|
||||
shape: []uint64{1024, 1024},
|
||||
ftype: fsggml.FileTypeF32,
|
||||
expected: fsggml.TensorTypeQ6_K,
|
||||
},
|
||||
{
|
||||
name: "attn_v.weight_q4_k_m",
|
||||
qs: quantizeState{
|
||||
iAttnV: 2,
|
||||
nAttnV: 3 * 8,
|
||||
},
|
||||
kv: map[string]any{},
|
||||
newType: fsggml.TensorTypeQ4_0,
|
||||
tensor_name: "blk.0.attn_v.weight",
|
||||
shape: []uint64{256},
|
||||
ftype: fsggml.FileTypeQ4_K_M,
|
||||
expected: fsggml.TensorTypeQ6_K,
|
||||
},
|
||||
{
|
||||
name: "attn_v.weight_q4_k_s",
|
||||
qs: quantizeState{},
|
||||
kv: map[string]any{},
|
||||
newType: fsggml.TensorTypeQ4_0,
|
||||
tensor_name: "blk.0.attn_v.weight",
|
||||
shape: []uint64{256},
|
||||
ftype: fsggml.FileTypeQ4_K_S,
|
||||
expected: fsggml.TensorTypeQ5_K,
|
||||
},
|
||||
{
|
||||
name: "attn_v.weight_8_expert",
|
||||
qs: quantizeState{},
|
||||
kv: map[string]any{
|
||||
"general.architecture": "foo",
|
||||
"foo.expert_count": uint32(8),
|
||||
},
|
||||
newType: fsggml.TensorTypeQ4_0,
|
||||
tensor_name: "blk.0.attn_v.weight",
|
||||
shape: []uint64{256},
|
||||
ftype: fsggml.FileTypeF32,
|
||||
expected: fsggml.TensorTypeQ8_0,
|
||||
},
|
||||
{
|
||||
name: "attn_k.weight_8_expert",
|
||||
qs: quantizeState{},
|
||||
kv: map[string]any{
|
||||
"general.architecture": "foo",
|
||||
"foo.expert_count": uint32(8),
|
||||
},
|
||||
newType: fsggml.TensorTypeQ4_0,
|
||||
tensor_name: "blk.0.attn_k.weight",
|
||||
shape: []uint64{256},
|
||||
ftype: fsggml.FileTypeF32,
|
||||
expected: fsggml.TensorTypeQ8_0,
|
||||
},
|
||||
{
|
||||
name: "ffn_down_q4_k_m",
|
||||
qs: quantizeState{
|
||||
iFfnDown: 1,
|
||||
nFfnDown: 8,
|
||||
},
|
||||
kv: map[string]any{},
|
||||
newType: fsggml.TensorTypeQ4_0,
|
||||
tensor_name: "ffn_down",
|
||||
shape: []uint64{256},
|
||||
ftype: fsggml.FileTypeQ4_K_M,
|
||||
expected: fsggml.TensorTypeQ4_0,
|
||||
},
|
||||
{
|
||||
name: "ffn_down_q4_k_m_6",
|
||||
qs: quantizeState{
|
||||
iFfnDown: 2,
|
||||
nFfnDown: 3 * 8,
|
||||
},
|
||||
kv: map[string]any{},
|
||||
newType: fsggml.TensorTypeQ4_0,
|
||||
tensor_name: "ffn_down",
|
||||
shape: []uint64{256},
|
||||
ftype: fsggml.FileTypeQ4_K_M,
|
||||
expected: fsggml.TensorTypeQ6_K,
|
||||
},
|
||||
{
|
||||
name: "ffn_down_q4_k_s",
|
||||
qs: quantizeState{
|
||||
iFfnDown: 2,
|
||||
nFfnDown: 3 * 8,
|
||||
},
|
||||
kv: map[string]any{},
|
||||
newType: fsggml.TensorTypeQ4_0,
|
||||
tensor_name: "ffn_down",
|
||||
shape: []uint64{256},
|
||||
ftype: fsggml.FileTypeQ4_K_S,
|
||||
expected: fsggml.TensorTypeQ5_K,
|
||||
},
|
||||
{
|
||||
name: "attn_qkv.weight_q4_k_m",
|
||||
qs: quantizeState{},
|
||||
kv: map[string]any{},
|
||||
newType: fsggml.TensorTypeQ4_0,
|
||||
tensor_name: "blk.0.attn_qkv.weight",
|
||||
shape: []uint64{256},
|
||||
ftype: fsggml.FileTypeQ4_K_M,
|
||||
expected: fsggml.TensorTypeQ5_K,
|
||||
},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.expectedPanic != "" {
|
||||
defer func() {
|
||||
e := recover()
|
||||
if !strings.Contains(fmt.Sprintf("%v", e), tt.expectedPanic) {
|
||||
t.Fatalf("incorrect panic\ngot: %v\nexpected: %s", e, tt.expectedPanic)
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
defer func() {
|
||||
e := recover()
|
||||
if e != nil {
|
||||
t.Fatalf("hit unexpected panic %v", e)
|
||||
}
|
||||
}()
|
||||
}
|
||||
ret := getTensorNewType(tt.kv, &tt.qs, tt.newType, tt.tensor_name, tt.shape, tt.ftype)
|
||||
if ret != tt.expected {
|
||||
t.Fatalf("incorrect type returned\ngot: %d\nexpected: %d", ret, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3LinearAttentionQuantOverride(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
arch string
|
||||
tensor string
|
||||
fileType fsggml.FileType
|
||||
expected fsggml.TensorType
|
||||
}{
|
||||
{
|
||||
name: "qwen35_beta",
|
||||
arch: "qwen35",
|
||||
tensor: "blk.0.ssm_beta.weight",
|
||||
fileType: fsggml.FileTypeQ4_K_M,
|
||||
expected: fsggml.TensorTypeQ4_K,
|
||||
},
|
||||
{
|
||||
name: "qwen35_alpha",
|
||||
arch: "qwen35",
|
||||
tensor: "blk.0.ssm_alpha.weight",
|
||||
fileType: fsggml.FileTypeQ4_K_M,
|
||||
expected: fsggml.TensorTypeQ4_K,
|
||||
},
|
||||
{
|
||||
name: "qwen35moe_attn_qkv",
|
||||
arch: "qwen35moe",
|
||||
tensor: "blk.0.attn_qkv.weight",
|
||||
fileType: fsggml.FileTypeQ4_K_M,
|
||||
expected: fsggml.TensorTypeQ4_K,
|
||||
},
|
||||
{
|
||||
name: "non_qwen35_falls_back",
|
||||
arch: "foo",
|
||||
tensor: "blk.0.attn_qkv.weight",
|
||||
fileType: fsggml.FileTypeQ4_K_M,
|
||||
expected: fsggml.TensorTypeQ5_K,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
kv := fsggml.KV{"general.architecture": tt.arch}
|
||||
got := newType(&fsggml.Tensor{
|
||||
Name: tt.tensor,
|
||||
Shape: []uint64{256, 256},
|
||||
Kind: uint32(fsggml.TensorTypeF16),
|
||||
}, kv, &quantizeState{}, tt.fileType)
|
||||
|
||||
if got != tt.expected {
|
||||
t.Fatalf("unexpected tensor type for %s (%s): got %s want %s", tt.tensor, tt.arch, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuantizeModel(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
kv map[string]any
|
||||
tensors []*fsggml.Tensor
|
||||
newType string
|
||||
expectedTensorTypes map[string]fsggml.TensorType
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "f16_q4_k",
|
||||
kv: map[string]any{
|
||||
"general.architecture": "foo",
|
||||
},
|
||||
tensors: []*fsggml.Tensor{
|
||||
{
|
||||
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF16),
|
||||
Offset: uint64(0), Shape: []uint64{512, 2},
|
||||
WriterTo: bytes.NewReader(
|
||||
append(append(append(quantBytes[fsggml.TensorTypeF16], quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...),
|
||||
),
|
||||
},
|
||||
{
|
||||
Name: "output.weight", Kind: uint32(fsggml.TensorTypeF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 4},
|
||||
WriterTo: bytes.NewReader(
|
||||
append(append(append(quantBytes[fsggml.TensorTypeF16], quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...),
|
||||
),
|
||||
},
|
||||
},
|
||||
newType: "Q4_K",
|
||||
expectedTensorTypes: map[string]fsggml.TensorType{
|
||||
"blk.0.attn.weight": fsggml.TensorTypeQ4_K,
|
||||
"output.weight": fsggml.TensorTypeQ6_K,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "f32_q4_k",
|
||||
kv: map[string]any{
|
||||
"general.architecture": "foo",
|
||||
},
|
||||
tensors: []*fsggml.Tensor{
|
||||
{
|
||||
Name: "blk.0.attn_v.weight", Kind: uint32(fsggml.TensorTypeF32),
|
||||
Offset: uint64(0), Shape: []uint64{512, 2},
|
||||
WriterTo: bytes.NewReader(
|
||||
append(append(append(quantBytes[fsggml.TensorTypeF32], quantBytes[fsggml.TensorTypeF32]...), quantBytes[fsggml.TensorTypeF32]...), quantBytes[fsggml.TensorTypeF32]...),
|
||||
),
|
||||
},
|
||||
{
|
||||
Name: "output.weight", Kind: uint32(fsggml.TensorTypeF32),
|
||||
Offset: uint64(0), Shape: []uint64{512},
|
||||
WriterTo: bytes.NewReader(append(quantBytes[fsggml.TensorTypeF32], quantBytes[fsggml.TensorTypeF32]...)),
|
||||
},
|
||||
},
|
||||
newType: "Q4_K",
|
||||
expectedTensorTypes: map[string]fsggml.TensorType{
|
||||
"blk.0.attn_v.weight": fsggml.TensorTypeQ6_K,
|
||||
"output.weight": fsggml.TensorTypeF32,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "f16_q8_0",
|
||||
kv: map[string]any{
|
||||
"general.architecture": "foo",
|
||||
},
|
||||
tensors: []*fsggml.Tensor{
|
||||
{
|
||||
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF16),
|
||||
Offset: uint64(0), Shape: []uint64{32, 16, 2},
|
||||
WriterTo: bytes.NewReader(
|
||||
append(append(append(quantBytes[fsggml.TensorTypeF16], quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...),
|
||||
),
|
||||
},
|
||||
{
|
||||
Name: "output.weight", Kind: uint32(fsggml.TensorTypeF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 4},
|
||||
WriterTo: bytes.NewReader(
|
||||
append(append(append(quantBytes[fsggml.TensorTypeF16], quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...), quantBytes[fsggml.TensorTypeF16]...),
|
||||
),
|
||||
},
|
||||
},
|
||||
newType: "Q8_0",
|
||||
expectedTensorTypes: map[string]fsggml.TensorType{
|
||||
"blk.0.attn.weight": fsggml.TensorTypeQ8_0,
|
||||
"output.weight": fsggml.TensorTypeQ8_0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "source_fp8_q8_preserves_bf16_tensors",
|
||||
kv: map[string]any{
|
||||
"general.architecture": "test",
|
||||
"source_quantization": "hf_fp8",
|
||||
"source_fp8_tensors": []string{"blk.1.ffn_down_exps.weight"},
|
||||
},
|
||||
tensors: []*fsggml.Tensor{
|
||||
{
|
||||
Name: "blk.1.ffn_down_exps.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 1},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
{
|
||||
Name: "blk.1.attn_q.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 1},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
},
|
||||
newType: "Q8_0",
|
||||
expectedTensorTypes: map[string]fsggml.TensorType{
|
||||
"blk.1.ffn_down_exps.weight": fsggml.TensorTypeQ8_0,
|
||||
"blk.1.attn_q.weight": fsggml.TensorTypeBF16,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "source_fp8_q4_promotes_bf16_tensors_to_q8",
|
||||
kv: map[string]any{
|
||||
"general.architecture": "test",
|
||||
"source_quantization": "hf_fp8",
|
||||
"source_fp8_tensors": []string{
|
||||
"blk.1.ffn_gate_exps.weight",
|
||||
"blk.1.ffn_down_exps.weight",
|
||||
},
|
||||
},
|
||||
tensors: []*fsggml.Tensor{
|
||||
{
|
||||
Name: "blk.1.ffn_gate_exps.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 1},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
{
|
||||
Name: "blk.1.ffn_down_exps.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 1},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
{
|
||||
Name: "blk.1.attn_q.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 1},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
{
|
||||
Name: "blk.1.attn_v.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 1},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
{
|
||||
Name: "blk.1.ffn_down.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 1},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
{
|
||||
Name: "blk.1.attn_q_norm.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
{
|
||||
Name: "blk.1.ffn_gate_inp.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 1},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
{
|
||||
Name: "output.weight", Kind: uint32(fsggml.TensorTypeBF16),
|
||||
Offset: uint64(0), Shape: []uint64{256, 1},
|
||||
WriterTo: bytes.NewReader(quantBytes[fsggml.TensorTypeBF16]),
|
||||
},
|
||||
},
|
||||
newType: "Q4_K_M",
|
||||
expectedTensorTypes: map[string]fsggml.TensorType{
|
||||
"blk.1.ffn_gate_exps.weight": fsggml.TensorTypeQ4_K,
|
||||
"blk.1.ffn_down_exps.weight": fsggml.TensorTypeQ6_K,
|
||||
"blk.1.attn_q.weight": fsggml.TensorTypeQ8_0,
|
||||
"blk.1.attn_v.weight": fsggml.TensorTypeQ8_0,
|
||||
"blk.1.ffn_down.weight": fsggml.TensorTypeQ8_0,
|
||||
"blk.1.attn_q_norm.weight": fsggml.TensorTypeBF16,
|
||||
"blk.1.ffn_gate_inp.weight": fsggml.TensorTypeBF16,
|
||||
"output.weight": fsggml.TensorTypeQ8_0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "f32_short_data",
|
||||
kv: map[string]any{
|
||||
"general.architecture": "foo",
|
||||
},
|
||||
tensors: []*fsggml.Tensor{
|
||||
{
|
||||
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF32),
|
||||
Offset: uint64(0), Shape: []uint64{512, 2},
|
||||
WriterTo: bytes.NewReader(make([]byte, 32)),
|
||||
},
|
||||
},
|
||||
newType: "Q4_K",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "f16_short_data",
|
||||
kv: map[string]any{
|
||||
"general.architecture": "foo",
|
||||
},
|
||||
tensors: []*fsggml.Tensor{
|
||||
{
|
||||
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF16),
|
||||
Offset: uint64(0), Shape: []uint64{512, 2},
|
||||
WriterTo: bytes.NewReader(make([]byte, 32)),
|
||||
},
|
||||
},
|
||||
newType: "Q4_K",
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p, _ := createBinFile(t, tt.kv, tt.tensors)
|
||||
fp, err := os.Open(p)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
defer fp.Close()
|
||||
meta, err := fsggml.Decode(fp, -1)
|
||||
if tt.expectErr && err != nil {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
progressCalled := false
|
||||
progress := func(n uint64) {
|
||||
// fmt.Fprintf(os.Stderr, "progress: %f\n", p)
|
||||
progressCalled = true
|
||||
}
|
||||
tmp, err := os.CreateTemp(t.TempDir(), tt.name+".out")
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
defer tmp.Close()
|
||||
ftype, err := fsggml.ParseFileType(tt.newType)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
err = quantize(fp, tmp, meta, ftype, progress)
|
||||
if tt.expectErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected quantize to return an error")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("error during quantize: %s", err)
|
||||
}
|
||||
if !progressCalled {
|
||||
t.Fatalf("progress was not reported")
|
||||
}
|
||||
// Now attempt to load it back and make sure types match expected
|
||||
fpNew, err := os.Open(tmp.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err)
|
||||
}
|
||||
defer fpNew.Close()
|
||||
newMeta, err := fsggml.Decode(fpNew, -1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err)
|
||||
}
|
||||
tensors := newMeta.Tensors()
|
||||
for _, l := range tensors.GroupLayers() {
|
||||
for _, tensor := range l {
|
||||
if fsggml.TensorType(tensor.Kind) != tt.expectedTensorTypes[tensor.Name] {
|
||||
t.Fatalf("incorrect output type for %s\ngot:%s\nexpected:%s", tensor.Name, fsggml.TensorType(tensor.Kind), tt.expectedTensorTypes[tensor.Name])
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToF32(t *testing.T) {
|
||||
expected := make([]float32, 256)
|
||||
for i := range expected {
|
||||
expected[i] = float32(i)
|
||||
}
|
||||
for dtype, data := range quantBytes {
|
||||
// Skip the no-op
|
||||
if dtype == fsggml.TensorTypeF32 {
|
||||
continue
|
||||
}
|
||||
t.Run(dtype.String(), func(t *testing.T) {
|
||||
fp32 := ggml.ConvertToF32(data, uint32(dtype), 256)
|
||||
similarity := cosineSimilarity(expected, fp32)
|
||||
if similarity < 0.999 {
|
||||
t.Fatalf("Results not similar enough: %s %f", dtype.String(), similarity)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func dotProduct[V float32 | float64](v1, v2 []V) V {
|
||||
var result V = 0
|
||||
for i := range v1 {
|
||||
result += v1[i] * v2[i]
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func magnitude[V float32 | float64](v []V) V {
|
||||
var result V = 0
|
||||
for _, val := range v {
|
||||
result += val * val
|
||||
}
|
||||
return V(math.Sqrt(float64(result)))
|
||||
}
|
||||
|
||||
func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
|
||||
return dotProduct(v1, v2) / (magnitude(v1) * magnitude(v2))
|
||||
}
|
||||
|
||||
// Precomputed quantized data - arange 256
|
||||
// # For gguf-py supported types
|
||||
// import gguf
|
||||
// import numpy as np
|
||||
// print(repr(gguf.quantize(np.arange(256, dtype=np.float16), gguf.GGMLQuantizationType.Q4_0)))
|
||||
//
|
||||
// For types not supported by gguf-py converted via ggml_fp32_to_fp16_row and quantize_XXX
|
||||
//
|
||||
// data := make([]byte, 256*2)
|
||||
// fp32 := make([]float32, 256)
|
||||
// for i := range 256 {
|
||||
// fp32[i] = float32(i)
|
||||
// }
|
||||
// l := C.quantize_q6_K((*C.float)(&fp32[0]), unsafe.Pointer(&data[0]), 1, 256, nil)
|
||||
// for i := range data[:int(l)] {
|
||||
// fmt.Printf("%d, ", data[i])
|
||||
// }
|
||||
var (
|
||||
quantBytes = map[fsggml.TensorType][]byte{
|
||||
fsggml.TensorTypeQ4_0: {
|
||||
192, 195, 72, 72, 55, 55, 55, 55, 38, 38, 38, 38, 21,
|
||||
21, 21, 21, 4, 4, 224, 199, 36, 36, 36, 36, 19, 19,
|
||||
19, 19, 19, 19, 19, 19, 2, 2, 2, 2, 240, 201, 19,
|
||||
19, 18, 18, 18, 18, 18, 18, 18, 18, 2, 2, 2, 2,
|
||||
1, 1, 240, 203, 18, 18, 18, 18, 18, 18, 18, 18, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 248, 204, 18, 18, 17, 17,
|
||||
17, 17, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 248,
|
||||
205, 17, 17, 17, 17, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 248, 206, 17, 17, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 248, 207, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1,
|
||||
},
|
||||
fsggml.TensorTypeQ4_1: {
|
||||
34, 64, 0, 0, 128, 128, 145, 145, 162, 162, 179, 179, 196,
|
||||
196, 213, 213, 230, 230, 247, 247, 34, 64, 0, 80, 128, 128,
|
||||
145, 145, 162, 162, 179, 179, 196, 196, 213, 213, 230, 230, 247,
|
||||
247, 34, 64, 0, 84, 128, 128, 145, 145, 162, 162, 179, 179,
|
||||
196, 196, 213, 213, 230, 230, 247, 247, 34, 64, 0, 86, 128,
|
||||
128, 145, 145, 162, 162, 179, 179, 196, 196, 213, 213, 230, 230,
|
||||
247, 247, 34, 64, 0, 88, 128, 128, 145, 145, 162, 162, 179,
|
||||
179, 196, 196, 213, 213, 230, 230, 247, 247, 34, 64, 0, 89,
|
||||
128, 128, 145, 145, 162, 162, 179, 179, 196, 196, 213, 213, 230,
|
||||
230, 247, 247, 34, 64, 0, 90, 128, 128, 145, 145, 162, 162,
|
||||
179, 179, 196, 196, 213, 213, 230, 230, 247, 247, 34, 64, 0,
|
||||
91, 128, 128, 145, 145, 162, 162, 179, 179, 196, 196, 213, 213,
|
||||
230, 230, 247, 247,
|
||||
},
|
||||
fsggml.TensorTypeQ5_0: {
|
||||
192, 191, 1, 0, 0, 0, 128, 127, 127, 110, 110, 93, 93,
|
||||
76, 76, 59, 59, 42, 42, 25, 25, 8, 224, 195, 0, 0,
|
||||
0, 0, 72, 72, 55, 55, 55, 55, 38, 38, 38, 38, 21,
|
||||
21, 21, 21, 4, 4, 240, 197, 0, 0, 0, 0, 53, 37,
|
||||
37, 37, 37, 36, 36, 20, 20, 20, 20, 19, 19, 3, 3,
|
||||
3, 240, 199, 0, 0, 0, 0, 36, 36, 36, 36, 19, 19,
|
||||
19, 19, 19, 19, 19, 19, 2, 2, 2, 2, 248, 200, 0,
|
||||
0, 0, 0, 35, 19, 19, 19, 19, 19, 19, 18, 18, 18,
|
||||
18, 2, 2, 2, 2, 2, 248, 201, 0, 0, 0, 0, 19,
|
||||
19, 18, 18, 18, 18, 18, 18, 18, 18, 2, 2, 2, 2,
|
||||
1, 1, 248, 202, 0, 0, 0, 0, 18, 18, 18, 18, 18,
|
||||
18, 18, 18, 18, 2, 2, 1, 1, 1, 1, 1, 248, 203,
|
||||
0, 0, 0, 0, 18, 18, 18, 18, 18, 18, 18, 18, 1,
|
||||
1, 1, 1, 1, 1, 1, 1,
|
||||
},
|
||||
fsggml.TensorTypeQ5_1: {
|
||||
0, 60, 0, 0, 0, 0, 255, 255, 0, 17, 34, 51, 68,
|
||||
85, 102, 119, 136, 153, 170, 187, 204, 221, 238, 255, 0, 60,
|
||||
0, 80, 0, 0, 255, 255, 0, 17, 34, 51, 68, 85, 102,
|
||||
119, 136, 153, 170, 187, 204, 221, 238, 255, 0, 60, 0, 84,
|
||||
0, 0, 255, 255, 0, 17, 34, 51, 68, 85, 102, 119, 136,
|
||||
153, 170, 187, 204, 221, 238, 255, 0, 60, 0, 86, 0, 0,
|
||||
255, 255, 0, 17, 34, 51, 68, 85, 102, 119, 136, 153, 170,
|
||||
187, 204, 221, 238, 255, 0, 60, 0, 88, 0, 0, 255, 255,
|
||||
0, 17, 34, 51, 68, 85, 102, 119, 136, 153, 170, 187, 204,
|
||||
221, 238, 255, 0, 60, 0, 89, 0, 0, 255, 255, 0, 17,
|
||||
34, 51, 68, 85, 102, 119, 136, 153, 170, 187, 204, 221, 238,
|
||||
255, 0, 60, 0, 90, 0, 0, 255, 255, 0, 17, 34, 51,
|
||||
68, 85, 102, 119, 136, 153, 170, 187, 204, 221, 238, 255, 0,
|
||||
60, 0, 91, 0, 0, 255, 255, 0, 17, 34, 51, 68, 85,
|
||||
102, 119, 136, 153, 170, 187, 204, 221, 238, 255,
|
||||
},
|
||||
fsggml.TensorTypeQ8_0: {
|
||||
208, 51, 0, 4, 8, 12, 16, 20, 25, 29, 33, 37, 41,
|
||||
45, 49, 53, 57, 61, 66, 70, 74, 78, 82, 86, 90, 94,
|
||||
98, 102, 107, 111, 115, 119, 123, 127, 240, 55, 65, 67, 69,
|
||||
71, 73, 75, 77, 79, 81, 83, 85, 87, 89, 91, 93, 95,
|
||||
97, 99, 101, 103, 105, 107, 109, 111, 113, 115, 117, 119, 121,
|
||||
123, 125, 127, 252, 57, 86, 87, 88, 90, 91, 92, 94, 95,
|
||||
96, 98, 99, 100, 102, 103, 104, 106, 107, 108, 110, 111, 112,
|
||||
114, 115, 116, 118, 119, 120, 122, 123, 124, 126, 127, 0, 60,
|
||||
96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108,
|
||||
109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,
|
||||
122, 123, 124, 125, 126, 127, 2, 61, 102, 103, 104, 105, 105,
|
||||
106, 107, 108, 109, 109, 110, 111, 112, 113, 113, 114, 115, 116,
|
||||
117, 117, 118, 119, 120, 121, 121, 122, 123, 124, 125, 125, 126,
|
||||
127, 4, 62, 106, 107, 108, 108, 109, 110, 110, 111, 112, 112,
|
||||
113, 114, 114, 115, 116, 116, 117, 118, 118, 119, 120, 120, 121,
|
||||
122, 122, 123, 124, 124, 125, 126, 126, 127, 6, 63, 109, 110,
|
||||
110, 111, 112, 112, 113, 113, 114, 114, 115, 116, 116, 117, 117,
|
||||
118, 118, 119, 120, 120, 121, 121, 122, 122, 123, 124, 124, 125,
|
||||
125, 126, 126, 127, 4, 64, 112, 112, 113, 113, 114, 114, 115,
|
||||
115, 116, 116, 117, 117, 118, 118, 119, 119, 120, 120, 121, 121,
|
||||
122, 122, 123, 123, 124, 124, 125, 125, 126, 126, 127, 127,
|
||||
},
|
||||
fsggml.TensorTypeBF16: {
|
||||
0, 0, 128, 63, 0, 64, 64, 64, 128, 64, 160, 64, 192,
|
||||
64, 224, 64, 0, 65, 16, 65, 32, 65, 48, 65, 64, 65,
|
||||
80, 65, 96, 65, 112, 65, 128, 65, 136, 65, 144, 65, 152,
|
||||
65, 160, 65, 168, 65, 176, 65, 184, 65, 192, 65, 200, 65,
|
||||
208, 65, 216, 65, 224, 65, 232, 65, 240, 65, 248, 65, 0,
|
||||
66, 4, 66, 8, 66, 12, 66, 16, 66, 20, 66, 24, 66,
|
||||
28, 66, 32, 66, 36, 66, 40, 66, 44, 66, 48, 66, 52,
|
||||
66, 56, 66, 60, 66, 64, 66, 68, 66, 72, 66, 76, 66,
|
||||
80, 66, 84, 66, 88, 66, 92, 66, 96, 66, 100, 66, 104,
|
||||
66, 108, 66, 112, 66, 116, 66, 120, 66, 124, 66, 128, 66,
|
||||
130, 66, 132, 66, 134, 66, 136, 66, 138, 66, 140, 66, 142,
|
||||
66, 144, 66, 146, 66, 148, 66, 150, 66, 152, 66, 154, 66,
|
||||
156, 66, 158, 66, 160, 66, 162, 66, 164, 66, 166, 66, 168,
|
||||
66, 170, 66, 172, 66, 174, 66, 176, 66, 178, 66, 180, 66,
|
||||
182, 66, 184, 66, 186, 66, 188, 66, 190, 66, 192, 66, 194,
|
||||
66, 196, 66, 198, 66, 200, 66, 202, 66, 204, 66, 206, 66,
|
||||
208, 66, 210, 66, 212, 66, 214, 66, 216, 66, 218, 66, 220,
|
||||
66, 222, 66, 224, 66, 226, 66, 228, 66, 230, 66, 232, 66,
|
||||
234, 66, 236, 66, 238, 66, 240, 66, 242, 66, 244, 66, 246,
|
||||
66, 248, 66, 250, 66, 252, 66, 254, 66, 0, 67, 1, 67,
|
||||
2, 67, 3, 67, 4, 67, 5, 67, 6, 67, 7, 67, 8,
|
||||
67, 9, 67, 10, 67, 11, 67, 12, 67, 13, 67, 14, 67,
|
||||
15, 67, 16, 67, 17, 67, 18, 67, 19, 67, 20, 67, 21,
|
||||
67, 22, 67, 23, 67, 24, 67, 25, 67, 26, 67, 27, 67,
|
||||
28, 67, 29, 67, 30, 67, 31, 67, 32, 67, 33, 67, 34,
|
||||
67, 35, 67, 36, 67, 37, 67, 38, 67, 39, 67, 40, 67,
|
||||
41, 67, 42, 67, 43, 67, 44, 67, 45, 67, 46, 67, 47,
|
||||
67, 48, 67, 49, 67, 50, 67, 51, 67, 52, 67, 53, 67,
|
||||
54, 67, 55, 67, 56, 67, 57, 67, 58, 67, 59, 67, 60,
|
||||
67, 61, 67, 62, 67, 63, 67, 64, 67, 65, 67, 66, 67,
|
||||
67, 67, 68, 67, 69, 67, 70, 67, 71, 67, 72, 67, 73,
|
||||
67, 74, 67, 75, 67, 76, 67, 77, 67, 78, 67, 79, 67,
|
||||
80, 67, 81, 67, 82, 67, 83, 67, 84, 67, 85, 67, 86,
|
||||
67, 87, 67, 88, 67, 89, 67, 90, 67, 91, 67, 92, 67,
|
||||
93, 67, 94, 67, 95, 67, 96, 67, 97, 67, 98, 67, 99,
|
||||
67, 100, 67, 101, 67, 102, 67, 103, 67, 104, 67, 105, 67,
|
||||
106, 67, 107, 67, 108, 67, 109, 67, 110, 67, 111, 67, 112,
|
||||
67, 113, 67, 114, 67, 115, 67, 116, 67, 117, 67, 118, 67,
|
||||
119, 67, 120, 67, 121, 67, 122, 67, 123, 67, 124, 67, 125,
|
||||
67, 126, 67, 127, 67,
|
||||
},
|
||||
fsggml.TensorTypeF16: {
|
||||
0, 0, 0, 60, 0, 64, 0, 66, 0, 68, 0, 69, 0, 70, 0, 71, 0,
|
||||
72, 128, 72, 0, 73, 128, 73, 0, 74, 128, 74, 0, 75, 128, 75,
|
||||
0, 76, 64, 76, 128, 76, 192, 76, 0, 77, 64, 77, 128, 77, 192,
|
||||
77, 0, 78, 64, 78, 128, 78, 192, 78, 0, 79, 64, 79, 128, 79,
|
||||
192, 79, 0, 80, 32, 80, 64, 80, 96, 80, 128, 80, 160, 80,
|
||||
192, 80, 224, 80, 0, 81, 32, 81, 64, 81, 96, 81, 128, 81,
|
||||
160, 81, 192, 81, 224, 81, 0, 82, 32, 82, 64, 82, 96, 82,
|
||||
128, 82, 160, 82, 192, 82, 224, 82, 0, 83, 32, 83, 64, 83,
|
||||
96, 83, 128, 83, 160, 83, 192, 83, 224, 83, 0, 84, 16, 84,
|
||||
32, 84, 48, 84, 64, 84, 80, 84, 96, 84, 112, 84, 128, 84,
|
||||
144, 84, 160, 84, 176, 84, 192, 84, 208, 84, 224, 84, 240,
|
||||
84, 0, 85, 16, 85, 32, 85, 48, 85, 64, 85, 80, 85, 96, 85,
|
||||
112, 85, 128, 85, 144, 85, 160, 85, 176, 85, 192, 85, 208,
|
||||
85, 224, 85, 240, 85, 0, 86, 16, 86, 32, 86, 48, 86, 64,
|
||||
86, 80, 86, 96, 86, 112, 86, 128, 86, 144, 86, 160, 86,
|
||||
176, 86, 192, 86, 208, 86, 224, 86, 240, 86, 0, 87, 16,
|
||||
87, 32, 87, 48, 87, 64, 87, 80, 87, 96, 87, 112, 87, 128,
|
||||
87, 144, 87, 160, 87, 176, 87, 192, 87, 208, 87, 224, 87,
|
||||
240, 87, 0, 88, 8, 88, 16, 88, 24, 88, 32, 88, 40, 88,
|
||||
48, 88, 56, 88, 64, 88, 72, 88, 80, 88, 88, 88, 96, 88,
|
||||
104, 88, 112, 88, 120, 88, 128, 88, 136, 88, 144, 88, 152,
|
||||
88, 160, 88, 168, 88, 176, 88, 184, 88, 192, 88, 200, 88,
|
||||
208, 88, 216, 88, 224, 88, 232, 88, 240, 88, 248, 88, 0,
|
||||
89, 8, 89, 16, 89, 24, 89, 32, 89, 40, 89, 48, 89, 56, 89,
|
||||
64, 89, 72, 89, 80, 89, 88, 89, 96, 89, 104, 89, 112, 89,
|
||||
120, 89, 128, 89, 136, 89, 144, 89, 152, 89, 160, 89, 168,
|
||||
89, 176, 89, 184, 89, 192, 89, 200, 89, 208, 89, 216, 89,
|
||||
224, 89, 232, 89, 240, 89, 248, 89, 0, 90, 8, 90, 16, 90,
|
||||
24, 90, 32, 90, 40, 90, 48, 90, 56, 90, 64, 90, 72, 90, 80,
|
||||
90, 88, 90, 96, 90, 104, 90, 112, 90, 120, 90, 128, 90,
|
||||
136, 90, 144, 90, 152, 90, 160, 90, 168, 90, 176, 90, 184,
|
||||
90, 192, 90, 200, 90, 208, 90, 216, 90, 224, 90, 232, 90,
|
||||
240, 90, 248, 90, 0, 91, 8, 91, 16, 91, 24, 91, 32, 91, 40,
|
||||
91, 48, 91, 56, 91, 64, 91, 72, 91, 80, 91, 88, 91, 96, 91,
|
||||
104, 91, 112, 91, 120, 91, 128, 91, 136, 91, 144, 91, 152,
|
||||
91, 160, 91, 168, 91, 176, 91, 184, 91, 192, 91, 200, 91,
|
||||
208, 91, 216, 91, 224, 91, 232, 91, 240, 91, 248, 91,
|
||||
},
|
||||
fsggml.TensorTypeF32: {
|
||||
0, 0, 0, 0, 0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128,
|
||||
64, 0, 0, 160, 64, 0, 0, 192, 64, 0, 0, 224, 64, 0, 0, 0, 65, 0,
|
||||
0, 16, 65, 0, 0, 32, 65, 0, 0, 48, 65, 0, 0, 64, 65, 0, 0, 80, 65,
|
||||
0, 0, 96, 65, 0, 0, 112, 65, 0, 0, 128, 65, 0, 0, 136, 65, 0, 0,
|
||||
144, 65, 0, 0, 152, 65, 0, 0, 160, 65, 0, 0, 168, 65, 0, 0, 176,
|
||||
65, 0, 0, 184, 65, 0, 0, 192, 65, 0, 0, 200, 65, 0, 0, 208, 65, 0,
|
||||
0, 216, 65, 0, 0, 224, 65, 0, 0, 232, 65, 0, 0, 240, 65, 0, 0, 248,
|
||||
65, 0, 0, 0, 66, 0, 0, 4, 66, 0, 0, 8, 66, 0, 0, 12, 66, 0, 0, 16,
|
||||
66, 0, 0, 20, 66, 0, 0, 24, 66, 0, 0, 28, 66, 0, 0, 32, 66, 0, 0,
|
||||
36, 66, 0, 0, 40, 66, 0, 0, 44, 66, 0, 0, 48, 66, 0, 0, 52, 66, 0,
|
||||
0, 56, 66, 0, 0, 60, 66, 0, 0, 64, 66, 0, 0, 68, 66, 0, 0, 72, 66,
|
||||
0, 0, 76, 66, 0, 0, 80, 66, 0, 0, 84, 66, 0, 0, 88, 66, 0, 0, 92, 66,
|
||||
0, 0, 96, 66, 0, 0, 100, 66, 0, 0, 104, 66, 0, 0, 108, 66, 0, 0, 112,
|
||||
66, 0, 0, 116, 66, 0, 0, 120, 66, 0, 0, 124, 66, 0, 0, 128, 66, 0, 0,
|
||||
130, 66, 0, 0, 132, 66, 0, 0, 134, 66, 0, 0, 136, 66, 0, 0, 138, 66,
|
||||
0, 0, 140, 66, 0, 0, 142, 66, 0, 0, 144, 66, 0, 0, 146, 66, 0, 0, 148,
|
||||
66, 0, 0, 150, 66, 0, 0, 152, 66, 0, 0, 154, 66, 0, 0, 156, 66, 0, 0,
|
||||
158, 66, 0, 0, 160, 66, 0, 0, 162, 66, 0, 0, 164, 66, 0, 0, 166, 66,
|
||||
0, 0, 168, 66, 0, 0, 170, 66, 0, 0, 172, 66, 0, 0, 174, 66, 0, 0, 176,
|
||||
66, 0, 0, 178, 66, 0, 0, 180, 66, 0, 0, 182, 66, 0, 0, 184, 66, 0, 0,
|
||||
186, 66, 0, 0, 188, 66, 0, 0, 190, 66, 0, 0, 192, 66, 0, 0, 194, 66, 0,
|
||||
0, 196, 66, 0, 0, 198, 66, 0, 0, 200, 66, 0, 0, 202, 66, 0, 0, 204, 66,
|
||||
0, 0, 206, 66, 0, 0, 208, 66, 0, 0, 210, 66, 0, 0, 212, 66, 0, 0, 214, 66,
|
||||
0, 0, 216, 66, 0, 0, 218, 66, 0, 0, 220, 66, 0, 0, 222, 66, 0, 0, 224, 66,
|
||||
0, 0, 226, 66, 0, 0, 228, 66, 0, 0, 230, 66, 0, 0, 232, 66, 0, 0, 234, 66,
|
||||
0, 0, 236, 66, 0, 0, 238, 66, 0, 0, 240, 66, 0, 0, 242, 66, 0, 0, 244, 66,
|
||||
0, 0, 246, 66, 0, 0, 248, 66, 0, 0, 250, 66, 0, 0, 252, 66, 0, 0, 254, 66,
|
||||
0, 0, 0, 67, 0, 0, 1, 67, 0, 0, 2, 67, 0, 0, 3, 67, 0, 0, 4, 67, 0, 0, 5, 67,
|
||||
0, 0, 6, 67, 0, 0, 7, 67, 0, 0, 8, 67, 0, 0, 9, 67, 0, 0, 10, 67, 0, 0, 11,
|
||||
67, 0, 0, 12, 67, 0, 0, 13, 67, 0, 0, 14, 67, 0, 0, 15, 67, 0, 0, 16, 67,
|
||||
0, 0, 17, 67, 0, 0, 18, 67, 0, 0, 19, 67, 0, 0, 20, 67, 0, 0, 21, 67, 0, 0,
|
||||
22, 67, 0, 0, 23, 67, 0, 0, 24, 67, 0, 0, 25, 67, 0, 0, 26, 67, 0, 0, 27,
|
||||
67, 0, 0, 28, 67, 0, 0, 29, 67, 0, 0, 30, 67, 0, 0, 31, 67, 0, 0, 32, 67,
|
||||
0, 0, 33, 67, 0, 0, 34, 67, 0, 0, 35, 67, 0, 0, 36, 67, 0, 0, 37, 67, 0, 0,
|
||||
38, 67, 0, 0, 39, 67, 0, 0, 40, 67, 0, 0, 41, 67, 0, 0, 42, 67, 0, 0, 43, 67,
|
||||
0, 0, 44, 67, 0, 0, 45, 67, 0, 0, 46, 67, 0, 0, 47, 67, 0, 0, 48, 67, 0, 0,
|
||||
49, 67, 0, 0, 50, 67, 0, 0, 51, 67, 0, 0, 52, 67, 0, 0, 53, 67, 0, 0, 54, 67,
|
||||
0, 0, 55, 67, 0, 0, 56, 67, 0, 0, 57, 67, 0, 0, 58, 67, 0, 0, 59, 67, 0, 0,
|
||||
60, 67, 0, 0, 61, 67, 0, 0, 62, 67, 0, 0, 63, 67, 0, 0, 64, 67, 0, 0, 65, 67,
|
||||
0, 0, 66, 67, 0, 0, 67, 67, 0, 0, 68, 67, 0, 0, 69, 67, 0, 0, 70, 67, 0, 0, 71,
|
||||
67, 0, 0, 72, 67, 0, 0, 73, 67, 0, 0, 74, 67, 0, 0, 75, 67, 0, 0, 76, 67, 0,
|
||||
0, 77, 67, 0, 0, 78, 67, 0, 0, 79, 67, 0, 0, 80, 67, 0, 0, 81, 67, 0, 0, 82,
|
||||
67, 0, 0, 83, 67, 0, 0, 84, 67, 0, 0, 85, 67, 0, 0, 86, 67, 0, 0, 87, 67, 0,
|
||||
0, 88, 67, 0, 0, 89, 67, 0, 0, 90, 67, 0, 0, 91, 67, 0, 0, 92, 67, 0, 0, 93,
|
||||
67, 0, 0, 94, 67, 0, 0, 95, 67, 0, 0, 96, 67, 0, 0, 97, 67, 0, 0, 98, 67, 0,
|
||||
0, 99, 67, 0, 0, 100, 67, 0, 0, 101, 67, 0, 0, 102, 67, 0, 0, 103, 67, 0, 0,
|
||||
104, 67, 0, 0, 105, 67, 0, 0, 106, 67, 0, 0, 107, 67, 0, 0, 108, 67, 0, 0, 109,
|
||||
67, 0, 0, 110, 67, 0, 0, 111, 67, 0, 0, 112, 67, 0, 0, 113, 67, 0, 0, 114, 67,
|
||||
0, 0, 115, 67, 0, 0, 116, 67, 0, 0, 117, 67, 0, 0, 118, 67, 0, 0, 119, 67, 0,
|
||||
0, 120, 67, 0, 0, 121, 67, 0, 0, 122, 67, 0, 0, 123, 67, 0, 0, 124, 67, 0, 0,
|
||||
125, 67, 0, 0, 126, 67, 0, 0, 127, 67,
|
||||
},
|
||||
fsggml.TensorTypeQ4_K: {
|
||||
52, 52, 0, 0, 136, 208, 216, 223, 0, 0, 0, 0, 8, 0, 8, 15, 128,
|
||||
128, 129, 129, 146, 146, 147, 147, 164, 164, 165, 165, 166, 182,
|
||||
183, 183, 184, 200, 201, 201, 202, 218, 218, 219, 219, 236, 236,
|
||||
237, 237, 254, 254, 255, 202, 202, 202, 203, 203, 203, 219, 219,
|
||||
219, 220, 220, 220, 220, 220, 236, 237, 237, 237, 237, 237,
|
||||
237, 237, 238, 254, 254, 254, 254, 254, 255, 255, 255, 255, 220,
|
||||
220, 220, 220, 221, 221, 221, 221, 221, 221, 221, 237, 237, 237,
|
||||
238, 238, 238, 238, 238, 238, 238, 238, 238, 254, 254, 255, 255,
|
||||
255, 255, 255, 255, 255, 237, 237, 237, 237, 237, 237, 237, 238,
|
||||
238, 238, 238, 238, 238, 238, 238, 238, 254, 254, 254, 254, 254,
|
||||
254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
},
|
||||
fsggml.TensorTypeQ2_K: {
|
||||
1, 2, 3, 3, 4, 5, 7, 7, 8, 9, 10, 11, 12, 13, 14, 15, 184, 184,
|
||||
184, 185, 249, 249, 249, 249, 249, 250, 250, 254, 254, 254, 254,
|
||||
255, 253, 253, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
|
||||
254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 171, 69, 0, 0,
|
||||
},
|
||||
fsggml.TensorTypeQ5_K: {
|
||||
32, 48, 0, 0, 136, 208, 216, 223, 0, 0, 0, 0, 8, 0, 7, 15, 254,
|
||||
254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
|
||||
254, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 0, 1, 2, 19, 20, 37, 38, 55, 56, 73, 74,
|
||||
91, 92, 109, 110, 127, 112, 128, 129, 146, 147, 164, 165, 182, 183,
|
||||
200, 201, 218, 219, 236, 237, 254, 133, 133, 149, 150, 150, 150,
|
||||
167, 167, 167, 168, 184, 184, 185, 185, 201, 202, 202, 202, 219,
|
||||
219, 219, 219, 236, 236, 236, 237, 253, 253, 254, 254, 254, 255,
|
||||
169, 169, 169, 169, 186, 186, 186, 186, 186, 187, 187, 203, 203,
|
||||
203, 204, 204, 204, 220, 220, 221, 221, 221, 221, 237, 237, 238,
|
||||
238, 238, 238, 254, 255, 255, 203, 203, 203, 204, 204, 204, 204,
|
||||
204, 220, 220, 220, 221, 221, 221, 221, 221, 237, 237, 238, 238,
|
||||
238, 238, 238, 238, 254, 255, 255, 255, 255, 255, 255, 255,
|
||||
},
|
||||
fsggml.TensorTypeQ6_K: {
|
||||
96, 110, 92, 90, 88, 70, 68, 50, 48, 46, 44, 42, 24, 22, 4, 2, 80,
|
||||
95, 78, 77, 76, 59, 58, 57, 40, 39, 38, 21, 20, 19, 2, 1, 75, 75,
|
||||
74, 57, 57, 56, 55, 39, 38, 37, 21, 20, 20, 19, 2, 2, 72, 55, 55,
|
||||
54, 54, 37, 37, 36, 36, 19, 19, 18, 18, 1, 1, 0, 35, 35, 35, 35,
|
||||
34, 18, 18, 18, 17, 17, 17, 1, 1, 0, 0, 0, 35, 35, 34, 34, 18,
|
||||
18, 18, 17, 17, 17, 17, 1, 0, 0, 0, 0, 35, 35, 35, 19, 19, 18, 18,
|
||||
18, 18, 18, 1, 1, 1, 1, 1, 1, 34, 34, 18, 18, 18, 18, 17, 17, 17,
|
||||
17, 1, 1, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 248, 240, 231, 224, 216, 208, 200, 192, 184, 176,
|
||||
166, 160, 152, 144, 136, 128, 235, 43,
|
||||
},
|
||||
fsggml.TensorTypeQ3_K: {
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 20, 20, 23, 23, 7, 7, 6, 6, 6, 2,
|
||||
1, 1, 1, 1, 0, 0, 22, 22, 6, 6, 5, 5, 5, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 238, 204, 170, 136, 102, 68,
|
||||
34, 1, 5, 5, 5, 5, 189, 63,
|
||||
},
|
||||
}
|
||||
)
|
||||
110
server/renderer_resolution.go
Normal file
110
server/renderer_resolution.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
const (
|
||||
gemma4RendererLegacy = "gemma4"
|
||||
gemma4RendererSmall = "gemma4-small"
|
||||
gemma4RendererLarge = "gemma4-large"
|
||||
|
||||
// Gemma 4 small templates cover the e2b/e4b family, while 26b/31b use the
|
||||
// large template. Default to the small prompt unless the model is clearly in
|
||||
// the large range.
|
||||
gemma4LargeMinParameterCount = 16_000_000_000
|
||||
)
|
||||
|
||||
func resolveRendererName(m *Model) string {
|
||||
if m == nil || m.Config.Renderer == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch m.Config.Renderer {
|
||||
case gemma4RendererLegacy:
|
||||
return resolveGemma4Renderer(m)
|
||||
default:
|
||||
return m.Config.Renderer
|
||||
}
|
||||
}
|
||||
|
||||
func resolveGemma4Renderer(m *Model) string {
|
||||
if m == nil || m.Config.Renderer != gemma4RendererLegacy {
|
||||
if m == nil {
|
||||
return gemma4RendererLegacy
|
||||
}
|
||||
return m.Config.Renderer
|
||||
}
|
||||
|
||||
if renderer, ok := gemma4RendererFromName(m.ShortName); ok {
|
||||
return renderer
|
||||
}
|
||||
|
||||
if renderer, ok := gemma4RendererFromName(m.Name); ok {
|
||||
return renderer
|
||||
}
|
||||
|
||||
if parameterCount, ok := parseHumanParameterCount(m.Config.ModelType); ok {
|
||||
return gemma4RendererForParameterCount(parameterCount)
|
||||
}
|
||||
|
||||
return gemma4RendererSmall
|
||||
}
|
||||
|
||||
func gemma4RendererForParameterCount(parameterCount uint64) string {
|
||||
if parameterCount >= gemma4LargeMinParameterCount {
|
||||
return gemma4RendererLarge
|
||||
}
|
||||
|
||||
return gemma4RendererSmall
|
||||
}
|
||||
|
||||
func gemma4RendererFromName(name string) (string, bool) {
|
||||
lower := strings.ToLower(name)
|
||||
switch {
|
||||
case strings.Contains(lower, "e2b"), strings.Contains(lower, "e4b"):
|
||||
return gemma4RendererSmall, true
|
||||
case strings.Contains(lower, "26b"), strings.Contains(lower, "31b"):
|
||||
return gemma4RendererLarge, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func parseHumanParameterCount(s string) (uint64, bool) {
|
||||
if s == "" {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
unit := strings.ToUpper(s[len(s)-1:])
|
||||
var multiplier float64
|
||||
switch unit {
|
||||
case "B":
|
||||
multiplier = float64(format.Billion)
|
||||
case "M":
|
||||
multiplier = float64(format.Million)
|
||||
case "K":
|
||||
multiplier = float64(format.Thousand)
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
|
||||
value, err := strconv.ParseFloat(s[:len(s)-1], 64)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return uint64(value * multiplier), true
|
||||
}
|
||||
|
||||
func isGemma4Renderer(renderer string) bool {
|
||||
switch renderer {
|
||||
case gemma4RendererLegacy, gemma4RendererSmall, gemma4RendererLarge:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
2842
server/routes.go
Normal file
2842
server/routes.go
Normal file
File diff suppressed because it is too large
Load Diff
1147
server/routes_cloud_test.go
Normal file
1147
server/routes_cloud_test.go
Normal file
File diff suppressed because it is too large
Load Diff
1506
server/routes_create_test.go
Normal file
1506
server/routes_create_test.go
Normal file
File diff suppressed because it is too large
Load Diff
415
server/routes_debug_test.go
Normal file
415
server/routes_debug_test.go
Normal file
@@ -0,0 +1,415 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
func TestGenerateDebugRenderOnly(t *testing.T) {
|
||||
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
waitForRecovery: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
|
||||
// add small delay to simulate loading
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create a test model
|
||||
stream := false
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
"llama.context_length": uint32(8192),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
"llama.attention.head_count": uint32(32),
|
||||
"llama.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-model",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Template: "{{ .Prompt }}",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
request api.GenerateRequest
|
||||
expectDebug bool
|
||||
expectTemplate string
|
||||
expectNumImages int
|
||||
}{
|
||||
{
|
||||
name: "debug render only enabled",
|
||||
request: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello, world!",
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "Hello, world!",
|
||||
},
|
||||
{
|
||||
name: "debug render only disabled",
|
||||
request: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello, world!",
|
||||
DebugRenderOnly: false,
|
||||
},
|
||||
expectDebug: false,
|
||||
},
|
||||
{
|
||||
name: "debug render only with system prompt",
|
||||
request: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "User question",
|
||||
System: "You are a helpful assistant",
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "User question",
|
||||
},
|
||||
{
|
||||
name: "debug render only with template",
|
||||
request: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello",
|
||||
Template: "PROMPT: {{ .Prompt }}",
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "PROMPT: Hello",
|
||||
},
|
||||
{
|
||||
name: "debug render only with images",
|
||||
request: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Describe this image",
|
||||
Images: []api.ImageData{[]byte("fake-image-data")},
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "[img-0]Describe this image",
|
||||
expectNumImages: 1,
|
||||
},
|
||||
{
|
||||
name: "debug render only with raw mode",
|
||||
request: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Raw prompt text",
|
||||
Raw: true,
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "Raw prompt text",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
// Test both with and without streaming
|
||||
streamValues := []bool{false, true}
|
||||
for _, stream := range streamValues {
|
||||
streamSuffix := ""
|
||||
if stream {
|
||||
streamSuffix = " (streaming)"
|
||||
}
|
||||
t.Run(tt.name+streamSuffix, func(t *testing.T) {
|
||||
req := tt.request
|
||||
req.Stream = &stream
|
||||
w := createRequest(t, s.GenerateHandler, req)
|
||||
|
||||
if tt.expectDebug {
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var response api.GenerateResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if response.Model != tt.request.Model {
|
||||
t.Errorf("expected model %s, got %s", tt.request.Model, response.Model)
|
||||
}
|
||||
|
||||
if tt.expectTemplate != "" && response.DebugInfo.RenderedTemplate != tt.expectTemplate {
|
||||
t.Errorf("expected template %q, got %q", tt.expectTemplate, response.DebugInfo.RenderedTemplate)
|
||||
}
|
||||
|
||||
if tt.expectNumImages > 0 && response.DebugInfo.ImageCount != tt.expectNumImages {
|
||||
t.Errorf("expected image count %d, got %d", tt.expectNumImages, response.DebugInfo.ImageCount)
|
||||
}
|
||||
} else {
|
||||
// When debug is disabled, it should attempt normal processing
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatDebugRenderOnly(t *testing.T) {
|
||||
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
waitForRecovery: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
|
||||
// add small delay to simulate loading
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create a test model
|
||||
stream := false
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
"llama.context_length": uint32(8192),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
"llama.attention.head_count": uint32(32),
|
||||
"llama.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-model",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Template: "{{ if .Tools }}{{ .Tools }}{{ end }}{{ range .Messages }}{{ .Role }}: {{ .Content }}\n{{ end }}",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
request api.ChatRequest
|
||||
expectDebug bool
|
||||
expectTemplate string
|
||||
expectNumImages int
|
||||
}{
|
||||
{
|
||||
name: "chat debug render only enabled",
|
||||
request: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant"},
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "system: You are a helpful assistant\nuser: Hello\n",
|
||||
},
|
||||
{
|
||||
name: "chat debug render only disabled",
|
||||
request: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
DebugRenderOnly: false,
|
||||
},
|
||||
expectDebug: false,
|
||||
},
|
||||
{
|
||||
name: "chat debug with assistant message",
|
||||
request: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "Hi there!"},
|
||||
{Role: "user", Content: "How are you?"},
|
||||
},
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "user: Hello\nassistant: Hi there!\nuser: How are you?\n",
|
||||
},
|
||||
{
|
||||
name: "chat debug with images",
|
||||
request: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's in this image?",
|
||||
Images: []api.ImageData{[]byte("fake-image-data")},
|
||||
},
|
||||
},
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "user: [img-0]What's in this image?\n",
|
||||
expectNumImages: 1,
|
||||
},
|
||||
{
|
||||
name: "chat debug with tools",
|
||||
request: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Get the weather"},
|
||||
},
|
||||
Tools: api.Tools{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather information",
|
||||
},
|
||||
},
|
||||
},
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "[{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather information\",\"parameters\":{\"type\":\"\",\"properties\":null}}}]user: Get the weather\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
// Test both with and without streaming
|
||||
streamValues := []bool{false, true}
|
||||
for _, stream := range streamValues {
|
||||
streamSuffix := ""
|
||||
if stream {
|
||||
streamSuffix = " (streaming)"
|
||||
}
|
||||
t.Run(tt.name+streamSuffix, func(t *testing.T) {
|
||||
req := tt.request
|
||||
req.Stream = &stream
|
||||
w := createRequest(t, s.ChatHandler, req)
|
||||
|
||||
if tt.expectDebug {
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var response api.ChatResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if response.Model != tt.request.Model {
|
||||
t.Errorf("expected model %s, got %s", tt.request.Model, response.Model)
|
||||
}
|
||||
|
||||
if tt.expectTemplate != "" && response.DebugInfo.RenderedTemplate != tt.expectTemplate {
|
||||
t.Errorf("expected template %q, got %q", tt.expectTemplate, response.DebugInfo.RenderedTemplate)
|
||||
}
|
||||
|
||||
if tt.expectNumImages > 0 && response.DebugInfo.ImageCount != tt.expectNumImages {
|
||||
t.Errorf("expected image count %d, got %d", tt.expectNumImages, response.DebugInfo.ImageCount)
|
||||
}
|
||||
} else {
|
||||
// When debug is disabled, it should attempt normal processing
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
142
server/routes_delete_test.go
Normal file
142
server/routes_delete_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestDelete(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
|
||||
var s Server
|
||||
|
||||
_, digest := createBinFile(t, nil, nil)
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "test2",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Template: "{{ .System }} {{ .Prompt }}",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
||||
})
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
|
||||
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
|
||||
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
|
||||
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
|
||||
})
|
||||
|
||||
w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "test"})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
||||
})
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
|
||||
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
|
||||
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
|
||||
})
|
||||
|
||||
w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "test2"})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{})
|
||||
}
|
||||
|
||||
func TestDeleteDuplicateLayers(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
var s Server
|
||||
|
||||
n := model.ParseName("test")
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(&model.ConfigV2{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// create a manifest with duplicate layers
|
||||
if err := manifest.WriteManifest(n, config, []manifest.Layer{config}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
w := createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "test"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||
}
|
||||
|
||||
func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
|
||||
var s Server
|
||||
|
||||
_, digest := createBinFile(t, nil, nil)
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "gpt-oss:20b-cloud",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "gpt-oss", "20b-cloud"),
|
||||
})
|
||||
|
||||
w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "gpt-oss:20b:cloud"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d (%s)", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||
}
|
||||
315
server/routes_generate_renderer_test.go
Normal file
315
server/routes_generate_renderer_test.go
Normal file
@@ -0,0 +1,315 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
// TestGenerateWithBuiltinRenderer tests that api/generate uses built-in renderers
|
||||
// when in chat-like flow (messages present, no suffix, no template)
|
||||
func TestGenerateWithBuiltinRenderer(t *testing.T) {
|
||||
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
waitForRecovery: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create a model with a built-in renderer (qwen3-coder)
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "qwen3",
|
||||
"qwen3.block_count": uint32(1),
|
||||
"qwen3.context_length": uint32(8192),
|
||||
"qwen3.embedding_length": uint32(4096),
|
||||
"qwen3.attention.head_count": uint32(32),
|
||||
"qwen3.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
|
||||
// Create a model with the qwen3-coder renderer
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-renderer",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Renderer: "qwen3-coder",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
mock.CompletionResponse.Content = "Hi!"
|
||||
|
||||
t.Run("chat-like flow uses renderer", func(t *testing.T) {
|
||||
// Test that when using messages (chat-like flow), the built-in renderer is used
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-renderer",
|
||||
Prompt: "Write a hello world function",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// The qwen3-coder renderer produces output with <|im_start|> and <|im_end|> tags
|
||||
// When messages are built internally from prompt, it should use the renderer
|
||||
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
|
||||
t.Errorf("expected prompt to contain <|im_start|> from qwen3-coder renderer, got: %s", mock.CompletionRequest.Prompt)
|
||||
}
|
||||
|
||||
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_end|>") {
|
||||
t.Errorf("expected prompt to contain <|im_end|> from qwen3-coder renderer, got: %s", mock.CompletionRequest.Prompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("chat-like flow with system message uses renderer", func(t *testing.T) {
|
||||
// Test that system messages work with the renderer
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-renderer",
|
||||
Prompt: "Write a hello world function",
|
||||
System: "You are a helpful coding assistant.",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Should contain the system message and use renderer format
|
||||
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>system") {
|
||||
t.Errorf("expected prompt to contain system message with renderer format, got: %s", mock.CompletionRequest.Prompt)
|
||||
}
|
||||
|
||||
if !strings.Contains(mock.CompletionRequest.Prompt, "You are a helpful coding assistant.") {
|
||||
t.Errorf("expected prompt to contain system message content, got: %s", mock.CompletionRequest.Prompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("custom template bypasses renderer", func(t *testing.T) {
|
||||
// Test that providing a custom template uses the legacy flow
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-renderer",
|
||||
Prompt: "Write a hello world function",
|
||||
Template: "{{ .Prompt }}",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Should NOT use the renderer format when custom template is provided
|
||||
if strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
|
||||
t.Errorf("expected prompt to NOT use renderer when custom template provided, got: %s", mock.CompletionRequest.Prompt)
|
||||
}
|
||||
|
||||
// Should just be the raw prompt from the template
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Write a hello world function"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
// Create a model with suffix support for the next test
|
||||
w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-suffix-renderer",
|
||||
From: "test-renderer",
|
||||
Template: `{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
|
||||
{{- else }}{{ .Prompt }}
|
||||
{{- end }}`,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
t.Run("suffix bypasses renderer", func(t *testing.T) {
|
||||
// Test that providing a suffix uses the legacy flow
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-suffix-renderer",
|
||||
Prompt: "def add(",
|
||||
Suffix: " return c",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Should NOT use the renderer format when suffix is provided
|
||||
if strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
|
||||
t.Errorf("expected prompt to NOT use renderer when suffix provided, got: %s", mock.CompletionRequest.Prompt)
|
||||
}
|
||||
|
||||
// Should use the suffix template format
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestGenerateWithDebugRenderOnly tests that debug_render_only works with built-in renderers
|
||||
func TestGenerateWithDebugRenderOnly(t *testing.T) {
|
||||
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
waitForRecovery: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create a model with a built-in renderer
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "qwen3",
|
||||
"qwen3.block_count": uint32(1),
|
||||
"qwen3.context_length": uint32(8192),
|
||||
"qwen3.embedding_length": uint32(4096),
|
||||
"qwen3.attention.head_count": uint32(32),
|
||||
"qwen3.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-debug-renderer",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Renderer: "qwen3-coder",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
t.Run("debug_render_only with renderer", func(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-debug-renderer",
|
||||
Prompt: "Write a hello world function",
|
||||
System: "You are a coding assistant",
|
||||
DebugRenderOnly: true,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.GenerateResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.DebugInfo == nil {
|
||||
t.Fatalf("expected debug info, got nil")
|
||||
}
|
||||
|
||||
// Verify that the rendered template uses the built-in renderer
|
||||
if !strings.Contains(resp.DebugInfo.RenderedTemplate, "<|im_start|>") {
|
||||
t.Errorf("expected rendered template to use qwen3-coder renderer format, got: %s", resp.DebugInfo.RenderedTemplate)
|
||||
}
|
||||
|
||||
if !strings.Contains(resp.DebugInfo.RenderedTemplate, "You are a coding assistant") {
|
||||
t.Errorf("expected rendered template to contain system message, got: %s", resp.DebugInfo.RenderedTemplate)
|
||||
}
|
||||
|
||||
if !strings.Contains(resp.DebugInfo.RenderedTemplate, "Write a hello world function") {
|
||||
t.Errorf("expected rendered template to contain prompt, got: %s", resp.DebugInfo.RenderedTemplate)
|
||||
}
|
||||
})
|
||||
}
|
||||
2808
server/routes_generate_test.go
Normal file
2808
server/routes_generate_test.go
Normal file
File diff suppressed because it is too large
Load Diff
683
server/routes_harmony_streaming_test.go
Normal file
683
server/routes_harmony_streaming_test.go
Normal file
@@ -0,0 +1,683 @@
|
||||
package server
|
||||
|
||||
// this test file is to test integration of harmony parser into routes.go (as
|
||||
// opposed to harmonyparser_test.go, which tests the parser in isolation)
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
func getTestTools() []api.Tool {
|
||||
return []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "calculate",
|
||||
Description: "Calculate a mathematical expression",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"expression"},
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"expression": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The mathematical expression to calculate",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func createHarmonyTestModel(t *testing.T) (string, string) {
|
||||
t.Helper()
|
||||
|
||||
return createBinFile(t, ggml.KV{
|
||||
"general.architecture": "gptoss",
|
||||
"llama.block_count": uint32(1),
|
||||
"llama.context_length": uint32(8192),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
"llama.attention.head_count": uint32(32),
|
||||
"llama.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
}
|
||||
|
||||
// TestChatHarmonyParserStreamingRealtime verifies that chunks are emitted as soon as they're available
|
||||
func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
type step struct {
|
||||
input llm.CompletionResponse
|
||||
wantContent string
|
||||
wantThinking string
|
||||
wantToolCalls []api.ToolCall
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
steps []step
|
||||
only bool
|
||||
}{
|
||||
{
|
||||
name: "content streams as it arrives",
|
||||
steps: []step{
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|message|>Hello", Done: false},
|
||||
wantContent: "Hello",
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: ", world", Done: false},
|
||||
wantContent: ", world",
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
wantContent: "!",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "thinking streams separately from content",
|
||||
steps: []step{
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false},
|
||||
wantThinking: "Thinking...",
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|end|>", Done: false},
|
||||
// No output expected - just closes the analysis message and resets state to normal
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false},
|
||||
wantContent: "Answer", // After message end, state is reset to normal
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
// No output expected - just closes the assistant message
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "partial tags buffer until complete",
|
||||
steps: []step{
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|chan", Done: false},
|
||||
// No output - partial tag
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false},
|
||||
// No output - still building tags
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "age|>Deep ", Done: false},
|
||||
wantThinking: "Deep ",
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "thought<|end|>", Done: false},
|
||||
wantThinking: "thought",
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
wantContent: "Done", // After message end, state is reset to normal
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "simple assistant after analysis",
|
||||
steps: []step{
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
wantContent: "Answer",
|
||||
wantThinking: "Think",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call parsed and returned correctly",
|
||||
steps: []step{
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
wantContent: "The weather is sunny",
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "San Francisco",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with streaming JSON across chunks",
|
||||
steps: []step{
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false},
|
||||
// No output yet - incomplete JSON
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false},
|
||||
// Still no output - incomplete JSON
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "2\"}", Done: true},
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calculate",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"expression": "2+2",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
anyOnlies := false
|
||||
for _, tc := range testCases {
|
||||
if tc.only {
|
||||
anyOnlies = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
if anyOnlies && !tc.only {
|
||||
continue
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var chunks []api.ChatResponse
|
||||
chunkIdx := 0
|
||||
|
||||
mockResponses := make([]llm.CompletionResponse, len(tc.steps))
|
||||
for i, step := range tc.steps {
|
||||
mockResponses[i] = step.input
|
||||
}
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
for _, resp := range mockResponses {
|
||||
fn(resp)
|
||||
// Give the handler time to process each response
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
waitForRecovery: 100 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create a simple test model
|
||||
_, digest := createHarmonyTestModel(t)
|
||||
|
||||
streamFalse := false
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "harmony-test-streaming",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
|
||||
Stream: &streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("failed to create model: %d", w.Code)
|
||||
}
|
||||
|
||||
// Test chat endpoint with streaming
|
||||
streamTrue := true
|
||||
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "harmony-test-streaming",
|
||||
Messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||
Stream: &streamTrue,
|
||||
Tools: getTestTools(),
|
||||
})
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Parse all chunks
|
||||
decoder := json.NewDecoder(w.Body)
|
||||
for decoder.More() {
|
||||
var chunk api.ChatResponse
|
||||
if err := decoder.Decode(&chunk); err != nil {
|
||||
t.Fatalf("failed to decode chunk: %v", err)
|
||||
}
|
||||
if chunk.Message.Content != "" || chunk.Message.Thinking != "" || len(chunk.Message.ToolCalls) > 0 {
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
}
|
||||
|
||||
// Log received chunks for debugging
|
||||
if t.Failed() || len(chunks) == 0 {
|
||||
t.Logf("Received %d chunks:", len(chunks))
|
||||
for i, chunk := range chunks {
|
||||
t.Logf(" Chunk %d: content=%q thinking=%q", i, chunk.Message.Content, chunk.Message.Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify chunks match expected steps
|
||||
for i, step := range tc.steps {
|
||||
// Skip steps that don't expect any output
|
||||
if step.wantContent == "" && step.wantThinking == "" && len(step.wantToolCalls) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if chunkIdx >= len(chunks) {
|
||||
t.Errorf("step %d: expected chunk not received (wanted content=%q thinking=%q)",
|
||||
i, step.wantContent, step.wantThinking)
|
||||
continue
|
||||
}
|
||||
|
||||
chunk := chunks[chunkIdx]
|
||||
if chunk.Message.Content != step.wantContent || chunk.Message.Thinking != step.wantThinking {
|
||||
t.Errorf("step %d: chunk mismatch: got (content=%q, thinking=%q), want (content=%q, thinking=%q)",
|
||||
i, chunk.Message.Content, chunk.Message.Thinking, step.wantContent, step.wantThinking)
|
||||
}
|
||||
|
||||
// Check tool calls if expected
|
||||
if len(step.wantToolCalls) > 0 {
|
||||
if len(chunk.Message.ToolCalls) != len(step.wantToolCalls) {
|
||||
t.Errorf("step %d: tool calls count mismatch: got %d, want %d",
|
||||
i, len(chunk.Message.ToolCalls), len(step.wantToolCalls))
|
||||
} else {
|
||||
for j, wantCall := range step.wantToolCalls {
|
||||
if j >= len(chunk.Message.ToolCalls) {
|
||||
break
|
||||
}
|
||||
gotCall := chunk.Message.ToolCalls[j]
|
||||
if gotCall.Function.Name != wantCall.Function.Name {
|
||||
t.Errorf("step %d, tool call %d: name mismatch: got %q, want %q",
|
||||
i, j, gotCall.Function.Name, wantCall.Function.Name)
|
||||
}
|
||||
// Compare arguments as JSON strings for simplicity
|
||||
gotArgs, _ := json.Marshal(gotCall.Function.Arguments)
|
||||
wantArgs, _ := json.Marshal(wantCall.Function.Arguments)
|
||||
if string(gotArgs) != string(wantArgs) {
|
||||
t.Errorf("step %d, tool call %d: arguments mismatch: got %s, want %s",
|
||||
i, j, string(gotArgs), string(wantArgs))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
chunkIdx++
|
||||
}
|
||||
|
||||
// Check if we have extra chunks
|
||||
if chunkIdx < len(chunks) {
|
||||
t.Errorf("received %d extra chunks", len(chunks)-chunkIdx)
|
||||
for i := chunkIdx; i < len(chunks); i++ {
|
||||
t.Logf(" extra chunk %d: content=%q thinking=%q",
|
||||
i-chunkIdx, chunks[i].Message.Content, chunks[i].Message.Thinking)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestChatHarmonyParserStreamingSimple is a simpler test that just verifies basic streaming
|
||||
func TestChatHarmonyParserStreamingSimple(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mockResponses := []llm.CompletionResponse{
|
||||
{Content: "<|message|>First ", Done: false},
|
||||
{Content: "chunk ", Done: false},
|
||||
{Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
}
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
t.Logf("Mock received prompt: %q", r.Prompt)
|
||||
t.Logf("Mock sending %d responses", len(mockResponses))
|
||||
for i, resp := range mockResponses {
|
||||
t.Logf("Sending response %d: %q", i, resp.Content)
|
||||
fn(resp)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
waitForRecovery: 100 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create model
|
||||
_, digest := createHarmonyTestModel(t)
|
||||
streamFalse := false
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "gpt-oss",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Template: `<|start|><|end|>{{ .Tools }}{{ .Prompt }}`,
|
||||
Stream: &streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("failed to create model: %d", w.Code)
|
||||
}
|
||||
|
||||
// Test streaming
|
||||
streamTrue := true
|
||||
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "gpt-oss",
|
||||
Messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||
Stream: &streamTrue,
|
||||
Tools: getTestTools(),
|
||||
})
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Parse chunks
|
||||
var chunks []api.ChatResponse
|
||||
decoder := json.NewDecoder(w.Body)
|
||||
for decoder.More() {
|
||||
var chunk api.ChatResponse
|
||||
if err := decoder.Decode(&chunk); err != nil {
|
||||
t.Fatalf("failed to decode chunk: %v", err)
|
||||
}
|
||||
chunks = append(chunks, chunk)
|
||||
t.Logf("Received chunk %d: content=%q thinking=%q done=%v",
|
||||
len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
|
||||
}
|
||||
|
||||
// Verify we got chunks
|
||||
if len(chunks) == 0 {
|
||||
t.Fatal("expected streaming chunks, got none")
|
||||
}
|
||||
|
||||
// Verify content
|
||||
var content strings.Builder
|
||||
for _, chunk := range chunks {
|
||||
content.WriteString(chunk.Message.Content)
|
||||
}
|
||||
|
||||
expectedContent := "First chunk here"
|
||||
if content.String() != expectedContent {
|
||||
t.Errorf("content mismatch: got %q, want %q", content.String(), expectedContent)
|
||||
}
|
||||
|
||||
// Verify we got multiple chunks (streaming)
|
||||
contentChunks := 0
|
||||
for _, chunk := range chunks {
|
||||
if chunk.Message.Content != "" {
|
||||
contentChunks++
|
||||
}
|
||||
}
|
||||
|
||||
if contentChunks < 2 {
|
||||
t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatHarmonyParserStreaming(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
type expectedChunk struct {
|
||||
afterResponse int // Which mock response this chunk should appear after
|
||||
content string // Expected content in this chunk
|
||||
thinking string // Expected thinking in this chunk
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
mockResponses []llm.CompletionResponse
|
||||
expectedChunks []expectedChunk
|
||||
wantContent string
|
||||
wantThinking string
|
||||
}{
|
||||
{
|
||||
name: "simple message without thinking",
|
||||
mockResponses: []llm.CompletionResponse{
|
||||
{Content: "<|start|>assistant<|message|>Hello, ", Done: false},
|
||||
{Content: "how can I help?", Done: false},
|
||||
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
},
|
||||
expectedChunks: []expectedChunk{
|
||||
{afterResponse: 1, content: "Hello, "},
|
||||
{afterResponse: 2, content: "how can I help?"},
|
||||
},
|
||||
wantContent: "Hello, how can I help?",
|
||||
},
|
||||
{
|
||||
name: "message with analysis channel for thinking",
|
||||
mockResponses: []llm.CompletionResponse{
|
||||
{Content: "<|channel|>analysis<|message|>", Done: false},
|
||||
{Content: "Let me think ", Done: false},
|
||||
{Content: "about this problem...", Done: false},
|
||||
{Content: "<|end|>", Done: false},
|
||||
{Content: "<|start|>assistant<|message|>", Done: false},
|
||||
{Content: "The answer ", Done: false},
|
||||
{Content: "is 42", Done: false},
|
||||
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
},
|
||||
expectedChunks: []expectedChunk{
|
||||
{afterResponse: 2, thinking: "Let me think "},
|
||||
{afterResponse: 3, thinking: "about this problem..."},
|
||||
{afterResponse: 6, content: "The answer "},
|
||||
{afterResponse: 7, content: "is 42"},
|
||||
},
|
||||
wantContent: "The answer is 42",
|
||||
wantThinking: "Let me think about this problem...",
|
||||
},
|
||||
{
|
||||
name: "streaming with partial tags across boundaries",
|
||||
mockResponses: []llm.CompletionResponse{
|
||||
{Content: "<|chan", Done: false},
|
||||
{Content: "nel|>analy", Done: false},
|
||||
{Content: "sis<|mess", Done: false},
|
||||
{Content: "age|>Think", Done: false},
|
||||
{Content: "ing deeply...<|end|>", Done: false},
|
||||
{Content: "<|start|>assi", Done: false},
|
||||
{Content: "stant<|message|>Result ", Done: false},
|
||||
{Content: "computed<|e", Done: false},
|
||||
{Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
},
|
||||
expectedChunks: []expectedChunk{
|
||||
{afterResponse: 4, thinking: "Think"},
|
||||
{afterResponse: 5, thinking: "ing deeply..."},
|
||||
{afterResponse: 7, content: "Result "},
|
||||
{afterResponse: 8, content: "computed"},
|
||||
},
|
||||
wantContent: "Result computed",
|
||||
wantThinking: "Thinking deeply...",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Channel to synchronize mock responses with chunk verification
|
||||
responsesSent := make(chan int, len(tc.mockResponses))
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
// Send mock responses one at a time, notifying when each is sent
|
||||
for i, resp := range tc.mockResponses {
|
||||
fn(resp)
|
||||
responsesSent <- i + 1
|
||||
}
|
||||
close(responsesSent)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
waitForRecovery: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create a minimal model
|
||||
_, digest := createHarmonyTestModel(t)
|
||||
|
||||
// Create model with passthrough template
|
||||
stream := false
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "harmony-test",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("failed to create model: %d", w.Code)
|
||||
}
|
||||
|
||||
// Test chat endpoint with streaming
|
||||
streamTrue := true
|
||||
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "harmony-test",
|
||||
Messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||
Stream: &streamTrue,
|
||||
Tools: getTestTools(),
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Parse streaming response
|
||||
var chunks []api.ChatResponse
|
||||
var content, thinking strings.Builder
|
||||
|
||||
decoder := json.NewDecoder(w.Body)
|
||||
for decoder.More() {
|
||||
var chunk api.ChatResponse
|
||||
if err := decoder.Decode(&chunk); err != nil {
|
||||
t.Fatalf("failed to decode chunk: %v", err)
|
||||
}
|
||||
chunks = append(chunks, chunk)
|
||||
|
||||
// Accumulate content and thinking from each chunk
|
||||
content.WriteString(chunk.Message.Content)
|
||||
thinking.WriteString(chunk.Message.Thinking)
|
||||
|
||||
// Debug output
|
||||
t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
|
||||
}
|
||||
|
||||
// Verify we got streaming chunks
|
||||
if len(chunks) == 0 {
|
||||
t.Fatal("expected streaming chunks, got none")
|
||||
}
|
||||
|
||||
gotContent := content.String()
|
||||
gotThinking := thinking.String()
|
||||
|
||||
if gotContent != tc.wantContent {
|
||||
t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent)
|
||||
}
|
||||
if gotThinking != tc.wantThinking {
|
||||
t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking)
|
||||
}
|
||||
|
||||
// Verify last chunk has done=true
|
||||
lastChunk := chunks[len(chunks)-1]
|
||||
if !lastChunk.Done {
|
||||
t.Error("expected last chunk to have done=true")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
78
server/routes_list_test.go
Normal file
78
server/routes_list_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestList(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
|
||||
expectNames := []string{
|
||||
"mistral:7b-instruct-q4_0",
|
||||
"zephyr:7b-beta-q5_K_M",
|
||||
"apple/OpenELM:latest",
|
||||
"boreas:2b-code-v1.5-q6_K",
|
||||
"notus:7b-v1-IQ2_S",
|
||||
// TODO: host:port currently fails on windows (#4107)
|
||||
// "localhost:5000/library/eurus:700b-v0.5-iq3_XXS",
|
||||
"mynamespace/apeliotes:latest",
|
||||
"myhost/mynamespace/lips:code",
|
||||
}
|
||||
|
||||
s := Server{modelCaches: &modelCaches{modelList: newModelListCache()}}
|
||||
s.modelCaches.modelList.Start(context.Background())
|
||||
if err := s.modelCaches.modelList.Wait(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, n := range expectNames {
|
||||
_, digest := createBinFile(t, nil, nil)
|
||||
|
||||
createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: n,
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
})
|
||||
}
|
||||
|
||||
w := createRequest(t, s.ListHandler, nil)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.ListResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(resp.Models) != len(expectNames) {
|
||||
t.Fatalf("expected %d models, actual %d", len(expectNames), len(resp.Models))
|
||||
}
|
||||
|
||||
actualNames := make([]string, len(resp.Models))
|
||||
for i, m := range resp.Models {
|
||||
actualNames[i] = m.Name
|
||||
}
|
||||
|
||||
slices.Sort(actualNames)
|
||||
slices.Sort(expectNames)
|
||||
|
||||
if !slices.Equal(actualNames, expectNames) {
|
||||
t.Fatalf("expected slices to be equal %v", actualNames)
|
||||
}
|
||||
|
||||
for _, m := range resp.Models {
|
||||
if !slices.Contains(m.Capabilities, "completion") {
|
||||
t.Fatalf("capabilities for %q = %v, want completion", m.Name, m.Capabilities)
|
||||
}
|
||||
}
|
||||
}
|
||||
127
server/routes_options_test.go
Normal file
127
server/routes_options_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestModelOptionsNumCtxPriority(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envContextLen string // empty means not set (uses 0 sentinel)
|
||||
defaultNumCtx int // VRAM-based default
|
||||
modelNumCtx int // 0 means not set in model
|
||||
requestNumCtx int // 0 means not set in request
|
||||
expectedNumCtx int
|
||||
}{
|
||||
{
|
||||
name: "vram default when nothing else set",
|
||||
envContextLen: "",
|
||||
defaultNumCtx: 32768,
|
||||
modelNumCtx: 0,
|
||||
requestNumCtx: 0,
|
||||
expectedNumCtx: 32768,
|
||||
},
|
||||
{
|
||||
name: "env var overrides vram default",
|
||||
envContextLen: "8192",
|
||||
defaultNumCtx: 32768,
|
||||
modelNumCtx: 0,
|
||||
requestNumCtx: 0,
|
||||
expectedNumCtx: 8192,
|
||||
},
|
||||
{
|
||||
name: "model overrides vram default",
|
||||
envContextLen: "",
|
||||
defaultNumCtx: 32768,
|
||||
modelNumCtx: 16384,
|
||||
requestNumCtx: 0,
|
||||
expectedNumCtx: 16384,
|
||||
},
|
||||
{
|
||||
name: "model overrides env var",
|
||||
envContextLen: "8192",
|
||||
defaultNumCtx: 32768,
|
||||
modelNumCtx: 16384,
|
||||
requestNumCtx: 0,
|
||||
expectedNumCtx: 16384,
|
||||
},
|
||||
{
|
||||
name: "request overrides everything",
|
||||
envContextLen: "8192",
|
||||
defaultNumCtx: 32768,
|
||||
modelNumCtx: 16384,
|
||||
requestNumCtx: 4096,
|
||||
expectedNumCtx: 4096,
|
||||
},
|
||||
{
|
||||
name: "request overrides vram default",
|
||||
envContextLen: "",
|
||||
defaultNumCtx: 32768,
|
||||
modelNumCtx: 0,
|
||||
requestNumCtx: 4096,
|
||||
expectedNumCtx: 4096,
|
||||
},
|
||||
{
|
||||
name: "request overrides model",
|
||||
envContextLen: "",
|
||||
defaultNumCtx: 32768,
|
||||
modelNumCtx: 16384,
|
||||
requestNumCtx: 4096,
|
||||
expectedNumCtx: 4096,
|
||||
},
|
||||
{
|
||||
name: "low vram tier default",
|
||||
envContextLen: "",
|
||||
defaultNumCtx: 4096,
|
||||
modelNumCtx: 0,
|
||||
requestNumCtx: 0,
|
||||
expectedNumCtx: 4096,
|
||||
},
|
||||
{
|
||||
name: "high vram tier default",
|
||||
envContextLen: "",
|
||||
defaultNumCtx: 262144,
|
||||
modelNumCtx: 0,
|
||||
requestNumCtx: 0,
|
||||
expectedNumCtx: 262144,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Set or clear environment variable
|
||||
if tt.envContextLen != "" {
|
||||
t.Setenv("OLLAMA_CONTEXT_LENGTH", tt.envContextLen)
|
||||
}
|
||||
|
||||
// Create server with VRAM-based default
|
||||
s := &Server{
|
||||
defaultNumCtx: tt.defaultNumCtx,
|
||||
}
|
||||
|
||||
// Create model options (use float64 as FromMap expects JSON-style numbers)
|
||||
var modelOpts map[string]any
|
||||
if tt.modelNumCtx != 0 {
|
||||
modelOpts = map[string]any{"num_ctx": float64(tt.modelNumCtx)}
|
||||
}
|
||||
model := &Model{
|
||||
Options: modelOpts,
|
||||
}
|
||||
|
||||
// Create request options (use float64 as FromMap expects JSON-style numbers)
|
||||
var requestOpts map[string]any
|
||||
if tt.requestNumCtx != 0 {
|
||||
requestOpts = map[string]any{"num_ctx": float64(tt.requestNumCtx)}
|
||||
}
|
||||
|
||||
opts, err := s.modelOptions(model, requestOpts)
|
||||
if err != nil {
|
||||
t.Fatalf("modelOptions failed: %v", err)
|
||||
}
|
||||
|
||||
if opts.NumCtx != tt.expectedNumCtx {
|
||||
t.Errorf("NumCtx = %d, want %d", opts.NumCtx, tt.expectedNumCtx)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
128
server/routes_request_log_test.go
Normal file
128
server/routes_request_log_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestInferenceRequestLoggerMiddlewareWritesReplayArtifacts(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logDir := t.TempDir()
|
||||
requestLogger := &inferenceRequestLogger{dir: logDir}
|
||||
|
||||
const route = "/v1/chat/completions"
|
||||
const requestBody = `{"model":"test-model","messages":[{"role":"user","content":"hello"}]}`
|
||||
|
||||
var bodySeenByHandler string
|
||||
|
||||
r := gin.New()
|
||||
r.POST(route, requestLogger.middleware(route), func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read body in handler: %v", err)
|
||||
}
|
||||
|
||||
bodySeenByHandler = string(body)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, route, strings.NewReader(requestBody))
|
||||
req.Host = "127.0.0.1:11434"
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if bodySeenByHandler != requestBody {
|
||||
t.Fatalf("handler body mismatch:\nexpected: %s\ngot: %s", requestBody, bodySeenByHandler)
|
||||
}
|
||||
|
||||
bodyFiles, err := filepath.Glob(filepath.Join(logDir, "*_v1_chat_completions_body.json"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to glob body logs: %v", err)
|
||||
}
|
||||
if len(bodyFiles) != 1 {
|
||||
t.Fatalf("expected 1 body log, got %d (%v)", len(bodyFiles), bodyFiles)
|
||||
}
|
||||
|
||||
curlFiles, err := filepath.Glob(filepath.Join(logDir, "*_v1_chat_completions_request.sh"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to glob curl logs: %v", err)
|
||||
}
|
||||
if len(curlFiles) != 1 {
|
||||
t.Fatalf("expected 1 curl log, got %d (%v)", len(curlFiles), curlFiles)
|
||||
}
|
||||
|
||||
bodyData, err := os.ReadFile(bodyFiles[0])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read body log: %v", err)
|
||||
}
|
||||
if string(bodyData) != requestBody {
|
||||
t.Fatalf("body log mismatch:\nexpected: %s\ngot: %s", requestBody, string(bodyData))
|
||||
}
|
||||
|
||||
curlData, err := os.ReadFile(curlFiles[0])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read curl log: %v", err)
|
||||
}
|
||||
|
||||
curlString := string(curlData)
|
||||
if !strings.Contains(curlString, "http://127.0.0.1:11434"+route) {
|
||||
t.Fatalf("curl log does not contain expected route URL: %s", curlString)
|
||||
}
|
||||
|
||||
bodyFileName := filepath.Base(bodyFiles[0])
|
||||
if !strings.Contains(curlString, "@\"${SCRIPT_DIR}/"+bodyFileName+"\"") {
|
||||
t.Fatalf("curl log does not reference sibling body file: %s", curlString)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewInferenceRequestLoggerCreatesDirectory(t *testing.T) {
|
||||
requestLogger, err := newInferenceRequestLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error creating request logger: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(requestLogger.dir)
|
||||
})
|
||||
|
||||
if requestLogger == nil || requestLogger.dir == "" {
|
||||
t.Fatalf("expected request logger directory to be set")
|
||||
}
|
||||
|
||||
info, err := os.Stat(requestLogger.dir)
|
||||
if err != nil {
|
||||
t.Fatalf("expected directory to exist: %v", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
t.Fatalf("expected %q to be a directory", requestLogger.dir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeRouteForFilename(t *testing.T) {
|
||||
tests := []struct {
|
||||
route string
|
||||
want string
|
||||
}{
|
||||
{route: "/api/generate", want: "api_generate"},
|
||||
{route: "/v1/chat/completions", want: "v1_chat_completions"},
|
||||
{route: "/v1/messages", want: "v1_messages"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := sanitizeRouteForFilename(tt.route); got != tt.want {
|
||||
t.Fatalf("sanitizeRouteForFilename(%q) = %q, want %q", tt.route, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
1159
server/routes_test.go
Normal file
1159
server/routes_test.go
Normal file
File diff suppressed because it is too large
Load Diff
340
server/routes_web_experimental_test.go
Normal file
340
server/routes_web_experimental_test.go
Normal file
@@ -0,0 +1,340 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
type webExperimentalUpstreamCapture struct {
|
||||
path string
|
||||
body string
|
||||
header http.Header
|
||||
}
|
||||
|
||||
func newWebExperimentalUpstream(t *testing.T, responseBody string) (*httptest.Server, *webExperimentalUpstreamCapture) {
|
||||
t.Helper()
|
||||
|
||||
capture := &webExperimentalUpstreamCapture{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
payload, _ := io.ReadAll(r.Body)
|
||||
capture.path = r.URL.Path
|
||||
capture.body = string(payload)
|
||||
capture.header = r.Header.Clone()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(responseBody))
|
||||
}))
|
||||
|
||||
return srv, capture
|
||||
}
|
||||
|
||||
func TestExperimentalWebEndpointsPassthrough(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
localPath string
|
||||
upstreamPath string
|
||||
requestBody string
|
||||
responseBody string
|
||||
assertBody string
|
||||
}{
|
||||
{
|
||||
name: "web_search",
|
||||
localPath: "/api/experimental/web_search",
|
||||
upstreamPath: "/api/web_search",
|
||||
requestBody: `{"query":"what is ollama?","max_results":3}`,
|
||||
responseBody: `{"results":[{"title":"Ollama","url":"https://ollama.com","content":"Cloud models are now available"}]}`,
|
||||
assertBody: `"query":"what is ollama?"`,
|
||||
},
|
||||
{
|
||||
name: "web_fetch",
|
||||
localPath: "/api/experimental/web_fetch",
|
||||
upstreamPath: "/api/web_fetch",
|
||||
requestBody: `{"url":"https://ollama.com"}`,
|
||||
responseBody: `{"title":"Ollama","content":"Cloud models are now available","links":["https://ollama.com/"]}`,
|
||||
assertBody: `"url":"https://ollama.com"`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
upstream, capture := newWebExperimentalUpstream(t, tt.responseBody)
|
||||
defer upstream.Close()
|
||||
|
||||
original := cloudProxyBaseURL
|
||||
cloudProxyBaseURL = upstream.URL
|
||||
t.Cleanup(func() { cloudProxyBaseURL = original })
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
local := httptest.NewServer(router)
|
||||
defer local.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+tt.localPath, bytes.NewBufferString(tt.requestBody))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer should-forward")
|
||||
req.Header.Set("X-Test-Header", "web-experimental")
|
||||
req.Header.Set(cloudProxyClientVersionHeader, "should-be-overwritten")
|
||||
|
||||
resp, err := local.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body))
|
||||
}
|
||||
if capture.path != tt.upstreamPath {
|
||||
t.Fatalf("expected upstream path %q, got %q", tt.upstreamPath, capture.path)
|
||||
}
|
||||
if !bytes.Contains([]byte(capture.body), []byte(tt.assertBody)) {
|
||||
t.Fatalf("expected upstream body to contain %q, got %q", tt.assertBody, capture.body)
|
||||
}
|
||||
if got := capture.header.Get("Authorization"); got != "Bearer should-forward" {
|
||||
t.Fatalf("expected forwarded Authorization header, got %q", got)
|
||||
}
|
||||
if got := capture.header.Get("X-Test-Header"); got != "web-experimental" {
|
||||
t.Fatalf("expected forwarded X-Test-Header=web-experimental, got %q", got)
|
||||
}
|
||||
if got := capture.header.Get(cloudProxyClientVersionHeader); got != version.Version {
|
||||
t.Fatalf("expected %s=%q, got %q", cloudProxyClientVersionHeader, version.Version, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExperimentalWebEndpointsMissingBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
local := httptest.NewServer(router)
|
||||
defer local.Close()
|
||||
|
||||
tests := []string{
|
||||
"/api/experimental/web_search",
|
||||
"/api/experimental/web_fetch",
|
||||
}
|
||||
|
||||
for _, path := range tests {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+path, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := local.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d (%s)", resp.StatusCode, string(body))
|
||||
}
|
||||
if string(body) != `{"error":"missing request body"}` {
|
||||
t.Fatalf("unexpected response body: %s", string(body))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExperimentalWebEndpointsCloudDisabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
t.Setenv("OLLAMA_NO_CLOUD", "1")
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
local := httptest.NewServer(router)
|
||||
defer local.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
request string
|
||||
operation string
|
||||
}{
|
||||
{
|
||||
name: "web_search",
|
||||
path: "/api/experimental/web_search",
|
||||
request: `{"query":"latest ollama release"}`,
|
||||
operation: cloudErrWebSearchUnavailable,
|
||||
},
|
||||
{
|
||||
name: "web_fetch",
|
||||
path: "/api/experimental/web_fetch",
|
||||
request: `{"url":"https://ollama.com"}`,
|
||||
operation: cloudErrWebFetchUnavailable,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+tt.path, bytes.NewBufferString(tt.request))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := local.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Fatalf("expected status 403, got %d (%s)", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var got map[string]string
|
||||
if err := json.Unmarshal(body, &got); err != nil {
|
||||
t.Fatalf("expected json error body, got: %q", string(body))
|
||||
}
|
||||
if got["error"] != internalcloud.DisabledError(tt.operation) {
|
||||
t.Fatalf("unexpected error message: %q", got["error"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExperimentalWebEndpointSigningFailureReturnsUnauthorized(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
origSignRequest := cloudProxySignRequest
|
||||
origSigninURL := cloudProxySigninURL
|
||||
cloudProxySignRequest = func(context.Context, *http.Request) error {
|
||||
return errors.New("ssh: no key found")
|
||||
}
|
||||
cloudProxySigninURL = func() (string, error) {
|
||||
return "https://ollama.com/signin/example", nil
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
cloudProxySignRequest = origSignRequest
|
||||
cloudProxySigninURL = origSigninURL
|
||||
})
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
local := httptest.NewServer(router)
|
||||
defer local.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/experimental/web_search", bytes.NewBufferString(`{"query":"hello"}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := local.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status 401, got %d (%s)", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(body, &got); err != nil {
|
||||
t.Fatalf("expected json error body, got: %q", string(body))
|
||||
}
|
||||
if got["error"] != "unauthorized" {
|
||||
t.Fatalf("unexpected error message: %v", got["error"])
|
||||
}
|
||||
if got["signin_url"] != "https://ollama.com/signin/example" {
|
||||
t.Fatalf("unexpected signin_url: %v", got["signin_url"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestExperimentalWebEndpointSigningFailureWithoutSigninURL(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
origSignRequest := cloudProxySignRequest
|
||||
origSigninURL := cloudProxySigninURL
|
||||
cloudProxySignRequest = func(context.Context, *http.Request) error {
|
||||
return errors.New("ssh: no key found")
|
||||
}
|
||||
cloudProxySigninURL = func() (string, error) {
|
||||
return "", errors.New("key missing")
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
cloudProxySignRequest = origSignRequest
|
||||
cloudProxySigninURL = origSigninURL
|
||||
})
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
local := httptest.NewServer(router)
|
||||
defer local.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/experimental/web_fetch", bytes.NewBufferString(`{"url":"https://ollama.com"}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := local.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status 401, got %d (%s)", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(body, &got); err != nil {
|
||||
t.Fatalf("expected json error body, got: %q", string(body))
|
||||
}
|
||||
if got["error"] != "unauthorized" {
|
||||
t.Fatalf("unexpected error message: %v", got["error"])
|
||||
}
|
||||
if _, ok := got["signin_url"]; ok {
|
||||
t.Fatalf("did not expect signin_url when helper fails, got %v", got["signin_url"])
|
||||
}
|
||||
}
|
||||
932
server/sched.go
Normal file
932
server/sched.go
Normal file
@@ -0,0 +1,932 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/discover"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/mlxrunner"
|
||||
)
|
||||
|
||||
type LlmRequest struct {
|
||||
ctx context.Context //nolint:containedctx
|
||||
model *Model
|
||||
opts api.Options
|
||||
sessionDuration *api.Duration
|
||||
successCh chan *runnerRef
|
||||
errCh chan error
|
||||
schedAttempts uint
|
||||
}
|
||||
|
||||
type Scheduler struct {
|
||||
pendingReqCh chan *LlmRequest
|
||||
finishedReqCh chan *LlmRequest
|
||||
expiredCh chan *runnerRef
|
||||
unloadedCh chan any
|
||||
|
||||
// loadedMu protects loaded and activeLoading
|
||||
loadedMu sync.Mutex
|
||||
|
||||
// activeLoading is the model that we are currently working on loading,
|
||||
// including by evicting one or more other models. We can only load
|
||||
// one model at a time but new requests to models that already loaded can
|
||||
// happen in parallel
|
||||
activeLoading llm.LlamaServer
|
||||
loaded map[string]*runnerRef
|
||||
|
||||
loadFn func(req *LlmRequest, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool
|
||||
newServerFn func(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
|
||||
getGpuFn func(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.DeviceInfo
|
||||
getSystemInfoFn func() ml.SystemInfo
|
||||
waitForRecovery time.Duration
|
||||
}
|
||||
|
||||
// Default automatic value for number of models we allow per GPU
|
||||
// Model will still need to fit in VRAM, but loading many small models
|
||||
// on a large GPU can cause stalling
|
||||
var defaultModelsPerGPU = 3
|
||||
|
||||
var ErrMaxQueue = errors.New("server busy, please try again. maximum pending requests exceeded")
|
||||
|
||||
func InitScheduler(ctx context.Context) *Scheduler {
|
||||
maxQueue := envconfig.MaxQueue()
|
||||
sched := &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, maxQueue),
|
||||
finishedReqCh: make(chan *LlmRequest, maxQueue),
|
||||
expiredCh: make(chan *runnerRef, maxQueue),
|
||||
unloadedCh: make(chan any, maxQueue),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: llm.NewLlamaServer,
|
||||
getGpuFn: discover.GPUDevices,
|
||||
getSystemInfoFn: discover.GetSystemInfo,
|
||||
waitForRecovery: 5 * time.Second,
|
||||
}
|
||||
sched.loadFn = sched.load
|
||||
return sched
|
||||
}
|
||||
|
||||
// schedulerModelKey returns the scheduler map key for a model.
|
||||
// GGUF-backed models use ModelPath; safetensors/image models without a
|
||||
// ModelPath use manifest digest so distinct models don't collide.
|
||||
func schedulerModelKey(m *Model) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
if m.ModelPath != "" {
|
||||
return m.ModelPath
|
||||
}
|
||||
if m.Digest != "" {
|
||||
return "digest:" + m.Digest
|
||||
}
|
||||
if m.Name != "" {
|
||||
return "name:" + m.Name
|
||||
}
|
||||
if m.ShortName != "" {
|
||||
return "short:" + m.ShortName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// context must be canceled to decrement ref count and release the runner
|
||||
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
|
||||
if opts.NumCtx < 4 {
|
||||
opts.NumCtx = 4
|
||||
}
|
||||
|
||||
if m.CheckCapabilities(model.CapabilityVision) == nil {
|
||||
// multimodal models require at least 2048 context
|
||||
opts.NumCtx = max(opts.NumCtx, 2048)
|
||||
}
|
||||
|
||||
req := &LlmRequest{
|
||||
ctx: c,
|
||||
model: m,
|
||||
opts: opts,
|
||||
sessionDuration: sessionDuration,
|
||||
successCh: make(chan *runnerRef, 1),
|
||||
errCh: make(chan error, 1),
|
||||
}
|
||||
|
||||
key := schedulerModelKey(req.model)
|
||||
s.loadedMu.Lock()
|
||||
runner := s.loaded[key]
|
||||
s.loadedMu.Unlock()
|
||||
if runner != nil && !runner.needsReload(c, req) {
|
||||
req.useLoadedRunner(runner, s.finishedReqCh)
|
||||
} else {
|
||||
select {
|
||||
case s.pendingReqCh <- req:
|
||||
default:
|
||||
req.errCh <- ErrMaxQueue
|
||||
}
|
||||
}
|
||||
return req.successCh, req.errCh
|
||||
}
|
||||
|
||||
// Returns immediately, spawns go routines for the scheduler which will shutdown when ctx is done
|
||||
func (s *Scheduler) Run(ctx context.Context) {
|
||||
slog.Debug("starting llm scheduler")
|
||||
go func() {
|
||||
s.processPending(ctx)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
s.processCompleted(ctx)
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Scheduler) processPending(ctx context.Context) {
|
||||
maxRunners := envconfig.MaxRunners()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Debug("shutting down scheduler pending loop")
|
||||
return
|
||||
case pending := <-s.pendingReqCh:
|
||||
// Block other requests until we get this pending request running
|
||||
pending.schedAttempts++
|
||||
|
||||
if pending.ctx.Err() != nil {
|
||||
slog.Debug("pending request cancelled or timed out, skipping scheduling")
|
||||
continue
|
||||
}
|
||||
logutil.Trace("processing incoming request", "model", pending.model.ModelPath)
|
||||
|
||||
for {
|
||||
var runnerToExpire *runnerRef
|
||||
pendingKey := schedulerModelKey(pending.model)
|
||||
s.loadedMu.Lock()
|
||||
runner := s.loaded[pendingKey]
|
||||
loadedCount := len(s.loaded)
|
||||
runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
|
||||
for _, r := range s.loaded {
|
||||
runnersSnapshot = append(runnersSnapshot, r)
|
||||
}
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
if runner != nil {
|
||||
if runner.needsReload(ctx, pending) {
|
||||
slog.Debug("reloading", "runner", runner)
|
||||
runnerToExpire = runner
|
||||
} else {
|
||||
// Runner is usable, return it
|
||||
logutil.Trace("using existing loaded runner", "model", pendingKey)
|
||||
pending.useLoadedRunner(runner, s.finishedReqCh)
|
||||
break
|
||||
}
|
||||
} else if maxRunners > 0 && loadedCount >= int(maxRunners) {
|
||||
slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
|
||||
runnerToExpire = s.findRunnerToUnload()
|
||||
} else {
|
||||
// Either no models are loaded or below envconfig.MaxRunners
|
||||
// Get a refreshed GPU list
|
||||
var gpus []ml.DeviceInfo
|
||||
if pending.opts.NumGPU == 0 {
|
||||
gpus = []ml.DeviceInfo{}
|
||||
} else {
|
||||
logutil.Trace("refreshing GPU list", "model", pending.model.ModelPath)
|
||||
gpus = s.getGpuFn(ctx, runnersSnapshot)
|
||||
}
|
||||
logutil.Trace("refreshing system information", "model", pending.model.ModelPath)
|
||||
systemInfo := s.getSystemInfoFn()
|
||||
if maxRunners <= 0 {
|
||||
// No user specified MaxRunners, so figure out what automatic setting to use for the next load attempt
|
||||
if pending.opts.NumGPU == 0 {
|
||||
// Need to get actual GPU list to set the correct default max models
|
||||
logutil.Trace("refreshing GPU list", "model", pending.model.ModelPath)
|
||||
g := s.getGpuFn(ctx, runnersSnapshot)
|
||||
maxRunners = uint(defaultModelsPerGPU * max(len(g), 1))
|
||||
} else {
|
||||
maxRunners = uint(defaultModelsPerGPU * max(len(gpus), 1))
|
||||
}
|
||||
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
|
||||
}
|
||||
|
||||
// Update free memory from currently loaded models
|
||||
logutil.Trace("updating free space", "gpu_count", len(gpus), "model", pending.model.ModelPath)
|
||||
s.updateFreeSpace(gpus)
|
||||
|
||||
if loadedCount == 0 {
|
||||
// No models loaded. Load the model but prefer the best fit.
|
||||
slog.Debug("loading first model", "model", pending.model.ModelPath)
|
||||
s.loadFn(pending, systemInfo, gpus, false)
|
||||
break
|
||||
}
|
||||
|
||||
// More than one loaded model, so we have to see if the
|
||||
// new one fits
|
||||
logutil.Trace("loading additional model", "model", pending.model.ModelPath)
|
||||
needEvict := s.loadFn(pending, systemInfo, gpus, true)
|
||||
if !needEvict {
|
||||
slog.Debug("new model fits with existing models, loading")
|
||||
break
|
||||
}
|
||||
|
||||
runnerToExpire = s.findRunnerToUnload()
|
||||
}
|
||||
|
||||
if runnerToExpire == nil {
|
||||
// While we were performing load calculations, the loaded runner(s) unloaded in parallel
|
||||
// so findRunnerToUnload returned no runners. We'll try again and the loadedCount should be zero
|
||||
slog.Debug("runner to expire was nil, retrying")
|
||||
continue
|
||||
}
|
||||
// Trigger an expiration to unload once it's done
|
||||
runnerToExpire.refMu.Lock()
|
||||
slog.Debug("resetting model to expire immediately to make room", "runner", runnerToExpire, "refCount", runnerToExpire.refCount)
|
||||
if runnerToExpire.expireTimer != nil {
|
||||
runnerToExpire.expireTimer.Stop()
|
||||
runnerToExpire.expireTimer = nil
|
||||
}
|
||||
runnerToExpire.sessionDuration = 0
|
||||
if runnerToExpire.refCount <= 0 {
|
||||
s.expiredCh <- runnerToExpire
|
||||
}
|
||||
runnerToExpire.refMu.Unlock()
|
||||
// Wait for the unload to happen
|
||||
slog.Debug("waiting for pending requests to complete and unload to occur", "runner", runnerToExpire)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Debug("shutting down scheduler pending loop")
|
||||
return
|
||||
case <-s.unloadedCh:
|
||||
slog.Debug("unload completed", "runner", runnerToExpire)
|
||||
continue
|
||||
}
|
||||
}
|
||||
case <-s.unloadedCh:
|
||||
// An unload request when there are no pending request can be ignored
|
||||
slog.Debug("ignoring unload event with no pending requests")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scheduler) processCompleted(ctx context.Context) {
|
||||
// Process completed requests, expired timers, and unloading models
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Debug("shutting down scheduler completed loop")
|
||||
return
|
||||
case finished := <-s.finishedReqCh:
|
||||
finishedKey := schedulerModelKey(finished.model)
|
||||
s.loadedMu.Lock()
|
||||
runner := s.loaded[finishedKey]
|
||||
s.loadedMu.Unlock()
|
||||
if runner == nil {
|
||||
slog.Error("finished request signal received after model unloaded", "modelPath", finishedKey)
|
||||
continue
|
||||
}
|
||||
runner.refMu.Lock()
|
||||
runner.refCount--
|
||||
if runner.refCount <= 0 {
|
||||
if runner.sessionDuration <= 0 {
|
||||
slog.Debug("runner with zero duration has gone idle, expiring to unload", "runner", runner)
|
||||
if runner.expireTimer != nil {
|
||||
runner.expireTimer.Stop()
|
||||
runner.expireTimer = nil
|
||||
}
|
||||
s.expiredCh <- runner
|
||||
} else if runner.expireTimer == nil {
|
||||
slog.Debug("runner with non-zero duration has gone idle, adding timer", "runner", runner, "duration", runner.sessionDuration)
|
||||
runner.expireTimer = time.AfterFunc(runner.sessionDuration, func() {
|
||||
slog.Debug("timer expired, expiring to unload", "runner", runner)
|
||||
runner.refMu.Lock()
|
||||
defer runner.refMu.Unlock()
|
||||
if runner.expireTimer != nil {
|
||||
runner.expireTimer.Stop()
|
||||
runner.expireTimer = nil
|
||||
}
|
||||
s.expiredCh <- runner
|
||||
})
|
||||
runner.expiresAt = time.Now().Add(runner.sessionDuration)
|
||||
} else {
|
||||
slog.Debug("runner with non-zero duration has gone idle, resetting timer", "runner", runner, "duration", runner.sessionDuration)
|
||||
runner.expireTimer.Reset(runner.sessionDuration)
|
||||
runner.expiresAt = time.Now().Add(runner.sessionDuration)
|
||||
}
|
||||
}
|
||||
slog.Debug("after processing request finished event", "runner", runner, "refCount", runner.refCount)
|
||||
runner.refMu.Unlock()
|
||||
case runner := <-s.expiredCh:
|
||||
slog.Debug("runner expired event received", "runner", runner)
|
||||
runner.refMu.Lock()
|
||||
if runner.refCount > 0 {
|
||||
slog.Debug("expired event with positive ref count, retrying", "runner", runner, "refCount", runner.refCount)
|
||||
go func(runner *runnerRef) {
|
||||
// We can't unload yet, but want to as soon as the current request completes
|
||||
// So queue up another expired event
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
s.expiredCh <- runner
|
||||
}(runner)
|
||||
runner.refMu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
slog.Debug("got lock to unload expired event", "runner", runner)
|
||||
runnerToUnload := s.loaded[runner.modelKey]
|
||||
if runnerToUnload == nil {
|
||||
// If runnerToUnload is nil, we already processed an event and
|
||||
// unloaded it. This double unload can happen if the initial
|
||||
// request is canceled and we're trying to load another model
|
||||
// that requires this one to be evicted, or the settings change
|
||||
// and require a reload
|
||||
s.loadedMu.Unlock()
|
||||
runner.refMu.Unlock()
|
||||
slog.Debug("duplicate expired event, ignoring", "runner", runner)
|
||||
} else if runner.pid != runnerToUnload.pid {
|
||||
// If the pids do not match, we likely had multiple load
|
||||
// failures for the same model in quick succession due to
|
||||
// request context canceled and are draining the queue of
|
||||
// events. Ensure the orphaned runner is properly shut down, but
|
||||
// do not delete the mismatched loaded runner, or wait for VRAM
|
||||
// convergence.
|
||||
slog.Debug("orphaned runner shutting down", "orphan", runner, "loaded", runnerToUnload)
|
||||
runner.unload()
|
||||
s.loadedMu.Unlock()
|
||||
runner.refMu.Unlock()
|
||||
} else {
|
||||
slog.Debug("starting background wait for VRAM recovery", "runner", runner)
|
||||
runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
|
||||
for _, r := range s.loaded {
|
||||
runnersSnapshot = append(runnersSnapshot, r)
|
||||
}
|
||||
finished := s.waitForVRAMRecovery(runner, runnersSnapshot)
|
||||
runner.unload()
|
||||
delete(s.loaded, runner.modelKey)
|
||||
s.loadedMu.Unlock()
|
||||
slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner)
|
||||
<-finished
|
||||
runner.refMu.Unlock()
|
||||
slog.Debug("sending an unloaded event", "runner", runner)
|
||||
s.unloadedCh <- struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Complete the pending request and send the runner back to the requester
|
||||
// Wires up a finished event after the request context is completed
|
||||
// Updates session duration, and resets expiration timer
|
||||
func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *LlmRequest) {
|
||||
runner.refMu.Lock()
|
||||
defer runner.refMu.Unlock()
|
||||
runner.refCount++
|
||||
if runner.expireTimer != nil {
|
||||
runner.expireTimer.Stop()
|
||||
runner.expireTimer = nil
|
||||
}
|
||||
if pending.sessionDuration != nil {
|
||||
runner.sessionDuration = pending.sessionDuration.Duration
|
||||
}
|
||||
pending.successCh <- runner
|
||||
go func() {
|
||||
<-pending.ctx.Done()
|
||||
slog.Debug("context for request finished", "runner", runner)
|
||||
finished <- pending
|
||||
}()
|
||||
}
|
||||
|
||||
// load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs
|
||||
// (if any). Returns whether the scheduler needs to evict a model to make this one fit.
|
||||
func (s *Scheduler) load(req *LlmRequest, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool {
|
||||
numParallel := max(int(envconfig.NumParallel()), 1)
|
||||
|
||||
// Embedding models should always be loaded with parallel=1
|
||||
if req.model.CheckCapabilities(model.CapabilityCompletion) != nil {
|
||||
numParallel = 1
|
||||
}
|
||||
|
||||
// Some architectures are not safe with num_parallel > 1.
|
||||
// ref: https://github.com/ollama/ollama/issues/4165
|
||||
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen35", "qwen35moe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe", "nemotron_h_omni"}, req.model.Config.ModelFamily) && numParallel != 1 {
|
||||
numParallel = 1
|
||||
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
|
||||
}
|
||||
|
||||
sessionDuration := envconfig.KeepAlive()
|
||||
if req.sessionDuration != nil {
|
||||
sessionDuration = req.sessionDuration.Duration
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
llama := s.activeLoading
|
||||
|
||||
if llama == nil {
|
||||
var err error
|
||||
if !req.model.IsMLX() {
|
||||
f, loadErr := llm.LoadModel(req.model.ModelPath, 1024)
|
||||
if loadErr != nil {
|
||||
slog.Info("failed to load model metadata", "model", req.model.ModelPath, "error", loadErr)
|
||||
req.errCh <- loadErr
|
||||
s.loadedMu.Unlock()
|
||||
return false
|
||||
}
|
||||
llama, err = s.newServerFn(systemInfo, gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
|
||||
if err != nil {
|
||||
// some older models are not compatible with newer versions of llama.cpp
|
||||
// show a generalized compatibility error until there is a better way to
|
||||
// check for model compatibility
|
||||
if errors.Is(err, ggml.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
|
||||
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
modelName := req.model.ShortName
|
||||
if slices.Contains(req.model.Config.Capabilities, "image") {
|
||||
llama, err = imagegen.NewServer(modelName)
|
||||
} else {
|
||||
llama, err = mlxrunner.NewClient(modelName)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
slog.Info("failed to create server", "model", req.model.ShortName, "error", err)
|
||||
req.errCh <- err
|
||||
s.loadedMu.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
s.activeLoading = llama
|
||||
} else {
|
||||
wantPath := req.model.ModelPath
|
||||
if wantPath == "" {
|
||||
wantPath = req.model.ShortName
|
||||
}
|
||||
if s.activeLoading.ModelPath() != wantPath {
|
||||
panic(fmt.Errorf("attempting to load different model after eviction (original %v new %v)", s.activeLoading.ModelPath(), wantPath))
|
||||
}
|
||||
}
|
||||
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
systemTotalMemory := systemInfo.TotalMemory
|
||||
systemFreeMemory := systemInfo.FreeMemory
|
||||
systemSwapFreeMemory := systemInfo.FreeSwap
|
||||
slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
|
||||
|
||||
for _, gpu := range gpus {
|
||||
available := gpu.FreeMemory - envconfig.GpuOverhead() - gpu.MinimumMemory()
|
||||
if gpu.FreeMemory < envconfig.GpuOverhead()+gpu.MinimumMemory() {
|
||||
available = 0
|
||||
}
|
||||
slog.Info("gpu memory", "id", gpu.ID, "library", gpu.Library,
|
||||
"available", format.HumanBytes2(available),
|
||||
"free", format.HumanBytes2(gpu.FreeMemory),
|
||||
"minimum", format.HumanBytes2(gpu.MinimumMemory()),
|
||||
"overhead", format.HumanBytes2(envconfig.GpuOverhead()))
|
||||
}
|
||||
|
||||
gpuIDs, err := llama.Load(req.ctx, systemInfo, gpus, requireFull)
|
||||
if err != nil {
|
||||
if errors.Is(err, llm.ErrLoadRequiredFull) {
|
||||
if !requireFull {
|
||||
// No other models loaded, yet we still don't fit, so report an error
|
||||
slog.Info("model is too large for system memory", "requireFull", requireFull)
|
||||
s.activeLoading.Close()
|
||||
s.activeLoading = nil
|
||||
req.errCh <- err
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
slog.Info("Load failed", "model", req.model.ModelPath, "error", err)
|
||||
s.activeLoading.Close()
|
||||
s.activeLoading = nil
|
||||
req.errCh <- err
|
||||
return false
|
||||
}
|
||||
|
||||
// Determine if we have discrete GPUs which we should monitor VRAM usage on during shutdown
|
||||
discreteGPUs := false
|
||||
iGPUScan:
|
||||
for _, devid := range gpuIDs {
|
||||
for _, dev := range gpus {
|
||||
if dev.DeviceID == devid {
|
||||
if !dev.Integrated {
|
||||
discreteGPUs = true
|
||||
break iGPUScan
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
totalSize, vramSize := llama.MemorySize()
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
modelKey: schedulerModelKey(req.model),
|
||||
llama: llama,
|
||||
Options: &req.opts,
|
||||
sessionDuration: sessionDuration,
|
||||
gpus: gpuIDs,
|
||||
discreteGPUs: discreteGPUs,
|
||||
isImagegen: slices.Contains(req.model.Config.Capabilities, "image"),
|
||||
totalSize: totalSize,
|
||||
vramSize: vramSize,
|
||||
loading: true,
|
||||
pid: llama.Pid(),
|
||||
}
|
||||
runner.numParallel = numParallel
|
||||
runner.refMu.Lock() // hold lock until running or aborted
|
||||
|
||||
s.loadedMu.Lock()
|
||||
if oldRunner, ok := s.loaded[runner.modelKey]; ok {
|
||||
// Shouldn't happen, but safeguard against leaking a runner
|
||||
slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner)
|
||||
oldRunner.refMu.Lock()
|
||||
oldRunner.unload()
|
||||
oldRunner.refMu.Unlock()
|
||||
}
|
||||
s.activeLoading = nil
|
||||
s.loaded[runner.modelKey] = runner
|
||||
slog.Info("loaded runners", "count", len(s.loaded))
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
go func() {
|
||||
defer runner.refMu.Unlock()
|
||||
if err = llama.WaitUntilRunning(req.ctx); err != nil {
|
||||
slog.Error("error loading llama server", "error", err)
|
||||
req.errCh <- err
|
||||
slog.Debug("triggering expiration for failed load", "runner", runner)
|
||||
s.expiredCh <- runner
|
||||
return
|
||||
}
|
||||
slog.Debug("finished setting up", "runner", runner)
|
||||
if runner.pid < 0 {
|
||||
runner.pid = llama.Pid()
|
||||
}
|
||||
runner.refCount++
|
||||
runner.loading = false
|
||||
go func() {
|
||||
<-req.ctx.Done()
|
||||
slog.Debug("context for request finished")
|
||||
s.finishedReqCh <- req
|
||||
}()
|
||||
req.successCh <- runner
|
||||
}()
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) {
|
||||
if len(allGpus) == 0 {
|
||||
return
|
||||
}
|
||||
predMap := map[ml.DeviceID]uint64{} // Sum up the total predicted usage per GPU for all runners
|
||||
s.loadedMu.Lock()
|
||||
runners := make([]*runnerRef, 0, len(s.loaded))
|
||||
for _, r := range s.loaded {
|
||||
runners = append(runners, r)
|
||||
}
|
||||
s.loadedMu.Unlock()
|
||||
for _, r := range runners {
|
||||
r.refMu.Lock()
|
||||
if r.llama != nil {
|
||||
for _, gpu := range allGpus {
|
||||
predMap[gpu.DeviceID] += r.llama.VRAMByGPU(gpu.DeviceID)
|
||||
}
|
||||
} else {
|
||||
slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
|
||||
}
|
||||
r.refMu.Unlock()
|
||||
}
|
||||
|
||||
// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
|
||||
for i := range allGpus {
|
||||
if p, ok := predMap[allGpus[i].DeviceID]; ok {
|
||||
slog.Debug("gpu reported", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "available", format.HumanBytes2(allGpus[i].FreeMemory))
|
||||
if p > allGpus[i].TotalMemory {
|
||||
// Shouldn't happen
|
||||
slog.Warn("predicted usage exceeds VRAM", "gpu", allGpus[i].ID, "totalMemory", allGpus[i].TotalMemory, "predicted", p)
|
||||
allGpus[i].FreeMemory = 0
|
||||
} else if (allGpus[i].TotalMemory - p) < allGpus[i].FreeMemory { // predicted free is smaller than reported free, use it
|
||||
// TODO maybe we should just always trust our numbers, since cuda's free memory reporting is laggy
|
||||
// and we might unload models we didn't actually need to. The risk is if some other GPU intensive app is loaded
|
||||
// after we start our first runner, then we'll never account for that, so picking the smallest free value seems prudent.
|
||||
allGpus[i].FreeMemory = allGpus[i].TotalMemory - p
|
||||
}
|
||||
slog.Info("updated VRAM based on existing loaded models", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "total", format.HumanBytes2(allGpus[i].TotalMemory), "available", format.HumanBytes2(allGpus[i].FreeMemory))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO consolidate sched_types.go
|
||||
type runnerRef struct {
|
||||
refMu sync.Mutex
|
||||
refCount uint // prevent unloading if > 0
|
||||
|
||||
llama llm.LlamaServer
|
||||
pid int
|
||||
loading bool // True only during initial load, then false forever
|
||||
gpus []ml.DeviceID // Recorded at time of provisioning
|
||||
discreteGPUs bool // True if all devices are discrete GPUs - used to skip VRAM recovery check for iGPUs
|
||||
isImagegen bool // True if loaded via imagegen runner (vs mlxrunner)
|
||||
vramSize uint64
|
||||
totalSize uint64
|
||||
|
||||
sessionDuration time.Duration
|
||||
expireTimer *time.Timer
|
||||
expiresAt time.Time
|
||||
|
||||
model *Model
|
||||
modelPath string
|
||||
modelKey string
|
||||
numParallel int
|
||||
*api.Options
|
||||
}
|
||||
|
||||
// The refMu must already be held when calling unload
|
||||
func (runner *runnerRef) unload() {
|
||||
if runner.expireTimer != nil {
|
||||
runner.expireTimer.Stop()
|
||||
runner.expireTimer = nil
|
||||
}
|
||||
if runner.llama != nil {
|
||||
runner.llama.Close()
|
||||
}
|
||||
runner.model = nil
|
||||
runner.Options = nil
|
||||
runner.gpus = nil
|
||||
}
|
||||
|
||||
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
|
||||
slog.Debug("evaluating already loaded", "model", schedulerModelKey(req.model))
|
||||
runner.refMu.Lock()
|
||||
defer runner.refMu.Unlock()
|
||||
|
||||
// Check if runner type (imagegen vs mlxrunner) matches what's requested.
|
||||
wantImagegen := slices.Contains(req.model.Config.Capabilities, "image")
|
||||
if runner.isImagegen != wantImagegen {
|
||||
return true
|
||||
}
|
||||
|
||||
timeout := 10 * time.Second
|
||||
if runner.loading {
|
||||
timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
|
||||
}
|
||||
|
||||
if runner.Options == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Don't reload runner if num_gpu=-1 was provided
|
||||
optsExisting := runner.Options.Runner
|
||||
optsNew := req.opts.Runner
|
||||
if optsNew.NumGPU < 0 {
|
||||
optsExisting.NumGPU = -1
|
||||
optsNew.NumGPU = -1
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
||||
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
|
||||
(!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed?
|
||||
runner.llama.Ping(ctx) != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Free memory reporting on GPUs can lag for a while even after the runner
|
||||
// exits, so we have to keep checking until we see the available memory recover,
|
||||
// otherwise subsequent model loads will get far less layers loaded or worse
|
||||
// case, may completely fall back to CPU mode.
|
||||
// This routine must be called before the runner unloads so it can establish
|
||||
// a before and after GPU memory allocation. The returned channel
|
||||
// will be notified when we're done waiting, or have timed out and should
|
||||
// proceed anyway
|
||||
func (s *Scheduler) waitForVRAMRecovery(runner *runnerRef, runners []ml.FilteredRunnerDiscovery) chan any {
|
||||
finished := make(chan any, 1)
|
||||
|
||||
// CPU, Metal and iGPUs don't need checking, so no waiting required
|
||||
if len(runner.gpus) == 0 || !runner.discreteGPUs ||
|
||||
(len(runner.gpus) == 1 && runner.gpus[0].Library == "Metal") {
|
||||
finished <- struct{}{}
|
||||
slog.Debug("no need to wait for VRAM recovery", "runner", runner)
|
||||
return finished
|
||||
}
|
||||
start := time.Now()
|
||||
|
||||
// Establish a baseline before we unload
|
||||
gpusBefore := s.getGpuFn(context.Background(), runners)
|
||||
var totalMemoryBefore, freeMemoryBefore uint64
|
||||
for _, gpu := range gpusBefore {
|
||||
totalMemoryBefore += gpu.TotalMemory
|
||||
freeMemoryBefore += gpu.FreeMemory
|
||||
}
|
||||
totalMemoryNow := totalMemoryBefore
|
||||
freeMemoryNow := freeMemoryBefore
|
||||
|
||||
go func() {
|
||||
// typical convergence is 0.5-1.5s - If it takes too long to discover and converge, let the scheduler estimate VRAM usage
|
||||
ctx, cancel := context.WithTimeout(context.Background(), s.waitForRecovery)
|
||||
defer cancel()
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// Query GPUs, look for free to go back up
|
||||
gpusNow := s.getGpuFn(ctx, runners)
|
||||
totalMemoryNow = 0
|
||||
freeMemoryNow = 0
|
||||
for _, gpu := range gpusNow {
|
||||
totalMemoryNow += gpu.TotalMemory
|
||||
freeMemoryNow += gpu.FreeMemory
|
||||
}
|
||||
if freeMemoryNow > freeMemoryBefore {
|
||||
logutil.Trace("gpu VRAM convergence", "percent", int(float32(freeMemoryNow-freeMemoryBefore)/float32(runner.vramSize)*100))
|
||||
} else {
|
||||
logutil.Trace("gpu VRAM convergence", "percent", 0)
|
||||
}
|
||||
// If we're within ~75% of the estimated memory usage recovered, bail out
|
||||
if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.vramSize)*0.75 {
|
||||
slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "free_before", format.HumanBytes2(freeMemoryBefore), "free_now", format.HumanBytes2(freeMemoryNow), "runner", runner)
|
||||
finished <- struct{}{}
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
slog.Debug("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds(), "free_before", format.HumanBytes2(freeMemoryBefore), "free_now", format.HumanBytes2(freeMemoryNow), "runner", runner)
|
||||
finished <- struct{}{}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return finished
|
||||
}
|
||||
|
||||
func (runner *runnerRef) LogValue() slog.Value {
|
||||
if runner == nil {
|
||||
return slog.StringValue("nil")
|
||||
}
|
||||
modelID := runner.modelPath
|
||||
if modelID == "" {
|
||||
modelID = runner.modelKey
|
||||
}
|
||||
attrs := []slog.Attr{}
|
||||
if runner.model != nil {
|
||||
attrs = append(attrs, slog.String("name", runner.model.Name))
|
||||
}
|
||||
if len(runner.gpus) > 0 {
|
||||
attrs = append(attrs,
|
||||
slog.Any("inference", runner.gpus),
|
||||
)
|
||||
}
|
||||
attrs = append(attrs,
|
||||
slog.String("size", format.HumanBytes2(runner.totalSize)),
|
||||
slog.String("vram", format.HumanBytes2(runner.vramSize)),
|
||||
slog.Int("parallel", runner.numParallel),
|
||||
slog.Int("pid", runner.pid),
|
||||
slog.String("model", modelID),
|
||||
)
|
||||
if runner.Options != nil {
|
||||
attrs = append(attrs, slog.Int("num_ctx", runner.Options.NumCtx))
|
||||
}
|
||||
return slog.GroupValue(attrs...)
|
||||
}
|
||||
|
||||
// Implements discover.RunnerDiscovery
|
||||
func (runner *runnerRef) GetPort() int {
|
||||
if runner.llama != nil {
|
||||
return runner.llama.GetPort()
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func (runner *runnerRef) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
||||
if runner.llama != nil {
|
||||
return runner.llama.GetDeviceInfos(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (runner *runnerRef) GetActiveDeviceIDs() []ml.DeviceID {
|
||||
return runner.gpus
|
||||
}
|
||||
|
||||
func (runner *runnerRef) HasExited() bool {
|
||||
if runner.llama != nil {
|
||||
return runner.llama.HasExited()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type ByDurationAndName []*runnerRef
|
||||
|
||||
func (a ByDurationAndName) Len() int { return len(a) }
|
||||
func (a ByDurationAndName) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a ByDurationAndName) Less(i, j int) bool {
|
||||
// Primary sort by session duration (uint64 to handle negatives)
|
||||
d1 := uint64(a[i].sessionDuration)
|
||||
d2 := uint64(a[j].sessionDuration)
|
||||
if d1 != d2 {
|
||||
return d1 < d2
|
||||
}
|
||||
// Secondary sort by model key/path lex order
|
||||
n1 := a[i].modelPath
|
||||
if n1 == "" {
|
||||
n1 = a[i].modelKey
|
||||
}
|
||||
n2 := a[j].modelPath
|
||||
if n2 == "" {
|
||||
n2 = a[j].modelKey
|
||||
}
|
||||
return n1 < n2
|
||||
}
|
||||
|
||||
// TODO - future consideration to pick runners based on size
|
||||
// type BySize []*runnerRef
|
||||
// func (a BySize) Len() int { return len(a) }
|
||||
// func (a BySize) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
// func (a BySize) Less(i, j int) bool { return a[i].vramSize < a[j].vramSize }
|
||||
|
||||
// findRunnerToUnload finds a runner to unload to make room for a new model
|
||||
func (s *Scheduler) findRunnerToUnload() *runnerRef {
|
||||
s.loadedMu.Lock()
|
||||
runnerList := make([]*runnerRef, 0, len(s.loaded))
|
||||
for _, r := range s.loaded {
|
||||
runnerList = append(runnerList, r)
|
||||
}
|
||||
s.loadedMu.Unlock()
|
||||
if len(runnerList) == 0 {
|
||||
slog.Debug("no loaded runner to unload")
|
||||
return nil
|
||||
}
|
||||
|
||||
// In the future we can enhance the algorithm to be smarter about picking the optimal runner to unload
|
||||
// e.g., if we have multiple options, will one make room for the request?
|
||||
sort.Sort(ByDurationAndName(runnerList))
|
||||
|
||||
// First try to find a runner that's already idle
|
||||
for _, runner := range runnerList {
|
||||
runner.refMu.Lock()
|
||||
rc := runner.refCount
|
||||
runner.refMu.Unlock()
|
||||
if rc == 0 {
|
||||
slog.Debug("found an idle runner to unload", "runner", runner)
|
||||
return runner
|
||||
}
|
||||
}
|
||||
// None appear idle, just wait for the one with the shortest duration
|
||||
slog.Debug("no idle runners, picking the shortest duration", "runner_count", len(runnerList), "runner", runnerList[0])
|
||||
return runnerList[0]
|
||||
}
|
||||
|
||||
func (s *Scheduler) unloadAllRunners() {
|
||||
s.loadedMu.Lock()
|
||||
defer s.loadedMu.Unlock()
|
||||
|
||||
if s.activeLoading != nil {
|
||||
slog.Debug("shutting down currently loading runner")
|
||||
s.activeLoading.Close()
|
||||
s.activeLoading = nil
|
||||
}
|
||||
|
||||
for model, runner := range s.loaded {
|
||||
if runner.llama != nil {
|
||||
slog.Debug("shutting down runner", "model", model)
|
||||
runner.llama.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scheduler) expireRunner(model *Model) {
|
||||
modelKey := schedulerModelKey(model)
|
||||
s.loadedMu.Lock()
|
||||
runner, ok := s.loaded[modelKey]
|
||||
s.loadedMu.Unlock()
|
||||
if ok {
|
||||
runner.refMu.Lock()
|
||||
runner.expiresAt = time.Now()
|
||||
if runner.expireTimer != nil {
|
||||
runner.expireTimer.Stop()
|
||||
runner.expireTimer = nil
|
||||
}
|
||||
runner.sessionDuration = 0
|
||||
if runner.refCount <= 0 {
|
||||
s.expiredCh <- runner
|
||||
}
|
||||
runner.refMu.Unlock()
|
||||
}
|
||||
}
|
||||
1018
server/sched_test.go
Normal file
1018
server/sched_test.go
Normal file
File diff suppressed because it is too large
Load Diff
8
server/sparse_common.go
Normal file
8
server/sparse_common.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
|
||||
import "os"
|
||||
|
||||
func setSparse(*os.File) {
|
||||
}
|
||||
17
server/sparse_windows.go
Normal file
17
server/sparse_windows.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func setSparse(file *os.File) {
|
||||
// exFat (and other FS types) don't support sparse files, so ignore errors
|
||||
windows.DeviceIoControl( //nolint:errcheck
|
||||
windows.Handle(file.Fd()), windows.FSCTL_SET_SPARSE,
|
||||
nil, 0,
|
||||
nil, 0,
|
||||
nil, nil,
|
||||
)
|
||||
}
|
||||
15
server/test_home_test.go
Normal file
15
server/test_home_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
func setTestHome(t *testing.T, home string) {
|
||||
t.Helper()
|
||||
t.Setenv("HOME", home)
|
||||
t.Setenv("USERPROFILE", home)
|
||||
t.Setenv("OLLAMA_MODELS", "")
|
||||
envconfig.ReloadServerConfig()
|
||||
}
|
||||
405
server/upload.go
Normal file
405
server/upload.go
Normal file
@@ -0,0 +1,405 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
var blobUploadManager sync.Map
|
||||
|
||||
type blobUpload struct {
|
||||
manifest.Layer
|
||||
|
||||
Total int64
|
||||
Completed atomic.Int64
|
||||
|
||||
Parts []blobUploadPart
|
||||
|
||||
nextURL chan *url.URL
|
||||
|
||||
context.CancelFunc
|
||||
|
||||
file *os.File
|
||||
|
||||
done bool
|
||||
err error
|
||||
references atomic.Int32
|
||||
}
|
||||
|
||||
const (
|
||||
numUploadParts = 16
|
||||
minUploadPartSize int64 = 100 * format.MegaByte
|
||||
maxUploadPartSize int64 = 1000 * format.MegaByte
|
||||
)
|
||||
|
||||
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
|
||||
p, err := manifest.BlobsPath(b.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if b.From != "" {
|
||||
values := requestURL.Query()
|
||||
values.Add("mount", b.Digest)
|
||||
values.Add("from", model.ParseName(b.From).DisplayNamespaceModel())
|
||||
requestURL.RawQuery = values.Encode()
|
||||
}
|
||||
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodPost, requestURL, nil, nil, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
location := resp.Header.Get("Docker-Upload-Location")
|
||||
if location == "" {
|
||||
location = resp.Header.Get("Location")
|
||||
}
|
||||
|
||||
fi, err := os.Stat(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.Total = fi.Size()
|
||||
|
||||
// http.StatusCreated indicates a blob has been mounted
|
||||
// ref: https://distribution.github.io/distribution/spec/api/#cross-repository-blob-mount
|
||||
if resp.StatusCode == http.StatusCreated {
|
||||
b.Completed.Store(b.Total)
|
||||
b.done = true
|
||||
return nil
|
||||
}
|
||||
|
||||
size := b.Total / numUploadParts
|
||||
switch {
|
||||
case size < minUploadPartSize:
|
||||
size = minUploadPartSize
|
||||
case size > maxUploadPartSize:
|
||||
size = maxUploadPartSize
|
||||
}
|
||||
|
||||
var offset int64
|
||||
for offset < fi.Size() {
|
||||
if offset+size > fi.Size() {
|
||||
size = fi.Size() - offset
|
||||
}
|
||||
|
||||
// set part.N to the current number of parts
|
||||
b.Parts = append(b.Parts, blobUploadPart{N: len(b.Parts), Offset: offset, Size: size})
|
||||
offset += size
|
||||
}
|
||||
|
||||
if len(b.Parts) > 0 {
|
||||
slog.Info(fmt.Sprintf("uploading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size)))
|
||||
}
|
||||
|
||||
requestURL, err = url.Parse(location)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.nextURL = make(chan *url.URL, 1)
|
||||
b.nextURL <- requestURL
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
|
||||
// in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
|
||||
func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
|
||||
defer blobUploadManager.Delete(b.Digest)
|
||||
ctx, b.CancelFunc = context.WithCancel(ctx)
|
||||
|
||||
p, err := manifest.BlobsPath(b.Digest)
|
||||
if err != nil {
|
||||
b.err = err
|
||||
return
|
||||
}
|
||||
|
||||
b.file, err = os.Open(p)
|
||||
if err != nil {
|
||||
b.err = err
|
||||
return
|
||||
}
|
||||
defer b.file.Close()
|
||||
|
||||
g, inner := errgroup.WithContext(ctx)
|
||||
g.SetLimit(numUploadParts)
|
||||
for i := range b.Parts {
|
||||
part := &b.Parts[i]
|
||||
select {
|
||||
case <-inner.Done():
|
||||
case requestURL := <-b.nextURL:
|
||||
g.Go(func() error {
|
||||
var err error
|
||||
for try := range maxRetries {
|
||||
err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled):
|
||||
return err
|
||||
case errors.Is(err, errMaxRetriesExceeded):
|
||||
return err
|
||||
case err != nil:
|
||||
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
|
||||
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
|
||||
time.Sleep(sleep)
|
||||
continue
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
b.err = err
|
||||
return
|
||||
}
|
||||
|
||||
requestURL := <-b.nextURL
|
||||
|
||||
// calculate md5 checksum and add it to the commit request
|
||||
md5sum := md5.New()
|
||||
for _, part := range b.Parts {
|
||||
md5sum.Write(part.Sum(nil))
|
||||
}
|
||||
|
||||
values := requestURL.Query()
|
||||
values.Add("digest", b.Digest)
|
||||
values.Add("etag", fmt.Sprintf("%x-%d", md5sum.Sum(nil), len(b.Parts)))
|
||||
requestURL.RawQuery = values.Encode()
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "application/octet-stream")
|
||||
headers.Set("Content-Length", "0")
|
||||
|
||||
for try := range maxRetries {
|
||||
var resp *http.Response
|
||||
resp, err = makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
break
|
||||
} else if err != nil {
|
||||
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
|
||||
slog.Info(fmt.Sprintf("%s complete upload attempt %d failed: %v, retrying in %s", b.Digest[7:19], try, err, sleep))
|
||||
time.Sleep(sleep)
|
||||
continue
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
break
|
||||
}
|
||||
|
||||
b.err = err
|
||||
b.done = true
|
||||
}
|
||||
|
||||
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *registryOptions) error {
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "application/octet-stream")
|
||||
headers.Set("Content-Length", strconv.FormatInt(part.Size, 10))
|
||||
|
||||
if method == http.MethodPatch {
|
||||
headers.Set("X-Redirect-Uploads", "1")
|
||||
headers.Set("Content-Range", fmt.Sprintf("%d-%d", part.Offset, part.Offset+part.Size-1))
|
||||
}
|
||||
|
||||
sr := io.NewSectionReader(b.file, part.Offset, part.Size)
|
||||
|
||||
md5sum := md5.New()
|
||||
w := &progressWriter{blobUpload: b}
|
||||
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
|
||||
if err != nil {
|
||||
w.Rollback()
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
location := resp.Header.Get("Docker-Upload-Location")
|
||||
if location == "" {
|
||||
location = resp.Header.Get("Location")
|
||||
}
|
||||
|
||||
nextURL, err := url.Parse(location)
|
||||
if err != nil {
|
||||
w.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusTemporaryRedirect:
|
||||
w.Rollback()
|
||||
b.nextURL <- nextURL
|
||||
|
||||
redirectURL, err := resp.Location()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// retry uploading to the redirect URL
|
||||
for try := range maxRetries {
|
||||
err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, ®istryOptions{})
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled):
|
||||
return err
|
||||
case errors.Is(err, errMaxRetriesExceeded):
|
||||
return err
|
||||
case err != nil:
|
||||
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
|
||||
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
|
||||
time.Sleep(sleep)
|
||||
continue
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
|
||||
|
||||
case resp.StatusCode == http.StatusUnauthorized:
|
||||
w.Rollback()
|
||||
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
|
||||
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.Token = token
|
||||
fallthrough
|
||||
case resp.StatusCode >= http.StatusBadRequest:
|
||||
w.Rollback()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fmt.Errorf("http status %s: %s", resp.Status, body)
|
||||
}
|
||||
|
||||
if method == http.MethodPatch {
|
||||
b.nextURL <- nextURL
|
||||
}
|
||||
|
||||
part.Hash = md5sum
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *blobUpload) acquire() {
|
||||
b.references.Add(1)
|
||||
}
|
||||
|
||||
func (b *blobUpload) release() {
|
||||
if b.references.Add(-1) == 0 {
|
||||
b.CancelFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
|
||||
b.acquire()
|
||||
defer b.release()
|
||||
|
||||
ticker := time.NewTicker(60 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("pushing %s", b.Digest[7:19]),
|
||||
Digest: b.Digest,
|
||||
Total: b.Total,
|
||||
Completed: b.Completed.Load(),
|
||||
})
|
||||
|
||||
if b.done || b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type blobUploadPart struct {
|
||||
// N is the part number
|
||||
N int
|
||||
Offset int64
|
||||
Size int64
|
||||
hash.Hash
|
||||
}
|
||||
|
||||
type progressWriter struct {
|
||||
written int64
|
||||
*blobUpload
|
||||
}
|
||||
|
||||
func (p *progressWriter) Write(b []byte) (n int, err error) {
|
||||
n = len(b)
|
||||
p.written += int64(n)
|
||||
p.Completed.Add(int64(n))
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (p *progressWriter) Rollback() {
|
||||
p.Completed.Add(-p.written)
|
||||
p.written = 0
|
||||
}
|
||||
|
||||
func uploadBlob(ctx context.Context, n model.Name, layer manifest.Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
requestURL := n.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs", layer.Digest)
|
||||
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
case err != nil:
|
||||
return err
|
||||
default:
|
||||
defer resp.Body.Close()
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("pushing %s", layer.Digest[7:19]),
|
||||
Digest: layer.Digest,
|
||||
Total: layer.Size,
|
||||
Completed: layer.Size,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
|
||||
upload := data.(*blobUpload)
|
||||
if !ok {
|
||||
requestURL := n.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs/uploads/")
|
||||
if err := upload.Prepare(ctx, requestURL, opts); err != nil {
|
||||
blobUploadManager.Delete(layer.Digest)
|
||||
return err
|
||||
}
|
||||
|
||||
//nolint:contextcheck
|
||||
go upload.Run(context.Background(), opts)
|
||||
}
|
||||
|
||||
return upload.Wait(ctx, fn)
|
||||
}
|
||||
Reference in New Issue
Block a user