2910 lines
78 KiB
Go
2910 lines
78 KiB
Go
package transfer
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// chunkedSession tracks accumulated PATCH body bytes for each upload session,
// keyed by the session's URL path. Tests that mock the registry use it to
// handle the GGUF-style POST → PATCH → PUT-finalize flow without each test
// reimplementing the bookkeeping. All access to the buffers goes through mu,
// so it can be shared by concurrently running handler goroutines.
type chunkedSession struct {
	mu       sync.Mutex               // guards sessions and every buffer it holds
	sessions map[string]*bytes.Buffer // session URL path → bytes received so far
}

// newChunkedSession returns an empty chunkedSession ready for use.
func newChunkedSession() *chunkedSession {
	return &chunkedSession{sessions: make(map[string]*bytes.Buffer)}
}

// recordPatch appends the request body to the buffer for r.URL.Path and
// writes a 202 Accepted response with Docker-Upload-Location pointing at the
// same session URL. Use this from a mock handler's PATCH branch.
//
// If the body cannot be read completely, nothing is recorded and the handler
// replies 500, so the client sees the failure instead of a silently
// truncated upload being acknowledged.
func (c *chunkedSession) recordPatch(w http.ResponseWriter, r *http.Request) {
	// Read outside the lock so a slow client cannot stall other sessions.
	body, err := io.ReadAll(r.Body)
	if err != nil {
		http.Error(w, "reading patch body: "+err.Error(), http.StatusInternalServerError)
		return
	}

	// Append under the lock: finalize reads these buffers under c.mu, and
	// concurrent PATCHes to one session must not write unsynchronized.
	c.mu.Lock()
	buf, ok := c.sessions[r.URL.Path]
	if !ok {
		buf = &bytes.Buffer{}
		c.sessions[r.URL.Path] = buf
	}
	buf.Write(body)
	c.mu.Unlock()

	w.Header().Set("Docker-Upload-Location", r.URL.Path)
	w.WriteHeader(http.StatusAccepted)
}

