Files
ollama/x/transfer/transfer_test.go
2026-05-22 17:19:10 +08:00

2910 lines
78 KiB
Go

package transfer
import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
// chunkedSession tracks accumulated PATCH body bytes for an upload session.
// Tests that mock the registry use it to handle the GGUF-style POST → PATCH →
// PUT-finalize flow without each test reimplementing the bookkeeping.
//
// Every read and write of the per-session buffers happens while holding mu,
// so handlers may call recordPatch and finalize from concurrent requests.
type chunkedSession struct {
	mu       sync.Mutex
	sessions map[string]*bytes.Buffer // keyed by upload session URL path
}

// newChunkedSession returns an empty chunkedSession ready for use.
func newChunkedSession() *chunkedSession {
	return &chunkedSession{sessions: make(map[string]*bytes.Buffer)}
}

// recordPatch reads the full request body and appends it to the buffer for
// the request's URL path. It then writes a 202 Accepted response with
// Docker-Upload-Location pointing at the same session URL. Use this from a
// mock handler's PATCH branch.
//
// If the body cannot be read completely (for example because the client
// aborted the request), nothing is recorded and a 500 is returned. That way a
// truncated chunk is never mistaken for a successful one.
func (c *chunkedSession) recordPatch(w http.ResponseWriter, r *http.Request) {
	// Read outside the lock so a slow client body does not block other
	// sessions, then append the whole chunk while holding the lock.
	body, err := io.ReadAll(r.Body)
	if err != nil {
		http.Error(w, "reading patch body: "+err.Error(), http.StatusInternalServerError)
		return
	}

	c.mu.Lock()
	buf, ok := c.sessions[r.URL.Path]
	if !ok {
		buf = &bytes.Buffer{}
		c.sessions[r.URL.Path] = buf
	}
	buf.Write(body)
	c.mu.Unlock()

	w.Header().Set("Docker-Upload-Location", r.URL.Path)
	w.WriteHeader(http.StatusAccepted)
}

