ollama source for Momentry Core verification
This commit is contained in:
1197
server/internal/client/ollama/registry.go
Normal file
1197
server/internal/client/ollama/registry.go
Normal file
File diff suppressed because it is too large
Load Diff
51
server/internal/client/ollama/registry_synctest_test.go
Normal file
51
server/internal/client/ollama/registry_synctest_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
// TODO: go:build goexperiment.synctest
|
||||
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPullDownloadTimeout(t *testing.T) {
|
||||
rc, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
defer t.Log("upstream", r.Method, r.URL.Path)
|
||||
switch {
|
||||
case strings.HasPrefix(r.URL.Path, "/v2/library/smol/manifests/"):
|
||||
io.WriteString(w, `{
|
||||
"layers": [{"digest": "sha256:1111111111111111111111111111111111111111111111111111111111111111", "size": 3}]
|
||||
}`)
|
||||
case strings.HasPrefix(r.URL.Path, "/v2/library/smol/blobs/sha256:1111111111111111111111111111111111111111111111111111111111111111"):
|
||||
// Get headers out to client and then hang on the response
|
||||
w.WriteHeader(200)
|
||||
w.(http.Flusher).Flush()
|
||||
|
||||
// Hang on the response and unblock when the client
|
||||
// gives up
|
||||
<-r.Context().Done()
|
||||
default:
|
||||
t.Fatalf("unexpected request: %s", r.URL.Path)
|
||||
}
|
||||
})
|
||||
rc.ReadTimeout = 100 * time.Millisecond
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- rc.Pull(ctx, "http://example.com/library/smol")
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
want := context.DeadlineExceeded
|
||||
if !errors.Is(err, want) {
|
||||
t.Errorf("err = %v, want %v", err, want)
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Error("timeout waiting for Pull to finish")
|
||||
}
|
||||
}
|
||||
953
server/internal/client/ollama/registry_test.go
Normal file
953
server/internal/client/ollama/registry_test.go
Normal file
@@ -0,0 +1,953 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/testutil"
|
||||
)
|
||||
|
||||
func ExampleRegistry_cancelOnFirstError() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ctx = WithTrace(ctx, &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
if err != nil {
|
||||
// Discontinue pulling layers if there is an
|
||||
// error instead of continuing to pull more
|
||||
// data.
|
||||
cancel()
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
var r Registry
|
||||
if err := r.Pull(ctx, "model"); err != nil {
|
||||
// panic for demo purposes
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifestMarshalJSON(t *testing.T) {
|
||||
// All manifests should contain an "empty" config object.
|
||||
var m Manifest
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Contains(data, []byte(`"config":{"digest":"sha256:`)) {
|
||||
t.Error("expected manifest to contain empty config")
|
||||
t.Fatalf("got:\n%s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
var errRoundTrip = errors.New("forced roundtrip error")
|
||||
|
||||
type recordRoundTripper http.HandlerFunc
|
||||
|
||||
func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
w := httptest.NewRecorder()
|
||||
rr(w, req)
|
||||
if w.Code == 499 {
|
||||
return nil, errRoundTrip
|
||||
}
|
||||
resp := w.Result()
|
||||
// For some reason, Response.Request is not set by httptest.NewRecorder, so we
|
||||
// set it manually.
|
||||
resp.Request = req
|
||||
return w.Result(), nil
|
||||
}
|
||||
|
||||
// newClient constructs a cache with predefined manifests for testing. The manifests are:
|
||||
//
|
||||
// empty: no data
|
||||
// zero: no layers
|
||||
// single: one layer with the contents "exists"
|
||||
// multiple: two layers with the contents "exists" and "here"
|
||||
// notfound: a layer that does not exist in the cache
|
||||
// null: one null layer (e.g. [null])
|
||||
// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
|
||||
// invalid: a layer with invalid JSON data
|
||||
//
|
||||
// Tests that want to ensure the client does not communicate with the upstream
|
||||
// registry should pass a nil handler, which will cause a panic if
|
||||
// communication is attempted.
|
||||
//
|
||||
// To simulate a network error, pass a handler that returns a 499 status code.
|
||||
func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
||||
t.Helper()
|
||||
|
||||
c, err := blob.Open(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mklayer := func(data string) *Layer {
|
||||
return &Layer{
|
||||
Digest: importBytes(t, c, data),
|
||||
Size: int64(len(data)),
|
||||
}
|
||||
}
|
||||
|
||||
r := &Registry{
|
||||
Cache: c,
|
||||
HTTPClient: &http.Client{
|
||||
Transport: recordRoundTripper(upstreamRegistry),
|
||||
},
|
||||
}
|
||||
|
||||
link := func(name string, manifest string) {
|
||||
n, err := r.parseName(name)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := c.Link(n.String(), d); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
commit := func(name string, layers ...*Layer) {
|
||||
t.Helper()
|
||||
data, err := json.Marshal(&Manifest{Layers: layers})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
link(name, string(data))
|
||||
}
|
||||
|
||||
link("empty", "")
|
||||
commit("zero")
|
||||
commit("single", mklayer("exists"))
|
||||
commit("multiple", mklayer("exists"), mklayer("present"))
|
||||
commit("notfound", &Layer{Digest: blob.DigestFromBytes("notfound"), Size: int64(len("notfound"))})
|
||||
commit("null", nil)
|
||||
commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
|
||||
link("invalid", "!!!!!")
|
||||
|
||||
return r, c
|
||||
}
|
||||
|
||||
func okHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func checkErrCode(t *testing.T, err error, status int, code string) {
|
||||
t.Helper()
|
||||
var e *Error
|
||||
if !errors.As(err, &e) || e.status != status || e.Code != code {
|
||||
t.Errorf("err = %v; want %v %v", err, status, code)
|
||||
}
|
||||
}
|
||||
|
||||
func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest {
|
||||
d, err := c.Import(strings.NewReader(data), int64(len(data)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
|
||||
t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }}
|
||||
return WithTrace(ctx, t), t
|
||||
}
|
||||
|
||||
func TestPushZero(t *testing.T) {
|
||||
rc, _ := newClient(t, okHandler)
|
||||
err := rc.Push(t.Context(), "empty", nil)
|
||||
if !errors.Is(err, ErrManifestInvalid) {
|
||||
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushSingle(t *testing.T) {
|
||||
rc, _ := newClient(t, okHandler)
|
||||
err := rc.Push(t.Context(), "single", nil)
|
||||
testutil.Check(t, err)
|
||||
}
|
||||
|
||||
func TestPushMultiple(t *testing.T) {
|
||||
rc, _ := newClient(t, okHandler)
|
||||
err := rc.Push(t.Context(), "multiple", nil)
|
||||
testutil.Check(t, err)
|
||||
}
|
||||
|
||||
func TestPushNotFound(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Errorf("unexpected request: %v", r)
|
||||
})
|
||||
err := rc.Push(t.Context(), "notfound", nil)
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushNullLayer(t *testing.T) {
|
||||
rc, _ := newClient(t, nil)
|
||||
err := rc.Push(t.Context(), "null", nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
||||
t.Errorf("err = %v; want invalid manifest", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushSizeMismatch(t *testing.T) {
|
||||
rc, _ := newClient(t, nil)
|
||||
ctx, _ := withTraceUnexpected(t.Context())
|
||||
got := rc.Push(ctx, "sizemismatch", nil)
|
||||
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
|
||||
t.Errorf("err = %v; want size mismatch", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushInvalid(t *testing.T) {
|
||||
rc, _ := newClient(t, nil)
|
||||
err := rc.Push(t.Context(), "invalid", nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
||||
t.Errorf("err = %v; want invalid manifest", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushExistsAtRemote(t *testing.T) {
|
||||
var pushed bool
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/uploads/") {
|
||||
if !pushed {
|
||||
// First push. Return an uploadURL.
|
||||
pushed = true
|
||||
w.Header().Set("Location", "http://blob.store/blobs/123")
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
return
|
||||
}
|
||||
|
||||
io.Copy(io.Discard, r.Body)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
rc.MaxStreams = 1 // prevent concurrent uploads
|
||||
|
||||
var errs []error
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(_ *Layer, n int64, err error) {
|
||||
// uploading one at a time so no need to lock
|
||||
errs = append(errs, err)
|
||||
},
|
||||
})
|
||||
|
||||
check := testutil.Checker(t)
|
||||
|
||||
err := rc.Push(ctx, "single", nil)
|
||||
check(err)
|
||||
|
||||
if !errors.Is(errors.Join(errs...), nil) {
|
||||
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
|
||||
}
|
||||
|
||||
err = rc.Push(ctx, "single", nil)
|
||||
check(err)
|
||||
}
|
||||
|
||||
func TestPushRemoteError(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||
w.WriteHeader(500)
|
||||
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
|
||||
return
|
||||
}
|
||||
})
|
||||
got := rc.Push(t.Context(), "single", nil)
|
||||
checkErrCode(t, got, 500, "blob_error")
|
||||
}
|
||||
|
||||
func TestPushLocationError(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Location", ":///x")
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
})
|
||||
got := rc.Push(t.Context(), "single", nil)
|
||||
wantContains := "invalid upload URL"
|
||||
if got == nil || !strings.Contains(got.Error(), wantContains) {
|
||||
t.Errorf("err = %v; want to contain %v", got, wantContains)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushUploadRoundtripError(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Host == "blob.store" {
|
||||
w.WriteHeader(499) // force RoundTrip error on upload
|
||||
return
|
||||
}
|
||||
w.Header().Set("Location", "http://blob.store/blobs/123")
|
||||
})
|
||||
got := rc.Push(t.Context(), "single", nil)
|
||||
if !errors.Is(got, errRoundTrip) {
|
||||
t.Errorf("got = %v; want %v", got, errRoundTrip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushUploadFileOpenError(t *testing.T) {
|
||||
rc, c := newClient(t, okHandler)
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, _ int64, err error) {
|
||||
// Remove the file just before it is opened for upload,
|
||||
// but after the initial Stat that happens before the
|
||||
// upload starts
|
||||
os.Remove(c.GetFile(l.Digest))
|
||||
},
|
||||
})
|
||||
got := rc.Push(ctx, "single", nil)
|
||||
if !errors.Is(got, fs.ErrNotExist) {
|
||||
t.Errorf("got = %v; want fs.ErrNotExist", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushCommitRoundtripError(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||
panic("unexpected")
|
||||
}
|
||||
w.WriteHeader(499) // force RoundTrip error
|
||||
})
|
||||
err := rc.Push(t.Context(), "zero", nil)
|
||||
if !errors.Is(err, errRoundTrip) {
|
||||
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPullInvalidName(t *testing.T) {
|
||||
rc, _ := newRegistryClient(t, nil)
|
||||
err := rc.Pull(t.Context(), "://")
|
||||
if !errors.Is(err, ErrNameInvalid) {
|
||||
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPullInvalidManifest(t *testing.T) {
|
||||
cases := []string{
|
||||
"",
|
||||
"null",
|
||||
"!!!",
|
||||
`{"layers":[]}`,
|
||||
}
|
||||
|
||||
for _, resp := range cases {
|
||||
rc, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, resp)
|
||||
})
|
||||
err := rc.Pull(t.Context(), "http://example.com/a/b")
|
||||
if !errors.Is(err, ErrManifestInvalid) {
|
||||
t.Errorf("err = %v; want invalid manifest", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryResolveByDigest(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
|
||||
exists := blob.DigestFromBytes("exists")
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v2/alice/palace/blobs/"+exists.String() {
|
||||
w.WriteHeader(499) // should not hit manifest endpoint
|
||||
}
|
||||
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
|
||||
})
|
||||
|
||||
_, err := rc.Resolve(t.Context(), "alice/palace@"+exists.String())
|
||||
check(err)
|
||||
}
|
||||
|
||||
func TestInsecureSkipVerify(t *testing.T) {
|
||||
exists := blob.DigestFromBytes("exists")
|
||||
|
||||
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
const name = "library/insecure"
|
||||
|
||||
var rc Registry
|
||||
url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
|
||||
_, err := rc.Resolve(t.Context(), url)
|
||||
if err == nil || !strings.Contains(err.Error(), "failed to verify") {
|
||||
t.Errorf("err = %v; want cert verification failure", err)
|
||||
}
|
||||
|
||||
url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name)
|
||||
_, err = rc.Resolve(t.Context(), url)
|
||||
testutil.Check(t, err)
|
||||
}
|
||||
|
||||
func TestErrorUnmarshal(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
data string
|
||||
want *Error
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "errors empty",
|
||||
data: `{"errors":[]}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "errors empty",
|
||||
data: `{"errors":[]}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "errors single",
|
||||
data: `{"errors":[{"code":"blob_unknown"}]}`,
|
||||
want: &Error{Code: "blob_unknown", Message: ""},
|
||||
},
|
||||
{
|
||||
name: "errors multiple",
|
||||
data: `{"errors":[{"code":"blob_unknown"},{"code":"blob_error"}]}`,
|
||||
want: &Error{Code: "blob_unknown", Message: ""},
|
||||
},
|
||||
{
|
||||
name: "error empty",
|
||||
data: `{"error":""}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "error very empty",
|
||||
data: `{}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "error message",
|
||||
data: `{"error":"message", "code":"code"}`,
|
||||
want: &Error{Code: "code", Message: "message"},
|
||||
},
|
||||
{
|
||||
name: "invalid value",
|
||||
data: `{"error": 1}`,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got Error
|
||||
err := json.Unmarshal([]byte(tt.data), &got)
|
||||
if err != nil {
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
t.Errorf("Unmarshal() error = %v", err)
|
||||
// fallthrough and check got
|
||||
}
|
||||
if tt.want == nil {
|
||||
tt.want = &Error{}
|
||||
}
|
||||
if !reflect.DeepEqual(got, *tt.want) {
|
||||
t.Errorf("got = %v; want %v", got, *tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseNameExtendedErrors tests that parseName returns errors messages with enough
|
||||
// detail for users to debug naming issues they may encounter. Previous to this
|
||||
// test, the error messages were not very helpful and each problem was reported
|
||||
// as the same message.
|
||||
//
|
||||
// It is only for testing error messages, not that all invalids and valids are
|
||||
// covered. Those are in other tests for names.Name and blob.Digest.
|
||||
func TestParseNameExtendedErrors(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
want string
|
||||
}{}
|
||||
|
||||
var r Registry
|
||||
for _, tt := range cases {
|
||||
_, _, _, err := r.parseNameExtended(tt.name)
|
||||
if !errors.Is(err, tt.err) {
|
||||
t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), tt.want) {
|
||||
t.Errorf("[%s]: err =\n\t%v\nwant\n\t%v", tt.name, err, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNameExtended(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
scheme string
|
||||
name string
|
||||
digest string
|
||||
err string
|
||||
}{
|
||||
{in: "http://m", scheme: "http", name: "m"},
|
||||
{in: "https+insecure://m", scheme: "https+insecure", name: "m"},
|
||||
{in: "http+insecure://m", err: "unsupported scheme"},
|
||||
|
||||
{in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"},
|
||||
|
||||
{in: "", err: "invalid or missing name"},
|
||||
{in: "m", scheme: "https", name: "m"},
|
||||
{in: "://", err: "invalid or missing name"},
|
||||
{in: "@sha256:deadbeef", err: "invalid digest"},
|
||||
{in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
var r Registry
|
||||
scheme, n, digest, err := r.parseNameExtended(tt.in)
|
||||
if err != nil {
|
||||
if tt.err == "" {
|
||||
t.Errorf("err = %v; want nil", err)
|
||||
} else if !strings.Contains(err.Error(), tt.err) {
|
||||
t.Errorf("err = %v; want %q", err, tt.err)
|
||||
}
|
||||
} else if tt.err != "" {
|
||||
t.Errorf("err = nil; want %q", tt.err)
|
||||
}
|
||||
if err == nil && !n.IsFullyQualified() {
|
||||
t.Errorf("name = %q; want fully qualified", n)
|
||||
}
|
||||
|
||||
if scheme != tt.scheme {
|
||||
t.Errorf("scheme = %q; want %q", scheme, tt.scheme)
|
||||
}
|
||||
|
||||
// smoke-test name is superset of tt.name
|
||||
if !strings.Contains(n.String(), tt.name) {
|
||||
t.Errorf("name = %q; want %q", n, tt.name)
|
||||
}
|
||||
|
||||
tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String())
|
||||
if digest.String() != tt.digest {
|
||||
t.Errorf("digest = %q; want %q", digest, tt.digest)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnlink(t *testing.T) {
|
||||
t.Run("found by name", func(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
|
||||
rc, _ := newRegistryClient(t, nil)
|
||||
// make a blob and link it
|
||||
d := blob.DigestFromBytes("{}")
|
||||
err := blob.PutBytes(rc.Cache, d, "{}")
|
||||
check(err)
|
||||
err = rc.Cache.Link("registry.ollama.ai/library/single:latest", d)
|
||||
check(err)
|
||||
|
||||
// confirm linked
|
||||
_, err = rc.ResolveLocal("single")
|
||||
check(err)
|
||||
|
||||
// unlink
|
||||
_, err = rc.Unlink("single")
|
||||
check(err)
|
||||
|
||||
// confirm unlinked
|
||||
_, err = rc.ResolveLocal("single")
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Errorf("err = %v; want fs.ErrNotExist", err)
|
||||
}
|
||||
})
|
||||
t.Run("not found by name", func(t *testing.T) {
|
||||
rc, _ := newRegistryClient(t, nil)
|
||||
ok, err := rc.Unlink("manifestNotFound")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ok {
|
||||
t.Error("expected not found")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Many tests from here out, in this file are based on a single blob, "abc",
|
||||
// with the checksum of its sha256 hash. The checksum is:
|
||||
//
|
||||
// "abc" -> sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad
|
||||
//
|
||||
// Using the literal value instead of a constant with fmt.Xprintf calls proved
|
||||
// to be the most readable and maintainable approach. The sum is consistently
|
||||
// used in the tests and unique so searches do not yield false positives.
|
||||
|
||||
func checkRequest(t *testing.T, req *http.Request, method, path string) {
|
||||
t.Helper()
|
||||
if got := req.URL.Path; got != path {
|
||||
t.Errorf("URL = %q, want %q", got, path)
|
||||
}
|
||||
if req.Method != method {
|
||||
t.Errorf("Method = %q, want %q", req.Method, method)
|
||||
}
|
||||
}
|
||||
|
||||
func newRegistryClient(t *testing.T, upstream http.HandlerFunc) (*Registry, context.Context) {
|
||||
s := httptest.NewServer(upstream)
|
||||
t.Cleanup(s.Close)
|
||||
cache, err := blob.Open(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
},
|
||||
})
|
||||
|
||||
rc := &Registry{
|
||||
Cache: cache,
|
||||
HTTPClient: &http.Client{Transport: &http.Transport{
|
||||
Dial: func(network, addr string) (net.Conn, error) {
|
||||
return net.Dial(network, s.Listener.Addr().String())
|
||||
},
|
||||
}},
|
||||
}
|
||||
return rc, ctx
|
||||
}
|
||||
|
||||
func TestPullChunked(t *testing.T) {
|
||||
var steps atomic.Int64
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch steps.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
case 3, 4:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
switch rng := r.Header.Get("Range"); rng {
|
||||
case "bytes=0-1":
|
||||
io.WriteString(w, "ab")
|
||||
case "bytes=2-2":
|
||||
t.Logf("writing c")
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected range %q", rng)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", steps.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
testutil.Check(t, err)
|
||||
|
||||
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
||||
testutil.Check(t, err)
|
||||
|
||||
if g := steps.Load(); g != 4 {
|
||||
t.Fatalf("got %d steps, want 4", g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullCached(t *testing.T) {
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
})
|
||||
|
||||
check := testutil.Checker(t)
|
||||
|
||||
// Premeptively cache the blob
|
||||
d, err := blob.ParseDigest("sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
check(err)
|
||||
err = blob.PutBytes(c.Cache, d, []byte("abc"))
|
||||
check(err)
|
||||
|
||||
// Pull only the manifest, which should be enough to resolve the cached blob
|
||||
err = c.Pull(ctx, "http://o.com/library/abc")
|
||||
check(err)
|
||||
}
|
||||
|
||||
func TestPullManifestError(t *testing.T) {
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
io.WriteString(w, `{"errors":[{"code":"MANIFEST_UNKNOWN"}]}`)
|
||||
})
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
var got *Error
|
||||
if !errors.Is(err, ErrModelNotFound) {
|
||||
t.Fatalf("err = %v, want %v", got, ErrModelNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullLayerError(t *testing.T) {
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `!`)
|
||||
})
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
var want *json.SyntaxError
|
||||
if !errors.As(err, &want) {
|
||||
t.Fatalf("err = %T, want %T", err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullLayerChecksumError(t *testing.T) {
|
||||
var step atomic.Int64
|
||||
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
case 3:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
io.WriteString(w, `{"errors":[{"code":"BLOB_UNKNOWN"}]}`)
|
||||
case 4:
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.MaxStreams = 1
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
var written atomic.Int64
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
written.Add(n)
|
||||
},
|
||||
})
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
var got *Error
|
||||
if !errors.As(err, &got) || got.Code != "BLOB_UNKNOWN" {
|
||||
t.Fatalf("err = %v, want %v", err, got)
|
||||
}
|
||||
|
||||
if g := written.Load(); g != 1 {
|
||||
t.Fatalf("wrote %d bytes, want 1", g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullChunksumStreamError(t *testing.T) {
|
||||
var step atomic.Int64
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
|
||||
// Write one valid chunksum and one invalid chunksum
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) // valid
|
||||
fmt.Fprint(w, "sha256:!") // invalid
|
||||
case 3:
|
||||
io.WriteString(w, "ab")
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
got := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if !errors.Is(got, ErrIncomplete) {
|
||||
t.Fatalf("err = %v, want %v", got, ErrIncomplete)
|
||||
}
|
||||
}
|
||||
|
||||
type flushAfterWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (f *flushAfterWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = f.w.Write(p)
|
||||
f.w.(http.Flusher).Flush() // panic if not a flusher
|
||||
return
|
||||
}
|
||||
|
||||
func TestPullChunksumStreaming(t *testing.T) {
|
||||
csr, csw := io.Pipe()
|
||||
defer csw.Close()
|
||||
|
||||
var step atomic.Int64
|
||||
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fw := &flushAfterWriter{w} // ensure client gets data as it arrives by aggressively flushing
|
||||
_, err := io.Copy(fw, csr)
|
||||
if err != nil {
|
||||
t.Errorf("copy: %v", err)
|
||||
}
|
||||
case 3:
|
||||
io.WriteString(w, "ab")
|
||||
case 4:
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
update := make(chan int64, 1)
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
if n > 0 {
|
||||
update <- n
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
errc <- c.Pull(ctx, "http://o.com/library/abc")
|
||||
}()
|
||||
|
||||
// Send first chunksum and ensure it kicks off work immediately
|
||||
fmt.Fprintf(csw, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
if g := <-update; g != 2 {
|
||||
t.Fatalf("got %d, want 2", g)
|
||||
}
|
||||
|
||||
// now send the second chunksum and ensure it kicks off work immediately
|
||||
fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
if g := <-update; g != 3 {
|
||||
t.Fatalf("got %d, want 3", g)
|
||||
}
|
||||
csw.Close()
|
||||
testutil.Check(t, <-errc)
|
||||
}
|
||||
|
||||
func TestPullChunksumsCached(t *testing.T) {
|
||||
var step atomic.Int64
|
||||
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
case 3, 4:
|
||||
switch rng := r.Header.Get("Range"); rng {
|
||||
case "bytes=0-1":
|
||||
io.WriteString(w, "ab")
|
||||
case "bytes=2-2":
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected range %q", rng)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.MaxStreams = 1 // force serial processing of chunksums
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
// Cancel the pull after the first chunksum is processed, but before
|
||||
// the second chunksum is processed (which is waiting because
|
||||
// MaxStreams=1). This should cause the second chunksum to error out
|
||||
// leaving the blob incomplete.
|
||||
ctx = WithTrace(ctx, &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
if n > 0 {
|
||||
cancel()
|
||||
}
|
||||
},
|
||||
})
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("err = %v, want %v", err, context.Canceled)
|
||||
}
|
||||
|
||||
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Fatalf("err = %v, want nil", err)
|
||||
}
|
||||
|
||||
// Reset state and pull again to ensure the blob chunks that should
|
||||
// have been cached are, and the remaining chunk was downloaded, making
|
||||
// the blob complete.
|
||||
step.Store(0)
|
||||
var written atomic.Int64
|
||||
var cached atomic.Int64
|
||||
ctx = WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
if errors.Is(err, ErrCached) {
|
||||
cached.Add(n)
|
||||
}
|
||||
written.Add(n)
|
||||
},
|
||||
})
|
||||
|
||||
check := testutil.Checker(t)
|
||||
|
||||
err = c.Pull(ctx, "http://o.com/library/abc")
|
||||
check(err)
|
||||
|
||||
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
||||
check(err)
|
||||
|
||||
if g := written.Load(); g != 5 {
|
||||
t.Fatalf("wrote %d bytes, want 3", g)
|
||||
}
|
||||
if g := cached.Load(); g != 2 { // "ab" should have been cached
|
||||
t.Fatalf("cached %d bytes, want 5", g)
|
||||
}
|
||||
}
|
||||
72
server/internal/client/ollama/trace.go
Normal file
72
server/internal/client/ollama/trace.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Trace is a set of functions that are called to report progress during blob
|
||||
// downloads and uploads.
|
||||
//
|
||||
// Use [WithTrace] to attach a Trace to a context for use with [Registry.Push]
|
||||
// and [Registry.Pull].
|
||||
type Trace struct {
|
||||
// Update is called during [Registry.Push] and [Registry.Pull] to
|
||||
// report the progress of blob uploads and downloads.
|
||||
//
|
||||
// The n argument is the number of bytes transferred so far, and err is
|
||||
// any error that has occurred. If n == 0, and err is nil, the download
|
||||
// or upload has just started. If err is [ErrCached], the download or
|
||||
// upload has been skipped because the blob is already present in the
|
||||
// local cache or remote registry, respectively. Otherwise, if err is
|
||||
// non-nil, the download or upload has failed. When l.Size == n, and
|
||||
// err is nil, the download or upload has completed.
|
||||
//
|
||||
// A function assigned must be safe for concurrent use. The function is
|
||||
// called synchronously and so should not block or take long to run.
|
||||
Update func(_ *Layer, n int64, _ error)
|
||||
}
|
||||
|
||||
func (t *Trace) update(l *Layer, n int64, err error) {
|
||||
if t.Update != nil {
|
||||
t.Update(l, n, err)
|
||||
}
|
||||
}
|
||||
|
||||
type traceKey struct{}
|
||||
|
||||
// WithTrace adds a trace to the context for transfer progress reporting.
|
||||
func WithTrace(ctx context.Context, t *Trace) context.Context {
|
||||
old := traceFromContext(ctx)
|
||||
if old == t {
|
||||
// No change, return the original context. This also prevents
|
||||
// infinite recursion below, if the caller passes the same
|
||||
// Trace.
|
||||
return ctx
|
||||
}
|
||||
|
||||
// Create a new Trace that wraps the old one, if any. If we used the
|
||||
// same pointer t, we end up with a recursive structure.
|
||||
composed := &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
if old != nil {
|
||||
old.update(l, n, err)
|
||||
}
|
||||
t.update(l, n, err)
|
||||
},
|
||||
}
|
||||
return context.WithValue(ctx, traceKey{}, composed)
|
||||
}
|
||||
|
||||
var emptyTrace = &Trace{}
|
||||
|
||||
// traceFromContext returns the Trace associated with ctx, or an empty Trace if
|
||||
// none is found.
|
||||
//
|
||||
// It never returns nil.
|
||||
func traceFromContext(ctx context.Context) *Trace {
|
||||
t, _ := ctx.Value(traceKey{}).(*Trace)
|
||||
if t == nil {
|
||||
return emptyTrace
|
||||
}
|
||||
return t
|
||||
}
|
||||
Reference in New Issue
Block a user