// finalize returns a copy of the bytes accumulated for sessionPath, or nil if
// no PATCH was recorded for it. Use it from the PUT-finalize branch of a mock
// handler. Returning a copy (rather than the buffer's backing slice) keeps the
// result stable even if the session receives more data afterwards.
func (c *chunkedSession) finalize(sessionPath string) []byte {
	c.mu.Lock()
	defer c.mu.Unlock()
	if buf, ok := c.sessions[sessionPath]; ok {
		return bytes.Clone(buf.Bytes())
	}
	return nil
}
|
|
|
|
// createTestBlob creates a blob with deterministic content and returns its digest
|
|
func createTestBlob(t *testing.T, dir string, size int) (Blob, []byte) {
|
|
t.Helper()
|
|
|
|
// Create deterministic content
|
|
data := make([]byte, size)
|
|
for i := range data {
|
|
data[i] = byte(i % 256)
|
|
}
|
|
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
|
|
// Write to file
|
|
path := filepath.Join(dir, digestToPath(digest))
|
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := os.WriteFile(path, data, 0o644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
return Blob{Digest: digest, Size: int64(size)}, data
|
|
}
|
|
|
|
func TestDownload(t *testing.T) {
|
|
// Create test blobs on "server"
|
|
serverDir := t.TempDir()
|
|
blob1, data1 := createTestBlob(t, serverDir, 1024)
|
|
blob2, data2 := createTestBlob(t, serverDir, 2048)
|
|
blob3, data3 := createTestBlob(t, serverDir, 512)
|
|
|
|
// Mock server
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Extract digest from URL: /v2/library/_/blobs/sha256:...
|
|
digest := filepath.Base(r.URL.Path)
|
|
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(data)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Download to client dir
|
|
clientDir := t.TempDir()
|
|
|
|
var progressCalls atomic.Int32
|
|
var lastCompleted, lastTotal atomic.Int64
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob1, blob2, blob3},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
Concurrency: 2,
|
|
Progress: func(completed, total int64) {
|
|
progressCalls.Add(1)
|
|
lastCompleted.Store(completed)
|
|
lastTotal.Store(total)
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
// Verify files
|
|
verifyBlob(t, clientDir, blob1, data1)
|
|
verifyBlob(t, clientDir, blob2, data2)
|
|
verifyBlob(t, clientDir, blob3, data3)
|
|
|
|
// Verify progress was called
|
|
if progressCalls.Load() == 0 {
|
|
t.Error("Progress callback never called")
|
|
}
|
|
if lastTotal.Load() != blob1.Size+blob2.Size+blob3.Size {
|
|
t.Errorf("Wrong total: got %d, want %d", lastTotal.Load(), blob1.Size+blob2.Size+blob3.Size)
|
|
}
|
|
}
|
|
|
|
func TestDownloadWithRedirect(t *testing.T) {
|
|
// Create test blob on "CDN"
|
|
cdnDir := t.TempDir()
|
|
blob, data := createTestBlob(t, cdnDir, 1024)
|
|
|
|
// CDN server (the redirect target)
|
|
cdn := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Serve the blob content
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(cdnDir, digestToPath(digest))
|
|
blobData, err := os.ReadFile(path)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(blobData)))
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(blobData)
|
|
}))
|
|
defer cdn.Close()
|
|
|
|
// Registry server (redirects to CDN)
|
|
registry := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Redirect to CDN
|
|
cdnURL := cdn.URL + r.URL.Path
|
|
http.Redirect(w, r, cdnURL, http.StatusTemporaryRedirect)
|
|
}))
|
|
defer registry.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: registry.URL,
|
|
DestDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download with redirect failed: %v", err)
|
|
}
|
|
|
|
verifyBlob(t, clientDir, blob, data)
|
|
}
|
|
|
|
func TestDownloadWithRetry(t *testing.T) {
|
|
// Create test blob
|
|
serverDir := t.TempDir()
|
|
blob, data := createTestBlob(t, serverDir, 1024)
|
|
|
|
var requestCount atomic.Int32
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
count := requestCount.Add(1)
|
|
|
|
// Fail first 2 attempts, succeed on 3rd
|
|
if count < 3 {
|
|
http.Error(w, "temporary error", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
blobData, err := os.ReadFile(path)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(blobData)
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download with retry failed: %v", err)
|
|
}
|
|
|
|
verifyBlob(t, clientDir, blob, data)
|
|
|
|
// Should have made 3 requests (2 failures + 1 success)
|
|
if requestCount.Load() < 3 {
|
|
t.Errorf("Expected at least 3 requests for retry, got %d", requestCount.Load())
|
|
}
|
|
}
|
|
|
|
func TestDownloadWithAuth(t *testing.T) {
|
|
serverDir := t.TempDir()
|
|
blob, data := createTestBlob(t, serverDir, 1024)
|
|
|
|
var authCalled atomic.Bool
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Require auth
|
|
auth := r.Header.Get("Authorization")
|
|
if auth != "Bearer valid-token" {
|
|
w.Header().Set("WWW-Authenticate", `Bearer realm="https://auth.example.com",service="registry",scope="repository:library:pull"`)
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
blobData, err := os.ReadFile(path)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(blobData)
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
GetToken: func(ctx context.Context, challenge AuthChallenge) (string, error) {
|
|
authCalled.Store(true)
|
|
if challenge.Realm != "https://auth.example.com" {
|
|
t.Errorf("Wrong realm: %s", challenge.Realm)
|
|
}
|
|
if challenge.Service != "registry" {
|
|
t.Errorf("Wrong service: %s", challenge.Service)
|
|
}
|
|
return "valid-token", nil
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download with auth failed: %v", err)
|
|
}
|
|
|
|
if !authCalled.Load() {
|
|
t.Error("GetToken was never called")
|
|
}
|
|
|
|
verifyBlob(t, clientDir, blob, data)
|
|
}
|
|
|
|
func TestDownloadSkipsExisting(t *testing.T) {
|
|
serverDir := t.TempDir()
|
|
blob1, data1 := createTestBlob(t, serverDir, 1024)
|
|
|
|
// Pre-populate client dir
|
|
clientDir := t.TempDir()
|
|
path := filepath.Join(clientDir, digestToPath(blob1.Digest))
|
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := os.WriteFile(path, data1, 0o644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
var requestCount atomic.Int32
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
requestCount.Add(1)
|
|
http.NotFound(w, r)
|
|
}))
|
|
defer server.Close()
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob1},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
// Should not have made any requests (blob already exists)
|
|
if requestCount.Load() != 0 {
|
|
t.Errorf("Made %d requests, expected 0 (blob should be skipped)", requestCount.Load())
|
|
}
|
|
}
|
|
|
|
func TestDownloadResumeProgressTotal(t *testing.T) {
|
|
// Test that when resuming a download with some blobs already present:
|
|
// 1. Total reflects ALL blob sizes (not just remaining)
|
|
// 2. Completed starts at the size of already-downloaded blobs
|
|
serverDir := t.TempDir()
|
|
blob1, data1 := createTestBlob(t, serverDir, 1000)
|
|
blob2, data2 := createTestBlob(t, serverDir, 2000)
|
|
blob3, data3 := createTestBlob(t, serverDir, 3000)
|
|
|
|
// Pre-populate client with blob1 and blob2 (simulating partial download)
|
|
clientDir := t.TempDir()
|
|
for _, b := range []struct {
|
|
blob Blob
|
|
data []byte
|
|
}{{blob1, data1}, {blob2, data2}} {
|
|
path := filepath.Join(clientDir, digestToPath(b.blob.Digest))
|
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := os.WriteFile(path, b.data, 0o644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(data)
|
|
}))
|
|
defer server.Close()
|
|
|
|
var firstCompleted, firstTotal int64
|
|
var gotFirstProgress bool
|
|
var mu sync.Mutex
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob1, blob2, blob3},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
Concurrency: 1,
|
|
Progress: func(completed, total int64) {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
if !gotFirstProgress {
|
|
firstCompleted = completed
|
|
firstTotal = total
|
|
gotFirstProgress = true
|
|
}
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
// Total should be sum of ALL blobs, not just blob3
|
|
expectedTotal := blob1.Size + blob2.Size + blob3.Size
|
|
if firstTotal != expectedTotal {
|
|
t.Errorf("Total = %d, want %d (should include all blobs)", firstTotal, expectedTotal)
|
|
}
|
|
|
|
// First progress call should show already-completed bytes from blob1+blob2
|
|
expectedCompleted := blob1.Size + blob2.Size
|
|
if firstCompleted < expectedCompleted {
|
|
t.Errorf("First completed = %d, want >= %d (should include already-downloaded blobs)", firstCompleted, expectedCompleted)
|
|
}
|
|
|
|
// Verify blob3 was downloaded
|
|
verifyBlob(t, clientDir, blob3, data3)
|
|
}
|
|
|
|
func TestDownloadDigestMismatch(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Return wrong data
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("wrong data"))
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{{Digest: "sha256:0000000000000000000000000000000000000000000000000000000000000000", Size: 10}},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
})
|
|
if err == nil {
|
|
t.Fatal("Expected error for digest mismatch")
|
|
}
|
|
}
|
|
|
|
func TestUpload(t *testing.T) {
|
|
// Create test blobs
|
|
clientDir := t.TempDir()
|
|
blob1, _ := createTestBlob(t, clientDir, 1024)
|
|
blob2, _ := createTestBlob(t, clientDir, 2048)
|
|
|
|
var uploadedBlobs sync.Map
|
|
var uploadID atomic.Int32
|
|
session := newChunkedSession()
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case r.Method == http.MethodHead:
|
|
http.NotFound(w, r)
|
|
|
|
case r.Method == http.MethodPost && r.URL.Path == "/v2/library/_/blobs/uploads/":
|
|
id := uploadID.Add(1)
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/%d", serverURL, id))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
|
|
case r.Method == http.MethodPatch:
|
|
session.recordPatch(w, r)
|
|
|
|
case r.Method == http.MethodPut:
|
|
digest := r.URL.Query().Get("digest")
|
|
uploadedBlobs.Store(digest, session.finalize(r.URL.Path))
|
|
w.WriteHeader(http.StatusCreated)
|
|
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
var progressCalls atomic.Int32
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob1, blob2},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
Concurrency: 2,
|
|
Progress: func(completed, total int64) {
|
|
progressCalls.Add(1)
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload failed: %v", err)
|
|
}
|
|
|
|
// Verify both blobs were uploaded
|
|
if _, ok := uploadedBlobs.Load(blob1.Digest); !ok {
|
|
t.Error("Blob 1 not uploaded")
|
|
}
|
|
if _, ok := uploadedBlobs.Load(blob2.Digest); !ok {
|
|
t.Error("Blob 2 not uploaded")
|
|
}
|
|
|
|
if progressCalls.Load() == 0 {
|
|
t.Error("Progress callback never called")
|
|
}
|
|
}
|
|
|
|
func TestUploadWithRedirect(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
blob, _ := createTestBlob(t, clientDir, 1024)
|
|
|
|
var uploadedBlobs sync.Map
|
|
var cdnCalled atomic.Bool
|
|
|
|
// CDN server (PATCH redirect target). PATCH is redirected to a PUT here,
|
|
// matching production: server issues 307 + Location to a presigned CDN URL,
|
|
// and the client re-uploads the part body via PUT.
|
|
cdn := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
cdnCalled.Store(true)
|
|
if r.Method == http.MethodPut {
|
|
data, _ := io.ReadAll(r.Body)
|
|
// Stash the body keyed by the path so the main server can pick it
|
|
// up at finalize time.
|
|
uploadedBlobs.Store(r.URL.Path, data)
|
|
w.WriteHeader(http.StatusCreated)
|
|
}
|
|
}))
|
|
defer cdn.Close()
|
|
|
|
var serverURL string
|
|
var uploadID atomic.Int32
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodHead:
|
|
http.NotFound(w, r)
|
|
|
|
case http.MethodPost:
|
|
id := uploadID.Add(1)
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/%d", serverURL, id))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
|
|
case http.MethodPatch:
|
|
// Redirect PATCH body to CDN, mirroring server behavior
|
|
cdnURL := cdn.URL + r.URL.Path
|
|
w.Header().Set("Docker-Upload-Location", r.URL.Path)
|
|
http.Redirect(w, r, cdnURL, http.StatusTemporaryRedirect)
|
|
|
|
case http.MethodPut:
|
|
// Finalize: copy body the CDN received under this session path
|
|
// to the uploadedBlobs map keyed by digest.
|
|
digest := r.URL.Query().Get("digest")
|
|
if v, ok := uploadedBlobs.Load(r.URL.Path); ok {
|
|
uploadedBlobs.Store(digest, v)
|
|
uploadedBlobs.Delete(r.URL.Path)
|
|
}
|
|
w.WriteHeader(http.StatusCreated)
|
|
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload with redirect failed: %v", err)
|
|
}
|
|
|
|
if !cdnCalled.Load() {
|
|
t.Error("CDN was never called (redirect not followed)")
|
|
}
|
|
|
|
if _, ok := uploadedBlobs.Load(blob.Digest); !ok {
|
|
t.Error("Blob not uploaded to CDN")
|
|
}
|
|
}
|
|
|
|
func TestUploadWithAuth(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
blob, _ := createTestBlob(t, clientDir, 1024)
|
|
|
|
var uploadedBlobs sync.Map
|
|
var authCalled atomic.Bool
|
|
var uploadID atomic.Int32
|
|
session := newChunkedSession()
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Require auth for all requests
|
|
auth := r.Header.Get("Authorization")
|
|
if auth != "Bearer valid-token" {
|
|
w.Header().Set("WWW-Authenticate", `Bearer realm="https://auth.example.com",service="registry",scope="repository:library:push"`)
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
switch r.Method {
|
|
case http.MethodHead:
|
|
http.NotFound(w, r)
|
|
|
|
case http.MethodPost:
|
|
id := uploadID.Add(1)
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/%d", serverURL, id))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
|
|
case http.MethodPatch:
|
|
session.recordPatch(w, r)
|
|
|
|
case http.MethodPut:
|
|
digest := r.URL.Query().Get("digest")
|
|
uploadedBlobs.Store(digest, session.finalize(r.URL.Path))
|
|
w.WriteHeader(http.StatusCreated)
|
|
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
GetToken: func(ctx context.Context, challenge AuthChallenge) (string, error) {
|
|
authCalled.Store(true)
|
|
return "valid-token", nil
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload with auth failed: %v", err)
|
|
}
|
|
|
|
if !authCalled.Load() {
|
|
t.Error("GetToken was never called")
|
|
}
|
|
|
|
if _, ok := uploadedBlobs.Load(blob.Digest); !ok {
|
|
t.Error("Blob not uploaded")
|
|
}
|
|
}
|
|
|
|
func TestUploadSkipsExisting(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
blob1, _ := createTestBlob(t, clientDir, 1024)
|
|
|
|
var headChecked atomic.Bool
|
|
var putCalled atomic.Bool
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodHead:
|
|
// HEAD check for blob existence - return 200 OK to indicate blob exists
|
|
headChecked.Store(true)
|
|
w.WriteHeader(http.StatusOK)
|
|
case http.MethodPost:
|
|
http.NotFound(w, r)
|
|
case http.MethodPut:
|
|
putCalled.Store(true)
|
|
w.WriteHeader(http.StatusCreated)
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob1},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload failed: %v", err)
|
|
}
|
|
|
|
// Verify HEAD check was used
|
|
if !headChecked.Load() {
|
|
t.Error("HEAD check was never made")
|
|
}
|
|
|
|
// Should not have attempted PUT (blob already exists)
|
|
if putCalled.Load() {
|
|
t.Error("PUT was called even though blob exists (HEAD returned 200)")
|
|
}
|
|
|
|
t.Log("HEAD-based existence check verified")
|
|
}
|
|
|
|
// TestUploadWithCustomRepository verifies that custom repository paths are used
|
|
func TestUploadWithCustomRepository(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
blob1, _ := createTestBlob(t, clientDir, 1024)
|
|
|
|
var headPath, postPath string
|
|
var mu sync.Mutex
|
|
session := newChunkedSession()
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodHead:
|
|
mu.Lock()
|
|
headPath = r.URL.Path
|
|
mu.Unlock()
|
|
w.WriteHeader(http.StatusNotFound) // Blob doesn't exist
|
|
case http.MethodPost:
|
|
mu.Lock()
|
|
postPath = r.URL.Path
|
|
mu.Unlock()
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/myorg/mymodel/blobs/uploads/1", serverURL))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
case http.MethodPatch:
|
|
session.recordPatch(w, r)
|
|
case http.MethodPut:
|
|
w.WriteHeader(http.StatusCreated)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob1},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
Repository: "myorg/mymodel", // Custom repository
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload failed: %v", err)
|
|
}
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
// Verify HEAD used custom repository path
|
|
expectedHeadPath := fmt.Sprintf("/v2/myorg/mymodel/blobs/%s", blob1.Digest)
|
|
if headPath != expectedHeadPath {
|
|
t.Errorf("HEAD path mismatch: got %s, want %s", headPath, expectedHeadPath)
|
|
}
|
|
|
|
// Verify POST used custom repository path
|
|
expectedPostPath := "/v2/myorg/mymodel/blobs/uploads/"
|
|
if postPath != expectedPostPath {
|
|
t.Errorf("POST path mismatch: got %s, want %s", postPath, expectedPostPath)
|
|
}
|
|
|
|
t.Logf("Custom repository paths verified: HEAD=%s, POST=%s", headPath, postPath)
|
|
}
|
|
|
|
// TestDownloadWithCustomRepository verifies that custom repository paths are used
|
|
func TestDownloadWithCustomRepository(t *testing.T) {
|
|
serverDir := t.TempDir()
|
|
blob, data := createTestBlob(t, serverDir, 1024)
|
|
|
|
var requestPath string
|
|
var mu sync.Mutex
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
mu.Lock()
|
|
requestPath = r.URL.Path
|
|
mu.Unlock()
|
|
|
|
// Serve blob from any path
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
blobData, err := os.ReadFile(path)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(blobData)
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
Repository: "myorg/mymodel", // Custom repository
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
verifyBlob(t, clientDir, blob, data)
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
// Verify request used custom repository path
|
|
expectedPath := fmt.Sprintf("/v2/myorg/mymodel/blobs/%s", blob.Digest)
|
|
if requestPath != expectedPath {
|
|
t.Errorf("Request path mismatch: got %s, want %s", requestPath, expectedPath)
|
|
}
|
|
|
|
t.Logf("Custom repository path verified: %s", requestPath)
|
|
}
|
|
|
|
func TestDigestToPath(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
want string
|
|
}{
|
|
{"sha256:abc123", "sha256-abc123"},
|
|
{"sha256-abc123", "sha256-abc123"},
|
|
{"other", "other"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
got := digestToPath(tt.input)
|
|
if got != tt.want {
|
|
t.Errorf("digestToPath(%q) = %q, want %q", tt.input, got, tt.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestParseAuthChallenge(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
want AuthChallenge
|
|
}{
|
|
{
|
|
input: `Bearer realm="https://auth.example.com/token",service="registry",scope="repository:library/test:pull"`,
|
|
want: AuthChallenge{
|
|
Realm: "https://auth.example.com/token",
|
|
Service: "registry",
|
|
Scope: "repository:library/test:pull",
|
|
},
|
|
},
|
|
{
|
|
input: `Bearer realm="https://auth.example.com"`,
|
|
want: AuthChallenge{
|
|
Realm: "https://auth.example.com",
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
got := parseAuthChallenge(tt.input)
|
|
if got.Realm != tt.want.Realm {
|
|
t.Errorf("parseAuthChallenge(%q).Realm = %q, want %q", tt.input, got.Realm, tt.want.Realm)
|
|
}
|
|
if got.Service != tt.want.Service {
|
|
t.Errorf("parseAuthChallenge(%q).Service = %q, want %q", tt.input, got.Service, tt.want.Service)
|
|
}
|
|
if got.Scope != tt.want.Scope {
|
|
t.Errorf("parseAuthChallenge(%q).Scope = %q, want %q", tt.input, got.Scope, tt.want.Scope)
|
|
}
|
|
}
|
|
}
|
|
|
|
func verifyBlob(t *testing.T, dir string, blob Blob, expected []byte) {
|
|
t.Helper()
|
|
|
|
path := filepath.Join(dir, digestToPath(blob.Digest))
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
t.Errorf("Failed to read %s: %v", blob.Digest[:19], err)
|
|
return
|
|
}
|
|
|
|
if len(data) != len(expected) {
|
|
t.Errorf("Size mismatch for %s: got %d, want %d", blob.Digest[:19], len(data), len(expected))
|
|
return
|
|
}
|
|
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
if digest != blob.Digest {
|
|
t.Errorf("Digest mismatch for %s: got %s", blob.Digest[:19], digest[:19])
|
|
}
|
|
}
|
|
|
|
// ==================== Parallelism Tests ====================
|
|
|
|
func TestDownloadParallelism(t *testing.T) {
|
|
// Create many blobs to test parallelism
|
|
serverDir := t.TempDir()
|
|
numBlobs := 10
|
|
blobs := make([]Blob, numBlobs)
|
|
blobData := make([][]byte, numBlobs)
|
|
|
|
for i := range numBlobs {
|
|
blobs[i], blobData[i] = createTestBlob(t, serverDir, 1024+i*100)
|
|
}
|
|
|
|
var activeRequests atomic.Int32
|
|
var maxConcurrent atomic.Int32
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
current := activeRequests.Add(1)
|
|
defer activeRequests.Add(-1)
|
|
|
|
// Track max concurrent requests
|
|
for {
|
|
old := maxConcurrent.Load()
|
|
if current <= old || maxConcurrent.CompareAndSwap(old, current) {
|
|
break
|
|
}
|
|
}
|
|
|
|
// Simulate network latency to ensure parallelism is visible
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(data)
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
start := time.Now()
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: blobs,
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
Concurrency: 4,
|
|
BodyConcurrency: 4,
|
|
})
|
|
elapsed := time.Since(start)
|
|
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
// Verify all blobs downloaded
|
|
for i, blob := range blobs {
|
|
verifyBlob(t, clientDir, blob, blobData[i])
|
|
}
|
|
|
|
// Verify parallelism was used
|
|
if maxConcurrent.Load() < 2 {
|
|
t.Errorf("Max concurrent requests was %d, expected at least 2 for parallelism", maxConcurrent.Load())
|
|
}
|
|
|
|
// With 10 blobs at 50ms each, sequential would take ~500ms
|
|
// Parallel with 4 workers should take ~150ms (relax to 1s for CI variance)
|
|
if elapsed > time.Second {
|
|
t.Errorf("Downloads took %v, expected faster with parallelism", elapsed)
|
|
}
|
|
|
|
t.Logf("Downloaded %d blobs in %v with max %d concurrent requests", numBlobs, elapsed, maxConcurrent.Load())
|
|
}
|
|
|
|
func TestUploadParallelism(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
numBlobs := 10
|
|
blobs := make([]Blob, numBlobs)
|
|
|
|
for i := range numBlobs {
|
|
blobs[i], _ = createTestBlob(t, clientDir, 1024+i*100)
|
|
}
|
|
|
|
var activeRequests atomic.Int32
|
|
var maxConcurrent atomic.Int32
|
|
var uploadedBlobs sync.Map
|
|
var uploadID atomic.Int32
|
|
session := newChunkedSession()
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
current := activeRequests.Add(1)
|
|
defer activeRequests.Add(-1)
|
|
|
|
// Track max concurrent
|
|
for {
|
|
old := maxConcurrent.Load()
|
|
if current <= old || maxConcurrent.CompareAndSwap(old, current) {
|
|
break
|
|
}
|
|
}
|
|
|
|
switch r.Method {
|
|
case http.MethodHead:
|
|
http.NotFound(w, r)
|
|
|
|
case http.MethodPost:
|
|
id := uploadID.Add(1)
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/%d", serverURL, id))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
|
|
case http.MethodPatch:
|
|
time.Sleep(50 * time.Millisecond) // simulate upload time on body chunk
|
|
session.recordPatch(w, r)
|
|
|
|
case http.MethodPut:
|
|
digest := r.URL.Query().Get("digest")
|
|
uploadedBlobs.Store(digest, session.finalize(r.URL.Path))
|
|
w.WriteHeader(http.StatusCreated)
|
|
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
start := time.Now()
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: blobs,
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
Concurrency: 4,
|
|
})
|
|
elapsed := time.Since(start)
|
|
|
|
if err != nil {
|
|
t.Fatalf("Upload failed: %v", err)
|
|
}
|
|
|
|
// Verify all blobs uploaded
|
|
for _, blob := range blobs {
|
|
if _, ok := uploadedBlobs.Load(blob.Digest); !ok {
|
|
t.Errorf("Blob %s not uploaded", blob.Digest[:19])
|
|
}
|
|
}
|
|
|
|
if maxConcurrent.Load() < 2 {
|
|
t.Errorf("Max concurrent requests was %d, expected at least 2", maxConcurrent.Load())
|
|
}
|
|
|
|
t.Logf("Uploaded %d blobs in %v with max %d concurrent requests", numBlobs, elapsed, maxConcurrent.Load())
|
|
}
|
|
|
|
// ==================== Stall Detection Test ====================
|
|
|
|
func TestDownloadStallDetection(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("Skipping stall detection test in short mode")
|
|
}
|
|
|
|
serverDir := t.TempDir()
|
|
blob, _ := createTestBlob(t, serverDir, 10*1024) // 10KB
|
|
|
|
var requestCount atomic.Int32
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
count := requestCount.Add(1)
|
|
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
if count == 1 {
|
|
// First request: send partial data then stall
|
|
w.Write(data[:1024]) // Send first 1KB
|
|
if f, ok := w.(http.Flusher); ok {
|
|
f.Flush()
|
|
}
|
|
// Stall for longer than stall timeout (test uses 200ms)
|
|
time.Sleep(500 * time.Millisecond)
|
|
return
|
|
}
|
|
|
|
// Subsequent requests: send full data
|
|
w.Write(data)
|
|
}))
|
|
defer func() {
|
|
server.CloseClientConnections()
|
|
server.Close()
|
|
}()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
start := time.Now()
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
StallTimeout: 200 * time.Millisecond, // Short timeout for testing
|
|
})
|
|
elapsed := time.Since(start)
|
|
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
// Should have retried after stall detection
|
|
if requestCount.Load() < 2 {
|
|
t.Errorf("Expected at least 2 requests (stall + retry), got %d", requestCount.Load())
|
|
}
|
|
|
|
// Should complete quickly with short stall timeout
|
|
if elapsed > 3*time.Second {
|
|
t.Errorf("Download took %v, stall detection should have triggered faster", elapsed)
|
|
}
|
|
|
|
t.Logf("Stall detection worked: %d requests in %v", requestCount.Load(), elapsed)
|
|
}
|
|
|
|
// ==================== Context Cancellation Tests ====================
|
|
|
|
func TestDownloadCancellation(t *testing.T) {
|
|
serverDir := t.TempDir()
|
|
blob, _ := createTestBlob(t, serverDir, 100*1024) // 100KB (smaller for faster test)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
data, _ := os.ReadFile(path)
|
|
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
// Send data slowly
|
|
for i := 0; i < len(data); i += 1024 {
|
|
end := i + 1024
|
|
if end > len(data) {
|
|
end = len(data)
|
|
}
|
|
w.Write(data[i:end])
|
|
if f, ok := w.(http.Flusher); ok {
|
|
f.Flush()
|
|
}
|
|
time.Sleep(5 * time.Millisecond)
|
|
}
|
|
}))
|
|
defer func() {
|
|
server.CloseClientConnections()
|
|
server.Close()
|
|
}()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
// Cancel after 50ms
|
|
go func() {
|
|
time.Sleep(50 * time.Millisecond)
|
|
cancel()
|
|
}()
|
|
|
|
start := time.Now()
|
|
err := Download(ctx, DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
})
|
|
elapsed := time.Since(start)
|
|
|
|
if err == nil {
|
|
t.Fatal("Expected error from cancellation")
|
|
}
|
|
|
|
if !errors.Is(err, context.Canceled) {
|
|
t.Errorf("Expected context.Canceled error, got: %v", err)
|
|
}
|
|
|
|
// Should cancel quickly, not wait for full download
|
|
if elapsed > 500*time.Millisecond {
|
|
t.Errorf("Cancellation took %v, expected faster response", elapsed)
|
|
}
|
|
|
|
t.Logf("Cancellation worked in %v", elapsed)
|
|
}
|
|
|
|
func TestUploadCancellation(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
blob, _ := createTestBlob(t, clientDir, 100*1024) // 100KB (smaller for faster test)
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodHead:
|
|
http.NotFound(w, r)
|
|
|
|
case http.MethodPost:
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
|
|
case http.MethodPatch:
|
|
// Read slowly so the cancellation has time to interrupt the body upload.
|
|
buf := make([]byte, 1024)
|
|
for {
|
|
_, err := r.Body.Read(buf)
|
|
if err != nil {
|
|
break
|
|
}
|
|
time.Sleep(5 * time.Millisecond)
|
|
}
|
|
w.Header().Set("Docker-Upload-Location", r.URL.Path)
|
|
w.WriteHeader(http.StatusAccepted)
|
|
|
|
case http.MethodPut:
|
|
w.WriteHeader(http.StatusCreated)
|
|
}
|
|
}))
|
|
defer func() {
|
|
server.CloseClientConnections()
|
|
server.Close()
|
|
}()
|
|
serverURL = server.URL
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
go func() {
|
|
time.Sleep(50 * time.Millisecond)
|
|
cancel()
|
|
}()
|
|
|
|
start := time.Now()
|
|
err := Upload(ctx, UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
})
|
|
elapsed := time.Since(start)
|
|
|
|
if err == nil {
|
|
t.Fatal("Expected error from cancellation")
|
|
}
|
|
|
|
if elapsed > 500*time.Millisecond {
|
|
t.Errorf("Cancellation took %v, expected faster", elapsed)
|
|
}
|
|
|
|
t.Logf("Upload cancellation worked in %v", elapsed)
|
|
}
|
|
|
|
// ==================== Progress Tracking Tests ====================
|
|
|
|
func TestProgressTracking(t *testing.T) {
|
|
serverDir := t.TempDir()
|
|
blob1, data1 := createTestBlob(t, serverDir, 5000)
|
|
blob2, data2 := createTestBlob(t, serverDir, 3000)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
digest := filepath.Base(r.URL.Path)
|
|
path := filepath.Join(serverDir, digestToPath(digest))
|
|
data, _ := os.ReadFile(path)
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(data)
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
var progressHistory []struct{ completed, total int64 }
|
|
var mu sync.Mutex
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob1, blob2},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
Concurrency: 1, // Sequential to make progress predictable
|
|
Progress: func(completed, total int64) {
|
|
mu.Lock()
|
|
progressHistory = append(progressHistory, struct{ completed, total int64 }{completed, total})
|
|
mu.Unlock()
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
verifyBlob(t, clientDir, blob1, data1)
|
|
verifyBlob(t, clientDir, blob2, data2)
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
if len(progressHistory) == 0 {
|
|
t.Fatal("No progress callbacks received")
|
|
}
|
|
|
|
// Total should always be sum of blob sizes
|
|
expectedTotal := blob1.Size + blob2.Size
|
|
for _, p := range progressHistory {
|
|
if p.total != expectedTotal {
|
|
t.Errorf("Total changed during download: got %d, want %d", p.total, expectedTotal)
|
|
}
|
|
}
|
|
|
|
// Completed should be monotonically increasing
|
|
var lastCompleted int64
|
|
for _, p := range progressHistory {
|
|
if p.completed < lastCompleted {
|
|
t.Errorf("Progress went backwards: %d -> %d", lastCompleted, p.completed)
|
|
}
|
|
lastCompleted = p.completed
|
|
}
|
|
|
|
// Final completed should equal total
|
|
final := progressHistory[len(progressHistory)-1]
|
|
if final.completed != expectedTotal {
|
|
t.Errorf("Final completed %d != total %d", final.completed, expectedTotal)
|
|
}
|
|
|
|
t.Logf("Progress tracked correctly: %d callbacks, final %d/%d", len(progressHistory), final.completed, final.total)
|
|
}
|
|
|
|
// ==================== Edge Cases ====================
|
|
|
|
func TestDownloadEmptyBlobList(t *testing.T) {
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{},
|
|
BaseURL: "http://unused",
|
|
DestDir: t.TempDir(),
|
|
})
|
|
if err != nil {
|
|
t.Errorf("Expected no error for empty blob list, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestUploadEmptyBlobList(t *testing.T) {
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{},
|
|
BaseURL: "http://unused",
|
|
SrcDir: t.TempDir(),
|
|
})
|
|
if err != nil {
|
|
t.Errorf("Expected no error for empty blob list, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestUploadRetryOnFailure(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
blob, _ := createTestBlob(t, clientDir, 1024)
|
|
|
|
var patchCount atomic.Int32
|
|
var uploadedBlobs sync.Map
|
|
session := newChunkedSession()
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodHead:
|
|
http.NotFound(w, r)
|
|
|
|
case http.MethodPost:
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
|
|
case http.MethodPatch:
|
|
count := patchCount.Add(1)
|
|
if count < 3 {
|
|
// Fail first 2 PATCH attempts to exercise the retry path
|
|
io.Copy(io.Discard, r.Body)
|
|
http.Error(w, "server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
session.recordPatch(w, r)
|
|
|
|
case http.MethodPut:
|
|
digest := r.URL.Query().Get("digest")
|
|
uploadedBlobs.Store(digest, session.finalize(r.URL.Path))
|
|
w.WriteHeader(http.StatusCreated)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload with retry failed: %v", err)
|
|
}
|
|
|
|
if _, ok := uploadedBlobs.Load(blob.Digest); !ok {
|
|
t.Error("Blob not uploaded after retry")
|
|
}
|
|
|
|
if patchCount.Load() < 3 {
|
|
t.Errorf("Expected at least 3 PATCH attempts, got %d", patchCount.Load())
|
|
}
|
|
}
|
|
|
|
// TestProgressRollback verifies that progress is rolled back on retry
|
|
func TestProgressRollback(t *testing.T) {
|
|
content := []byte("test content for rollback test")
|
|
digest := fmt.Sprintf("sha256:%x", sha256.Sum256(content))
|
|
blob := Blob{Digest: digest, Size: int64(len(content))}
|
|
|
|
clientDir := t.TempDir()
|
|
path := filepath.Join(clientDir, digestToPath(digest))
|
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := os.WriteFile(path, content, 0o644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
var patchCount atomic.Int32
|
|
var progressValues []int64
|
|
var mu sync.Mutex
|
|
session := newChunkedSession()
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodHead:
|
|
http.NotFound(w, r)
|
|
|
|
case http.MethodPost:
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
|
|
case http.MethodPatch:
|
|
// Read some bytes (so the client reports progress) before failing,
|
|
// to exercise the rollback-on-retry path.
|
|
count := patchCount.Add(1)
|
|
if count < 3 {
|
|
io.CopyN(io.Discard, r.Body, 10)
|
|
io.Copy(io.Discard, r.Body)
|
|
http.Error(w, "server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
session.recordPatch(w, r)
|
|
|
|
case http.MethodPut:
|
|
w.WriteHeader(http.StatusCreated)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
Progress: func(completed, total int64) {
|
|
mu.Lock()
|
|
progressValues = append(progressValues, completed)
|
|
mu.Unlock()
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload with retry failed: %v", err)
|
|
}
|
|
|
|
// Check that progress was rolled back (should have negative values or drops)
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
// Final progress should equal blob size
|
|
if len(progressValues) > 0 {
|
|
final := progressValues[len(progressValues)-1]
|
|
if final != blob.Size {
|
|
t.Errorf("Final progress %d != blob size %d", final, blob.Size)
|
|
}
|
|
}
|
|
|
|
t.Logf("Progress rollback test: %d progress callbacks", len(progressValues))
|
|
}
|
|
|
|
// TestUserAgentHeader verifies User-Agent header is set on requests
|
|
func TestUserAgentHeader(t *testing.T) {
|
|
content := []byte("test content")
|
|
digest := fmt.Sprintf("sha256:%x", sha256.Sum256(content))
|
|
blob := Blob{Digest: digest, Size: int64(len(content))}
|
|
|
|
destDir := t.TempDir()
|
|
var userAgents []string
|
|
var mu sync.Mutex
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
mu.Lock()
|
|
ua := r.Header.Get("User-Agent")
|
|
userAgents = append(userAgents, ua)
|
|
mu.Unlock()
|
|
|
|
if r.Method == http.MethodGet {
|
|
w.Write(content)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Test with custom User-Agent
|
|
customUA := "test-agent/1.0"
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: destDir,
|
|
UserAgent: customUA,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
// Verify custom User-Agent was used
|
|
for _, ua := range userAgents {
|
|
if ua != customUA {
|
|
t.Errorf("User-Agent %q != expected %q", ua, customUA)
|
|
}
|
|
}
|
|
t.Logf("User-Agent header test: %d requests with correct User-Agent", len(userAgents))
|
|
}
|
|
|
|
// TestDefaultUserAgent verifies default User-Agent is used when not specified
|
|
func TestDefaultUserAgent(t *testing.T) {
|
|
content := []byte("test content")
|
|
digest := fmt.Sprintf("sha256:%x", sha256.Sum256(content))
|
|
blob := Blob{Digest: digest, Size: int64(len(content))}
|
|
|
|
destDir := t.TempDir()
|
|
var userAgent string
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
userAgent = r.Header.Get("User-Agent")
|
|
if r.Method == http.MethodGet {
|
|
w.Write(content)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: destDir,
|
|
// No UserAgent specified - should use default
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
if userAgent == "" {
|
|
t.Error("User-Agent header was empty")
|
|
}
|
|
if userAgent != defaultUserAgent {
|
|
t.Errorf("Default User-Agent %q != expected %q", userAgent, defaultUserAgent)
|
|
}
|
|
}
|
|
|
|
// TestManifestPush verifies that manifest is pushed after blobs
|
|
func TestManifestPush(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
blob, _ := createTestBlob(t, clientDir, 1000)
|
|
|
|
testManifest := []byte(`{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json"}`)
|
|
testRepo := "library/test-model"
|
|
testRef := "latest"
|
|
|
|
var manifestReceived []byte
|
|
var manifestPath string
|
|
var manifestContentType string
|
|
var serverURL string
|
|
|
|
session := newChunkedSession()
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case r.Method == http.MethodHead:
|
|
http.NotFound(w, r)
|
|
|
|
case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/blobs/uploads"):
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
|
|
case r.Method == http.MethodPatch:
|
|
session.recordPatch(w, r)
|
|
|
|
case r.Method == http.MethodPut && strings.Contains(r.URL.Path, "/blobs/"):
|
|
// Finalize the chunked blob upload
|
|
w.WriteHeader(http.StatusCreated)
|
|
|
|
case r.Method == http.MethodPut && strings.Contains(r.URL.Path, "/manifests/"):
|
|
manifestPath = r.URL.Path
|
|
manifestContentType = r.Header.Get("Content-Type")
|
|
manifestReceived, _ = io.ReadAll(r.Body)
|
|
w.WriteHeader(http.StatusCreated)
|
|
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
Manifest: testManifest,
|
|
ManifestRef: testRef,
|
|
Repository: testRepo,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload failed: %v", err)
|
|
}
|
|
|
|
// Verify manifest was pushed
|
|
if manifestReceived == nil {
|
|
t.Fatal("Manifest was not received by server")
|
|
}
|
|
|
|
if !bytes.Equal(manifestReceived, testManifest) {
|
|
t.Errorf("Manifest content mismatch: got %s, want %s", manifestReceived, testManifest)
|
|
}
|
|
|
|
expectedPath := fmt.Sprintf("/v2/%s/manifests/%s", testRepo, testRef)
|
|
if manifestPath != expectedPath {
|
|
t.Errorf("Manifest path mismatch: got %s, want %s", manifestPath, expectedPath)
|
|
}
|
|
|
|
if manifestContentType != "application/vnd.docker.distribution.manifest.v2+json" {
|
|
t.Errorf("Manifest content type mismatch: got %s", manifestContentType)
|
|
}
|
|
|
|
t.Logf("Manifest push test passed: received %d bytes at %s", len(manifestReceived), manifestPath)
|
|
}
|
|
|
|
// ==================== Throughput Benchmarks ====================
|
|
|
|
func BenchmarkDownloadThroughput(b *testing.B) {
|
|
// Create test data - 1MB blob
|
|
data := make([]byte, 1024*1024)
|
|
for i := range data {
|
|
data[i] = byte(i % 256)
|
|
}
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
blob := Blob{Digest: digest, Size: int64(len(data))}
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(data)
|
|
}))
|
|
defer server.Close()
|
|
|
|
b.SetBytes(int64(len(data)))
|
|
b.ResetTimer()
|
|
|
|
for range b.N {
|
|
clientDir := b.TempDir()
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
Concurrency: 1,
|
|
})
|
|
if err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func BenchmarkUploadThroughput(b *testing.B) {
|
|
// Create test data - 1MB blob
|
|
data := make([]byte, 1024*1024)
|
|
for i := range data {
|
|
data[i] = byte(i % 256)
|
|
}
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
blob := Blob{Digest: digest, Size: int64(len(data))}
|
|
|
|
// Create source file once
|
|
srcDir := b.TempDir()
|
|
path := filepath.Join(srcDir, digestToPath(digest))
|
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
if err := os.WriteFile(path, data, 0o644); err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodHead:
|
|
http.NotFound(w, r)
|
|
case http.MethodPost:
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
case http.MethodPut:
|
|
io.Copy(io.Discard, r.Body)
|
|
w.WriteHeader(http.StatusCreated)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
b.SetBytes(int64(len(data)))
|
|
b.ResetTimer()
|
|
|
|
for range b.N {
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: srcDir,
|
|
Concurrency: 1,
|
|
})
|
|
if err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ==================== Resume Tests ====================
|
|
|
|
func TestResumeFromPartialFile(t *testing.T) {
|
|
// Create a blob large enough for resume (>= resumeThreshold)
|
|
blobSize := resumeThreshold + 1024
|
|
data := make([]byte, blobSize)
|
|
for i := range data {
|
|
data[i] = byte((i * 13) % 256)
|
|
}
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
blob := Blob{Digest: digest, Size: int64(blobSize)}
|
|
|
|
var rangeHeader string
|
|
var mu sync.Mutex
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method == http.MethodHead {
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
|
|
mu.Lock()
|
|
rangeHeader = r.Header.Get("Range")
|
|
mu.Unlock()
|
|
|
|
rng := r.Header.Get("Range")
|
|
if rng != "" {
|
|
// Parse "bytes=N-"
|
|
var start int64
|
|
fmt.Sscanf(rng, "bytes=%d-", &start)
|
|
if start > 0 && start < int64(blobSize) {
|
|
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, blobSize-1, blobSize))
|
|
w.WriteHeader(http.StatusPartialContent)
|
|
w.Write(data[start:])
|
|
return
|
|
}
|
|
}
|
|
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(data)
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
// Pre-create a partial .tmp file (first half)
|
|
partialSize := blobSize / 2
|
|
dest := filepath.Join(clientDir, digestToPath(digest))
|
|
os.MkdirAll(filepath.Dir(dest), 0o755)
|
|
os.WriteFile(dest+".tmp", data[:partialSize], 0o644)
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Resume download failed: %v", err)
|
|
}
|
|
|
|
// Verify Range header was sent
|
|
mu.Lock()
|
|
if rangeHeader == "" {
|
|
t.Error("Expected Range header for resume, got none")
|
|
} else {
|
|
expected := fmt.Sprintf("bytes=%d-", partialSize)
|
|
if rangeHeader != expected {
|
|
t.Errorf("Range header = %q, want %q", rangeHeader, expected)
|
|
}
|
|
}
|
|
mu.Unlock()
|
|
|
|
// Verify final file is correct
|
|
finalData, err := os.ReadFile(dest)
|
|
if err != nil {
|
|
t.Fatalf("Failed to read final file: %v", err)
|
|
}
|
|
if len(finalData) != blobSize {
|
|
t.Errorf("Final file size = %d, want %d", len(finalData), blobSize)
|
|
}
|
|
finalHash := sha256.Sum256(finalData)
|
|
if fmt.Sprintf("sha256:%x", finalHash) != digest {
|
|
t.Error("Final file hash mismatch")
|
|
}
|
|
}
|
|
|
|
func TestResumeCorruptPartialFile(t *testing.T) {
|
|
blobSize := resumeThreshold + 1024
|
|
data := make([]byte, blobSize)
|
|
for i := range data {
|
|
data[i] = byte((i * 13) % 256)
|
|
}
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
blob := Blob{Digest: digest, Size: int64(blobSize)}
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method == http.MethodHead {
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
|
|
rng := r.Header.Get("Range")
|
|
if rng != "" {
|
|
var start int64
|
|
fmt.Sscanf(rng, "bytes=%d-", &start)
|
|
if start > 0 && start < int64(blobSize) {
|
|
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, blobSize-1, blobSize))
|
|
w.WriteHeader(http.StatusPartialContent)
|
|
w.Write(data[start:])
|
|
return
|
|
}
|
|
}
|
|
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(data)
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
// Pre-create a partial .tmp file with CORRUPT data
|
|
partialSize := blobSize / 2
|
|
corruptData := make([]byte, partialSize)
|
|
for i := range corruptData {
|
|
corruptData[i] = 0xFF // All 0xFF — definitely wrong
|
|
}
|
|
dest := filepath.Join(clientDir, digestToPath(digest))
|
|
os.MkdirAll(filepath.Dir(dest), 0o755)
|
|
os.WriteFile(dest+".tmp", corruptData, 0o644)
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
})
|
|
// First attempt resumes with corrupt data → hash mismatch → retry.
|
|
// Retry should clean up .tmp and re-download fully.
|
|
if err != nil {
|
|
t.Fatalf("Download with corrupt partial file failed: %v", err)
|
|
}
|
|
|
|
// Verify final file is correct
|
|
finalData, err := os.ReadFile(dest)
|
|
if err != nil {
|
|
t.Fatalf("Failed to read final file: %v", err)
|
|
}
|
|
finalHash := sha256.Sum256(finalData)
|
|
if fmt.Sprintf("sha256:%x", finalHash) != digest {
|
|
t.Error("Final file hash mismatch after corrupt resume recovery")
|
|
}
|
|
}
|
|
|
|
func TestResumePartialFileLargerThanBlob(t *testing.T) {
|
|
blobSize := resumeThreshold + 1024
|
|
data := make([]byte, blobSize)
|
|
for i := range data {
|
|
data[i] = byte((i * 13) % 256)
|
|
}
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
blob := Blob{Digest: digest, Size: int64(blobSize)}
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method == http.MethodHead {
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(data)
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
// Pre-create .tmp file LARGER than expected blob
|
|
oversizedData := make([]byte, blobSize+1000)
|
|
dest := filepath.Join(clientDir, digestToPath(digest))
|
|
os.MkdirAll(filepath.Dir(dest), 0o755)
|
|
os.WriteFile(dest+".tmp", oversizedData, 0o644)
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download with oversized .tmp failed: %v", err)
|
|
}
|
|
|
|
// Verify final file is correct
|
|
finalData, err := os.ReadFile(dest)
|
|
if err != nil {
|
|
t.Fatalf("Failed to read final file: %v", err)
|
|
}
|
|
finalHash := sha256.Sum256(finalData)
|
|
if fmt.Sprintf("sha256:%x", finalHash) != digest {
|
|
t.Error("Final file hash mismatch")
|
|
}
|
|
}
|
|
|
|
func TestResumeBelowThreshold(t *testing.T) {
|
|
// Blob below resume threshold should NOT attempt resume
|
|
blobSize := 1024 // Well below resumeThreshold
|
|
data := make([]byte, blobSize)
|
|
for i := range data {
|
|
data[i] = byte(i % 256)
|
|
}
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
blob := Blob{Digest: digest, Size: int64(blobSize)}
|
|
|
|
var gotRange atomic.Bool
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method == http.MethodHead {
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
if r.Header.Get("Range") != "" {
|
|
gotRange.Store(true)
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(data)
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
// Pre-create a partial .tmp file
|
|
dest := filepath.Join(clientDir, digestToPath(digest))
|
|
os.MkdirAll(filepath.Dir(dest), 0o755)
|
|
os.WriteFile(dest+".tmp", data[:blobSize/2], 0o644)
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
if gotRange.Load() {
|
|
t.Error("Range header sent for blob below resume threshold — should not attempt resume")
|
|
}
|
|
|
|
// Verify final file
|
|
finalData, err := os.ReadFile(dest)
|
|
if err != nil {
|
|
t.Fatalf("Failed to read final file: %v", err)
|
|
}
|
|
finalHash := sha256.Sum256(finalData)
|
|
if fmt.Sprintf("sha256:%x", finalHash) != digest {
|
|
t.Error("Final file hash mismatch")
|
|
}
|
|
}
|
|
|
|
func TestResumeServerDoesNotSupportRange(t *testing.T) {
|
|
blobSize := resumeThreshold + 1024
|
|
data := make([]byte, blobSize)
|
|
for i := range data {
|
|
data[i] = byte((i * 13) % 256)
|
|
}
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
blob := Blob{Digest: digest, Size: int64(blobSize)}
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method == http.MethodHead {
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
// Ignore Range header — always return full content with 200
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(data)
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
// Pre-create partial .tmp file
|
|
dest := filepath.Join(clientDir, digestToPath(digest))
|
|
os.MkdirAll(filepath.Dir(dest), 0o755)
|
|
os.WriteFile(dest+".tmp", data[:blobSize/2], 0o644)
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download failed when server doesn't support Range: %v", err)
|
|
}
|
|
|
|
// Verify final file is correct
|
|
finalData, err := os.ReadFile(dest)
|
|
if err != nil {
|
|
t.Fatalf("Failed to read final file: %v", err)
|
|
}
|
|
finalHash := sha256.Sum256(finalData)
|
|
if fmt.Sprintf("sha256:%x", finalHash) != digest {
|
|
t.Error("Final file hash mismatch")
|
|
}
|
|
}
|
|
|
|
func TestResumePartialFileExactSize(t *testing.T) {
|
|
blobSize := resumeThreshold + 1024
|
|
data := make([]byte, blobSize)
|
|
for i := range data {
|
|
data[i] = byte((i * 13) % 256)
|
|
}
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
blob := Blob{Digest: digest, Size: int64(blobSize)}
|
|
|
|
var requestCount atomic.Int32
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method == http.MethodHead {
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
requestCount.Add(1)
|
|
|
|
rng := r.Header.Get("Range")
|
|
if rng != "" {
|
|
var start int64
|
|
fmt.Sscanf(rng, "bytes=%d-", &start)
|
|
if start >= int64(blobSize) {
|
|
// Nothing to send
|
|
w.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
|
return
|
|
}
|
|
if start > 0 {
|
|
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, blobSize-1, blobSize))
|
|
w.WriteHeader(http.StatusPartialContent)
|
|
w.Write(data[start:])
|
|
return
|
|
}
|
|
}
|
|
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", blobSize))
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(data)
|
|
}))
|
|
defer server.Close()
|
|
|
|
clientDir := t.TempDir()
|
|
|
|
// Pre-create .tmp file with exact correct content (full size)
|
|
// This simulates a download that completed but wasn't renamed
|
|
dest := filepath.Join(clientDir, digestToPath(digest))
|
|
os.MkdirAll(filepath.Dir(dest), 0o755)
|
|
os.WriteFile(dest+".tmp", data, 0o644)
|
|
|
|
err := Download(context.Background(), DownloadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
DestDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Download failed: %v", err)
|
|
}
|
|
|
|
// Verify final file is correct
|
|
finalData, err := os.ReadFile(dest)
|
|
if err != nil {
|
|
t.Fatalf("Failed to read final file: %v", err)
|
|
}
|
|
resumeHash := sha256.Sum256(finalData)
|
|
if fmt.Sprintf("sha256:%x", resumeHash) != digest {
|
|
t.Error("Final file hash mismatch")
|
|
}
|
|
}
|
|
|
|
// ==================== Chunked Upload Tests ====================
|
|
|
|
// chunkedUploadServer is a mock registry for tests that implements the OCI
// chunked upload protocol: POST → PATCH* (with Content-Range) → PUT (finalize).
// Serve it via handler(); serverURL must point at the server's base URL (a
// non-nil *string) before any POST or default-PATCH request is handled,
// because both dereference it to build Location headers.
type chunkedUploadServer struct {
	// t is the owning test. NOTE(review): not read by any method visible here.
	t *testing.T
	// mu guards parts, patchCount, finalized, finalDigest, finalEtag and
	// uploadCounter. patchHandler is read without it.
	mu sync.Mutex
	parts map[int][]byte // part offset -> received data
	// patchCount is the number of PATCH requests seen by defaultPatchHandler.
	patchCount int
	// finalized is set once a PUT to an /uploads URL has been received.
	finalized bool
	// finalDigest is the "digest" query parameter of the finalizing PUT.
	finalDigest string
	// finalEtag is the "etag" query parameter of the finalizing PUT.
	finalEtag string
	// When patchHandler is non-nil it handles every PATCH instead of
	// defaultPatchHandler.
	patchHandler func(w http.ResponseWriter, r *http.Request) // optional override
	// uploadCounter counts sessions opened via POST and mints session IDs.
	uploadCounter int
	// serverURL points at the server's base URL; see the type comment.
	serverURL *string
}
|
|
|
|
func newChunkedUploadServer(t *testing.T) *chunkedUploadServer {
|
|
return &chunkedUploadServer{
|
|
t: t,
|
|
parts: make(map[int][]byte),
|
|
}
|
|
}
|
|
|
|
func (s *chunkedUploadServer) handler() http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case r.Method == http.MethodHead:
|
|
http.NotFound(w, r)
|
|
|
|
case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/uploads"):
|
|
s.mu.Lock()
|
|
s.uploadCounter++
|
|
id := s.uploadCounter
|
|
s.mu.Unlock()
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/%d", *s.serverURL, id))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
|
|
case r.Method == http.MethodPatch:
|
|
if s.patchHandler != nil {
|
|
s.patchHandler(w, r)
|
|
return
|
|
}
|
|
s.defaultPatchHandler(w, r)
|
|
|
|
case r.Method == http.MethodPut && strings.Contains(r.URL.Path, "/uploads"):
|
|
s.mu.Lock()
|
|
s.finalized = true
|
|
s.finalDigest = r.URL.Query().Get("digest")
|
|
s.finalEtag = r.URL.Query().Get("etag")
|
|
s.mu.Unlock()
|
|
w.WriteHeader(http.StatusCreated)
|
|
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *chunkedUploadServer) defaultPatchHandler(w http.ResponseWriter, r *http.Request) {
|
|
s.mu.Lock()
|
|
s.patchCount++
|
|
patchNum := s.patchCount
|
|
s.mu.Unlock()
|
|
|
|
cr := r.Header.Get("Content-Range")
|
|
if cr == "" {
|
|
http.Error(w, "missing Content-Range", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
var start, end int64
|
|
fmt.Sscanf(cr, "%d-%d", &start, &end)
|
|
|
|
data, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
s.mu.Lock()
|
|
s.parts[int(start)] = data
|
|
s.mu.Unlock()
|
|
|
|
w.Header().Set("Docker-Upload-Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/session-%d", *s.serverURL, patchNum+1))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
}
|
|
|
|
func (s *chunkedUploadServer) reassemble(totalSize int) []byte {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
result := make([]byte, totalSize)
|
|
for offset, data := range s.parts {
|
|
copy(result[offset:], data)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func TestComputeParts(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
totalSize int64
|
|
wantParts int
|
|
wantFirst int64
|
|
}{
|
|
{
|
|
name: "1GB blob — clamped to min part size",
|
|
totalSize: 1 << 30,
|
|
wantParts: int((1<<30 + minUploadPartSize - 1) / minUploadPartSize),
|
|
wantFirst: minUploadPartSize,
|
|
},
|
|
{
|
|
name: "5GB blob — 16 parts",
|
|
totalSize: 5 << 30,
|
|
wantParts: 16,
|
|
wantFirst: 5 << 30 / 16,
|
|
},
|
|
{
|
|
name: "20GB blob — clamped to max part size",
|
|
totalSize: 20 << 30,
|
|
wantParts: int((20<<30 + maxUploadPartSize - 1) / maxUploadPartSize),
|
|
wantFirst: maxUploadPartSize,
|
|
},
|
|
{
|
|
name: "exactly min part size",
|
|
totalSize: minUploadPartSize,
|
|
wantParts: 1,
|
|
wantFirst: minUploadPartSize,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
parts := computeParts(tt.totalSize)
|
|
if len(parts) != tt.wantParts {
|
|
t.Errorf("computeParts(%d) = %d parts, want %d", tt.totalSize, len(parts), tt.wantParts)
|
|
}
|
|
if len(parts) > 0 && parts[0].size != tt.wantFirst {
|
|
t.Errorf("first part size = %d, want %d", parts[0].size, tt.wantFirst)
|
|
}
|
|
|
|
// Verify parts cover entire blob with no gaps
|
|
var total int64
|
|
for i, p := range parts {
|
|
if p.offset != total {
|
|
t.Errorf("part %d offset = %d, want %d", i, p.offset, total)
|
|
}
|
|
if p.n != i {
|
|
t.Errorf("part %d n = %d, want %d", i, p.n, i)
|
|
}
|
|
total += p.size
|
|
}
|
|
if total != tt.totalSize {
|
|
t.Errorf("total part sizes = %d, want %d", total, tt.totalSize)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestChunkedUploadBasic(t *testing.T) {
|
|
blobSize := resumeThreshold + 1024
|
|
data := make([]byte, blobSize)
|
|
for i := range data {
|
|
data[i] = byte((i * 7) % 256)
|
|
}
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
blob := Blob{Digest: digest, Size: int64(blobSize)}
|
|
|
|
clientDir := t.TempDir()
|
|
path := filepath.Join(clientDir, digestToPath(digest))
|
|
os.MkdirAll(filepath.Dir(path), 0o755)
|
|
os.WriteFile(path, data, 0o644)
|
|
|
|
srv := newChunkedUploadServer(t)
|
|
var serverURL string
|
|
srv.serverURL = &serverURL
|
|
server := httptest.NewServer(srv.handler())
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
var progressCalls atomic.Int32
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
Progress: func(completed, total int64) {
|
|
progressCalls.Add(1)
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Chunked upload failed: %v", err)
|
|
}
|
|
|
|
reassembled := srv.reassemble(blobSize)
|
|
reassembledHash := sha256.Sum256(reassembled)
|
|
if fmt.Sprintf("sha256:%x", reassembledHash) != digest {
|
|
t.Error("Reassembled data hash mismatch")
|
|
}
|
|
|
|
srv.mu.Lock()
|
|
if !srv.finalized {
|
|
t.Error("Finalize PUT was not called")
|
|
}
|
|
if srv.finalDigest != digest {
|
|
t.Errorf("Finalize digest = %s, want %s", srv.finalDigest, digest)
|
|
}
|
|
if srv.finalEtag == "" {
|
|
t.Error("Finalize etag is empty")
|
|
}
|
|
if srv.patchCount == 0 {
|
|
t.Error("No PATCH requests were sent")
|
|
}
|
|
srv.mu.Unlock()
|
|
|
|
if progressCalls.Load() == 0 {
|
|
t.Error("Progress callback never called")
|
|
}
|
|
}
|
|
|
|
// TestSmallBlobUsesChunkedFlow verifies that even small blobs go through the
|
|
// PATCH+finalize chunked flow. Server-side redirect logic is gated on PATCH,
|
|
// so a single-PUT path could never trigger CDN redirection — every blob must
|
|
// use PATCH so the server has the chance to redirect.
|
|
func TestChunkedUploadCDNRedirect(t *testing.T) {
|
|
blobSize := resumeThreshold + 1024
|
|
data := make([]byte, blobSize)
|
|
for i := range data {
|
|
data[i] = byte((i * 7) % 256)
|
|
}
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
blob := Blob{Digest: digest, Size: int64(blobSize)}
|
|
|
|
clientDir := t.TempDir()
|
|
path := filepath.Join(clientDir, digestToPath(digest))
|
|
os.MkdirAll(filepath.Dir(path), 0o755)
|
|
os.WriteFile(path, data, 0o644)
|
|
|
|
cdnParts := make(map[string][]byte)
|
|
var cdnMu sync.Mutex
|
|
var cdnGotAuth atomic.Bool
|
|
|
|
cdn := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Header.Get("Authorization") != "" {
|
|
cdnGotAuth.Store(true)
|
|
}
|
|
cdnData, _ := io.ReadAll(r.Body)
|
|
cdnMu.Lock()
|
|
cdnParts[r.URL.Path] = cdnData
|
|
cdnMu.Unlock()
|
|
w.WriteHeader(http.StatusCreated)
|
|
}))
|
|
defer cdn.Close()
|
|
|
|
var serverURL string
|
|
var patchCount atomic.Int32
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case r.Method == http.MethodHead:
|
|
http.NotFound(w, r)
|
|
case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/uploads"):
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
case r.Method == http.MethodPatch:
|
|
n := patchCount.Add(1)
|
|
w.Header().Set("Docker-Upload-Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/session-%d", serverURL, n+1))
|
|
cdnPath := fmt.Sprintf("/cdn/part-%d", n)
|
|
http.Redirect(w, r, cdn.URL+cdnPath, http.StatusTemporaryRedirect)
|
|
case r.Method == http.MethodPut && strings.Contains(r.URL.Path, "/uploads"):
|
|
w.WriteHeader(http.StatusCreated)
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload with CDN redirect failed: %v", err)
|
|
}
|
|
|
|
cdnMu.Lock()
|
|
totalCDNBytes := 0
|
|
for _, d := range cdnParts {
|
|
totalCDNBytes += len(d)
|
|
}
|
|
cdnMu.Unlock()
|
|
|
|
if totalCDNBytes != blobSize {
|
|
t.Errorf("CDN received %d bytes, want %d", totalCDNBytes, blobSize)
|
|
}
|
|
|
|
if cdnGotAuth.Load() {
|
|
t.Error("CDN received Authorization header — should not be sent to CDN")
|
|
}
|
|
}
|
|
|
|
func TestChunkedUploadPartRetry(t *testing.T) {
|
|
blobSize := resumeThreshold + 1024
|
|
data := make([]byte, blobSize)
|
|
for i := range data {
|
|
data[i] = byte(i % 256)
|
|
}
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
blob := Blob{Digest: digest, Size: int64(blobSize)}
|
|
|
|
clientDir := t.TempDir()
|
|
path := filepath.Join(clientDir, digestToPath(digest))
|
|
os.MkdirAll(filepath.Dir(path), 0o755)
|
|
os.WriteFile(path, data, 0o644)
|
|
|
|
var patchAttempts atomic.Int32
|
|
|
|
srv := newChunkedUploadServer(t)
|
|
srv.patchHandler = func(w http.ResponseWriter, r *http.Request) {
|
|
attempt := patchAttempts.Add(1)
|
|
if attempt == 1 {
|
|
http.Error(w, "server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
srv.defaultPatchHandler(w, r)
|
|
}
|
|
var serverURL string
|
|
srv.serverURL = &serverURL
|
|
server := httptest.NewServer(srv.handler())
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload with retry failed: %v", err)
|
|
}
|
|
|
|
if patchAttempts.Load() < 2 {
|
|
t.Errorf("Expected at least 2 PATCH attempts, got %d", patchAttempts.Load())
|
|
}
|
|
|
|
reassembled := srv.reassemble(blobSize)
|
|
reassembledHash := sha256.Sum256(reassembled)
|
|
if fmt.Sprintf("sha256:%x", reassembledHash) != digest {
|
|
t.Error("Data integrity failed after retry")
|
|
}
|
|
}
|
|
|
|
// multiPartTestHelper creates an uploader with small part sizes so multi-part
|
|
// behavior can be tested with small blobs. Returns the uploader and blob data.
|
|
func multiPartTestHelper(t *testing.T, blobSize int, partSize int64, serverURL string) (*uploader, Blob, []byte) {
|
|
t.Helper()
|
|
data := make([]byte, blobSize)
|
|
for i := range data {
|
|
data[i] = byte((i*7 + 13) % 256)
|
|
}
|
|
h := sha256.Sum256(data)
|
|
digest := fmt.Sprintf("sha256:%x", h)
|
|
blob := Blob{Digest: digest, Size: int64(blobSize)}
|
|
|
|
clientDir := t.TempDir()
|
|
path := filepath.Join(clientDir, digestToPath(digest))
|
|
os.MkdirAll(filepath.Dir(path), 0o755)
|
|
os.WriteFile(path, data, 0o644)
|
|
|
|
u := &uploader{
|
|
client: defaultClient,
|
|
baseURL: serverURL,
|
|
srcDir: clientDir,
|
|
userAgent: defaultUserAgent,
|
|
progress: newProgressTracker(int64(blobSize), nil),
|
|
makeParts: func(totalSize int64) []uploadPart {
|
|
return computePartsWithLimits(totalSize, 16, partSize, partSize*10)
|
|
},
|
|
}
|
|
return u, blob, data
|
|
}
|
|
|
|
func TestChunkedUploadMultiPartSessionURLChain(t *testing.T) {
|
|
// Use 10KB blobs with 2KB parts → 5 parts, exercising the URL chain
|
|
blobSize := 10240
|
|
partSize := int64(2048)
|
|
|
|
var patchURLs []string
|
|
var mu sync.Mutex
|
|
|
|
srv := newChunkedUploadServer(t)
|
|
srv.patchHandler = func(w http.ResponseWriter, r *http.Request) {
|
|
mu.Lock()
|
|
patchURLs = append(patchURLs, r.URL.Path)
|
|
mu.Unlock()
|
|
srv.defaultPatchHandler(w, r)
|
|
}
|
|
var serverURL string
|
|
srv.serverURL = &serverURL
|
|
server := httptest.NewServer(srv.handler())
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
u, blob, data := multiPartTestHelper(t, blobSize, partSize, server.URL)
|
|
|
|
f, err := os.Open(filepath.Join(u.srcDir, digestToPath(blob.Digest)))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer f.Close()
|
|
|
|
initURL := fmt.Sprintf("%s/v2/library/_/blobs/uploads/init-1", server.URL)
|
|
n, err := u.putChunked(context.Background(), initURL, f, blob)
|
|
if err != nil {
|
|
t.Fatalf("putChunked failed: %v", err)
|
|
}
|
|
if n != int64(blobSize) {
|
|
t.Errorf("bytes written = %d, want %d", n, blobSize)
|
|
}
|
|
|
|
// Verify data integrity
|
|
reassembled := srv.reassemble(blobSize)
|
|
if !bytes.Equal(reassembled, data) {
|
|
t.Error("Reassembled data mismatch")
|
|
}
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
// Should have 5 parts with distinct URLs
|
|
if len(patchURLs) != 5 {
|
|
t.Fatalf("Expected 5 PATCH requests, got %d", len(patchURLs))
|
|
}
|
|
|
|
// First PATCH uses the init URL
|
|
if !strings.Contains(patchURLs[0], "init-1") {
|
|
t.Errorf("First PATCH URL should contain init-1, got %s", patchURLs[0])
|
|
}
|
|
|
|
// Subsequent PATCHes should use session URLs from Docker-Upload-Location
|
|
for i := 1; i < len(patchURLs); i++ {
|
|
if patchURLs[i] == patchURLs[i-1] {
|
|
t.Errorf("PATCH %d used same URL as PATCH %d — chain broken", i, i-1)
|
|
}
|
|
if !strings.Contains(patchURLs[i], "session-") {
|
|
t.Errorf("PATCH %d URL should contain session-, got %s", i, patchURLs[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestChunkedUploadMultiPartDataIntegrity(t *testing.T) {
|
|
// Non-evenly-divisible: 10001 bytes with 3000-byte parts → 4 parts (3000+3000+3000+1001)
|
|
blobSize := 10001
|
|
partSize := int64(3000)
|
|
|
|
srv := newChunkedUploadServer(t)
|
|
var serverURL string
|
|
srv.serverURL = &serverURL
|
|
server := httptest.NewServer(srv.handler())
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
u, blob, data := multiPartTestHelper(t, blobSize, partSize, server.URL)
|
|
|
|
f, err := os.Open(filepath.Join(u.srcDir, digestToPath(blob.Digest)))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer f.Close()
|
|
|
|
initURL := fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", server.URL)
|
|
_, err = u.putChunked(context.Background(), initURL, f, blob)
|
|
if err != nil {
|
|
t.Fatalf("putChunked failed: %v", err)
|
|
}
|
|
|
|
reassembled := srv.reassemble(blobSize)
|
|
if !bytes.Equal(reassembled, data) {
|
|
t.Error("Reassembled data mismatch with non-evenly-divisible parts")
|
|
}
|
|
|
|
srv.mu.Lock()
|
|
if srv.patchCount != 4 {
|
|
t.Errorf("Expected 4 PATCH requests, got %d", srv.patchCount)
|
|
}
|
|
srv.mu.Unlock()
|
|
}
|
|
|
|
func TestChunkedUploadMultiPartProgressRollback(t *testing.T) {
|
|
blobSize := 6000
|
|
partSize := int64(2000) // 3 parts
|
|
|
|
var patchAttempts atomic.Int32
|
|
|
|
srv := newChunkedUploadServer(t)
|
|
srv.patchHandler = func(w http.ResponseWriter, r *http.Request) {
|
|
attempt := patchAttempts.Add(1)
|
|
// Fail the second PATCH attempt (part 1, first try). Drain the body
|
|
// before erroring so the server sends 100 Continue (under Expect:
|
|
// 100-continue) and the client uploads the body — that's what makes
|
|
// the progress rollback observable.
|
|
if attempt == 2 {
|
|
io.Copy(io.Discard, r.Body)
|
|
http.Error(w, "server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
srv.defaultPatchHandler(w, r)
|
|
}
|
|
var serverURL string
|
|
srv.serverURL = &serverURL
|
|
server := httptest.NewServer(srv.handler())
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
u, blob, data := multiPartTestHelper(t, blobSize, partSize, server.URL)
|
|
// Track progress
|
|
var progressValues []int64
|
|
var mu sync.Mutex
|
|
u.progress = newProgressTracker(int64(blobSize), func(completed, total int64) {
|
|
mu.Lock()
|
|
progressValues = append(progressValues, completed)
|
|
mu.Unlock()
|
|
})
|
|
|
|
f, err := os.Open(filepath.Join(u.srcDir, digestToPath(blob.Digest)))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer f.Close()
|
|
|
|
initURL := fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", server.URL)
|
|
_, err = u.putChunked(context.Background(), initURL, f, blob)
|
|
if err != nil {
|
|
t.Fatalf("putChunked failed: %v", err)
|
|
}
|
|
|
|
// Verify data integrity despite retry
|
|
reassembled := srv.reassemble(blobSize)
|
|
if !bytes.Equal(reassembled, data) {
|
|
t.Error("Data mismatch after retry")
|
|
}
|
|
|
|
// Verify progress had a rollback (decrease) then recovered
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
hadDecrease := false
|
|
for i := 1; i < len(progressValues); i++ {
|
|
if progressValues[i] < progressValues[i-1] {
|
|
hadDecrease = true
|
|
break
|
|
}
|
|
}
|
|
if !hadDecrease {
|
|
t.Error("Expected progress to decrease (rollback) during retry, but it was monotonic")
|
|
}
|
|
|
|
// Final should equal blob size
|
|
if len(progressValues) > 0 && progressValues[len(progressValues)-1] != int64(blobSize) {
|
|
t.Errorf("Final progress = %d, want %d", progressValues[len(progressValues)-1], blobSize)
|
|
}
|
|
}
|
|
|
|
// ==================== v2 direct-upload extension tests ====================
|
|
|
|
// TestV2InitRequestShape verifies the init POST advertises the v2 capability
|
|
// with the expected query parameter and headers, and that the request body
|
|
// is empty.
|
|
func TestV2InitRequestShape(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
blob, _ := createTestBlob(t, clientDir, 4096)
|
|
|
|
var sawDigestQuery, sawCapHeader, sawSizeHeader string
|
|
var bodyLen int
|
|
session := newChunkedSession()
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodHead:
|
|
http.NotFound(w, r)
|
|
case http.MethodPost:
|
|
sawDigestQuery = r.URL.Query().Get("digest")
|
|
sawCapHeader = r.Header.Get("X-Redirect-Uploads")
|
|
sawSizeHeader = r.Header.Get("X-Content-Length")
|
|
body, _ := io.ReadAll(r.Body)
|
|
bodyLen = len(body)
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
case http.MethodPatch:
|
|
session.recordPatch(w, r)
|
|
case http.MethodPut:
|
|
w.WriteHeader(http.StatusCreated)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload failed: %v", err)
|
|
}
|
|
|
|
if sawDigestQuery != blob.Digest {
|
|
t.Errorf("init POST ?digest= = %q, want %q", sawDigestQuery, blob.Digest)
|
|
}
|
|
if sawCapHeader != "2" {
|
|
t.Errorf("init POST X-Redirect-Uploads = %q, want %q", sawCapHeader, "2")
|
|
}
|
|
if sawSizeHeader != fmt.Sprintf("%d", blob.Size) {
|
|
t.Errorf("init POST X-Content-Length = %q, want %q", sawSizeHeader, fmt.Sprintf("%d", blob.Size))
|
|
}
|
|
if bodyLen != 0 {
|
|
t.Errorf("init POST body length = %d, want 0", bodyLen)
|
|
}
|
|
}
|
|
|
|
// TestV2DirectUpload verifies the v2 happy path: server returns
|
|
// X-Direct-Upload-URL + X-Signed-Header-X-Amz-Checksum-Sha256, the client
|
|
// PUTs body to the direct URL with the forwarded checksum header, then
|
|
// commits via a bodyless PUT to the session URL.
|
|
func TestV2DirectUpload(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
blob, data := createTestBlob(t, clientDir, 8192)
|
|
|
|
var (
|
|
cdnReceived []byte
|
|
cdnChecksumHeader string
|
|
cdnAuthHeader string
|
|
cdnHits atomic.Int32
|
|
)
|
|
|
|
cdn := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
cdnHits.Add(1)
|
|
if r.Method != http.MethodPut {
|
|
http.Error(w, "want PUT", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
cdnChecksumHeader = r.Header.Get("X-Amz-Checksum-Sha256")
|
|
cdnAuthHeader = r.Header.Get("Authorization")
|
|
cdnReceived, _ = io.ReadAll(r.Body)
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer cdn.Close()
|
|
|
|
// Compute the checksum the registry would sign — base64 of the SHA-256
|
|
// binary digest. The mock just hands the value back to the client; the
|
|
// client forwards it to the CDN.
|
|
hexDigest := strings.TrimPrefix(blob.Digest, "sha256:")
|
|
digestBytes := make([]byte, len(hexDigest)/2)
|
|
for i := range digestBytes {
|
|
fmt.Sscanf(hexDigest[i*2:i*2+2], "%02x", &digestBytes[i])
|
|
}
|
|
expectedChecksum := base64.StdEncoding.EncodeToString(digestBytes)
|
|
|
|
var (
|
|
commitDigest atomic.Value
|
|
commitBody atomic.Int32
|
|
)
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodHead:
|
|
http.NotFound(w, r)
|
|
case http.MethodPost:
|
|
directURL := cdn.URL + "/upload/" + blob.Digest
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
|
|
w.Header().Set("X-Direct-Upload-URL", directURL)
|
|
w.Header().Set("X-Signed-Header-X-Amz-Checksum-Sha256", expectedChecksum)
|
|
w.WriteHeader(http.StatusAccepted)
|
|
case http.MethodPut:
|
|
commitDigest.Store(r.URL.Query().Get("digest"))
|
|
body, _ := io.ReadAll(r.Body)
|
|
commitBody.Store(int32(len(body)))
|
|
w.WriteHeader(http.StatusCreated)
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload failed: %v", err)
|
|
}
|
|
|
|
if got := cdnHits.Load(); got != 1 {
|
|
t.Errorf("CDN hits = %d, want 1", got)
|
|
}
|
|
if !bytes.Equal(cdnReceived, data) {
|
|
t.Errorf("CDN body length = %d, want %d", len(cdnReceived), len(data))
|
|
}
|
|
if cdnChecksumHeader != expectedChecksum {
|
|
t.Errorf("CDN x-amz-checksum-sha256 = %q, want %q", cdnChecksumHeader, expectedChecksum)
|
|
}
|
|
if cdnAuthHeader != "" {
|
|
t.Errorf("CDN Authorization = %q, want empty (presigned URL shouldn't carry auth)", cdnAuthHeader)
|
|
}
|
|
if got, _ := commitDigest.Load().(string); got != blob.Digest {
|
|
t.Errorf("commit ?digest= = %q, want %q", got, blob.Digest)
|
|
}
|
|
if commitBody.Load() != 0 {
|
|
t.Errorf("commit body length = %d, want 0", commitBody.Load())
|
|
}
|
|
}
|
|
|
|
// TestV2FallbackToChunked verifies that when the server returns a standard
|
|
// 202 without v2 extension headers, the client falls back to the chunked
|
|
// PATCH path. This exercises the vanilla Docker Registry compatibility.
|
|
func TestV2FallbackToChunked(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
blob, data := createTestBlob(t, clientDir, 8192)
|
|
|
|
var (
|
|
uploadedBlobs sync.Map
|
|
patchHit atomic.Int32
|
|
commitHit atomic.Int32
|
|
)
|
|
session := newChunkedSession()
|
|
|
|
var serverURL string
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodHead:
|
|
http.NotFound(w, r)
|
|
case http.MethodPost:
|
|
// vanilla: standard Location only, no v2 extension headers
|
|
w.Header().Set("Location", fmt.Sprintf("%s/v2/library/_/blobs/uploads/1", serverURL))
|
|
w.WriteHeader(http.StatusAccepted)
|
|
case http.MethodPatch:
|
|
patchHit.Add(1)
|
|
session.recordPatch(w, r)
|
|
case http.MethodPut:
|
|
commitHit.Add(1)
|
|
digest := r.URL.Query().Get("digest")
|
|
uploadedBlobs.Store(digest, session.finalize(r.URL.Path))
|
|
w.WriteHeader(http.StatusCreated)
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
serverURL = server.URL
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload failed: %v", err)
|
|
}
|
|
|
|
if patchHit.Load() < 1 {
|
|
t.Error("expected at least one PATCH (chunked fallback), got none")
|
|
}
|
|
if commitHit.Load() != 1 {
|
|
t.Errorf("commit PUT hits = %d, want 1", commitHit.Load())
|
|
}
|
|
if got, ok := uploadedBlobs.Load(blob.Digest); !ok {
|
|
t.Error("blob not uploaded")
|
|
} else if !bytes.Equal(got.([]byte), data) {
|
|
t.Errorf("uploaded body length = %d, want %d", len(got.([]byte)), len(data))
|
|
}
|
|
}
|
|
|
|
// TestV2BlobAlreadyExists verifies that a 201 Created response from the init
|
|
// POST short-circuits the upload — the server has matched our ?digest=
|
|
// against existing storage and there's nothing to upload.
|
|
func TestV2BlobAlreadyExists(t *testing.T) {
|
|
clientDir := t.TempDir()
|
|
blob, _ := createTestBlob(t, clientDir, 1024)
|
|
|
|
var (
|
|
postHits atomic.Int32
|
|
patchHits atomic.Int32
|
|
putHits atomic.Int32
|
|
)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodHead:
|
|
http.NotFound(w, r)
|
|
case http.MethodPost:
|
|
postHits.Add(1)
|
|
// Server matched ?digest= against existing storage.
|
|
w.WriteHeader(http.StatusCreated)
|
|
case http.MethodPatch:
|
|
patchHits.Add(1)
|
|
w.WriteHeader(http.StatusAccepted)
|
|
case http.MethodPut:
|
|
putHits.Add(1)
|
|
w.WriteHeader(http.StatusCreated)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
err := Upload(context.Background(), UploadOptions{
|
|
Blobs: []Blob{blob},
|
|
BaseURL: server.URL,
|
|
SrcDir: clientDir,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Upload failed: %v", err)
|
|
}
|
|
|
|
if postHits.Load() != 1 {
|
|
t.Errorf("init POST hits = %d, want 1", postHits.Load())
|
|
}
|
|
if patchHits.Load() != 0 {
|
|
t.Errorf("PATCH hits = %d, want 0 (blob existed; nothing to upload)", patchHits.Load())
|
|
}
|
|
if putHits.Load() != 0 {
|
|
t.Errorf("PUT hits = %d, want 0 (blob existed; nothing to upload)", putHits.Load())
|
|
}
|
|
}
|