// finalize returns a copy of the bytes accumulated for the given session URL
// path, or nil if nothing was recorded for it. Use it from the PUT-finalize
// branch of a mock handler. The result is a copy, so it never aliases the
// session's internal buffer.
func (c *chunkedSession) finalize(sessionPath string) []byte {
	c.mu.Lock()
	defer c.mu.Unlock()
	if buf, ok := c.sessions[sessionPath]; ok {
		return bytes.Clone(buf.Bytes())
	}
	return nil
}
// createTestBlob writes size bytes of a repeating 0x00..0xFF pattern into dir,
// at the path digestToPath derives from the content's sha256 digest. It
// returns the matching Blob together with the raw bytes. The content depends
// only on size, so two calls with the same size produce the same blob.
func createTestBlob(t *testing.T, dir string, size int) (Blob, []byte) {
	t.Helper()

	content := make([]byte, size)
	for i := range content {
		content[i] = byte(i) // byte conversion truncates, same as i % 256
	}
	digest := fmt.Sprintf("sha256:%x", sha256.Sum256(content))

	blobPath := filepath.Join(dir, digestToPath(digest))
	if err := os.MkdirAll(filepath.Dir(blobPath), 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(blobPath, content, 0o644); err != nil {
		t.Fatal(err)
	}

	return Blob{Digest: digest, Size: int64(size)}, content
}
// TestDownload serves three blobs from a mock registry and checks three
// things. Download must fetch them with two workers. Each blob must be
// written under its digest path. Progress reports must carry a total equal to
// the sum of the blob sizes, and the highest completed value must reach that
// total.
func TestDownload(t *testing.T) {
	// Create test blobs on "server"
	serverDir := t.TempDir()
	blob1, data1 := createTestBlob(t, serverDir, 1024)
	blob2, data2 := createTestBlob(t, serverDir, 2048)
	blob3, data3 := createTestBlob(t, serverDir, 512)

	// Mock server
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Extract digest from URL: /v2/library/_/blobs/sha256:...
		digest := filepath.Base(r.URL.Path)
		path := filepath.Join(serverDir, digestToPath(digest))
		data, err := os.ReadFile(path)
		if err != nil {
			http.NotFound(w, r)
			return
		}
		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
		w.WriteHeader(http.StatusOK)
		w.Write(data)
	}))
	defer server.Close()

	// Download to client dir
	clientDir := t.TempDir()
	var progressCalls atomic.Int32
	var maxCompleted, lastTotal atomic.Int64
	err := Download(context.Background(), DownloadOptions{
		Blobs:       []Blob{blob1, blob2, blob3},
		BaseURL:     server.URL,
		DestDir:     clientDir,
		Concurrency: 2,
		Progress: func(completed, total int64) {
			progressCalls.Add(1)
			// With two workers, callbacks can interleave, so the last value
			// stored is not necessarily the largest. Keep the maximum instead.
			for {
				cur := maxCompleted.Load()
				if completed <= cur || maxCompleted.CompareAndSwap(cur, completed) {
					break
				}
			}
			lastTotal.Store(total)
		},
	})
	if err != nil {
		t.Fatalf("Download failed: %v", err)
	}

	// Verify files
	verifyBlob(t, clientDir, blob1, data1)
	verifyBlob(t, clientDir, blob2, data2)
	verifyBlob(t, clientDir, blob3, data3)

	// Verify progress was reported and accounted for every byte.
	if progressCalls.Load() == 0 {
		t.Error("Progress callback never called")
	}
	wantTotal := blob1.Size + blob2.Size + blob3.Size
	if lastTotal.Load() != wantTotal {
		t.Errorf("Wrong total: got %d, want %d", lastTotal.Load(), wantTotal)
	}
	if maxCompleted.Load() != wantTotal {
		t.Errorf("Wrong final completed: got %d, want %d", maxCompleted.Load(), wantTotal)
	}
}
// TestDownloadWithRedirect checks that Download follows a 307 from the
// registry to a separate CDN host and stores the content served there.
func TestDownloadWithRedirect(t *testing.T) {
	cdnDir := t.TempDir()
	blob, data := createTestBlob(t, cdnDir, 1024)

	// The CDN is the redirect target and is the only server holding bytes.
	cdn := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		blobFile := filepath.Join(cdnDir, digestToPath(filepath.Base(r.URL.Path)))
		blobData, err := os.ReadFile(blobFile)
		if err != nil {
			http.NotFound(w, r)
			return
		}
		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(blobData)))
		w.WriteHeader(http.StatusOK)
		w.Write(blobData)
	}))
	defer cdn.Close()

	// The registry never serves content; it bounces every request to the CDN
	// with the same path.
	registry := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.Redirect(w, r, cdn.URL+r.URL.Path, http.StatusTemporaryRedirect)
	}))
	defer registry.Close()

	clientDir := t.TempDir()
	if err := Download(context.Background(), DownloadOptions{
		Blobs:   []Blob{blob},
		BaseURL: registry.URL,
		DestDir: clientDir,
	}); err != nil {
		t.Fatalf("Download with redirect failed: %v", err)
	}

	verifyBlob(t, clientDir, blob, data)
}
// TestDownloadWithRetry has the server answer 503 to the first two requests
// and serve the blob from the third onward. It checks that Download keeps
// retrying until it succeeds.
func TestDownloadWithRetry(t *testing.T) {
	serverDir := t.TempDir()
	blob, data := createTestBlob(t, serverDir, 1024)

	const failingRequests = 2
	var requestCount atomic.Int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if requestCount.Add(1) <= failingRequests {
			http.Error(w, "temporary error", http.StatusServiceUnavailable)
			return
		}
		blobFile := filepath.Join(serverDir, digestToPath(filepath.Base(r.URL.Path)))
		blobData, err := os.ReadFile(blobFile)
		if err != nil {
			http.NotFound(w, r)
			return
		}
		w.WriteHeader(http.StatusOK)
		w.Write(blobData)
	}))
	defer server.Close()

	clientDir := t.TempDir()
	if err := Download(context.Background(), DownloadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		DestDir: clientDir,
	}); err != nil {
		t.Fatalf("Download with retry failed: %v", err)
	}

	verifyBlob(t, clientDir, blob, data)

	// Two rejected attempts plus the successful one.
	if n := requestCount.Load(); n < failingRequests+1 {
		t.Errorf("Expected at least 3 requests for retry, got %d", n)
	}
}
// TestDownloadWithAuth sets up a server that answers 401 with a Bearer
// challenge until the request carries "Bearer valid-token". It checks that
// Download hands the parsed challenge to GetToken and retries with the token
// it returns.
func TestDownloadWithAuth(t *testing.T) {
	serverDir := t.TempDir()
	blob, data := createTestBlob(t, serverDir, 1024)

	var authCalled atomic.Bool
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("Authorization") != "Bearer valid-token" {
			w.Header().Set("WWW-Authenticate", `Bearer realm="https://auth.example.com",service="registry",scope="repository:library:pull"`)
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return
		}
		blobFile := filepath.Join(serverDir, digestToPath(filepath.Base(r.URL.Path)))
		blobData, err := os.ReadFile(blobFile)
		if err != nil {
			http.NotFound(w, r)
			return
		}
		w.WriteHeader(http.StatusOK)
		w.Write(blobData)
	}))
	defer server.Close()

	// getToken validates the challenge fields and always grants the token the
	// server expects.
	getToken := func(ctx context.Context, challenge AuthChallenge) (string, error) {
		authCalled.Store(true)
		if realm := challenge.Realm; realm != "https://auth.example.com" {
			t.Errorf("Wrong realm: %s", realm)
		}
		if service := challenge.Service; service != "registry" {
			t.Errorf("Wrong service: %s", service)
		}
		return "valid-token", nil
	}

	clientDir := t.TempDir()
	if err := Download(context.Background(), DownloadOptions{
		Blobs:    []Blob{blob},
		BaseURL:  server.URL,
		DestDir:  clientDir,
		GetToken: getToken,
	}); err != nil {
		t.Fatalf("Download with auth failed: %v", err)
	}

	if !authCalled.Load() {
		t.Error("GetToken was never called")
	}
	verifyBlob(t, clientDir, blob, data)
}
// TestDownloadSkipsExisting pre-populates the destination with a blob and
// checks that Download returns without sending a single request.
func TestDownloadSkipsExisting(t *testing.T) {
	serverDir := t.TempDir()
	blob1, data1 := createTestBlob(t, serverDir, 1024)

	clientDir := t.TempDir()
	existing := filepath.Join(clientDir, digestToPath(blob1.Digest))
	if err := os.MkdirAll(filepath.Dir(existing), 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(existing, data1, 0o644); err != nil {
		t.Fatal(err)
	}

	// Count every request and answer it with 404; a correct client sends none.
	var requestCount atomic.Int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		requestCount.Add(1)
		http.NotFound(w, r)
	}))
	defer server.Close()

	if err := Download(context.Background(), DownloadOptions{
		Blobs:   []Blob{blob1},
		BaseURL: server.URL,
		DestDir: clientDir,
	}); err != nil {
		t.Fatalf("Download failed: %v", err)
	}

	if n := requestCount.Load(); n != 0 {
		t.Errorf("Made %d requests, expected 0 (blob should be skipped)", n)
	}
}
// TestDownloadResumeProgressTotal pre-populates two of three blobs and checks
// the first progress report. Its total must cover all three blobs, not just
// the remaining one. Its completed value must start at no less than the bytes
// already on disk.
func TestDownloadResumeProgressTotal(t *testing.T) {
	serverDir := t.TempDir()
	blob1, data1 := createTestBlob(t, serverDir, 1000)
	blob2, data2 := createTestBlob(t, serverDir, 2000)
	blob3, data3 := createTestBlob(t, serverDir, 3000)

	// Pre-populate client with blob1 and blob2 (simulating partial download)
	clientDir := t.TempDir()
	for _, b := range []struct {
		blob Blob
		data []byte
	}{{blob1, data1}, {blob2, data2}} {
		path := filepath.Join(clientDir, digestToPath(b.blob.Digest))
		if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
			t.Fatal(err)
		}
		if err := os.WriteFile(path, b.data, 0o644); err != nil {
			t.Fatal(err)
		}
	}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		digest := filepath.Base(r.URL.Path)
		path := filepath.Join(serverDir, digestToPath(digest))
		data, err := os.ReadFile(path)
		if err != nil {
			http.NotFound(w, r)
			return
		}
		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
		w.WriteHeader(http.StatusOK)
		w.Write(data)
	}))
	defer server.Close()

	// Capture only the first progress report; mu guards all three fields.
	var (
		mu               sync.Mutex
		gotFirstProgress bool
		firstCompleted   int64
		firstTotal       int64
	)
	err := Download(context.Background(), DownloadOptions{
		Blobs:       []Blob{blob1, blob2, blob3},
		BaseURL:     server.URL,
		DestDir:     clientDir,
		Concurrency: 1,
		Progress: func(completed, total int64) {
			mu.Lock()
			defer mu.Unlock()
			if !gotFirstProgress {
				firstCompleted = completed
				firstTotal = total
				gotFirstProgress = true
			}
		},
	})
	if err != nil {
		t.Fatalf("Download failed: %v", err)
	}

	// Read the captured values under the same lock the callback used.
	mu.Lock()
	defer mu.Unlock()
	if !gotFirstProgress {
		t.Fatal("Progress callback never called")
	}

	// Total should be sum of ALL blobs, not just blob3
	expectedTotal := blob1.Size + blob2.Size + blob3.Size
	if firstTotal != expectedTotal {
		t.Errorf("Total = %d, want %d (should include all blobs)", firstTotal, expectedTotal)
	}

	// First progress call should show already-completed bytes from blob1+blob2
	expectedCompleted := blob1.Size + blob2.Size
	if firstCompleted < expectedCompleted {
		t.Errorf("First completed = %d, want >= %d (should include already-downloaded blobs)", firstCompleted, expectedCompleted)
	}

	// Verify blob3 was downloaded
	verifyBlob(t, clientDir, blob3, data3)
}
func TestDownloadDigestMismatch(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Return wrong data
w.WriteHeader(http.StatusOK)
w.Write([]byte("wrong data"))
}))
defer server.Close()
clientDir := t.TempDir()
err := Download(context.Background(), DownloadOptions{
Blobs: []Blob{{Digest: "sha256:0000000000000000000000000000000000000000000000000000000000000000", Size: 10}},
BaseURL: server.URL,
DestDir: clientDir,
})
if err == nil {
t.Fatal("Expected error for digest mismatch")
}
}
// TestUpload pushes two blobs through a mock registry's POST → PATCH → PUT
// flow with two workers. It checks that both digests were finalized and that
// progress was reported.
func TestUpload(t *testing.T) {
	clientDir := t.TempDir()
	blob1, _ := createTestBlob(t, clientDir, 1024)
	blob2, _ := createTestBlob(t, clientDir, 2048)

	var (
		uploadedBlobs sync.Map // digest -> finalized bytes
		uploadID      atomic.Int32
		serverURL     string
	)
	session := newChunkedSession()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodHead:
			// Report every blob as missing so both get uploaded.
			http.NotFound(w, r)
		case http.MethodPost:
			if r.URL.Path != "/v2/library/_/blobs/uploads/" {
				http.NotFound(w, r)
				return
			}
			w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/%d", serverURL, uploadID.Add(1)))
			w.WriteHeader(http.StatusAccepted)
		case http.MethodPatch:
			session.recordPatch(w, r)
		case http.MethodPut:
			uploadedBlobs.Store(r.URL.Query().Get("digest"), session.finalize(r.URL.Path))
			w.WriteHeader(http.StatusCreated)
		default:
			http.NotFound(w, r)
		}
	}))
	defer server.Close()
	serverURL = server.URL

	var progressCalls atomic.Int32
	if err := Upload(context.Background(), UploadOptions{
		Blobs:       []Blob{blob1, blob2},
		BaseURL:     server.URL,
		SrcDir:      clientDir,
		Concurrency: 2,
		Progress:    func(int64, int64) { progressCalls.Add(1) },
	}); err != nil {
		t.Fatalf("Upload failed: %v", err)
	}

	if _, ok := uploadedBlobs.Load(blob1.Digest); !ok {
		t.Error("Blob 1 not uploaded")
	}
	if _, ok := uploadedBlobs.Load(blob2.Digest); !ok {
		t.Error("Blob 2 not uploaded")
	}
	if progressCalls.Load() == 0 {
		t.Error("Progress callback never called")
	}
}
// TestUploadWithRedirect checks that the client follows a 307 returned for a
// PATCH. The registry redirects each chunk to a separate CDN server, and the
// client re-sends the body there as a PUT. The registry's finalize PUT then
// re-keys what the CDN received under the blob's digest.
func TestUploadWithRedirect(t *testing.T) {
clientDir := t.TempDir()
blob, _ := createTestBlob(t, clientDir, 1024)
// uploadedBlobs first holds CDN-received bodies keyed by session path; the
// registry's finalize PUT moves each one to a key equal to the digest.
var uploadedBlobs sync.Map
var cdnCalled atomic.Bool
// CDN server (PATCH redirect target). PATCH is redirected to a PUT here,
// matching production: server issues 307 + Location to a presigned CDN URL,
// and the client re-uploads the part body via PUT.
cdn := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Any request to the CDN counts as "redirect followed"; non-PUT methods
// fall through with an implicit 200 and no body.
cdnCalled.Store(true)
if r.Method == http.MethodPut {
data, _ := io.ReadAll(r.Body)
// Stash the body keyed by the path so the main server can pick it
// up at finalize time.
uploadedBlobs.Store(r.URL.Path, data)
w.WriteHeader(http.StatusCreated)
}
}))
defer cdn.Close()
var serverURL string
var uploadID atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodHead:
// Blob does not exist yet, so the client must upload it.
http.NotFound(w, r)
case http.MethodPost:
id := uploadID.Add(1)
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/%d", serverURL, id))
w.WriteHeader(http.StatusAccepted)
case http.MethodPatch:
// Redirect PATCH body to CDN, mirroring server behavior
// The CDN URL reuses the session path, so the CDN stores the body
// under the same key the finalize PUT below looks up.
cdnURL := cdn.URL + r.URL.Path
w.Header().Set("Docker-Upload-Location", r.URL.Path)
http.Redirect(w, r, cdnURL, http.StatusTemporaryRedirect)
case http.MethodPut:
// Finalize: copy body the CDN received under this session path
// to the uploadedBlobs map keyed by digest.
digest := r.URL.Query().Get("digest")
if v, ok := uploadedBlobs.Load(r.URL.Path); ok {
uploadedBlobs.Store(digest, v)
uploadedBlobs.Delete(r.URL.Path)
}
w.WriteHeader(http.StatusCreated)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
// Set after the server starts; the handler reads it only when building the
// POST Location header.
serverURL = server.URL
err := Upload(context.Background(), UploadOptions{
Blobs: []Blob{blob},
BaseURL: server.URL,
SrcDir: clientDir,
})
if err != nil {
t.Fatalf("Upload with redirect failed: %v", err)
}
if !cdnCalled.Load() {
t.Error("CDN was never called (redirect not followed)")
}
// The digest key exists only if the CDN received a body AND the finalize
// PUT hit the same session path.
if _, ok := uploadedBlobs.Load(blob.Digest); !ok {
t.Error("Blob not uploaded to CDN")
}
}
// TestUploadWithAuth sets up a server that rejects every request lacking
// "Bearer valid-token" with a 401 push challenge. It checks that Upload
// obtains a token through GetToken and completes the POST → PATCH → PUT flow
// with it.
func TestUploadWithAuth(t *testing.T) {
	clientDir := t.TempDir()
	blob, _ := createTestBlob(t, clientDir, 1024)

	var (
		uploadedBlobs sync.Map // digest -> finalized bytes
		authCalled    atomic.Bool
		uploadID      atomic.Int32
		serverURL     string
	)
	session := newChunkedSession()

	// authorized writes the 401 challenge and reports false unless the
	// request carries the expected token.
	authorized := func(w http.ResponseWriter, r *http.Request) bool {
		if r.Header.Get("Authorization") == "Bearer valid-token" {
			return true
		}
		w.Header().Set("WWW-Authenticate", `Bearer realm="https://auth.example.com",service="registry",scope="repository:library:push"`)
		http.Error(w, "unauthorized", http.StatusUnauthorized)
		return false
	}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !authorized(w, r) {
			return
		}
		switch r.Method {
		case http.MethodHead:
			http.NotFound(w, r)
		case http.MethodPost:
			w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/%d", serverURL, uploadID.Add(1)))
			w.WriteHeader(http.StatusAccepted)
		case http.MethodPatch:
			session.recordPatch(w, r)
		case http.MethodPut:
			uploadedBlobs.Store(r.URL.Query().Get("digest"), session.finalize(r.URL.Path))
			w.WriteHeader(http.StatusCreated)
		default:
			http.NotFound(w, r)
		}
	}))
	defer server.Close()
	serverURL = server.URL

	if err := Upload(context.Background(), UploadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		SrcDir:  clientDir,
		GetToken: func(context.Context, AuthChallenge) (string, error) {
			authCalled.Store(true)
			return "valid-token", nil
		},
	}); err != nil {
		t.Fatalf("Upload with auth failed: %v", err)
	}

	if !authCalled.Load() {
		t.Error("GetToken was never called")
	}
	if _, ok := uploadedBlobs.Load(blob.Digest); !ok {
		t.Error("Blob not uploaded")
	}
}
// TestUploadSkipsExisting has HEAD report the blob as present (200). It
// checks that Upload performs that HEAD and then never issues a PUT.
func TestUploadSkipsExisting(t *testing.T) {
	clientDir := t.TempDir()
	blob1, _ := createTestBlob(t, clientDir, 1024)

	var headChecked, putCalled atomic.Bool
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodHead:
			// 200 OK tells the client the blob already exists.
			headChecked.Store(true)
			w.WriteHeader(http.StatusOK)
		case http.MethodPut:
			putCalled.Store(true)
			w.WriteHeader(http.StatusCreated)
		default:
			// POST and anything else is refused.
			http.NotFound(w, r)
		}
	}))
	defer server.Close()

	if err := Upload(context.Background(), UploadOptions{
		Blobs:   []Blob{blob1},
		BaseURL: server.URL,
		SrcDir:  clientDir,
	}); err != nil {
		t.Fatalf("Upload failed: %v", err)
	}

	if !headChecked.Load() {
		t.Error("HEAD check was never made")
	}
	if putCalled.Load() {
		t.Error("PUT was called even though blob exists (HEAD returned 200)")
	}
	t.Log("HEAD-based existence check verified")
}
// TestUploadWithCustomRepository verifies that custom repository paths are used.
// With UploadOptions.Repository set, both the HEAD existence check and the POST
// that opens the upload session must target /v2/<repository>/....
func TestUploadWithCustomRepository(t *testing.T) {
	clientDir := t.TempDir()
	blob1, _ := createTestBlob(t, clientDir, 1024)

	var (
		mu                 sync.Mutex
		headPath, postPath string
		serverURL          string
	)
	// record stores the request path into dst under mu.
	record := func(dst *string, r *http.Request) {
		mu.Lock()
		*dst = r.URL.Path
		mu.Unlock()
	}
	session := newChunkedSession()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodHead:
			record(&headPath, r)
			w.WriteHeader(http.StatusNotFound) // Blob doesn't exist
		case http.MethodPost:
			record(&postPath, r)
			w.Header().Set("Location", fmt.Sprintf("%s/v2/myorg/mymodel/blobs/uploads/1", serverURL))
			w.WriteHeader(http.StatusAccepted)
		case http.MethodPatch:
			session.recordPatch(w, r)
		case http.MethodPut:
			w.WriteHeader(http.StatusCreated)
		}
		// Any other method gets the implicit 200 with an empty body.
	}))
	defer server.Close()
	serverURL = server.URL

	if err := Upload(context.Background(), UploadOptions{
		Blobs:      []Blob{blob1},
		BaseURL:    server.URL,
		SrcDir:     clientDir,
		Repository: "myorg/mymodel", // Custom repository
	}); err != nil {
		t.Fatalf("Upload failed: %v", err)
	}

	mu.Lock()
	defer mu.Unlock()
	if want := fmt.Sprintf("/v2/myorg/mymodel/blobs/%s", blob1.Digest); headPath != want {
		t.Errorf("HEAD path mismatch: got %s, want %s", headPath, want)
	}
	if want := "/v2/myorg/mymodel/blobs/uploads/"; postPath != want {
		t.Errorf("POST path mismatch: got %s, want %s", postPath, want)
	}
	t.Logf("Custom repository paths verified: HEAD=%s, POST=%s", headPath, postPath)
}
// TestDownloadWithCustomRepository verifies that custom repository paths are
// used. With DownloadOptions.Repository set, the blob must be requested from
// /v2/<repository>/blobs/<digest>.
func TestDownloadWithCustomRepository(t *testing.T) {
	serverDir := t.TempDir()
	blob, data := createTestBlob(t, serverDir, 1024)

	var (
		mu          sync.Mutex
		requestPath string // most recent request path seen by the server
	)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		requestPath = r.URL.Path
		mu.Unlock()

		// Serve the blob from any path; only the final segment is used.
		blobFile := filepath.Join(serverDir, digestToPath(filepath.Base(r.URL.Path)))
		blobData, err := os.ReadFile(blobFile)
		if err != nil {
			http.NotFound(w, r)
			return
		}
		w.WriteHeader(http.StatusOK)
		w.Write(blobData)
	}))
	defer server.Close()

	clientDir := t.TempDir()
	if err := Download(context.Background(), DownloadOptions{
		Blobs:      []Blob{blob},
		BaseURL:    server.URL,
		DestDir:    clientDir,
		Repository: "myorg/mymodel", // Custom repository
	}); err != nil {
		t.Fatalf("Download failed: %v", err)
	}
	verifyBlob(t, clientDir, blob, data)

	mu.Lock()
	defer mu.Unlock()
	if want := fmt.Sprintf("/v2/myorg/mymodel/blobs/%s", blob.Digest); requestPath != want {
		t.Errorf("Request path mismatch: got %s, want %s", requestPath, want)
	}
	t.Logf("Custom repository path verified: %s", requestPath)
}
// TestDigestToPath checks that an "algo:hex" digest becomes "algo-hex". It
// also checks that inputs already in dash form, or without a separator, are
// returned unchanged.
func TestDigestToPath(t *testing.T) {
	for _, c := range [][2]string{
		{"sha256:abc123", "sha256-abc123"},
		{"sha256-abc123", "sha256-abc123"},
		{"other", "other"},
	} {
		in, want := c[0], c[1]
		if got := digestToPath(in); got != want {
			t.Errorf("digestToPath(%q) = %q, want %q", in, got, want)
		}
	}
}
// TestParseAuthChallenge checks that realm, service, and scope are extracted
// from a Bearer WWW-Authenticate value. Parameters missing from the header
// must come back as empty strings.
func TestParseAuthChallenge(t *testing.T) {
	cases := []struct {
		header string
		want   AuthChallenge
	}{
		{
			header: `Bearer realm="https://auth.example.com/token",service="registry",scope="repository:library/test:pull"`,
			want: AuthChallenge{
				Realm:   "https://auth.example.com/token",
				Service: "registry",
				Scope:   "repository:library/test:pull",
			},
		},
		{
			header: `Bearer realm="https://auth.example.com"`,
			want:   AuthChallenge{Realm: "https://auth.example.com"},
		},
	}

	for _, tc := range cases {
		parsed := parseAuthChallenge(tc.header)
		if parsed.Realm != tc.want.Realm {
			t.Errorf("parseAuthChallenge(%q).Realm = %q, want %q", tc.header, parsed.Realm, tc.want.Realm)
		}
		if parsed.Service != tc.want.Service {
			t.Errorf("parseAuthChallenge(%q).Service = %q, want %q", tc.header, parsed.Service, tc.want.Service)
		}
		if parsed.Scope != tc.want.Scope {
			t.Errorf("parseAuthChallenge(%q).Scope = %q, want %q", tc.header, parsed.Scope, tc.want.Scope)
		}
	}
}
// verifyBlob asserts three things about the file for blob under dir: it
// exists, it has exactly len(expected) bytes, and it hashes to blob.Digest.
// Failures are reported with t.Errorf, so the caller keeps checking other
// blobs.
func verifyBlob(t *testing.T, dir string, blob Blob, expected []byte) {
	t.Helper()

	// short abbreviates a digest for messages ("sha256:" + 12 hex chars).
	// Slicing a malformed, shorter digest directly would panic instead of
	// reporting a failure.
	short := func(s string) string {
		if len(s) > 19 {
			return s[:19]
		}
		return s
	}

	path := filepath.Join(dir, digestToPath(blob.Digest))
	data, err := os.ReadFile(path)
	if err != nil {
		t.Errorf("Failed to read %s: %v", short(blob.Digest), err)
		return
	}
	if len(data) != len(expected) {
		t.Errorf("Size mismatch for %s: got %d, want %d", short(blob.Digest), len(data), len(expected))
		return
	}
	h := sha256.Sum256(data)
	if digest := fmt.Sprintf("sha256:%x", h); digest != blob.Digest {
		t.Errorf("Digest mismatch for %s: got %s", short(blob.Digest), short(digest))
	}
}
// ==================== Parallelism Tests ====================
// TestDownloadParallelism downloads 10 blobs through a server that holds each
// request for 50ms. It checks two things: at least two requests were in
// flight at once, and the whole batch finished within 1s.
func TestDownloadParallelism(t *testing.T) {
// Create many blobs to test parallelism
serverDir := t.TempDir()
numBlobs := 10
blobs := make([]Blob, numBlobs)
blobData := make([][]byte, numBlobs)
// Sizes differ per blob (1024 + i*100) so each gets a distinct digest.
for i := range numBlobs {
blobs[i], blobData[i] = createTestBlob(t, serverDir, 1024+i*100)
}
var activeRequests atomic.Int32
var maxConcurrent atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
current := activeRequests.Add(1)
defer activeRequests.Add(-1)
// Track max concurrent requests
// CAS loop: raise maxConcurrent to current only if no other request has
// already stored a value >= current.
for {
old := maxConcurrent.Load()
if current <= old || maxConcurrent.CompareAndSwap(old, current) {
break
}
}
// Simulate network latency to ensure parallelism is visible
time.Sleep(50 * time.Millisecond)
digest := filepath.Base(r.URL.Path)
path := filepath.Join(serverDir, digestToPath(digest))
data, err := os.ReadFile(path)
if err != nil {
http.NotFound(w, r)
return
}
w.WriteHeader(http.StatusOK)
w.Write(data)
}))
defer server.Close()
clientDir := t.TempDir()
start := time.Now()
err := Download(context.Background(), DownloadOptions{
Blobs: blobs,
BaseURL: server.URL,
DestDir: clientDir,
Concurrency: 4,
BodyConcurrency: 4,
})
elapsed := time.Since(start)
if err != nil {
t.Fatalf("Download failed: %v", err)
}
// Verify all blobs downloaded
for i, blob := range blobs {
verifyBlob(t, clientDir, blob, blobData[i])
}
// Verify parallelism was used
if maxConcurrent.Load() < 2 {
t.Errorf("Max concurrent requests was %d, expected at least 2 for parallelism", maxConcurrent.Load())
}
// With 10 blobs at 50ms each, sequential would take ~500ms
// Parallel with 4 workers should take ~150ms (relax to 1s for CI variance)
if elapsed > time.Second {
t.Errorf("Downloads took %v, expected faster with parallelism", elapsed)
}
t.Logf("Downloaded %d blobs in %v with max %d concurrent requests", numBlobs, elapsed, maxConcurrent.Load())
}
// TestUploadParallelism uploads 10 blobs with Concurrency 4 through a mock
// registry whose PATCH handler sleeps 50ms. It checks that every digest was
// finalized and that at least two requests overlapped. Elapsed time is only
// logged here, not asserted.
func TestUploadParallelism(t *testing.T) {
clientDir := t.TempDir()
numBlobs := 10
blobs := make([]Blob, numBlobs)
// Sizes differ per blob (1024 + i*100) so each gets a distinct digest.
for i := range numBlobs {
blobs[i], _ = createTestBlob(t, clientDir, 1024+i*100)
}
var activeRequests atomic.Int32
var maxConcurrent atomic.Int32
var uploadedBlobs sync.Map
var uploadID atomic.Int32
session := newChunkedSession()
var serverURL string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Counts every in-flight request regardless of method (HEAD, POST,
// PATCH, PUT), not only body uploads.
current := activeRequests.Add(1)
defer activeRequests.Add(-1)
// Track max concurrent
// CAS loop: raise maxConcurrent to current only if no other request has
// already stored a value >= current.
for {
old := maxConcurrent.Load()
if current <= old || maxConcurrent.CompareAndSwap(old, current) {
break
}
}
switch r.Method {
case http.MethodHead:
http.NotFound(w, r)
case http.MethodPost:
// Each POST gets a fresh session ID so concurrent uploads keep
// separate chunk buffers in session.
id := uploadID.Add(1)
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/%d", serverURL, id))
w.WriteHeader(http.StatusAccepted)
case http.MethodPatch:
time.Sleep(50 * time.Millisecond) // simulate upload time on body chunk
session.recordPatch(w, r)
case http.MethodPut:
digest := r.URL.Query().Get("digest")
uploadedBlobs.Store(digest, session.finalize(r.URL.Path))
w.WriteHeader(http.StatusCreated)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
serverURL = server.URL
start := time.Now()
err := Upload(context.Background(), UploadOptions{
Blobs: blobs,
BaseURL: server.URL,
SrcDir: clientDir,
Concurrency: 4,
})
elapsed := time.Since(start)
if err != nil {
t.Fatalf("Upload failed: %v", err)
}
// Verify all blobs uploaded
for _, blob := range blobs {
if _, ok := uploadedBlobs.Load(blob.Digest); !ok {
t.Errorf("Blob %s not uploaded", blob.Digest[:19])
}
}
if maxConcurrent.Load() < 2 {
t.Errorf("Max concurrent requests was %d, expected at least 2", maxConcurrent.Load())
}
t.Logf("Uploaded %d blobs in %v with max %d concurrent requests", numBlobs, elapsed, maxConcurrent.Load())
}
// ==================== Stall Detection Test ====================
// TestDownloadStallDetection runs against a server whose first response sends
// 1KB of a 10KB blob and then goes silent for 500ms. That is longer than the
// 200ms StallTimeout given to Download. The test checks that Download
// abandons the stalled response, retries, and completes within 3s. It is
// skipped under -short.
func TestDownloadStallDetection(t *testing.T) {
if testing.Short() {
t.Skip("Skipping stall detection test in short mode")
}
serverDir := t.TempDir()
blob, _ := createTestBlob(t, serverDir, 10*1024) // 10KB
var requestCount atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
count := requestCount.Add(1)
digest := filepath.Base(r.URL.Path)
path := filepath.Join(serverDir, digestToPath(digest))
data, err := os.ReadFile(path)
if err != nil {
http.NotFound(w, r)
return
}
// Full Content-Length is advertised even on the stalled first response,
// so the client expects more bytes than it will receive.
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.WriteHeader(http.StatusOK)
if count == 1 {
// First request: send partial data then stall
w.Write(data[:1024]) // Send first 1KB
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
// Stall for longer than stall timeout (test uses 200ms)
time.Sleep(500 * time.Millisecond)
return
}
// Subsequent requests: send full data
w.Write(data)
}))
// NOTE(review): CloseClientConnections before Close presumably drops the
// abandoned stalled connection so teardown is not held up by it — confirm.
defer func() {
server.CloseClientConnections()
server.Close()
}()
clientDir := t.TempDir()
start := time.Now()
err := Download(context.Background(), DownloadOptions{
Blobs: []Blob{blob},
BaseURL: server.URL,
DestDir: clientDir,
StallTimeout: 200 * time.Millisecond, // Short timeout for testing
})
elapsed := time.Since(start)
if err != nil {
t.Fatalf("Download failed: %v", err)
}
// Should have retried after stall detection
if requestCount.Load() < 2 {
t.Errorf("Expected at least 2 requests (stall + retry), got %d", requestCount.Load())
}
// Should complete quickly with short stall timeout
if elapsed > 3*time.Second {
t.Errorf("Download took %v, stall detection should have triggered faster", elapsed)
}
t.Logf("Stall detection worked: %d requests in %v", requestCount.Load(), elapsed)
}
// ==================== Context Cancellation Tests ====================

