ollama source for Momentry Core verification
This commit is contained in:
25
internal/cloud/policy.go
Normal file
25
internal/cloud/policy.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package cloud
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
const DisabledMessagePrefix = "ollama cloud is disabled"
|
||||
|
||||
// Status returns whether cloud is disabled and the source of the decision.
|
||||
// Source is one of: "none", "env", "config", "both".
|
||||
func Status() (disabled bool, source string) {
|
||||
return envconfig.NoCloud(), envconfig.NoCloudSource()
|
||||
}
|
||||
|
||||
func Disabled() bool {
|
||||
return envconfig.NoCloud()
|
||||
}
|
||||
|
||||
func DisabledError(operation string) string {
|
||||
if operation == "" {
|
||||
return DisabledMessagePrefix
|
||||
}
|
||||
|
||||
return DisabledMessagePrefix + ": " + operation
|
||||
}
|
||||
85
internal/cloud/policy_test.go
Normal file
85
internal/cloud/policy_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package cloud
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envValue string
|
||||
configContent string
|
||||
disabled bool
|
||||
source string
|
||||
}{
|
||||
{
|
||||
name: "none",
|
||||
disabled: false,
|
||||
source: "none",
|
||||
},
|
||||
{
|
||||
name: "env only",
|
||||
envValue: "1",
|
||||
disabled: true,
|
||||
source: "env",
|
||||
},
|
||||
{
|
||||
name: "config only",
|
||||
configContent: `{"disable_ollama_cloud": true}`,
|
||||
disabled: true,
|
||||
source: "config",
|
||||
},
|
||||
{
|
||||
name: "both",
|
||||
envValue: "1",
|
||||
configContent: `{"disable_ollama_cloud": true}`,
|
||||
disabled: true,
|
||||
source: "both",
|
||||
},
|
||||
{
|
||||
name: "invalid config ignored",
|
||||
configContent: `{invalid json`,
|
||||
disabled: false,
|
||||
source: "none",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
if tt.configContent != "" {
|
||||
configPath := filepath.Join(home, ".ollama", "server.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(configPath, []byte(tt.configContent), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
setTestHome(t, home)
|
||||
t.Setenv("OLLAMA_NO_CLOUD", tt.envValue)
|
||||
|
||||
disabled, source := Status()
|
||||
if disabled != tt.disabled {
|
||||
t.Fatalf("disabled: expected %v, got %v", tt.disabled, disabled)
|
||||
}
|
||||
if source != tt.source {
|
||||
t.Fatalf("source: expected %q, got %q", tt.source, source)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisabledError(t *testing.T) {
|
||||
if got := DisabledError(""); got != DisabledMessagePrefix {
|
||||
t.Fatalf("expected %q, got %q", DisabledMessagePrefix, got)
|
||||
}
|
||||
|
||||
want := DisabledMessagePrefix + ": remote inference is unavailable"
|
||||
if got := DisabledError("remote inference is unavailable"); got != want {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
14
internal/cloud/test_home_test.go
Normal file
14
internal/cloud/test_home_test.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package cloud
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
func setTestHome(t *testing.T, home string) {
|
||||
t.Helper()
|
||||
t.Setenv("HOME", home)
|
||||
t.Setenv("USERPROFILE", home)
|
||||
envconfig.ReloadServerConfig()
|
||||
}
|
||||
115
internal/modelref/modelref.go
Normal file
115
internal/modelref/modelref.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package modelref
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ModelSource uint8
|
||||
|
||||
const (
|
||||
ModelSourceUnspecified ModelSource = iota
|
||||
ModelSourceLocal
|
||||
ModelSourceCloud
|
||||
)
|
||||
|
||||
var (
|
||||
ErrConflictingSourceSuffix = errors.New("use either :local or :cloud, not both")
|
||||
ErrModelRequired = errors.New("model is required")
|
||||
)
|
||||
|
||||
type ParsedRef struct {
|
||||
Original string
|
||||
Base string
|
||||
Source ModelSource
|
||||
}
|
||||
|
||||
func ParseRef(raw string) (ParsedRef, error) {
|
||||
var zero ParsedRef
|
||||
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return zero, ErrModelRequired
|
||||
}
|
||||
|
||||
base, source, explicit := parseSourceSuffix(raw)
|
||||
if explicit {
|
||||
if _, _, nested := parseSourceSuffix(base); nested {
|
||||
return zero, fmt.Errorf("%w: %q", ErrConflictingSourceSuffix, raw)
|
||||
}
|
||||
}
|
||||
|
||||
return ParsedRef{
|
||||
Original: raw,
|
||||
Base: base,
|
||||
Source: source,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func HasExplicitCloudSource(raw string) bool {
|
||||
parsedRef, err := ParseRef(raw)
|
||||
return err == nil && parsedRef.Source == ModelSourceCloud
|
||||
}
|
||||
|
||||
func HasExplicitLocalSource(raw string) bool {
|
||||
parsedRef, err := ParseRef(raw)
|
||||
return err == nil && parsedRef.Source == ModelSourceLocal
|
||||
}
|
||||
|
||||
func StripCloudSourceTag(raw string) (string, bool) {
|
||||
parsedRef, err := ParseRef(raw)
|
||||
if err != nil || parsedRef.Source != ModelSourceCloud {
|
||||
return strings.TrimSpace(raw), false
|
||||
}
|
||||
|
||||
return parsedRef.Base, true
|
||||
}
|
||||
|
||||
func NormalizePullName(raw string) (string, bool, error) {
|
||||
parsedRef, err := ParseRef(raw)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
if parsedRef.Source != ModelSourceCloud {
|
||||
return parsedRef.Base, false, nil
|
||||
}
|
||||
|
||||
return toLegacyCloudPullName(parsedRef.Base), true, nil
|
||||
}
|
||||
|
||||
func toLegacyCloudPullName(base string) string {
|
||||
if hasExplicitTag(base) {
|
||||
return base + "-cloud"
|
||||
}
|
||||
|
||||
return base + ":cloud"
|
||||
}
|
||||
|
||||
func hasExplicitTag(name string) bool {
|
||||
lastSlash := strings.LastIndex(name, "/")
|
||||
lastColon := strings.LastIndex(name, ":")
|
||||
return lastColon > lastSlash
|
||||
}
|
||||
|
||||
func parseSourceSuffix(raw string) (string, ModelSource, bool) {
|
||||
idx := strings.LastIndex(raw, ":")
|
||||
if idx >= 0 {
|
||||
suffixRaw := strings.TrimSpace(raw[idx+1:])
|
||||
suffix := strings.ToLower(suffixRaw)
|
||||
|
||||
switch suffix {
|
||||
case "cloud":
|
||||
return raw[:idx], ModelSourceCloud, true
|
||||
case "local":
|
||||
return raw[:idx], ModelSourceLocal, true
|
||||
}
|
||||
|
||||
if !strings.Contains(suffixRaw, "/") && strings.HasSuffix(suffix, "-cloud") {
|
||||
return raw[:idx+1] + suffixRaw[:len(suffixRaw)-len("-cloud")], ModelSourceCloud, true
|
||||
}
|
||||
}
|
||||
|
||||
return raw, ModelSourceUnspecified, false
|
||||
}
|
||||
268
internal/modelref/modelref_test.go
Normal file
268
internal/modelref/modelref_test.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package modelref
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseRef(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantBase string
|
||||
wantSource ModelSource
|
||||
wantErr error
|
||||
wantCloud bool
|
||||
wantLocal bool
|
||||
wantStripped string
|
||||
wantStripOK bool
|
||||
}{
|
||||
{
|
||||
name: "cloud suffix",
|
||||
input: "gpt-oss:20b:cloud",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantCloud: true,
|
||||
wantStripped: "gpt-oss:20b",
|
||||
wantStripOK: true,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud suffix",
|
||||
input: "gpt-oss:20b-cloud",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantCloud: true,
|
||||
wantStripped: "gpt-oss:20b",
|
||||
wantStripOK: true,
|
||||
},
|
||||
{
|
||||
name: "local suffix",
|
||||
input: "qwen3:8b:local",
|
||||
wantBase: "qwen3:8b",
|
||||
wantSource: ModelSourceLocal,
|
||||
wantLocal: true,
|
||||
wantStripped: "qwen3:8b:local",
|
||||
},
|
||||
{
|
||||
name: "no source suffix",
|
||||
input: "llama3.2",
|
||||
wantBase: "llama3.2",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantStripped: "llama3.2",
|
||||
},
|
||||
{
|
||||
name: "bare cloud name is not explicit cloud",
|
||||
input: "my-cloud-model",
|
||||
wantBase: "my-cloud-model",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantStripped: "my-cloud-model",
|
||||
},
|
||||
{
|
||||
name: "slash in suffix blocks legacy cloud parsing",
|
||||
input: "foo:bar-cloud/baz",
|
||||
wantBase: "foo:bar-cloud/baz",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantStripped: "foo:bar-cloud/baz",
|
||||
},
|
||||
{
|
||||
name: "conflicting source suffixes",
|
||||
input: "foo:cloud:local",
|
||||
wantErr: ErrConflictingSourceSuffix,
|
||||
wantSource: ModelSourceUnspecified,
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: " ",
|
||||
wantErr: ErrModelRequired,
|
||||
wantSource: ModelSourceUnspecified,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseRef(tt.input)
|
||||
if tt.wantErr != nil {
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Fatalf("ParseRef(%q) error = %v, want %v", tt.input, err, tt.wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("ParseRef(%q) returned error: %v", tt.input, err)
|
||||
}
|
||||
|
||||
if got.Base != tt.wantBase {
|
||||
t.Fatalf("base = %q, want %q", got.Base, tt.wantBase)
|
||||
}
|
||||
|
||||
if got.Source != tt.wantSource {
|
||||
t.Fatalf("source = %v, want %v", got.Source, tt.wantSource)
|
||||
}
|
||||
|
||||
if HasExplicitCloudSource(tt.input) != tt.wantCloud {
|
||||
t.Fatalf("HasExplicitCloudSource(%q) = %v, want %v", tt.input, HasExplicitCloudSource(tt.input), tt.wantCloud)
|
||||
}
|
||||
|
||||
if HasExplicitLocalSource(tt.input) != tt.wantLocal {
|
||||
t.Fatalf("HasExplicitLocalSource(%q) = %v, want %v", tt.input, HasExplicitLocalSource(tt.input), tt.wantLocal)
|
||||
}
|
||||
|
||||
stripped, ok := StripCloudSourceTag(tt.input)
|
||||
if ok != tt.wantStripOK {
|
||||
t.Fatalf("StripCloudSourceTag(%q) ok = %v, want %v", tt.input, ok, tt.wantStripOK)
|
||||
}
|
||||
if stripped != tt.wantStripped {
|
||||
t.Fatalf("StripCloudSourceTag(%q) base = %q, want %q", tt.input, stripped, tt.wantStripped)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizePullName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantName string
|
||||
wantCloud bool
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "explicit local strips source",
|
||||
input: "gpt-oss:20b:local",
|
||||
wantName: "gpt-oss:20b",
|
||||
},
|
||||
{
|
||||
name: "explicit cloud with size maps to legacy dash cloud tag",
|
||||
input: "gpt-oss:20b:cloud",
|
||||
wantName: "gpt-oss:20b-cloud",
|
||||
wantCloud: true,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud with size remains stable",
|
||||
input: "gpt-oss:20b-cloud",
|
||||
wantName: "gpt-oss:20b-cloud",
|
||||
wantCloud: true,
|
||||
},
|
||||
{
|
||||
name: "explicit cloud without tag maps to cloud tag",
|
||||
input: "qwen3:cloud",
|
||||
wantName: "qwen3:cloud",
|
||||
wantCloud: true,
|
||||
},
|
||||
{
|
||||
name: "host port without tag keeps host port and appends cloud tag",
|
||||
input: "localhost:11434/library/foo:cloud",
|
||||
wantName: "localhost:11434/library/foo:cloud",
|
||||
wantCloud: true,
|
||||
},
|
||||
{
|
||||
name: "conflicting source suffixes fail",
|
||||
input: "foo:cloud:local",
|
||||
wantErr: ErrConflictingSourceSuffix,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotName, gotCloud, err := NormalizePullName(tt.input)
|
||||
if tt.wantErr != nil {
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Fatalf("NormalizePullName(%q) error = %v, want %v", tt.input, err, tt.wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("NormalizePullName(%q) returned error: %v", tt.input, err)
|
||||
}
|
||||
|
||||
if gotName != tt.wantName {
|
||||
t.Fatalf("normalized name = %q, want %q", gotName, tt.wantName)
|
||||
}
|
||||
if gotCloud != tt.wantCloud {
|
||||
t.Fatalf("cloud = %v, want %v", gotCloud, tt.wantCloud)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSourceSuffix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantBase string
|
||||
wantSource ModelSource
|
||||
wantExplicit bool
|
||||
}{
|
||||
{
|
||||
name: "explicit cloud suffix",
|
||||
input: "gpt-oss:20b:cloud",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "explicit local suffix",
|
||||
input: "qwen3:8b:local",
|
||||
wantBase: "qwen3:8b",
|
||||
wantSource: ModelSourceLocal,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud suffix on tag",
|
||||
input: "gpt-oss:20b-cloud",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud suffix does not match model segment",
|
||||
input: "my-cloud-model",
|
||||
wantBase: "my-cloud-model",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantExplicit: false,
|
||||
},
|
||||
{
|
||||
name: "legacy cloud suffix blocked when suffix includes slash",
|
||||
input: "foo:bar-cloud/baz",
|
||||
wantBase: "foo:bar-cloud/baz",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantExplicit: false,
|
||||
},
|
||||
{
|
||||
name: "unknown suffix is not explicit source",
|
||||
input: "gpt-oss:clod",
|
||||
wantBase: "gpt-oss:clod",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantExplicit: false,
|
||||
},
|
||||
{
|
||||
name: "uppercase suffix is accepted",
|
||||
input: "gpt-oss:20b:CLOUD",
|
||||
wantBase: "gpt-oss:20b",
|
||||
wantSource: ModelSourceCloud,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "no suffix",
|
||||
input: "llama3.2",
|
||||
wantBase: "llama3.2",
|
||||
wantSource: ModelSourceUnspecified,
|
||||
wantExplicit: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotBase, gotSource, gotExplicit := parseSourceSuffix(tt.input)
|
||||
if gotBase != tt.wantBase {
|
||||
t.Fatalf("base = %q, want %q", gotBase, tt.wantBase)
|
||||
}
|
||||
if gotSource != tt.wantSource {
|
||||
t.Fatalf("source = %v, want %v", gotSource, tt.wantSource)
|
||||
}
|
||||
if gotExplicit != tt.wantExplicit {
|
||||
t.Fatalf("explicit = %v, want %v", gotExplicit, tt.wantExplicit)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
94
internal/orderedmap/orderedmap.go
Normal file
94
internal/orderedmap/orderedmap.go
Normal file
@@ -0,0 +1,94 @@
|
||||
// Package orderedmap provides a generic ordered map that maintains insertion order.
|
||||
// It wraps github.com/wk8/go-ordered-map/v2 to encapsulate the dependency.
|
||||
package orderedmap
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"iter"
|
||||
|
||||
orderedmap "github.com/wk8/go-ordered-map/v2"
|
||||
)
|
||||
|
||||
// Map is a generic ordered map that maintains insertion order.
|
||||
type Map[K comparable, V any] struct {
|
||||
om *orderedmap.OrderedMap[K, V]
|
||||
}
|
||||
|
||||
// New creates a new empty ordered map.
|
||||
func New[K comparable, V any]() *Map[K, V] {
|
||||
return &Map[K, V]{
|
||||
om: orderedmap.New[K, V](),
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a value by key.
|
||||
func (m *Map[K, V]) Get(key K) (V, bool) {
|
||||
if m == nil || m.om == nil {
|
||||
var zero V
|
||||
return zero, false
|
||||
}
|
||||
return m.om.Get(key)
|
||||
}
|
||||
|
||||
// Set sets a key-value pair. If the key already exists, its value is updated
|
||||
// but its position in the iteration order is preserved. If the key is new,
|
||||
// it is appended to the end.
|
||||
func (m *Map[K, V]) Set(key K, value V) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if m.om == nil {
|
||||
m.om = orderedmap.New[K, V]()
|
||||
}
|
||||
m.om.Set(key, value)
|
||||
}
|
||||
|
||||
// Len returns the number of entries.
|
||||
func (m *Map[K, V]) Len() int {
|
||||
if m == nil || m.om == nil {
|
||||
return 0
|
||||
}
|
||||
return m.om.Len()
|
||||
}
|
||||
|
||||
// All returns an iterator over all key-value pairs in insertion order.
|
||||
func (m *Map[K, V]) All() iter.Seq2[K, V] {
|
||||
return func(yield func(K, V) bool) {
|
||||
if m == nil || m.om == nil {
|
||||
return
|
||||
}
|
||||
for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
|
||||
if !yield(pair.Key, pair.Value) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ToMap converts to a regular Go map.
|
||||
// Note: The resulting map does not preserve order.
|
||||
func (m *Map[K, V]) ToMap() map[K]V {
|
||||
if m == nil || m.om == nil {
|
||||
return nil
|
||||
}
|
||||
result := make(map[K]V, m.om.Len())
|
||||
for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
|
||||
result[pair.Key] = pair.Value
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler. The JSON output preserves key order.
|
||||
func (m *Map[K, V]) MarshalJSON() ([]byte, error) {
|
||||
if m == nil || m.om == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return json.Marshal(m.om)
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler. The insertion order matches the
|
||||
// order of keys in the JSON input.
|
||||
func (m *Map[K, V]) UnmarshalJSON(data []byte) error {
|
||||
m.om = orderedmap.New[K, V]()
|
||||
return json.Unmarshal(data, &m.om)
|
||||
}
|
||||
348
internal/orderedmap/orderedmap_test.go
Normal file
348
internal/orderedmap/orderedmap_test.go
Normal file
@@ -0,0 +1,348 @@
|
||||
package orderedmap
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMap_BasicOperations(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
// Test empty map
|
||||
if m.Len() != 0 {
|
||||
t.Errorf("expected Len() = 0, got %d", m.Len())
|
||||
}
|
||||
v, ok := m.Get("a")
|
||||
if ok {
|
||||
t.Error("expected Get on empty map to return false")
|
||||
}
|
||||
if v != 0 {
|
||||
t.Errorf("expected zero value, got %d", v)
|
||||
}
|
||||
|
||||
// Test Set and Get
|
||||
m.Set("a", 1)
|
||||
m.Set("b", 2)
|
||||
m.Set("c", 3)
|
||||
|
||||
if m.Len() != 3 {
|
||||
t.Errorf("expected Len() = 3, got %d", m.Len())
|
||||
}
|
||||
|
||||
v, ok = m.Get("a")
|
||||
if !ok || v != 1 {
|
||||
t.Errorf("expected Get(a) = (1, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
|
||||
v, ok = m.Get("b")
|
||||
if !ok || v != 2 {
|
||||
t.Errorf("expected Get(b) = (2, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
|
||||
v, ok = m.Get("c")
|
||||
if !ok || v != 3 {
|
||||
t.Errorf("expected Get(c) = (3, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
|
||||
// Test updating existing key preserves position
|
||||
m.Set("a", 10)
|
||||
v, ok = m.Get("a")
|
||||
if !ok || v != 10 {
|
||||
t.Errorf("expected Get(a) = (10, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
if m.Len() != 3 {
|
||||
t.Errorf("expected Len() = 3 after update, got %d", m.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_InsertionOrderPreserved(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
// Insert in non-alphabetical order
|
||||
m.Set("z", 1)
|
||||
m.Set("a", 2)
|
||||
m.Set("m", 3)
|
||||
m.Set("b", 4)
|
||||
|
||||
// Verify iteration order matches insertion order
|
||||
var keys []string
|
||||
var values []int
|
||||
for k, v := range m.All() {
|
||||
keys = append(keys, k)
|
||||
values = append(values, v)
|
||||
}
|
||||
|
||||
expectedKeys := []string{"z", "a", "m", "b"}
|
||||
expectedValues := []int{1, 2, 3, 4}
|
||||
|
||||
if !slices.Equal(keys, expectedKeys) {
|
||||
t.Errorf("expected keys %v, got %v", expectedKeys, keys)
|
||||
}
|
||||
if !slices.Equal(values, expectedValues) {
|
||||
t.Errorf("expected values %v, got %v", expectedValues, values)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_UpdatePreservesPosition(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
m.Set("first", 1)
|
||||
m.Set("second", 2)
|
||||
m.Set("third", 3)
|
||||
|
||||
// Update middle element
|
||||
m.Set("second", 20)
|
||||
|
||||
var keys []string
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
// Order should still be first, second, third
|
||||
expected := []string{"first", "second", "third"}
|
||||
if !slices.Equal(keys, expected) {
|
||||
t.Errorf("expected keys %v, got %v", expected, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_MarshalJSON_PreservesOrder(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
// Insert in non-alphabetical order
|
||||
m.Set("z", 1)
|
||||
m.Set("a", 2)
|
||||
m.Set("m", 3)
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
// JSON should preserve insertion order, not alphabetical
|
||||
expected := `{"z":1,"a":2,"m":3}`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_UnmarshalJSON_PreservesOrder(t *testing.T) {
|
||||
// JSON with non-alphabetical key order
|
||||
jsonData := `{"z":1,"a":2,"m":3}`
|
||||
|
||||
m := New[string, int]()
|
||||
if err := json.Unmarshal([]byte(jsonData), m); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify iteration order matches JSON order
|
||||
var keys []string
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
expected := []string{"z", "a", "m"}
|
||||
if !slices.Equal(keys, expected) {
|
||||
t.Errorf("expected keys %v, got %v", expected, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_JSONRoundTrip(t *testing.T) {
|
||||
// Test that unmarshal -> marshal produces identical JSON
|
||||
original := `{"zebra":"z","apple":"a","mango":"m","banana":"b"}`
|
||||
|
||||
m := New[string, string]()
|
||||
if err := json.Unmarshal([]byte(original), m); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
if string(data) != original {
|
||||
t.Errorf("round trip failed: expected %s, got %s", original, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_ToMap(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
m.Set("a", 1)
|
||||
m.Set("b", 2)
|
||||
|
||||
regular := m.ToMap()
|
||||
|
||||
if len(regular) != 2 {
|
||||
t.Errorf("expected len 2, got %d", len(regular))
|
||||
}
|
||||
if regular["a"] != 1 {
|
||||
t.Errorf("expected regular[a] = 1, got %d", regular["a"])
|
||||
}
|
||||
if regular["b"] != 2 {
|
||||
t.Errorf("expected regular[b] = 2, got %d", regular["b"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_NilSafety(t *testing.T) {
|
||||
var m *Map[string, int]
|
||||
|
||||
// All operations should be safe on nil
|
||||
if m.Len() != 0 {
|
||||
t.Errorf("expected Len() = 0 on nil map, got %d", m.Len())
|
||||
}
|
||||
|
||||
v, ok := m.Get("a")
|
||||
if ok {
|
||||
t.Error("expected Get on nil map to return false")
|
||||
}
|
||||
if v != 0 {
|
||||
t.Errorf("expected zero value from nil map, got %d", v)
|
||||
}
|
||||
|
||||
// Set on nil is a no-op
|
||||
m.Set("a", 1)
|
||||
if m.Len() != 0 {
|
||||
t.Errorf("expected Len() = 0 after Set on nil, got %d", m.Len())
|
||||
}
|
||||
|
||||
// All returns empty iterator
|
||||
var keys []string
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
if len(keys) != 0 {
|
||||
t.Errorf("expected empty iteration on nil map, got %v", keys)
|
||||
}
|
||||
|
||||
// ToMap returns nil
|
||||
if m.ToMap() != nil {
|
||||
t.Error("expected ToMap to return nil on nil map")
|
||||
}
|
||||
|
||||
// MarshalJSON returns null
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
if string(data) != "null" {
|
||||
t.Errorf("expected null, got %s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_EmptyMapMarshal(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
if string(data) != "{}" {
|
||||
t.Errorf("expected {}, got %s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_NestedValues(t *testing.T) {
|
||||
m := New[string, any]()
|
||||
m.Set("string", "hello")
|
||||
m.Set("number", 42)
|
||||
m.Set("bool", true)
|
||||
m.Set("nested", map[string]int{"x": 1})
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
expected := `{"string":"hello","number":42,"bool":true,"nested":{"x":1}}`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_AllIteratorEarlyExit(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
m.Set("a", 1)
|
||||
m.Set("b", 2)
|
||||
m.Set("c", 3)
|
||||
m.Set("d", 4)
|
||||
|
||||
// Collect only first 2
|
||||
var keys []string
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
if len(keys) == 2 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
expected := []string{"a", "b"}
|
||||
if !slices.Equal(keys, expected) {
|
||||
t.Errorf("expected %v, got %v", expected, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_IntegerKeys(t *testing.T) {
|
||||
m := New[int, string]()
|
||||
m.Set(3, "three")
|
||||
m.Set(1, "one")
|
||||
m.Set(2, "two")
|
||||
|
||||
var keys []int
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
// Should preserve insertion order, not numerical order
|
||||
expected := []int{3, 1, 2}
|
||||
if !slices.Equal(keys, expected) {
|
||||
t.Errorf("expected %v, got %v", expected, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_UnmarshalIntoExisting(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
m.Set("existing", 999)
|
||||
|
||||
// Unmarshal should replace contents
|
||||
if err := json.Unmarshal([]byte(`{"new":1}`), m); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
_, ok := m.Get("existing")
|
||||
if ok {
|
||||
t.Error("existing key should be gone after unmarshal")
|
||||
}
|
||||
|
||||
v, ok := m.Get("new")
|
||||
if !ok || v != 1 {
|
||||
t.Errorf("expected Get(new) = (1, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_LargeOrderPreservation(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
// Create many keys in specific order
|
||||
keys := make([]string, 100)
|
||||
for i := range 100 {
|
||||
keys[i] = string(rune('a' + (99 - i))) // reverse order: 'd', 'c', 'b', 'a' (extended)
|
||||
if i >= 26 {
|
||||
keys[i] = string(rune('A'+i-26)) + string(rune('a'+i%26))
|
||||
}
|
||||
}
|
||||
|
||||
for i, k := range keys {
|
||||
m.Set(k, i)
|
||||
}
|
||||
|
||||
// Verify order preserved
|
||||
var resultKeys []string
|
||||
for k := range m.All() {
|
||||
resultKeys = append(resultKeys, k)
|
||||
}
|
||||
|
||||
if !slices.Equal(keys, resultKeys) {
|
||||
t.Error("large map should preserve insertion order")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user