// TestDownloadCancellation cancels the context about 50ms into a deliberately
// slow 100KB download. It checks that Download returns an error wrapping
// context.Canceled within 500ms instead of waiting for the full body.
func TestDownloadCancellation(t *testing.T) {
	serverDir := t.TempDir()
	blob, _ := createTestBlob(t, serverDir, 100*1024) // 100KB (smaller for faster test)

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		digest := filepath.Base(r.URL.Path)
		path := filepath.Join(serverDir, digestToPath(digest))
		data, err := os.ReadFile(path)
		if err != nil {
			http.NotFound(w, r)
			return
		}
		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
		w.WriteHeader(http.StatusOK)

		// Send data slowly (1KB every 5ms), but stop as soon as the client
		// goes away. httptest.Server.Close waits for in-flight handlers, so a
		// handler that keeps sleeping after cancellation only delays teardown.
		for off := 0; off < len(data); off += 1024 {
			end := min(off+1024, len(data))
			if _, err := w.Write(data[off:end]); err != nil {
				return
			}
			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
			select {
			case <-r.Context().Done():
				return
			case <-time.After(5 * time.Millisecond):
			}
		}
	}))
	defer func() {
		server.CloseClientConnections()
		server.Close()
	}()

	clientDir := t.TempDir()
	ctx, cancel := context.WithCancel(context.Background())

	// Cancel after 50ms
	go func() {
		time.Sleep(50 * time.Millisecond)
		cancel()
	}()

	start := time.Now()
	err := Download(ctx, DownloadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		DestDir: clientDir,
	})
	elapsed := time.Since(start)

	if err == nil {
		t.Fatal("Expected error from cancellation")
	}
	if !errors.Is(err, context.Canceled) {
		t.Errorf("Expected context.Canceled error, got: %v", err)
	}

	// Should cancel quickly, not wait for full download
	if elapsed > 500*time.Millisecond {
		t.Errorf("Cancellation took %v, expected faster response", elapsed)
	}
	t.Logf("Cancellation worked in %v", elapsed)
}
// TestUploadCancellation cancels the context about 50ms into a 100KB upload
// whose PATCH body the server reads slowly. It checks that Upload returns a
// non-nil error within 500ms. It does not check which error is returned.
func TestUploadCancellation(t *testing.T) {
clientDir := t.TempDir()
blob, _ := createTestBlob(t, clientDir, 100*1024) // 100KB (smaller for faster test)
var serverURL string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodHead:
http.NotFound(w, r)
case http.MethodPost:
// Single blob, so a fixed session ID is enough.
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
w.WriteHeader(http.StatusAccepted)
case http.MethodPatch:
// Read slowly so the cancellation has time to interrupt the body upload.
// The loop exits on any read error, including the one produced when
// the client aborts the request.
buf := make([]byte, 1024)
for {
_, err := r.Body.Read(buf)
if err != nil {
break
}
time.Sleep(5 * time.Millisecond)
}
w.Header().Set("Docker-Upload-Location", r.URL.Path)
w.WriteHeader(http.StatusAccepted)
case http.MethodPut:
w.WriteHeader(http.StatusCreated)
}
}))
defer func() {
server.CloseClientConnections()
server.Close()
}()
serverURL = server.URL
ctx, cancel := context.WithCancel(context.Background())
// Cancel after 50ms, while the PATCH body is still being read.
go func() {
time.Sleep(50 * time.Millisecond)
cancel()
}()
start := time.Now()
err := Upload(ctx, UploadOptions{
Blobs: []Blob{blob},
BaseURL: server.URL,
SrcDir: clientDir,
})
elapsed := time.Since(start)
if err == nil {
t.Fatal("Expected error from cancellation")
}
if elapsed > 500*time.Millisecond {
t.Errorf("Cancellation took %v, expected faster", elapsed)
}
t.Logf("Upload cancellation worked in %v", elapsed)
}
// ==================== Progress Tracking Tests ====================

// TestProgressTracking downloads two blobs sequentially and inspects every
// progress report. The total must stay constant, completed must never
// decrease, and the last report must have completed equal to total.
func TestProgressTracking(t *testing.T) {
	serverDir := t.TempDir()
	blob1, data1 := createTestBlob(t, serverDir, 5000)
	blob2, data2 := createTestBlob(t, serverDir, 3000)

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		blobFile := filepath.Join(serverDir, digestToPath(filepath.Base(r.URL.Path)))
		data, _ := os.ReadFile(blobFile)
		w.WriteHeader(http.StatusOK)
		w.Write(data)
	}))
	defer server.Close()

	type report struct{ completed, total int64 }
	var (
		mu      sync.Mutex
		history []report
	)

	clientDir := t.TempDir()
	if err := Download(context.Background(), DownloadOptions{
		Blobs:       []Blob{blob1, blob2},
		BaseURL:     server.URL,
		DestDir:     clientDir,
		Concurrency: 1, // Sequential to make progress predictable
		Progress: func(completed, total int64) {
			mu.Lock()
			history = append(history, report{completed, total})
			mu.Unlock()
		},
	}); err != nil {
		t.Fatalf("Download failed: %v", err)
	}

	verifyBlob(t, clientDir, blob1, data1)
	verifyBlob(t, clientDir, blob2, data2)

	mu.Lock()
	defer mu.Unlock()

	if len(history) == 0 {
		t.Fatal("No progress callbacks received")
	}

	expectedTotal := blob1.Size + blob2.Size
	for _, rep := range history {
		if rep.total != expectedTotal {
			t.Errorf("Total changed during download: got %d, want %d", rep.total, expectedTotal)
		}
	}

	var prev int64
	for _, rep := range history {
		if rep.completed < prev {
			t.Errorf("Progress went backwards: %d -> %d", prev, rep.completed)
		}
		prev = rep.completed
	}

	last := history[len(history)-1]
	if last.completed != expectedTotal {
		t.Errorf("Final completed %d != total %d", last.completed, expectedTotal)
	}
	t.Logf("Progress tracked correctly: %d callbacks, final %d/%d", len(history), last.completed, last.total)
}
// ==================== Edge Cases ====================
func TestDownloadEmptyBlobList(t *testing.T) {
err := Download(context.Background(), DownloadOptions{
Blobs: []Blob{},
BaseURL: "http://unused",
DestDir: t.TempDir(),
})
if err != nil {
t.Errorf("Expected no error for empty blob list, got: %v", err)
}
}
func TestUploadEmptyBlobList(t *testing.T) {
err := Upload(context.Background(), UploadOptions{
Blobs: []Blob{},
BaseURL: "http://unused",
SrcDir: t.TempDir(),
})
if err != nil {
t.Errorf("Expected no error for empty blob list, got: %v", err)
}
}
// TestUploadRetryOnFailure makes the mock registry reject the first two PATCH
// requests with 500 and checks that the upload still completes (the finalize
// PUT stores the blob) after at least three PATCH attempts.
func TestUploadRetryOnFailure(t *testing.T) {
	const failuresBeforeSuccess = 2

	clientDir := t.TempDir()
	blob, _ := createTestBlob(t, clientDir, 1024)

	var (
		patches   atomic.Int32
		uploaded  sync.Map
		serverURL string
	)
	session := newChunkedSession()
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodHead:
			http.NotFound(w, r)
		case http.MethodPost:
			w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
			w.WriteHeader(http.StatusAccepted)
		case http.MethodPatch:
			// Reject the first attempts so the client is forced down its retry path.
			if patches.Add(1) <= failuresBeforeSuccess {
				io.Copy(io.Discard, r.Body)
				http.Error(w, "server error", http.StatusInternalServerError)
				return
			}
			session.recordPatch(w, r)
		case http.MethodPut:
			uploaded.Store(r.URL.Query().Get("digest"), session.finalize(r.URL.Path))
			w.WriteHeader(http.StatusCreated)
		}
	}))
	defer server.Close()
	serverURL = server.URL

	if err := Upload(context.Background(), UploadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		SrcDir:  clientDir,
	}); err != nil {
		t.Fatalf("Upload with retry failed: %v", err)
	}

	if _, found := uploaded.Load(blob.Digest); !found {
		t.Error("Blob not uploaded after retry")
	}
	if got := patches.Load(); got < failuresBeforeSuccess+1 {
		t.Errorf("Expected at least 3 PATCH attempts, got %d", got)
	}
}
// TestProgressRollback verifies that bytes reported for a failed upload
// attempt are rolled back before the retry. The mock registry drains and then
// rejects the first two PATCH requests, so without rollback the reported
// progress would stack the failed attempts on top of the successful one and
// overshoot the blob size.
func TestProgressRollback(t *testing.T) {
	content := []byte("test content for rollback test")
	digest := fmt.Sprintf("sha256:%x", sha256.Sum256(content))
	blob := Blob{Digest: digest, Size: int64(len(content))}
	clientDir := t.TempDir()
	path := filepath.Join(clientDir, digestToPath(digest))
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(path, content, 0o644); err != nil {
		t.Fatal(err)
	}
	var patchCount atomic.Int32
	var progressValues []int64
	var mu sync.Mutex
	session := newChunkedSession()
	var serverURL string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodHead:
			http.NotFound(w, r)
		case http.MethodPost:
			w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
			w.WriteHeader(http.StatusAccepted)
		case http.MethodPatch:
			// Consume the body (so the client reports progress for it) before
			// failing, to exercise the rollback-on-retry path.
			count := patchCount.Add(1)
			if count < 3 {
				io.CopyN(io.Discard, r.Body, 10)
				io.Copy(io.Discard, r.Body)
				http.Error(w, "server error", http.StatusInternalServerError)
				return
			}
			session.recordPatch(w, r)
		case http.MethodPut:
			w.WriteHeader(http.StatusCreated)
		}
	}))
	defer server.Close()
	serverURL = server.URL
	err := Upload(context.Background(), UploadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		SrcDir:  clientDir,
		Progress: func(completed, total int64) {
			mu.Lock()
			progressValues = append(progressValues, completed)
			mu.Unlock()
		},
	})
	if err != nil {
		t.Fatalf("Upload with retry failed: %v", err)
	}

	mu.Lock()
	defer mu.Unlock()
	// With no progress reports at all there is nothing that could show a
	// rollback, so treat that as a failure rather than passing vacuously.
	if len(progressValues) == 0 {
		t.Fatal("No progress callbacks received")
	}
	// If the failed attempts' bytes were never subtracted, they would stack on
	// top of the successful attempt and push completed past the blob size.
	for i, v := range progressValues {
		if v > blob.Size {
			t.Errorf("progress[%d] = %d exceeds blob size %d; failed attempt was not rolled back", i, v, blob.Size)
		}
	}
	// Final progress should equal blob size
	if final := progressValues[len(progressValues)-1]; final != blob.Size {
		t.Errorf("Final progress %d != blob size %d", final, blob.Size)
	}
	t.Logf("Progress rollback test: %d progress callbacks", len(progressValues))
}
// TestUserAgentHeader verifies that a caller-supplied DownloadOptions.UserAgent
// is sent as the User-Agent header on every request the downloader makes.
func TestUserAgentHeader(t *testing.T) {
	const customUA = "test-agent/1.0"

	content := []byte("test content")
	blob := Blob{
		Digest: fmt.Sprintf("sha256:%x", sha256.Sum256(content)),
		Size:   int64(len(content)),
	}
	destDir := t.TempDir()

	// Record the header of every request; handlers run on server goroutines.
	var (
		mu   sync.Mutex
		seen []string
	)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		seen = append(seen, r.Header.Get("User-Agent"))
		mu.Unlock()
		if r.Method != http.MethodGet {
			return
		}
		w.Write(content)
	}))
	defer server.Close()

	if err := Download(context.Background(), DownloadOptions{
		Blobs:     []Blob{blob},
		BaseURL:   server.URL,
		DestDir:   destDir,
		UserAgent: customUA,
	}); err != nil {
		t.Fatalf("Download failed: %v", err)
	}

	mu.Lock()
	defer mu.Unlock()
	for _, got := range seen {
		if got != customUA {
			t.Errorf("User-Agent %q != expected %q", got, customUA)
		}
	}
	t.Logf("User-Agent header test: %d requests with correct User-Agent", len(seen))
}
// TestDefaultUserAgent verifies that defaultUserAgent is sent as the
// User-Agent header on every request when DownloadOptions.UserAgent is empty.
func TestDefaultUserAgent(t *testing.T) {
	content := []byte("test content")
	digest := fmt.Sprintf("sha256:%x", sha256.Sum256(content))
	blob := Blob{Digest: digest, Size: int64(len(content))}
	destDir := t.TempDir()

	// The handler runs on server goroutines and may be hit more than once, so
	// record every request's header under a lock instead of overwriting a
	// single shared string (which only ever checked the last request).
	var (
		mu         sync.Mutex
		userAgents []string
	)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		userAgents = append(userAgents, r.Header.Get("User-Agent"))
		mu.Unlock()
		if r.Method == http.MethodGet {
			w.Write(content)
		}
	}))
	defer server.Close()

	err := Download(context.Background(), DownloadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		DestDir: destDir,
		// No UserAgent specified - should use default
	})
	if err != nil {
		t.Fatalf("Download failed: %v", err)
	}

	mu.Lock()
	defer mu.Unlock()
	if len(userAgents) == 0 {
		t.Fatal("Server received no requests")
	}
	for _, ua := range userAgents {
		if ua == "" {
			t.Error("User-Agent header was empty")
		}
		if ua != defaultUserAgent {
			t.Errorf("Default User-Agent %q != expected %q", ua, defaultUserAgent)
		}
	}
}
// TestManifestPush verifies that, once the blob is uploaded, Upload PUTs the
// given manifest bytes to /v2/<repo>/manifests/<ref> with the Docker v2
// manifest media type as Content-Type.
func TestManifestPush(t *testing.T) {
	const (
		testRepo = "library/test-model"
		testRef  = "latest"
		wantType = "application/vnd.docker.distribution.manifest.v2+json"
	)

	clientDir := t.TempDir()
	blob, _ := createTestBlob(t, clientDir, 1000)
	testManifest := []byte(`{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json"}`)

	// What the registry saw on the manifest PUT.
	var got struct {
		body        []byte
		path        string
		contentType string
	}
	var serverURL string
	session := newChunkedSession()
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		p := r.URL.Path
		switch {
		case r.Method == http.MethodHead:
			http.NotFound(w, r)
		case r.Method == http.MethodPost && strings.Contains(p, "/blobs/uploads"):
			w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
			w.WriteHeader(http.StatusAccepted)
		case r.Method == http.MethodPatch:
			session.recordPatch(w, r)
		case r.Method == http.MethodPut && strings.Contains(p, "/blobs/"):
			// Chunked blob upload finalize.
			w.WriteHeader(http.StatusCreated)
		case r.Method == http.MethodPut && strings.Contains(p, "/manifests/"):
			got.path = p
			got.contentType = r.Header.Get("Content-Type")
			got.body, _ = io.ReadAll(r.Body)
			w.WriteHeader(http.StatusCreated)
		default:
			http.NotFound(w, r)
		}
	}))
	defer server.Close()
	serverURL = server.URL

	if err := Upload(context.Background(), UploadOptions{
		Blobs:       []Blob{blob},
		BaseURL:     server.URL,
		SrcDir:      clientDir,
		Manifest:    testManifest,
		ManifestRef: testRef,
		Repository:  testRepo,
	}); err != nil {
		t.Fatalf("Upload failed: %v", err)
	}

	if got.body == nil {
		t.Fatal("Manifest was not received by server")
	}
	if !bytes.Equal(got.body, testManifest) {
		t.Errorf("Manifest content mismatch: got %s, want %s", got.body, testManifest)
	}
	if wantPath := fmt.Sprintf("/v2/%s/manifests/%s", testRepo, testRef); got.path != wantPath {
		t.Errorf("Manifest path mismatch: got %s, want %s", got.path, wantPath)
	}
	if got.contentType != wantType {
		t.Errorf("Manifest content type mismatch: got %s", got.contentType)
	}
	t.Logf("Manifest push test passed: received %d bytes at %s", len(got.body), got.path)
}
// ==================== Throughput Benchmarks ====================
// BenchmarkDownloadThroughput measures single-stream download throughput of a
// 1 MiB blob served from memory by an in-process HTTP server. Each iteration
// downloads into a fresh temp directory.
func BenchmarkDownloadThroughput(b *testing.B) {
	// 1 MiB of repeating 0..255 bytes (the byte conversion wraps modulo 256).
	payload := make([]byte, 1<<20)
	for i := range payload {
		payload[i] = byte(i)
	}
	sum := sha256.Sum256(payload)
	blob := Blob{Digest: fmt.Sprintf("sha256:%x", sum), Size: int64(len(payload))}
	contentLength := fmt.Sprintf("%d", len(payload))

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Length", contentLength)
		w.WriteHeader(http.StatusOK)
		w.Write(payload)
	}))
	defer server.Close()

	b.SetBytes(int64(len(payload)))
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if err := Download(context.Background(), DownloadOptions{
			Blobs:       []Blob{blob},
			BaseURL:     server.URL,
			DestDir:     b.TempDir(),
			Concurrency: 1,
		}); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkUploadThroughput measures end-to-end upload throughput of a 1 MiB
// blob against an in-process registry mock. The mock implements the chunked
// POST → PATCH → PUT-finalize flow and discards everything it receives, so
// memory stays flat across b.N iterations.
func BenchmarkUploadThroughput(b *testing.B) {
	// Create test data - 1MB blob
	data := make([]byte, 1024*1024)
	for i := range data {
		data[i] = byte(i % 256)
	}
	h := sha256.Sum256(data)
	digest := fmt.Sprintf("sha256:%x", h)
	blob := Blob{Digest: digest, Size: int64(len(data))}

	// Create source file once; every iteration re-uploads it.
	srcDir := b.TempDir()
	path := filepath.Join(srcDir, digestToPath(digest))
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		b.Fatal(err)
	}
	if err := os.WriteFile(path, data, 0o644); err != nil {
		b.Fatal(err)
	}

	var serverURL string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodHead:
			// Report the blob as missing so every iteration really uploads it.
			http.NotFound(w, r)
		case http.MethodPost:
			w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
			w.WriteHeader(http.StatusAccepted)
		case http.MethodPatch:
			// Blob data arrives via PATCH. Previously this fell through to an
			// implicit 200 with no session location. Drain the body (rather than
			// buffering it like chunkedSession) and point the client back at the
			// same session for the finalize PUT.
			io.Copy(io.Discard, r.Body)
			w.Header().Set("Docker-Upload-Location", r.URL.Path)
			w.WriteHeader(http.StatusAccepted)
		case http.MethodPut:
			io.Copy(io.Discard, r.Body)
			w.WriteHeader(http.StatusCreated)
		}
	}))
	defer server.Close()
	serverURL = server.URL

	b.SetBytes(int64(len(data)))
	b.ResetTimer()
	for range b.N {
		err := Upload(context.Background(), UploadOptions{
			Blobs:       []Blob{blob},
			BaseURL:     server.URL,
			SrcDir:      srcDir,
			Concurrency: 1,
		})
		if err != nil {
			b.Fatal(err)
		}
	}
}
// ==================== Resume Tests ====================
// TestResumeFromPartialFile verifies that, for a blob of at least
// resumeThreshold bytes, a valid partial "<dest>.tmp" file is resumed with a
// "Range: bytes=N-" request (N = partial length). It also checks that the
// finished file matches the expected digest.
func TestResumeFromPartialFile(t *testing.T) {
	// Create a blob large enough for resume (>= resumeThreshold)
	blobSize := resumeThreshold + 1024
	data := make([]byte, blobSize)
	for i := range data {
		data[i] = byte((i * 13) % 256)
	}
	h := sha256.Sum256(data)
	digest := fmt.Sprintf("sha256:%x", h)
	blob := Blob{Digest: digest, Size: int64(blobSize)}
	var rangeHeader string
	var mu sync.Mutex
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodHead {
			w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
			w.WriteHeader(http.StatusOK)
			return
		}
		mu.Lock()
		rangeHeader = r.Header.Get("Range")
		mu.Unlock()
		rng := r.Header.Get("Range")
		if rng != "" {
			// Parse "bytes=N-"
			var start int64
			fmt.Sscanf(rng, "bytes=%d-", &start)
			if start > 0 && start < int64(blobSize) {
				w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, blobSize-1, blobSize))
				w.WriteHeader(http.StatusPartialContent)
				w.Write(data[start:])
				return
			}
		}
		w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
		w.WriteHeader(http.StatusOK)
		w.Write(data)
	}))
	defer server.Close()
	clientDir := t.TempDir()
	// Pre-create a partial .tmp file (first half). Setup failures must stop
	// the test here; otherwise they would surface as a misleading Range or
	// hash assertion failure below.
	partialSize := blobSize / 2
	dest := filepath.Join(clientDir, digestToPath(digest))
	if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(dest+".tmp", data[:partialSize], 0o644); err != nil {
		t.Fatal(err)
	}
	err := Download(context.Background(), DownloadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		DestDir: clientDir,
	})
	if err != nil {
		t.Fatalf("Resume download failed: %v", err)
	}
	// Verify Range header was sent
	mu.Lock()
	if rangeHeader == "" {
		t.Error("Expected Range header for resume, got none")
	} else {
		expected := fmt.Sprintf("bytes=%d-", partialSize)
		if rangeHeader != expected {
			t.Errorf("Range header = %q, want %q", rangeHeader, expected)
		}
	}
	mu.Unlock()
	// Verify final file is correct
	finalData, err := os.ReadFile(dest)
	if err != nil {
		t.Fatalf("Failed to read final file: %v", err)
	}
	if len(finalData) != blobSize {
		t.Errorf("Final file size = %d, want %d", len(finalData), blobSize)
	}
	finalHash := sha256.Sum256(finalData)
	if fmt.Sprintf("sha256:%x", finalHash) != digest {
		t.Error("Final file hash mismatch")
	}
}
// TestResumeCorruptPartialFile verifies that a partial "<dest>.tmp" file with
// wrong contents does not poison the download: Download must still succeed
// and the final file must match the expected digest.
func TestResumeCorruptPartialFile(t *testing.T) {
	blobSize := resumeThreshold + 1024
	data := make([]byte, blobSize)
	for i := range data {
		data[i] = byte((i * 13) % 256)
	}
	h := sha256.Sum256(data)
	digest := fmt.Sprintf("sha256:%x", h)
	blob := Blob{Digest: digest, Size: int64(blobSize)}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodHead {
			w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
			w.WriteHeader(http.StatusOK)
			return
		}
		rng := r.Header.Get("Range")
		if rng != "" {
			var start int64
			fmt.Sscanf(rng, "bytes=%d-", &start)
			if start > 0 && start < int64(blobSize) {
				w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, blobSize-1, blobSize))
				w.WriteHeader(http.StatusPartialContent)
				w.Write(data[start:])
				return
			}
		}
		w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
		w.WriteHeader(http.StatusOK)
		w.Write(data)
	}))
	defer server.Close()
	clientDir := t.TempDir()
	// Pre-create a partial .tmp file with CORRUPT data (all 0xFF — definitely wrong).
	partialSize := blobSize / 2
	corruptData := bytes.Repeat([]byte{0xFF}, partialSize)
	dest := filepath.Join(clientDir, digestToPath(digest))
	if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(dest+".tmp", corruptData, 0o644); err != nil {
		t.Fatal(err)
	}
	err := Download(context.Background(), DownloadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		DestDir: clientDir,
	})
	// Expected flow (per the scenario this test targets): the first attempt
	// resumes on top of corrupt data and fails the hash check; the retry
	// should discard the .tmp and re-download fully.
	if err != nil {
		t.Fatalf("Download with corrupt partial file failed: %v", err)
	}
	// Verify final file is correct
	finalData, err := os.ReadFile(dest)
	if err != nil {
		t.Fatalf("Failed to read final file: %v", err)
	}
	finalHash := sha256.Sum256(finalData)
	if fmt.Sprintf("sha256:%x", finalHash) != digest {
		t.Error("Final file hash mismatch after corrupt resume recovery")
	}
}
// TestResumePartialFileLargerThanBlob verifies that a "<dest>.tmp" file that
// is larger than the blob itself does not break the download: Download must
// succeed and the final file must match the expected digest.
func TestResumePartialFileLargerThanBlob(t *testing.T) {
	blobSize := resumeThreshold + 1024
	data := make([]byte, blobSize)
	for i := range data {
		data[i] = byte((i * 13) % 256)
	}
	h := sha256.Sum256(data)
	digest := fmt.Sprintf("sha256:%x", h)
	blob := Blob{Digest: digest, Size: int64(blobSize)}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodHead {
			w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
			w.WriteHeader(http.StatusOK)
			return
		}
		w.WriteHeader(http.StatusOK)
		w.Write(data)
	}))
	defer server.Close()
	clientDir := t.TempDir()
	// Pre-create .tmp file LARGER than expected blob
	oversizedData := make([]byte, blobSize+1000)
	dest := filepath.Join(clientDir, digestToPath(digest))
	if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(dest+".tmp", oversizedData, 0o644); err != nil {
		t.Fatal(err)
	}
	err := Download(context.Background(), DownloadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		DestDir: clientDir,
	})
	if err != nil {
		t.Fatalf("Download with oversized .tmp failed: %v", err)
	}
	// Verify final file is correct
	finalData, err := os.ReadFile(dest)
	if err != nil {
		t.Fatalf("Failed to read final file: %v", err)
	}
	finalHash := sha256.Sum256(finalData)
	if fmt.Sprintf("sha256:%x", finalHash) != digest {
		t.Error("Final file hash mismatch")
	}
}
// TestResumeBelowThreshold verifies that a blob smaller than resumeThreshold is
// re-downloaded from scratch, with no Range header, even when a partial
// "<dest>.tmp" file exists, and that the final file matches the digest.
func TestResumeBelowThreshold(t *testing.T) {
	// Blob below resume threshold should NOT attempt resume
	blobSize := 1024 // Well below resumeThreshold
	data := make([]byte, blobSize)
	for i := range data {
		data[i] = byte(i % 256)
	}
	h := sha256.Sum256(data)
	digest := fmt.Sprintf("sha256:%x", h)
	blob := Blob{Digest: digest, Size: int64(blobSize)}
	var gotRange atomic.Bool
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodHead {
			w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
			w.WriteHeader(http.StatusOK)
			return
		}
		if r.Header.Get("Range") != "" {
			gotRange.Store(true)
		}
		w.WriteHeader(http.StatusOK)
		w.Write(data)
	}))
	defer server.Close()
	clientDir := t.TempDir()
	// Pre-create a partial .tmp file. If this setup silently failed, the
	// "no Range header" assertion would pass trivially.
	dest := filepath.Join(clientDir, digestToPath(digest))
	if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(dest+".tmp", data[:blobSize/2], 0o644); err != nil {
		t.Fatal(err)
	}
	err := Download(context.Background(), DownloadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		DestDir: clientDir,
	})
	if err != nil {
		t.Fatalf("Download failed: %v", err)
	}
	if gotRange.Load() {
		t.Error("Range header sent for blob below resume threshold — should not attempt resume")
	}
	// Verify final file
	finalData, err := os.ReadFile(dest)
	if err != nil {
		t.Fatalf("Failed to read final file: %v", err)
	}
	finalHash := sha256.Sum256(finalData)
	if fmt.Sprintf("sha256:%x", finalHash) != digest {
		t.Error("Final file hash mismatch")
	}
}
// TestResumeServerDoesNotSupportRange verifies that when the server ignores
// the Range header and always answers 200 with the full body, a resumed
// download still produces a file matching the expected digest.
func TestResumeServerDoesNotSupportRange(t *testing.T) {
	blobSize := resumeThreshold + 1024
	data := make([]byte, blobSize)
	for i := range data {
		data[i] = byte((i * 13) % 256)
	}
	h := sha256.Sum256(data)
	digest := fmt.Sprintf("sha256:%x", h)
	blob := Blob{Digest: digest, Size: int64(blobSize)}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodHead {
			w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
			w.WriteHeader(http.StatusOK)
			return
		}
		// Ignore Range header — always return full content with 200
		w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
		w.WriteHeader(http.StatusOK)
		w.Write(data)
	}))
	defer server.Close()
	clientDir := t.TempDir()
	// Pre-create partial .tmp file
	dest := filepath.Join(clientDir, digestToPath(digest))
	if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(dest+".tmp", data[:blobSize/2], 0o644); err != nil {
		t.Fatal(err)
	}
	err := Download(context.Background(), DownloadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		DestDir: clientDir,
	})
	if err != nil {
		t.Fatalf("Download failed when server doesn't support Range: %v", err)
	}
	// Verify final file is correct
	finalData, err := os.ReadFile(dest)
	if err != nil {
		t.Fatalf("Failed to read final file: %v", err)
	}
	finalHash := sha256.Sum256(finalData)
	if fmt.Sprintf("sha256:%x", finalHash) != digest {
		t.Error("Final file hash mismatch")
	}
}
// TestResumePartialFileExactSize verifies that a "<dest>.tmp" file that already
// holds the complete, correct blob (a download that finished but was never
// renamed) still yields a correct final file. The mock answers a Range
// starting at or past the end with 416.
func TestResumePartialFileExactSize(t *testing.T) {
	blobSize := resumeThreshold + 1024
	data := make([]byte, blobSize)
	for i := range data {
		data[i] = byte((i * 13) % 256)
	}
	h := sha256.Sum256(data)
	digest := fmt.Sprintf("sha256:%x", h)
	blob := Blob{Digest: digest, Size: int64(blobSize)}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodHead {
			w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
			w.WriteHeader(http.StatusOK)
			return
		}
		rng := r.Header.Get("Range")
		if rng != "" {
			var start int64
			fmt.Sscanf(rng, "bytes=%d-", &start)
			if start >= int64(blobSize) {
				// Nothing to send
				w.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
				return
			}
			if start > 0 {
				w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, blobSize-1, blobSize))
				w.WriteHeader(http.StatusPartialContent)
				w.Write(data[start:])
				return
			}
		}
		w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
		w.WriteHeader(http.StatusOK)
		w.Write(data)
	}))
	defer server.Close()
	clientDir := t.TempDir()
	// Pre-create .tmp file with exact correct content (full size)
	// This simulates a download that completed but wasn't renamed
	dest := filepath.Join(clientDir, digestToPath(digest))
	if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(dest+".tmp", data, 0o644); err != nil {
		t.Fatal(err)
	}
	err := Download(context.Background(), DownloadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		DestDir: clientDir,
	})
	if err != nil {
		t.Fatalf("Download failed: %v", err)
	}
	// Verify final file is correct
	finalData, err := os.ReadFile(dest)
	if err != nil {
		t.Fatalf("Failed to read final file: %v", err)
	}
	finalHash := sha256.Sum256(finalData)
	if fmt.Sprintf("sha256:%x", finalHash) != digest {
		t.Error("Final file hash mismatch")
	}
}
// ==================== Chunked Upload Tests ====================
// chunkedUploadServer is a mock registry for the OCI chunked upload protocol:
// POST opens a session, each PATCH carries one part described by a
// Content-Range header, and a PUT on the session URL finalizes the blob.
type chunkedUploadServer struct {
	// Configuration, set before the server starts handling requests.
	t            *testing.T
	serverURL    *string                                      // filled in once the httptest server's URL is known
	patchHandler func(w http.ResponseWriter, r *http.Request) // optional PATCH override

	// mu guards all of the recorded state below.
	mu            sync.Mutex
	uploadCounter int            // upload sessions opened via POST
	patchCount    int            // PATCHes handled by defaultPatchHandler
	parts         map[int][]byte // part offset -> received data
	finalized     bool
	finalDigest   string
	finalEtag     string
}
// newChunkedUploadServer returns a mock with an empty parts map. The caller
// must still set serverURL before serving POST or PATCH requests.
func newChunkedUploadServer(t *testing.T) *chunkedUploadServer {
	s := new(chunkedUploadServer)
	s.t = t
	s.parts = map[int][]byte{}
	return s
}
// handler returns the mock's HTTP handler, which routes by method:
//   - HEAD always answers 404.
//   - POST on an "/uploads" path opens a numbered session (202 + Location).
//   - PATCH goes to patchHandler, or to defaultPatchHandler when it is nil.
//   - PUT on an "/uploads" path records the digest/etag query params (201).
//
// Anything else answers 404.
func (s *chunkedUploadServer) handler() http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		onUploads := strings.Contains(r.URL.Path, "/uploads")
		switch r.Method {
		case http.MethodHead:
			http.NotFound(w, r)
		case http.MethodPatch:
			patch := s.patchHandler
			if patch == nil {
				patch = s.defaultPatchHandler
			}
			patch(w, r)
		case http.MethodPost:
			if !onUploads {
				http.NotFound(w, r)
				return
			}
			s.mu.Lock()
			s.uploadCounter++
			id := s.uploadCounter
			s.mu.Unlock()
			w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/%d", *s.serverURL, id))
			w.WriteHeader(http.StatusAccepted)
		case http.MethodPut:
			if !onUploads {
				http.NotFound(w, r)
				return
			}
			q := r.URL.Query()
			s.mu.Lock()
			s.finalized = true
			s.finalDigest = q.Get("digest")
			s.finalEtag = q.Get("etag")
			s.mu.Unlock()
			w.WriteHeader(http.StatusCreated)
		default:
			http.NotFound(w, r)
		}
	}
}
// defaultPatchHandler records one uploaded part. It expects a Content-Range
// header of the form "<start>-<end>" and stores the request body in s.parts
// under its start offset. It replies 202 Accepted with a
// Docker-Upload-Location pointing at a new ".../uploads/session-N" URL.
// A missing or unparseable Content-Range is rejected with 400.
func (s *chunkedUploadServer) defaultPatchHandler(w http.ResponseWriter, r *http.Request) {
	s.mu.Lock()
	s.patchCount++
	patchNum := s.patchCount
	s.mu.Unlock()
	cr := r.Header.Get("Content-Range")
	if cr == "" {
		http.Error(w, "missing Content-Range", http.StatusBadRequest)
		return
	}
	// An unparseable header would otherwise leave start at 0 and silently
	// overwrite the first part, surfacing later as an opaque hash mismatch.
	var start, end int64
	if _, err := fmt.Sscanf(cr, "%d-%d", &start, &end); err != nil {
		http.Error(w, fmt.Sprintf("malformed Content-Range %q: %v", cr, err), http.StatusBadRequest)
		return
	}
	data, err := io.ReadAll(r.Body)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	s.mu.Lock()
	s.parts[int(start)] = data
	s.mu.Unlock()
	w.Header().Set("Docker-Upload-Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/session-%d", *s.serverURL, patchNum+1))
	w.WriteHeader(http.StatusAccepted)
}
// reassemble rebuilds a totalSize-byte buffer from the received parts, copying
// each part to its recorded offset. Bytes that no part covered remain zero.
func (s *chunkedUploadServer) reassemble(totalSize int) []byte {
	s.mu.Lock()
	defer s.mu.Unlock()
	out := make([]byte, totalSize)
	for off, chunk := range s.parts {
		copy(out[off:], chunk)
	}
	return out
}
// TestComputeParts checks the part count and first-part size that computeParts
// picks for several blob sizes. It also checks that the returned parts are
// numbered 0..n-1 and tile the blob contiguously with no gaps.
func TestComputeParts(t *testing.T) {
	// ceilDiv is the number of size-byte parts needed to cover total bytes.
	ceilDiv := func(total, size int64) int {
		return int((total + size - 1) / size)
	}

	cases := []struct {
		name      string
		size      int64
		wantCount int
		wantFirst int64
	}{
		{"1GB blob — clamped to min part size", 1 << 30, ceilDiv(1<<30, minUploadPartSize), minUploadPartSize},
		{"5GB blob — 16 parts", 5 << 30, 16, (5 << 30) / 16},
		{"20GB blob — clamped to max part size", 20 << 30, ceilDiv(20<<30, maxUploadPartSize), maxUploadPartSize},
		{"exactly min part size", minUploadPartSize, 1, minUploadPartSize},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			parts := computeParts(tc.size)
			if got := len(parts); got != tc.wantCount {
				t.Errorf("computeParts(%d) = %d parts, want %d", tc.size, got, tc.wantCount)
			}
			if len(parts) > 0 && parts[0].size != tc.wantFirst {
				t.Errorf("first part size = %d, want %d", parts[0].size, tc.wantFirst)
			}

			// Parts must be numbered in order and abut one another.
			var covered int64
			for idx := range parts {
				p := parts[idx]
				if p.offset != covered {
					t.Errorf("part %d offset = %d, want %d", idx, p.offset, covered)
				}
				if p.n != idx {
					t.Errorf("part %d n = %d, want %d", idx, p.n, idx)
				}
				covered += p.size
			}
			if covered != tc.size {
				t.Errorf("total part sizes = %d, want %d", covered, tc.size)
			}
		})
	}
}
// TestChunkedUploadBasic uploads a blob just above resumeThreshold through the
// public Upload API. It checks that the received parts reassemble to the
// original digest, that the finalize PUT carries the digest and a non-empty
// etag, that at least one PATCH was sent, and that progress was reported.
func TestChunkedUploadBasic(t *testing.T) {
	blobSize := resumeThreshold + 1024
	data := make([]byte, blobSize)
	for i := range data {
		data[i] = byte((i * 7) % 256)
	}
	h := sha256.Sum256(data)
	digest := fmt.Sprintf("sha256:%x", h)
	blob := Blob{Digest: digest, Size: int64(blobSize)}
	clientDir := t.TempDir()
	path := filepath.Join(clientDir, digestToPath(digest))
	// Fail fast on fixture errors instead of reporting a confusing upload failure.
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(path, data, 0o644); err != nil {
		t.Fatal(err)
	}
	srv := newChunkedUploadServer(t)
	var serverURL string
	srv.serverURL = &serverURL
	server := httptest.NewServer(srv.handler())
	defer server.Close()
	serverURL = server.URL
	var progressCalls atomic.Int32
	err := Upload(context.Background(), UploadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		SrcDir:  clientDir,
		Progress: func(completed, total int64) {
			progressCalls.Add(1)
		},
	})
	if err != nil {
		t.Fatalf("Chunked upload failed: %v", err)
	}
	reassembled := srv.reassemble(blobSize)
	reassembledHash := sha256.Sum256(reassembled)
	if fmt.Sprintf("sha256:%x", reassembledHash) != digest {
		t.Error("Reassembled data hash mismatch")
	}
	srv.mu.Lock()
	if !srv.finalized {
		t.Error("Finalize PUT was not called")
	}
	if srv.finalDigest != digest {
		t.Errorf("Finalize digest = %s, want %s", srv.finalDigest, digest)
	}
	if srv.finalEtag == "" {
		t.Error("Finalize etag is empty")
	}
	if srv.patchCount == 0 {
		t.Error("No PATCH requests were sent")
	}
	srv.mu.Unlock()
	if progressCalls.Load() == 0 {
		t.Error("Progress callback never called")
	}
}
// TestChunkedUploadCDNRedirect verifies the redirect path of the chunked
// upload: when the registry answers a PATCH with a 307 redirect to another
// host (standing in for a CDN), the client re-sends the part's bytes there.
// The CDN must receive the whole blob and must not receive an Authorization
// header.
func TestChunkedUploadCDNRedirect(t *testing.T) {
	blobSize := resumeThreshold + 1024
	data := make([]byte, blobSize)
	for i := range data {
		data[i] = byte((i * 7) % 256)
	}
	h := sha256.Sum256(data)
	digest := fmt.Sprintf("sha256:%x", h)
	blob := Blob{Digest: digest, Size: int64(blobSize)}
	clientDir := t.TempDir()
	path := filepath.Join(clientDir, digestToPath(digest))
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(path, data, 0o644); err != nil {
		t.Fatal(err)
	}
	cdnParts := make(map[string][]byte)
	var cdnMu sync.Mutex
	var cdnGotAuth atomic.Bool
	cdn := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("Authorization") != "" {
			cdnGotAuth.Store(true)
		}
		cdnData, _ := io.ReadAll(r.Body)
		cdnMu.Lock()
		cdnParts[r.URL.Path] = cdnData
		cdnMu.Unlock()
		w.WriteHeader(http.StatusCreated)
	}))
	defer cdn.Close()
	var serverURL string
	var patchCount atomic.Int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case r.Method == http.MethodHead:
			http.NotFound(w, r)
		case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/uploads"):
			w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
			w.WriteHeader(http.StatusAccepted)
		case r.Method == http.MethodPatch:
			// Hand out the next registry session URL, but send the part's
			// bytes to the CDN via a method- and body-preserving 307.
			n := patchCount.Add(1)
			w.Header().Set("Docker-Upload-Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/session-%d", serverURL, n+1))
			cdnPath := fmt.Sprintf("/cdn/part-%d", n)
			http.Redirect(w, r, cdn.URL+cdnPath, http.StatusTemporaryRedirect)
		case r.Method == http.MethodPut && strings.Contains(r.URL.Path, "/uploads"):
			w.WriteHeader(http.StatusCreated)
		default:
			http.NotFound(w, r)
		}
	}))
	defer server.Close()
	serverURL = server.URL
	err := Upload(context.Background(), UploadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		SrcDir:  clientDir,
	})
	if err != nil {
		t.Fatalf("Upload with CDN redirect failed: %v", err)
	}
	cdnMu.Lock()
	totalCDNBytes := 0
	for _, d := range cdnParts {
		totalCDNBytes += len(d)
	}
	cdnMu.Unlock()
	if totalCDNBytes != blobSize {
		t.Errorf("CDN received %d bytes, want %d", totalCDNBytes, blobSize)
	}
	if cdnGotAuth.Load() {
		t.Error("CDN received Authorization header — should not be sent to CDN")
	}
}
// TestChunkedUploadPartRetry makes the first PATCH fail with 500 and checks
// that the client retries the part (at least two PATCH attempts) and that the
// reassembled upload still matches the original digest.
func TestChunkedUploadPartRetry(t *testing.T) {
	blobSize := resumeThreshold + 1024
	data := make([]byte, blobSize)
	for i := range data {
		data[i] = byte(i % 256)
	}
	h := sha256.Sum256(data)
	digest := fmt.Sprintf("sha256:%x", h)
	blob := Blob{Digest: digest, Size: int64(blobSize)}
	clientDir := t.TempDir()
	path := filepath.Join(clientDir, digestToPath(digest))
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(path, data, 0o644); err != nil {
		t.Fatal(err)
	}
	var patchAttempts atomic.Int32
	srv := newChunkedUploadServer(t)
	srv.patchHandler = func(w http.ResponseWriter, r *http.Request) {
		attempt := patchAttempts.Add(1)
		if attempt == 1 {
			http.Error(w, "server error", http.StatusInternalServerError)
			return
		}
		srv.defaultPatchHandler(w, r)
	}
	var serverURL string
	srv.serverURL = &serverURL
	server := httptest.NewServer(srv.handler())
	defer server.Close()
	serverURL = server.URL
	err := Upload(context.Background(), UploadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		SrcDir:  clientDir,
	})
	if err != nil {
		t.Fatalf("Upload with retry failed: %v", err)
	}
	if patchAttempts.Load() < 2 {
		t.Errorf("Expected at least 2 PATCH attempts, got %d", patchAttempts.Load())
	}
	reassembled := srv.reassemble(blobSize)
	reassembledHash := sha256.Sum256(reassembled)
	if fmt.Sprintf("sha256:%x", reassembledHash) != digest {
		t.Error("Data integrity failed after retry")
	}
}
// multiPartTestHelper writes a deterministic blob of blobSize bytes into a
// fresh temp dir and builds an uploader aimed at serverURL. The uploader's part
// planner calls computePartsWithLimits(totalSize, 16, partSize, partSize*10),
// so multi-part behavior can be exercised with small blobs. It returns that
// uploader, the Blob descriptor, and the raw blob bytes for later comparison.
// NOTE(review): the 16 / partSize / partSize*10 arguments are presumably the
// target part count and the min/max part sizes; confirm against
// computePartsWithLimits.
func multiPartTestHelper(t *testing.T, blobSize int, partSize int64, serverURL string) (*uploader, Blob, []byte) {
	t.Helper()
	data := make([]byte, blobSize)
	for i := range data {
		data[i] = byte((i*7 + 13) % 256)
	}
	h := sha256.Sum256(data)
	digest := fmt.Sprintf("sha256:%x", h)
	blob := Blob{Digest: digest, Size: int64(blobSize)}
	clientDir := t.TempDir()
	path := filepath.Join(clientDir, digestToPath(digest))
	// Fail in the helper (reported at the caller's line via t.Helper) rather
	// than letting a missing fixture surface as an os.Open error later.
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(path, data, 0o644); err != nil {
		t.Fatal(err)
	}
	u := &uploader{
		client:    defaultClient,
		baseURL:   serverURL,
		srcDir:    clientDir,
		userAgent: defaultUserAgent,
		progress:  newProgressTracker(int64(blobSize), nil),
		makeParts: func(totalSize int64) []uploadPart {
			return computePartsWithLimits(totalSize, 16, partSize, partSize*10)
		},
	}
	return u, blob, data
}
// TestChunkedUploadMultiPartSessionURLChain verifies that putChunked follows
// the Docker-Upload-Location chain. The first PATCH must go to the initial
// upload URL, and each later PATCH must go to a "session-" URL different from
// the one before it. It also checks byte count and reassembled data.
func TestChunkedUploadMultiPartSessionURLChain(t *testing.T) {
	// A 10 KiB blob in 2 KiB parts gives five PATCHes, enough to walk the chain.
	const (
		blobSize = 10240
		partSize = 2048
	)

	var (
		mu   sync.Mutex
		urls []string
	)
	srv := newChunkedUploadServer(t)
	srv.patchHandler = func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		urls = append(urls, r.URL.Path)
		mu.Unlock()
		srv.defaultPatchHandler(w, r)
	}
	var serverURL string
	srv.serverURL = &serverURL
	server := httptest.NewServer(srv.handler())
	defer server.Close()
	serverURL = server.URL

	u, blob, want := multiPartTestHelper(t, blobSize, partSize, server.URL)
	src, err := os.Open(filepath.Join(u.srcDir, digestToPath(blob.Digest)))
	if err != nil {
		t.Fatal(err)
	}
	defer src.Close()

	written, err := u.putChunked(context.Background(), server.URL+"/v2/library/_/blobs/uploads/init-1", src, blob)
	if err != nil {
		t.Fatalf("putChunked failed: %v", err)
	}
	if written != blobSize {
		t.Errorf("bytes written = %d, want %d", written, blobSize)
	}
	if !bytes.Equal(srv.reassemble(blobSize), want) {
		t.Error("Reassembled data mismatch")
	}

	mu.Lock()
	defer mu.Unlock()
	if len(urls) != 5 {
		t.Fatalf("Expected 5 PATCH requests, got %d", len(urls))
	}
	if !strings.Contains(urls[0], "init-1") {
		t.Errorf("First PATCH URL should contain init-1, got %s", urls[0])
	}
	// Every later PATCH must have moved on to a fresh session URL.
	for prevIdx, cur := range urls[1:] {
		idx := prevIdx + 1
		if cur == urls[prevIdx] {
			t.Errorf("PATCH %d used same URL as PATCH %d — chain broken", idx, prevIdx)
		}
		if !strings.Contains(cur, "session-") {
			t.Errorf("PATCH %d URL should contain session-, got %s", idx, cur)
		}
	}
}
// TestChunkedUploadMultiPartDataIntegrity uploads a blob whose size is not a
// multiple of the part size: 10001 bytes in 3000-byte parts, i.e.
// 3000+3000+3000+1001. It checks that the server reassembles the exact bytes
// from exactly four PATCH requests.
func TestChunkedUploadMultiPartDataIntegrity(t *testing.T) {
	const blobSize = 10001
	const partSize int64 = 3000

	srv := newChunkedUploadServer(t)
	var serverURL string
	srv.serverURL = &serverURL
	server := httptest.NewServer(srv.handler())
	defer server.Close()
	serverURL = server.URL

	u, blob, data := multiPartTestHelper(t, blobSize, partSize, server.URL)
	src, err := os.Open(filepath.Join(u.srcDir, digestToPath(blob.Digest)))
	if err != nil {
		t.Fatal(err)
	}
	defer src.Close()

	initURL := server.URL + "/v2/library/_/blobs/uploads/1"
	if _, err := u.putChunked(context.Background(), initURL, src, blob); err != nil {
		t.Fatalf("putChunked failed: %v", err)
	}

	if !bytes.Equal(srv.reassemble(blobSize), data) {
		t.Error("Reassembled data mismatch with non-evenly-divisible parts")
	}

	// Snapshot the counter under the server's lock, then assert on the copy.
	srv.mu.Lock()
	patches := srv.patchCount
	srv.mu.Unlock()
	if patches != 4 {
		t.Errorf("Expected 4 PATCH requests, got %d", patches)
	}
}
// TestChunkedUploadMultiPartProgressRollback uploads three 2000-byte parts
// and fails the second PATCH request once with a 500. The retried upload must
// still reassemble byte-for-byte. The reported progress must dip at least once
// (the failed part's bytes are rolled back). If any progress was reported, the
// last value must equal the blob size.
func TestChunkedUploadMultiPartProgressRollback(t *testing.T) {
	const (
		blobSize       = 6000
		partSize int64 = 2000 // 3 parts
	)

	var patchAttempts atomic.Int32
	srv := newChunkedUploadServer(t)
	srv.patchHandler = func(w http.ResponseWriter, r *http.Request) {
		if patchAttempts.Add(1) != 2 {
			srv.defaultPatchHandler(w, r)
			return
		}
		// Second PATCH overall (part 1, first try). Read the whole body
		// before failing: under Expect: 100-continue this makes the server
		// send 100 Continue, so the client actually uploads the part. That
		// upload is what makes the progress rollback observable.
		io.Copy(io.Discard, r.Body)
		http.Error(w, "server error", http.StatusInternalServerError)
	}

	var serverURL string
	srv.serverURL = &serverURL
	server := httptest.NewServer(srv.handler())
	defer server.Close()
	serverURL = server.URL

	u, blob, data := multiPartTestHelper(t, blobSize, partSize, server.URL)

	// Record every progress callback; the callback may fire from upload
	// goroutines, so guard the slice.
	var (
		mu       sync.Mutex
		progress []int64
	)
	u.progress = newProgressTracker(int64(blobSize), func(completed, total int64) {
		mu.Lock()
		defer mu.Unlock()
		progress = append(progress, completed)
	})

	src, err := os.Open(filepath.Join(u.srcDir, digestToPath(blob.Digest)))
	if err != nil {
		t.Fatal(err)
	}
	defer src.Close()

	initURL := server.URL + "/v2/library/_/blobs/uploads/1"
	if _, err := u.putChunked(context.Background(), initURL, src, blob); err != nil {
		t.Fatalf("putChunked failed: %v", err)
	}

	if !bytes.Equal(srv.reassemble(blobSize), data) {
		t.Error("Data mismatch after retry")
	}

	mu.Lock()
	defer mu.Unlock()

	sawRollback := false
	for i := 1; i < len(progress) && !sawRollback; i++ {
		sawRollback = progress[i] < progress[i-1]
	}
	if !sawRollback {
		t.Error("Expected progress to decrease (rollback) during retry, but it was monotonic")
	}
	if n := len(progress); n > 0 && progress[n-1] != int64(blobSize) {
		t.Errorf("Final progress = %d, want %d", progress[n-1], blobSize)
	}
}
// ==================== v2 direct-upload extension tests ====================
// TestV2InitRequestShape checks how the upload-initiating POST advertises v2
// support. It must carry:
//   - ?digest=<blob digest>;
//   - X-Redirect-Uploads: 2;
//   - X-Content-Length: <blob size>;
//   - an empty body.
//
// The mock answers the POST with only a Location header and accepts the PATCH
// and PUT requests that follow, so Upload can run to completion.
func TestV2InitRequestShape(t *testing.T) {
	clientDir := t.TempDir()
	blob, _ := createTestBlob(t, clientDir, 4096)

	// What the init POST looked like on the wire.
	var seen struct {
		digest  string
		capHdr  string
		sizeHdr string
		bodyLen int
	}
	session := newChunkedSession()
	var serverURL string

	handler := func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodHead:
			http.NotFound(w, r)
		case http.MethodPost:
			seen.digest = r.URL.Query().Get("digest")
			seen.capHdr = r.Header.Get("X-Redirect-Uploads")
			seen.sizeHdr = r.Header.Get("X-Content-Length")
			body, _ := io.ReadAll(r.Body)
			seen.bodyLen = len(body)
			w.Header().Set("Location", serverURL+"/v2/library/_/blobs/uploads/1")
			w.WriteHeader(http.StatusAccepted)
		case http.MethodPatch:
			session.recordPatch(w, r)
		case http.MethodPut:
			w.WriteHeader(http.StatusCreated)
		}
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()
	serverURL = server.URL

	if err := Upload(context.Background(), UploadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		SrcDir:  clientDir,
	}); err != nil {
		t.Fatalf("Upload failed: %v", err)
	}

	wantSize := fmt.Sprintf("%d", blob.Size)
	if seen.digest != blob.Digest {
		t.Errorf("init POST ?digest= = %q, want %q", seen.digest, blob.Digest)
	}
	if seen.capHdr != "2" {
		t.Errorf("init POST X-Redirect-Uploads = %q, want %q", seen.capHdr, "2")
	}
	if seen.sizeHdr != wantSize {
		t.Errorf("init POST X-Content-Length = %q, want %q", seen.sizeHdr, wantSize)
	}
	if seen.bodyLen != 0 {
		t.Errorf("init POST body length = %d, want 0", seen.bodyLen)
	}
}
// TestV2DirectUpload verifies the v2 happy path:
//   - the server returns X-Direct-Upload-URL and
//     X-Signed-Header-X-Amz-Checksum-Sha256;
//   - the client PUTs the body to the direct URL, forwarding the checksum
//     header and sending no Authorization;
//   - the client then commits via a bodyless PUT carrying ?digest=.
func TestV2DirectUpload(t *testing.T) {
	clientDir := t.TempDir()
	blob, data := createTestBlob(t, clientDir, 8192)

	var (
		cdnReceived       []byte
		cdnChecksumHeader string
		cdnAuthHeader     string
		cdnHits           atomic.Int32
	)
	cdn := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		cdnHits.Add(1)
		if r.Method != http.MethodPut {
			http.Error(w, "want PUT", http.StatusMethodNotAllowed)
			return
		}
		cdnChecksumHeader = r.Header.Get("X-Amz-Checksum-Sha256")
		cdnAuthHeader = r.Header.Get("Authorization")
		cdnReceived, _ = io.ReadAll(r.Body)
		w.WriteHeader(http.StatusOK)
	}))
	defer cdn.Close()

	// The checksum the registry would sign is base64 of the raw SHA-256 of the
	// blob. createTestBlob derives blob.Digest from this same hash, so this is
	// the binary form of the digest's hex part. Hashing the data directly
	// avoids hand-parsing the hex string (and silently ignoring parse errors).
	// The mock just hands the value back; the client must forward it to the CDN.
	sum := sha256.Sum256(data)
	expectedChecksum := base64.StdEncoding.EncodeToString(sum[:])

	var (
		commitDigest atomic.Value
		commitBody   atomic.Int32
	)
	var serverURL string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodHead:
			http.NotFound(w, r)
		case http.MethodPost:
			directURL := cdn.URL + "/upload/" + blob.Digest
			w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
			w.Header().Set("X-Direct-Upload-URL", directURL)
			w.Header().Set("X-Signed-Header-X-Amz-Checksum-Sha256", expectedChecksum)
			w.WriteHeader(http.StatusAccepted)
		case http.MethodPut:
			commitDigest.Store(r.URL.Query().Get("digest"))
			body, _ := io.ReadAll(r.Body)
			commitBody.Store(int32(len(body)))
			w.WriteHeader(http.StatusCreated)
		default:
			// Any PATCH (chunked fallback) is rejected, so a client that
			// ignores the direct URL fails the upload.
			http.NotFound(w, r)
		}
	}))
	defer server.Close()
	serverURL = server.URL

	err := Upload(context.Background(), UploadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		SrcDir:  clientDir,
	})
	if err != nil {
		t.Fatalf("Upload failed: %v", err)
	}

	if got := cdnHits.Load(); got != 1 {
		t.Errorf("CDN hits = %d, want 1", got)
	}
	if !bytes.Equal(cdnReceived, data) {
		t.Errorf("CDN body length = %d, want %d", len(cdnReceived), len(data))
	}
	if cdnChecksumHeader != expectedChecksum {
		t.Errorf("CDN x-amz-checksum-sha256 = %q, want %q", cdnChecksumHeader, expectedChecksum)
	}
	if cdnAuthHeader != "" {
		t.Errorf("CDN Authorization = %q, want empty (presigned URL shouldn't carry auth)", cdnAuthHeader)
	}
	if got, _ := commitDigest.Load().(string); got != blob.Digest {
		t.Errorf("commit ?digest= = %q, want %q", got, blob.Digest)
	}
	if commitBody.Load() != 0 {
		t.Errorf("commit body length = %d, want 0", commitBody.Load())
	}
}
// TestV2FallbackToChunked checks compatibility with a vanilla Docker
// Registry. The init POST answers 202 with only a Location header (no v2
// extension headers). The client must then fall back to chunked PATCH
// uploads followed by exactly one committing PUT, and the committed bytes
// must match the blob.
func TestV2FallbackToChunked(t *testing.T) {
	clientDir := t.TempDir()
	blob, data := createTestBlob(t, clientDir, 8192)

	var (
		mu        sync.Mutex
		committed = map[string][]byte{} // digest -> bytes finalized by PUT
		patches   atomic.Int32
		commits   atomic.Int32
	)
	session := newChunkedSession()
	var serverURL string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodHead:
			http.NotFound(w, r)
		case http.MethodPost:
			// Vanilla registry: a standard Location header and nothing else.
			w.Header().Set("Location", serverURL+"/v2/library/_/blobs/uploads/1")
			w.WriteHeader(http.StatusAccepted)
		case http.MethodPatch:
			patches.Add(1)
			session.recordPatch(w, r)
		case http.MethodPut:
			commits.Add(1)
			body := session.finalize(r.URL.Path)
			mu.Lock()
			committed[r.URL.Query().Get("digest")] = body
			mu.Unlock()
			w.WriteHeader(http.StatusCreated)
		default:
			http.NotFound(w, r)
		}
	}))
	defer server.Close()
	serverURL = server.URL

	if err := Upload(context.Background(), UploadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		SrcDir:  clientDir,
	}); err != nil {
		t.Fatalf("Upload failed: %v", err)
	}

	if patches.Load() < 1 {
		t.Error("expected at least one PATCH (chunked fallback), got none")
	}
	if commits.Load() != 1 {
		t.Errorf("commit PUT hits = %d, want 1", commits.Load())
	}

	mu.Lock()
	got, ok := committed[blob.Digest]
	mu.Unlock()
	switch {
	case !ok:
		t.Error("blob not uploaded")
	case !bytes.Equal(got, data):
		t.Errorf("uploaded body length = %d, want %d", len(got), len(data))
	}
}
// TestV2BlobAlreadyExists checks that a 201 Created from the init POST ends
// the upload right away. The server has matched ?digest= against blobs it
// already stores, so the client must send exactly one POST and no PATCH or
// PUT.
func TestV2BlobAlreadyExists(t *testing.T) {
	clientDir := t.TempDir()
	blob, _ := createTestBlob(t, clientDir, 1024)

	// Per-method request counters.
	var posts, patches, puts atomic.Int32

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodHead:
			http.NotFound(w, r)
		case http.MethodPost:
			posts.Add(1)
			// Pretend ?digest= matched something already in storage.
			w.WriteHeader(http.StatusCreated)
		case http.MethodPatch:
			patches.Add(1)
			w.WriteHeader(http.StatusAccepted)
		case http.MethodPut:
			puts.Add(1)
			w.WriteHeader(http.StatusCreated)
		}
	}))
	defer server.Close()

	opts := UploadOptions{
		Blobs:   []Blob{blob},
		BaseURL: server.URL,
		SrcDir:  clientDir,
	}
	if err := Upload(context.Background(), opts); err != nil {
		t.Fatalf("Upload failed: %v", err)
	}

	if got := posts.Load(); got != 1 {
		t.Errorf("init POST hits = %d, want 1", got)
	}
	if got := patches.Load(); got != 0 {
		t.Errorf("PATCH hits = %d, want 0 (blob existed; nothing to upload)", got)
	}
	if got := puts.Load(); got != 0 {
		t.Errorf("PUT hits = %d, want 0 (blob existed; nothing to upload)", got)
	}
}