ollama source for Momentry Core verification

This commit is contained in:
Accusys
2026-05-22 17:19:10 +08:00
commit 0b31ff9135
2020 changed files with 1413145 additions and 0 deletions

13
cmd/background_unix.go Normal file
View File

@@ -0,0 +1,13 @@
//go:build !windows
package cmd
import "syscall"
// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Unix.
// Setpgid prevents the server from being killed when the parent process exits.
func backgroundServerSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{
Setpgid: true,
}
}

12
cmd/background_windows.go Normal file
View File

@@ -0,0 +1,12 @@
package cmd
import "syscall"
// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Windows.
// CREATE_NO_WINDOW (0x08000000) prevents a console window from appearing.
func backgroundServerSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{
CreationFlags: 0x08000000,
HideWindow: true,
}
}

143
cmd/bench/README.md Normal file
View File

@@ -0,0 +1,143 @@
Ollama Benchmark Tool
---------------------
A Go-based command-line tool for benchmarking Ollama models with configurable parameters, warmup phases, TTFT tracking, VRAM monitoring, and benchstat/CSV output.
## Features
* Benchmark multiple models in a single run
* Support for both text and image prompts
* Configurable generation parameters (temperature, max tokens, seed, etc.)
* Warmup phase before timed epochs to stabilize measurements
* Time-to-first-token (TTFT) tracking per epoch
* Model metadata display (parameter size, quantization level, family)
* VRAM and CPU memory usage tracking via running process info
* Controlled prompt token length for reproducible benchmarks
* Benchstat and CSV output formats
## Building from Source
```
go build -o ollama-bench ./cmd/bench
./ollama-bench -model gemma3 -epochs 6 -format csv
```
Using Go Run (without building)
```
go run ./cmd/bench -model gemma3 -epochs 3
```
## Usage
### Basic Example
```
./ollama-bench -model gemma3 -epochs 6
```
### Benchmark Multiple Models
```
./ollama-bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
benchstat -col /name gemma.bench
```
### With Image Prompt
```
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
```
### Controlled Prompt Length
```
./ollama-bench -model gemma3 -epochs 6 -prompt-tokens 512
```
### Advanced Example
```
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -warmup 2 -format csv -output results.csv
```
## Command Line Options
| Option | Description | Default |
|----------|-------------|---------|
| -model | Comma-separated list of models to benchmark | (required) |
| -epochs | Number of iterations per model | 6 |
| -max-tokens | Maximum tokens for model response | 200 |
| -temperature | Temperature parameter | 0.0 |
| -seed | Random seed | 0 (random) |
| -timeout | Timeout in seconds | 300 |
| -p | Prompt text | (default story prompt) |
| -image | Image file to include in prompt | |
| -k | Keep-alive duration in seconds | 0 |
| -format | Output format (benchstat, csv) | benchstat |
| -output | Output file for results | "" (stdout) |
| -warmup | Number of warmup requests before timing | 1 |
| -prompt-tokens | Generate prompt targeting ~N tokens (0 = use -p) | 0 |
| -v | Verbose mode | false |
| -debug | Show debug information | false |
## Output Formats
### Benchstat Format (default)
Compatible with Go's benchstat tool for statistical analysis. Uses one value/unit pair per line, standard `ns/op` for timing metrics, and `ns/token` for throughput. Each epoch produces one set of lines -- benchstat aggregates across repeated runs to compute statistics.
```
# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931
BenchmarkModel/name=gemma3/step=prefill 1 78125.00 ns/token 12800.00 token/sec
BenchmarkModel/name=gemma3/step=generate 1 19531.25 ns/token 51200.00 token/sec
BenchmarkModel/name=gemma3/step=ttft 1 45123000 ns/op
BenchmarkModel/name=gemma3/step=load 1 1500000000 ns/op
BenchmarkModel/name=gemma3/step=total 1 2861047625 ns/op
```
Use with benchstat:
```
./ollama-bench -model gemma3 -epochs 6 > gemma3.bench
benchstat -col /step gemma3.bench
```
Compare two runs:
```
./ollama-bench -model gemma3 -epochs 6 > before.bench
# ... make changes ...
./ollama-bench -model gemma3 -epochs 6 > after.bench
benchstat before.bench after.bench
```
### CSV Format
Machine-readable comma-separated values:
```
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931
gemma3,prefill,128,78125.00,12800.00
gemma3,generate,512,19531.25,51200.00
gemma3,ttft,1,45123000,0
gemma3,load,1,1500000000,0
gemma3,total,1,2861047625,0
```
## Metrics Explained
The tool reports the following metrics for each epoch:
* **prefill**: Time spent processing the prompt (ns/token)
* **generate**: Time spent generating the response (ns/token)
* **ttft**: Time to first token -- latency from request start to first response content
* **load**: Model loading time (one-time cost)
* **total**: Total request duration
Additionally, the model info comment line (displayed once per model before epochs) includes:
* **Params**: Model parameter count (e.g., 4.3B)
* **Quant**: Quantization level (e.g., Q4_K_M)
* **Family**: Model family (e.g., gemma3)
* **Size**: Total model memory in bytes
* **VRAM**: GPU memory used by the loaded model (when Size > VRAM, the difference is CPU spill)

561
cmd/bench/bench.go Normal file
View File

@@ -0,0 +1,561 @@
package main
import (
"cmp"
"context"
"flag"
"fmt"
"io"
"os"
"runtime"
"slices"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
)
type flagOptions struct {
models *string
epochs *int
maxTokens *int
temperature *float64
seed *int
timeout *int
prompt *string
imageFile *string
keepAlive *float64
format *string
outputFile *string
debug *bool
verbose *bool
warmup *int
promptTokens *int
numCtx *int
}
type Metrics struct {
Model string
Step string
Count int
Duration time.Duration
}
type ModelInfo struct {
Name string
ParameterSize string
QuantizationLevel string
Family string
SizeBytes int64
VRAMBytes int64
NumCtx int64
}
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
// Word list for generating prompts targeting a specific token count.
var promptWordList = []string{
"the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog",
"a", "bright", "sunny", "day", "in", "the", "meadow", "where",
"flowers", "bloom", "and", "birds", "sing", "their", "morning",
"songs", "while", "gentle", "breeze", "carries", "sweet", "scent",
"of", "pine", "trees", "across", "rolling", "hills", "toward",
"distant", "mountains", "covered", "with", "fresh", "snow",
"beneath", "clear", "blue", "sky", "children", "play", "near",
"old", "stone", "bridge", "that", "crosses", "winding", "river",
}
// tokensPerWord is the calibrated ratio of tokens to words for the current model.
// Initialized with a heuristic, then updated during warmup based on actual tokenization.
var tokensPerWord = 1.3
func generatePromptForTokenCount(targetTokens int, epoch int) string {
targetWords := int(float64(targetTokens) / tokensPerWord)
if targetWords < 1 {
targetWords = 1
}
// Vary the starting offset by epoch to defeat KV cache prefix matching
offset := epoch * 7 // stride by a prime to get good distribution
n := len(promptWordList)
words := make([]string, targetWords)
for i := range words {
words[i] = promptWordList[((i+offset)%n+n)%n]
}
return strings.Join(words, " ")
}
// calibratePromptTokens adjusts tokensPerWord based on actual tokenization from a warmup run.
func calibratePromptTokens(targetTokens, actualTokens, wordCount int) {
if actualTokens <= 0 || wordCount <= 0 {
return
}
tokensPerWord = float64(actualTokens) / float64(wordCount)
newWords := int(float64(targetTokens) / tokensPerWord)
fmt.Fprintf(os.Stderr, "bench: calibrated %.2f tokens/word (target=%d, got=%d, words=%d → %d)\n",
tokensPerWord, targetTokens, actualTokens, wordCount, newWords)
}
func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData, epoch int) *api.GenerateRequest {
options := make(map[string]interface{})
if *fOpt.maxTokens > 0 {
options["num_predict"] = *fOpt.maxTokens
}
options["temperature"] = *fOpt.temperature
if fOpt.seed != nil && *fOpt.seed > 0 {
options["seed"] = *fOpt.seed
}
if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
options["num_ctx"] = *fOpt.numCtx
}
var keepAliveDuration *api.Duration
if *fOpt.keepAlive > 0 {
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
keepAliveDuration = &duration
}
prompt := *fOpt.prompt
if *fOpt.promptTokens > 0 {
prompt = generatePromptForTokenCount(*fOpt.promptTokens, epoch)
} else {
// Vary the prompt per epoch to defeat KV cache prefix matching
prompt = fmt.Sprintf("[%d] %s", epoch, prompt)
}
req := &api.GenerateRequest{
Model: model,
Prompt: prompt,
Raw: true,
Options: options,
KeepAlive: keepAliveDuration,
}
if imgData != nil {
req.Images = []api.ImageData{imgData}
}
return req
}
func fetchModelInfo(ctx context.Context, client *api.Client, model string) ModelInfo {
info := ModelInfo{Name: model}
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
if err != nil {
fmt.Fprintf(os.Stderr, "WARNING: Could not fetch model info for '%s': %v\n", model, err)
return info
}
info.ParameterSize = resp.Details.ParameterSize
info.QuantizationLevel = resp.Details.QuantizationLevel
info.Family = resp.Details.Family
return info
}
func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (size, vram int64) {
resp, err := client.ListRunning(ctx)
if err != nil {
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
fmt.Fprintf(os.Stderr, "WARNING: Could not fetch memory usage: %v\n", err)
}
return 0, 0
}
for _, m := range resp.Models {
if m.Name == model || m.Model == model {
return m.Size, m.SizeVRAM
}
}
for _, m := range resp.Models {
if strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
return m.Size, m.SizeVRAM
}
}
return 0, 0
}
func fetchContextLength(ctx context.Context, client *api.Client, model string) int64 {
resp, err := client.ListRunning(ctx)
if err != nil {
return 0
}
for _, m := range resp.Models {
if m.Name == model || m.Model == model || strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
return int64(m.ContextLength)
}
}
return 0
}
func outputFormatHeader(w io.Writer, format string, verbose bool) {
switch format {
case "benchstat":
if verbose {
fmt.Fprintf(w, "goos: %s\n", runtime.GOOS)
fmt.Fprintf(w, "goarch: %s\n", runtime.GOARCH)
}
case "csv":
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
fmt.Fprintln(w, strings.Join(headings, ","))
}
}
func outputModelInfo(w io.Writer, format string, info ModelInfo) {
params := cmp.Or(info.ParameterSize, "unknown")
quant := cmp.Or(info.QuantizationLevel, "unknown")
family := cmp.Or(info.Family, "unknown")
memStr := ""
if info.SizeBytes > 0 {
memStr = fmt.Sprintf(" | Size: %d | VRAM: %d", info.SizeBytes, info.VRAMBytes)
}
ctxStr := ""
if info.NumCtx > 0 {
ctxStr = fmt.Sprintf(" | NumCtx: %d", info.NumCtx)
}
fmt.Fprintf(w, "# Model: %s | Params: %s | Quant: %s | Family: %s%s%s\n",
info.Name, params, quant, family, memStr, ctxStr)
}
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
switch format {
case "benchstat":
for _, m := range metrics {
if m.Step == "generate" || m.Step == "prefill" {
if m.Count > 0 {
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %.2f ns/token %.2f token/sec\n",
m.Model, m.Step, nsPerToken, tokensPerSec)
} else {
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 0 ns/token 0 token/sec\n",
m.Model, m.Step)
}
} else if m.Step == "ttft" {
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=ttft 1 %d ns/op\n",
m.Model, m.Duration.Nanoseconds())
} else {
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %d ns/op\n",
m.Model, m.Step, m.Duration.Nanoseconds())
}
}
case "csv":
for _, m := range metrics {
if m.Step == "generate" || m.Step == "prefill" {
var nsPerToken float64
var tokensPerSec float64
if m.Count > 0 {
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
}
fmt.Fprintf(w, "%s,%s,%d,%.2f,%.2f\n", m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
} else {
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
}
}
default:
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
}
}
func BenchmarkModel(fOpt flagOptions) error {
models := strings.Split(*fOpt.models, ",")
var imgData api.ImageData
var err error
if *fOpt.imageFile != "" {
imgData, err = readImage(*fOpt.imageFile)
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: Couldn't read image '%s': %v\n", *fOpt.imageFile, err)
return err
}
}
if *fOpt.debug && imgData != nil {
fmt.Fprintf(os.Stderr, "Read file '%s'\n", *fOpt.imageFile)
}
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: Couldn't create ollama client: %v\n", err)
return err
}
var out io.Writer = os.Stdout
if fOpt.outputFile != nil && *fOpt.outputFile != "" {
f, err := os.OpenFile(*fOpt.outputFile, os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: cannot open output file %s: %v\n", *fOpt.outputFile, err)
return err
}
defer f.Close()
out = f
}
outputFormatHeader(out, *fOpt.format, *fOpt.verbose)
// Log prompt-tokens info in debug mode
if *fOpt.debug && *fOpt.promptTokens > 0 {
prompt := generatePromptForTokenCount(*fOpt.promptTokens, 0)
wordCount := len(strings.Fields(prompt))
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (%d words, varied per epoch)\n", *fOpt.promptTokens, wordCount)
}
for _, model := range models {
// Fetch model info
infoCtx, infoCancel := context.WithTimeout(context.Background(), 10*time.Second)
info := fetchModelInfo(infoCtx, client, model)
infoCancel()
// Warmup phase (uses negative epoch numbers to avoid colliding with timed epochs)
for i := range *fOpt.warmup {
req := buildGenerateRequest(model, fOpt, imgData, -(i + 1))
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
var warmupMetrics *api.Metrics
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if resp.Done {
warmupMetrics = &resp.Metrics
}
return nil
})
cancel()
if err != nil {
fmt.Fprintf(os.Stderr, "WARNING: Warmup %d/%d for %s failed: %v\n", i+1, *fOpt.warmup, model, err)
} else {
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model)
}
// Calibrate prompt token count on last warmup run
if i == *fOpt.warmup-1 && *fOpt.promptTokens > 0 && warmupMetrics != nil {
prompt := generatePromptForTokenCount(*fOpt.promptTokens, -(i + 1))
wordCount := len(strings.Fields(prompt))
calibratePromptTokens(*fOpt.promptTokens, warmupMetrics.PromptEvalCount, wordCount)
}
}
}
// Fetch memory/context info once after warmup (model is loaded and stable)
memCtx, memCancel := context.WithTimeout(context.Background(), 5*time.Second)
info.SizeBytes, info.VRAMBytes = fetchMemoryUsage(memCtx, client, model)
if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
info.NumCtx = int64(*fOpt.numCtx)
} else {
info.NumCtx = fetchContextLength(memCtx, client, model)
}
memCancel()
outputModelInfo(out, *fOpt.format, info)
// Timed epoch loop
shortCount := 0
for epoch := range *fOpt.epochs {
var responseMetrics *api.Metrics
var ttft time.Duration
short := false
// Retry loop: if the model hits a stop token before max-tokens,
// retry with a different prompt (up to maxRetries times).
const maxRetries = 3
for attempt := range maxRetries + 1 {
responseMetrics = nil
ttft = 0
var ttftOnce sync.Once
req := buildGenerateRequest(model, fOpt, imgData, epoch+attempt*1000)
requestStart := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Thinking, resp.Response))
}
// Capture TTFT on first content
ttftOnce.Do(func() {
if resp.Response != "" || resp.Thinking != "" {
ttft = time.Since(requestStart)
}
})
if resp.Done {
responseMetrics = &resp.Metrics
}
return nil
})
cancel()
if *fOpt.debug {
fmt.Fprintln(os.Stderr)
}
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
fmt.Fprintf(os.Stderr, "ERROR: Request timed out with model '%s' after %vs\n", model, *fOpt.timeout)
} else {
fmt.Fprintf(os.Stderr, "ERROR: Couldn't generate with model '%s': %v\n", model, err)
}
break
}
if responseMetrics == nil {
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
break
}
// Check if the response was shorter than requested
short = *fOpt.maxTokens > 0 && responseMetrics.EvalCount < *fOpt.maxTokens
if !short || attempt == maxRetries {
break
}
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "Short response (%d/%d tokens), retrying with different prompt (attempt %d/%d)\n",
responseMetrics.EvalCount, *fOpt.maxTokens, attempt+1, maxRetries)
}
}
if err != nil || responseMetrics == nil {
continue
}
if short {
shortCount++
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "WARNING: Short response (%d/%d tokens) after %d retries for epoch %d\n",
responseMetrics.EvalCount, *fOpt.maxTokens, maxRetries, epoch+1)
}
}
metrics := []Metrics{
{
Model: model,
Step: "prefill",
Count: responseMetrics.PromptEvalCount,
Duration: responseMetrics.PromptEvalDuration,
},
{
Model: model,
Step: "generate",
Count: responseMetrics.EvalCount,
Duration: responseMetrics.EvalDuration,
},
{
Model: model,
Step: "ttft",
Count: 1,
Duration: ttft,
},
{
Model: model,
Step: "load",
Count: 1,
Duration: responseMetrics.LoadDuration,
},
{
Model: model,
Step: "total",
Count: 1,
Duration: responseMetrics.TotalDuration,
},
}
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
if *fOpt.debug && *fOpt.promptTokens > 0 {
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (actual: %d)\n",
*fOpt.promptTokens, responseMetrics.PromptEvalCount)
}
if *fOpt.keepAlive > 0 {
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
}
}
if shortCount > 0 {
fmt.Fprintf(os.Stderr, "WARNING: %d/%d epochs for '%s' had short responses (<%d tokens). Generation metrics may be unreliable.\n",
shortCount, *fOpt.epochs, model, *fOpt.maxTokens)
}
// Unload model before moving to the next one
unloadModel(client, model, *fOpt.timeout)
}
return nil
}
func unloadModel(client *api.Client, model string, timeout int) {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
defer cancel()
zero := api.Duration{Duration: 0}
req := &api.GenerateRequest{
Model: model,
KeepAlive: &zero,
}
_ = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
return nil
})
}
func readImage(filePath string) (api.ImageData, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, err
}
defer file.Close()
data, err := io.ReadAll(file)
if err != nil {
return nil, err
}
return api.ImageData(data), nil
}
func main() {
fOpt := flagOptions{
models: flag.String("model", "", "Model to benchmark"),
epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"),
temperature: flag.Float64("temperature", 0, "Temperature parameter"),
seed: flag.Int("seed", 0, "Random seed"),
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
imageFile: flag.String("image", "", "Filename for an image to include"),
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
format: flag.String("format", "benchstat", "Output format [benchstat|csv]"),
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
verbose: flag.Bool("v", false, "Show system information"),
debug: flag.Bool("debug", false, "Show debug information"),
warmup: flag.Int("warmup", 1, "Number of warmup requests before timing"),
promptTokens: flag.Int("prompt-tokens", 0, "Generate prompt targeting ~N tokens (0 = use -p prompt)"),
numCtx: flag.Int("num-ctx", 0, "Context size (0 = server default)"),
}
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [OPTIONS]\n\n", os.Args[0])
fmt.Fprintf(os.Stderr, "Description:\n")
fmt.Fprintf(os.Stderr, " Model benchmarking tool with configurable parameters\n\n")
fmt.Fprintf(os.Stderr, "Options:\n")
flag.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nExamples:\n")
fmt.Fprintf(os.Stderr, " bench -model gemma3,llama3 -epochs 6\n")
fmt.Fprintf(os.Stderr, " bench -model gemma3 -epochs 6 -prompt-tokens 512 -format csv\n")
}
flag.Parse()
if !slices.Contains([]string{"benchstat", "csv"}, *fOpt.format) {
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
os.Exit(1)
}
if len(*fOpt.models) == 0 {
fmt.Fprintf(os.Stderr, "ERROR: No model(s) specified to benchmark.\n")
flag.Usage()
return
}
BenchmarkModel(fOpt)
}

1410
cmd/bench/bench_test.go Normal file

File diff suppressed because it is too large Load Diff

2583
cmd/cmd.go Normal file

File diff suppressed because it is too large Load Diff

305
cmd/cmd_launcher_test.go Normal file
View File

@@ -0,0 +1,305 @@
package cmd
import (
"context"
"testing"
"github.com/spf13/cobra"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/launch"
"github.com/ollama/ollama/cmd/tui"
)
func setCmdTestHome(t *testing.T, dir string) {
t.Helper()
t.Setenv("HOME", dir)
t.Setenv("USERPROFILE", dir)
}
func unexpectedRunModelResolution(t *testing.T) func(context.Context, launch.RunModelRequest) (string, error) {
t.Helper()
return func(ctx context.Context, req launch.RunModelRequest) (string, error) {
t.Fatalf("did not expect run-model resolution: %+v", req)
return "", nil
}
}
func unexpectedIntegrationLaunch(t *testing.T) func(context.Context, launch.IntegrationLaunchRequest) error {
t.Helper()
return func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
t.Fatalf("did not expect integration launch: %+v", req)
return nil
}
}
func unexpectedModelLaunch(t *testing.T) func(*cobra.Command, string) error {
t.Helper()
return func(cmd *cobra.Command, model string) error {
t.Fatalf("did not expect chat launch: %s", model)
return nil
}
}
func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) {
tests := []struct {
name string
action tui.TUIAction
wantForce bool
wantModel string
}{
{
name: "enter uses saved model flow",
action: tui.TUIAction{Kind: tui.TUIActionRunModel},
wantModel: "qwen3:8b",
},
{
name: "right forces picker",
action: tui.TUIAction{Kind: tui.TUIActionRunModel, ForceConfigure: true},
wantForce: true,
wantModel: "glm-5:cloud",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setCmdTestHome(t, t.TempDir())
var menuCalls int
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
menuCalls++
if menuCalls == 1 {
return tt.action, nil
}
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
}
var gotReq launch.RunModelRequest
var launched string
prefetchedAccount := &launch.AccountState{}
accountUpdates := func(context.Context) <-chan *launch.AccountState { return nil }
deps := launcherDeps{
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
return &launch.LauncherState{}, nil
},
runMenu: func(state *launch.LauncherState) (tui.TUIAction, error) {
if state.AccountState != prefetchedAccount {
t.Fatalf("prefetched account state was not piped to menu state")
}
return runMenu(state)
},
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
gotReq = req
return tt.wantModel, nil
},
launchIntegration: unexpectedIntegrationLaunch(t),
runModel: func(cmd *cobra.Command, model string) error {
launched = model
return nil
},
accountState: func() *launch.AccountState {
return prefetchedAccount
},
accountStateUpdates: accountUpdates,
}
cmd := &cobra.Command{}
cmd.SetContext(context.Background())
for {
continueLoop, err := runInteractiveTUIStep(cmd, deps)
if err != nil {
t.Fatalf("unexpected step error: %v", err)
}
if !continueLoop {
break
}
}
if gotReq.ForcePicker != tt.wantForce {
t.Fatalf("expected ForcePicker=%v, got %v", tt.wantForce, gotReq.ForcePicker)
}
if gotReq.AccountState != prefetchedAccount {
t.Fatalf("expected prefetched account state to be passed to run model request")
}
if gotReq.AccountStateUpdates == nil {
t.Fatalf("expected account state updates to be passed to run model request")
}
if launched != tt.wantModel {
t.Fatalf("expected interactive launcher to run %q, got %q", tt.wantModel, launched)
}
if got := config.LastSelection(); got != "run" {
t.Fatalf("expected last selection to be run, got %q", got)
}
})
}
}
func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T) {
tests := []struct {
name string
action tui.TUIAction
wantForce bool
}{
{
name: "enter launches integration",
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"},
},
{
name: "right forces configure",
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true},
wantForce: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setCmdTestHome(t, t.TempDir())
var menuCalls int
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
menuCalls++
if menuCalls == 1 {
return tt.action, nil
}
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
}
var gotReq launch.IntegrationLaunchRequest
prefetchedAccount := &launch.AccountState{}
accountUpdates := func(context.Context) <-chan *launch.AccountState { return nil }
deps := launcherDeps{
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
return &launch.LauncherState{}, nil
},
runMenu: func(state *launch.LauncherState) (tui.TUIAction, error) {
if state.AccountState != prefetchedAccount {
t.Fatalf("prefetched account state was not piped to menu state")
}
return runMenu(state)
},
resolveRunModel: unexpectedRunModelResolution(t),
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
gotReq = req
return nil
},
runModel: unexpectedModelLaunch(t),
accountState: func() *launch.AccountState {
return prefetchedAccount
},
accountStateUpdates: accountUpdates,
}
cmd := &cobra.Command{}
cmd.SetContext(context.Background())
for {
continueLoop, err := runInteractiveTUIStep(cmd, deps)
if err != nil {
t.Fatalf("unexpected step error: %v", err)
}
if !continueLoop {
break
}
}
if gotReq.Name != "claude" {
t.Fatalf("expected integration name to be passed through, got %q", gotReq.Name)
}
if gotReq.ForceConfigure != tt.wantForce {
t.Fatalf("expected ForceConfigure=%v, got %v", tt.wantForce, gotReq.ForceConfigure)
}
if gotReq.AccountState != prefetchedAccount {
t.Fatalf("expected prefetched account state to be passed to integration request")
}
if gotReq.AccountStateUpdates == nil {
t.Fatalf("expected account state updates to be passed to integration request")
}
if got := config.LastSelection(); got != "claude" {
t.Fatalf("expected last selection to be claude, got %q", got)
}
})
}
}
func TestRunLauncherAction_RunModelContinuesAfterCancellation(t *testing.T) {
setCmdTestHome(t, t.TempDir())
cmd := &cobra.Command{}
cmd.SetContext(context.Background())
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionRunModel}, launcherDeps{
buildState: nil,
runMenu: nil,
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
return "", launch.ErrCancelled
},
launchIntegration: unexpectedIntegrationLaunch(t),
runModel: unexpectedModelLaunch(t),
})
if err != nil {
t.Fatalf("expected nil error on cancellation, got %v", err)
}
if !continueLoop {
t.Fatal("expected cancellation to continue the menu loop")
}
}
func TestRunLauncherAction_GUIAppsExitTUILoop(t *testing.T) {
setCmdTestHome(t, t.TempDir())
cmd := &cobra.Command{}
cmd.SetContext(context.Background())
for _, integration := range []string{"codex-app", "vscode"} {
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: integration}, launcherDeps{
resolveRunModel: unexpectedRunModelResolution(t),
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
return nil
},
runModel: unexpectedModelLaunch(t),
})
if err != nil {
t.Fatalf("expected nil error for %s, got %v", integration, err)
}
if continueLoop {
t.Fatalf("expected %s launch to exit the TUI loop (return false)", integration)
}
}
// Other integrations should continue the TUI loop (return true).
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
resolveRunModel: unexpectedRunModelResolution(t),
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
return nil
},
runModel: unexpectedModelLaunch(t),
})
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if !continueLoop {
t.Fatal("expected non-vscode integration to continue the TUI loop (return true)")
}
}
func TestRunLauncherAction_IntegrationContinuesAfterCancellation(t *testing.T) {
setCmdTestHome(t, t.TempDir())
cmd := &cobra.Command{}
cmd.SetContext(context.Background())
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
buildState: nil,
runMenu: nil,
resolveRunModel: unexpectedRunModelResolution(t),
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
return launch.ErrCancelled
},
runModel: unexpectedModelLaunch(t),
})
if err != nil {
t.Fatalf("expected nil error on cancellation, got %v", err)
}
if !continueLoop {
t.Fatal("expected cancellation to continue the menu loop")
}
}

2361
cmd/cmd_test.go Normal file

File diff suppressed because it is too large Load Diff

284
cmd/config/config.go Normal file
View File

@@ -0,0 +1,284 @@
// Package config provides integration configuration for external coding tools
// (Claude Code, Codex, Droid, OpenCode) to use Ollama models.
package config
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/cmd/internal/fileutil"
)
type integration struct {
Models []string `json:"models"`
Aliases map[string]string `json:"aliases,omitempty"`
Onboarded bool `json:"onboarded,omitempty"`
}
// IntegrationConfig is the persisted config for one integration.
type IntegrationConfig = integration
type config struct {
Integrations map[string]*integration `json:"integrations"`
LastModel string `json:"last_model,omitempty"`
LastSelection string `json:"last_selection,omitempty"` // "run" or integration name
}
func configPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", "config.json"), nil
}
func legacyConfigPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", "config", "config.json"), nil
}
// migrateConfig moves the config from the legacy path to ~/.ollama/config.json
func migrateConfig() (bool, error) {
oldPath, err := legacyConfigPath()
if err != nil {
return false, err
}
oldData, err := os.ReadFile(oldPath)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
// Ignore legacy files with invalid JSON and continue startup.
if !json.Valid(oldData) {
return false, nil
}
newPath, err := configPath()
if err != nil {
return false, err
}
if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil {
return false, err
}
if err := os.WriteFile(newPath, oldData, 0o644); err != nil {
return false, fmt.Errorf("write new config: %w", err)
}
_ = os.Remove(oldPath)
_ = os.Remove(filepath.Dir(oldPath)) // clean up empty directory
return true, nil
}
func load() (*config, error) {
path, err := configPath()
if err != nil {
return nil, err
}
data, err := os.ReadFile(path)
if err != nil && os.IsNotExist(err) {
if migrated, merr := migrateConfig(); merr == nil && migrated {
data, err = os.ReadFile(path)
}
}
if err != nil {
if os.IsNotExist(err) {
return &config{Integrations: make(map[string]*integration)}, nil
}
return nil, err
}
var cfg config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("failed to parse config: %w, at: %s", err, path)
}
if cfg.Integrations == nil {
cfg.Integrations = make(map[string]*integration)
}
return &cfg, nil
}
func save(cfg *config) error {
path, err := configPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return err
}
return fileutil.WriteWithBackup(path, data)
}
func SaveIntegration(appName string, models []string) error {
if appName == "" {
return errors.New("app name cannot be empty")
}
cfg, err := load()
if err != nil {
return err
}
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
var aliases map[string]string
var onboarded bool
if existing != nil {
aliases = existing.Aliases
onboarded = existing.Onboarded
}
cfg.Integrations[key] = &integration{
Models: models,
Aliases: aliases,
Onboarded: onboarded,
}
return save(cfg)
}
// MarkIntegrationOnboarded marks an integration as onboarded in Ollama's config.
func MarkIntegrationOnboarded(appName string) error {
cfg, err := load()
if err != nil {
return err
}
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
if existing == nil {
existing = &integration{}
}
existing.Onboarded = true
cfg.Integrations[key] = existing
return save(cfg)
}
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
func IntegrationModel(appName string) string {
integrationConfig, err := LoadIntegration(appName)
if err != nil || len(integrationConfig.Models) == 0 {
return ""
}
return integrationConfig.Models[0]
}
// IntegrationModels returns all configured models for an integration, or nil.
func IntegrationModels(appName string) []string {
integrationConfig, err := LoadIntegration(appName)
if err != nil || len(integrationConfig.Models) == 0 {
return nil
}
return integrationConfig.Models
}
// LastModel returns the last model that was run, or empty string if none.
func LastModel() string {
cfg, err := load()
if err != nil {
return ""
}
return cfg.LastModel
}
// SetLastModel saves the last model that was run.
func SetLastModel(model string) error {
cfg, err := load()
if err != nil {
return err
}
cfg.LastModel = model
return save(cfg)
}
// LastSelection returns the last menu selection ("run" or integration name), or empty string if none.
func LastSelection() string {
cfg, err := load()
if err != nil {
return ""
}
return cfg.LastSelection
}
// SetLastSelection saves the last menu selection ("run" or integration name).
func SetLastSelection(selection string) error {
cfg, err := load()
if err != nil {
return err
}
cfg.LastSelection = selection
return save(cfg)
}
// LoadIntegration returns the saved config for one integration.
func LoadIntegration(appName string) (*integration, error) {
cfg, err := load()
if err != nil {
return nil, err
}
integrationConfig, ok := cfg.Integrations[strings.ToLower(appName)]
if !ok {
return nil, os.ErrNotExist
}
return integrationConfig, nil
}
// SaveAliases replaces the saved aliases for one integration.
func SaveAliases(appName string, aliases map[string]string) error {
if appName == "" {
return errors.New("app name cannot be empty")
}
cfg, err := load()
if err != nil {
return err
}
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
if existing == nil {
existing = &integration{}
}
// Replace aliases entirely (not merge) so deletions are persisted
existing.Aliases = aliases
cfg.Integrations[key] = existing
return save(cfg)
}
func listIntegrations() ([]integration, error) {
cfg, err := load()
if err != nil {
return nil, err
}
result := make([]integration, 0, len(cfg.Integrations))
for _, integrationConfig := range cfg.Integrations {
result = append(result, *integrationConfig)
}
return result, nil
}

View File

@@ -0,0 +1,641 @@
package config
import (
"errors"
"os"
"path/filepath"
"testing"
)
func TestSetAliases_CloudModel(t *testing.T) {
// Test the SetAliases logic by checking the alias map behavior
aliases := map[string]string{
"primary": "kimi-k2.5:cloud",
"fast": "kimi-k2.5:cloud",
}
// Verify fast is set (cloud model behavior)
if aliases["fast"] == "" {
t.Error("cloud model should have fast alias set")
}
if aliases["fast"] != aliases["primary"] {
t.Errorf("fast should equal primary for auto-set, got fast=%q primary=%q", aliases["fast"], aliases["primary"])
}
}
func TestSetAliases_LocalModel(t *testing.T) {
aliases := map[string]string{
"primary": "llama3.2:latest",
}
// Simulate local model behavior: fast should be empty
delete(aliases, "fast")
if aliases["fast"] != "" {
t.Error("local model should have empty fast alias")
}
}
func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// First save with both primary and fast
initial := map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}
if err := SaveAliases("claude", initial); err != nil {
t.Fatalf("failed to save initial aliases: %v", err)
}
// Verify both are saved
loaded, err := LoadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["fast"] != "cloud-model" {
t.Errorf("expected fast=cloud-model, got %q", loaded.Aliases["fast"])
}
// Now save without fast (simulating switch to local model)
updated := map[string]string{
"primary": "local-model",
// fast intentionally missing
}
if err := SaveAliases("claude", updated); err != nil {
t.Fatalf("failed to save updated aliases: %v", err)
}
// Verify fast is GONE (not merged/preserved)
loaded, err = LoadIntegration("claude")
if err != nil {
t.Fatalf("failed to load after update: %v", err)
}
if loaded.Aliases["fast"] != "" {
t.Errorf("fast should be removed after saving without it, got %q", loaded.Aliases["fast"])
}
if loaded.Aliases["primary"] != "local-model" {
t.Errorf("primary should be updated to local-model, got %q", loaded.Aliases["primary"])
}
}
func TestSaveAliases_PreservesModels(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// First save integration with models
if err := SaveIntegration("claude", []string{"model1", "model2"}); err != nil {
t.Fatalf("failed to save integration: %v", err)
}
// Then update aliases
aliases := map[string]string{"primary": "new-model"}
if err := SaveAliases("claude", aliases); err != nil {
t.Fatalf("failed to save aliases: %v", err)
}
// Verify models are preserved
loaded, err := LoadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if len(loaded.Models) != 2 || loaded.Models[0] != "model1" {
t.Errorf("models should be preserved, got %v", loaded.Models)
}
}
// TestSaveAliases_EmptyMap clears all aliases
func TestSaveAliases_EmptyMap(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save with aliases
if err := SaveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Save empty map
if err := SaveAliases("claude", map[string]string{}); err != nil {
t.Fatalf("failed to save empty: %v", err)
}
loaded, err := LoadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if len(loaded.Aliases) != 0 {
t.Errorf("aliases should be empty, got %v", loaded.Aliases)
}
}
// TestSaveAliases_NilMap handles nil gracefully
func TestSaveAliases_NilMap(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save with aliases first
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Save nil map - should clear aliases
if err := SaveAliases("claude", nil); err != nil {
t.Fatalf("failed to save nil: %v", err)
}
loaded, err := LoadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if len(loaded.Aliases) > 0 {
t.Errorf("aliases should be nil or empty, got %v", loaded.Aliases)
}
}
// TestSaveAliases_EmptyAppName returns error
func TestSaveAliases_EmptyAppName(t *testing.T) {
err := SaveAliases("", map[string]string{"primary": "model"})
if err == nil {
t.Error("expected error for empty app name")
}
}
func TestSaveAliases_CaseInsensitive(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if err := SaveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Load with different case
loaded, err := LoadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != "model1" {
t.Errorf("expected primary=model1, got %q", loaded.Aliases["primary"])
}
// Update with different case
if err := SaveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
t.Fatalf("failed to update: %v", err)
}
loaded, err = LoadIntegration("claude")
if err != nil {
t.Fatalf("failed to load after update: %v", err)
}
if loaded.Aliases["primary"] != "model2" {
t.Errorf("expected primary=model2, got %q", loaded.Aliases["primary"])
}
}
// TestSaveAliases_CreatesIntegration creates integration if it doesn't exist
func TestSaveAliases_CreatesIntegration(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save aliases for non-existent integration
if err := SaveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
loaded, err := LoadIntegration("newintegration")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != "model" {
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
}
}
func TestConfigureAliases_AliasMap(t *testing.T) {
t.Run("cloud model auto-sets fast to primary", func(t *testing.T) {
aliases := make(map[string]string)
aliases["primary"] = "cloud-model"
// Simulate cloud model behavior
isCloud := true
if isCloud {
if aliases["fast"] == "" {
aliases["fast"] = aliases["primary"]
}
}
if aliases["fast"] != "cloud-model" {
t.Errorf("expected fast=cloud-model, got %q", aliases["fast"])
}
})
t.Run("cloud model preserves custom fast", func(t *testing.T) {
aliases := map[string]string{
"primary": "cloud-model",
"fast": "custom-fast-model",
}
// Simulate cloud model behavior - should preserve existing fast
isCloud := true
if isCloud {
if aliases["fast"] == "" {
aliases["fast"] = aliases["primary"]
}
}
if aliases["fast"] != "custom-fast-model" {
t.Errorf("expected fast=custom-fast-model (preserved), got %q", aliases["fast"])
}
})
t.Run("local model clears fast", func(t *testing.T) {
aliases := map[string]string{
"primary": "local-model",
"fast": "should-be-cleared",
}
// Simulate local model behavior
isCloud := false
if !isCloud {
delete(aliases, "fast")
}
if aliases["fast"] != "" {
t.Errorf("expected fast to be cleared, got %q", aliases["fast"])
}
})
t.Run("switching cloud to local clears fast", func(t *testing.T) {
// Start with cloud config
aliases := map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}
// Switch to local
aliases["primary"] = "local-model"
isCloud := false
if !isCloud {
delete(aliases, "fast")
}
if aliases["fast"] != "" {
t.Errorf("fast should be cleared when switching to local, got %q", aliases["fast"])
}
if aliases["primary"] != "local-model" {
t.Errorf("primary should be updated, got %q", aliases["primary"])
}
})
t.Run("switching local to cloud sets fast", func(t *testing.T) {
// Start with local config (no fast)
aliases := map[string]string{
"primary": "local-model",
}
// Switch to cloud
aliases["primary"] = "cloud-model"
isCloud := true
if isCloud {
if aliases["fast"] == "" {
aliases["fast"] = aliases["primary"]
}
}
if aliases["fast"] != "cloud-model" {
t.Errorf("fast should be set when switching to cloud, got %q", aliases["fast"])
}
})
}
func TestSetAliases_PrefixMapping(t *testing.T) {
// This tests the expected mapping without needing a real client
aliases := map[string]string{
"primary": "my-cloud-model",
"fast": "my-fast-model",
}
expectedMappings := map[string]string{
"claude-sonnet-": aliases["primary"],
"claude-haiku-": aliases["fast"],
}
if expectedMappings["claude-sonnet-"] != "my-cloud-model" {
t.Errorf("claude-sonnet- should map to primary")
}
if expectedMappings["claude-haiku-"] != "my-fast-model" {
t.Errorf("claude-haiku- should map to fast")
}
}
func TestSetAliases_LocalDeletesPrefixes(t *testing.T) {
aliases := map[string]string{
"primary": "local-model",
// fast is empty/missing - indicates local model
}
prefixesToDelete := []string{"claude-sonnet-", "claude-haiku-"}
// Verify the logic: when fast is empty, we should delete
if aliases["fast"] != "" {
t.Error("fast should be empty for local model")
}
// Verify we have the right prefixes to delete
if len(prefixesToDelete) != 2 {
t.Errorf("expected 2 prefixes to delete, got %d", len(prefixesToDelete))
}
}
// TestAtomicUpdate_ServerFailsConfigNotSaved simulates atomic update behavior
func TestAtomicUpdate_ServerFailsConfigNotSaved(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Simulate: server fails, config should NOT be saved
serverErr := errors.New("server unavailable")
if serverErr == nil {
t.Error("config should NOT be saved when server fails")
}
}
// TestAtomicUpdate_ServerSucceedsConfigSaved simulates successful atomic update
func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Simulate: server succeeds, config should be saved
var serverErr error
if serverErr != nil {
t.Fatal("server should succeed")
}
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
t.Fatalf("saveAliases failed: %v", err)
}
// Verify it was actually saved
loaded, err := LoadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != "model" {
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
}
}
func TestConfigFile_PreservesUnknownFields(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Write config with extra fields
configPath := filepath.Join(tmpDir, ".ollama", "config.json")
os.MkdirAll(filepath.Dir(configPath), 0o755)
// Note: Our config struct only has Integrations, so top-level unknown fields
// won't be preserved by our current implementation. This test documents that.
initialConfig := `{
"integrations": {
"claude": {
"models": ["model1"],
"aliases": {"primary": "model1"},
"unknownField": "should be lost"
}
},
"topLevelUnknown": "will be lost"
}`
os.WriteFile(configPath, []byte(initialConfig), 0o644)
// Update aliases
if err := SaveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
t.Fatalf("failed to save: %v", err)
}
// Read raw file to check
data, _ := os.ReadFile(configPath)
content := string(data)
// models should be preserved
if !contains(content, "model1") {
t.Error("models should be preserved")
}
// primary should be updated
if !contains(content, "model2") {
t.Error("primary should be updated to model2")
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
}
func containsHelper(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
func TestModelNameEdgeCases(t *testing.T) {
testCases := []struct {
name string
model string
}{
{"simple", "llama3.2"},
{"with tag", "llama3.2:latest"},
{"with cloud tag", "kimi-k2.5:cloud"},
{"with namespace", "library/llama3.2"},
{"with dots", "glm-4.7-flash"},
{"with numbers", "qwen3:8b"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
aliases := map[string]string{"primary": tc.model}
if err := SaveAliases("claude", aliases); err != nil {
t.Fatalf("failed to save model %q: %v", tc.model, err)
}
loaded, err := LoadIntegration("claude")
if err != nil {
t.Fatalf("failed to load: %v", err)
}
if loaded.Aliases["primary"] != tc.model {
t.Errorf("expected primary=%q, got %q", tc.model, loaded.Aliases["primary"])
}
})
}
}
func TestSwitchingScenarios(t *testing.T) {
t.Run("cloud to local removes fast", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial cloud config
if err := SaveAliases("claude", map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}); err != nil {
t.Fatal(err)
}
// Switch to local (no fast)
if err := SaveAliases("claude", map[string]string{
"primary": "local-model",
}); err != nil {
t.Fatal(err)
}
loaded, _ := LoadIntegration("claude")
if loaded.Aliases["fast"] != "" {
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
}
if loaded.Aliases["primary"] != "local-model" {
t.Errorf("primary should be local-model, got %q", loaded.Aliases["primary"])
}
})
t.Run("local to cloud adds fast", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial local config
if err := SaveAliases("claude", map[string]string{
"primary": "local-model",
}); err != nil {
t.Fatal(err)
}
// Switch to cloud (with fast)
if err := SaveAliases("claude", map[string]string{
"primary": "cloud-model",
"fast": "cloud-model",
}); err != nil {
t.Fatal(err)
}
loaded, _ := LoadIntegration("claude")
if loaded.Aliases["fast"] != "cloud-model" {
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
}
})
t.Run("cloud to different cloud updates both", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial cloud config
if err := SaveAliases("claude", map[string]string{
"primary": "cloud-model-1",
"fast": "cloud-model-1",
}); err != nil {
t.Fatal(err)
}
// Switch to different cloud
if err := SaveAliases("claude", map[string]string{
"primary": "cloud-model-2",
"fast": "cloud-model-2",
}); err != nil {
t.Fatal(err)
}
loaded, _ := LoadIntegration("claude")
if loaded.Aliases["primary"] != "cloud-model-2" {
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
}
if loaded.Aliases["fast"] != "cloud-model-2" {
t.Errorf("fast should be cloud-model-2, got %q", loaded.Aliases["fast"])
}
})
}
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Save aliases with one model
if err := SaveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
t.Fatal(err)
}
// Save integration with same model (this is the pattern we use)
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
t.Fatal(err)
}
loaded, _ := LoadIntegration("claude")
if loaded.Aliases["primary"] != loaded.Models[0] {
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
}
})
t.Run("out of sync config is detectable", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Simulate out-of-sync state (like manual edit or bug)
if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
t.Fatal(err)
}
if err := SaveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
t.Fatal(err)
}
loaded, _ := LoadIntegration("claude")
// They should be different (this is the bug state)
if loaded.Models[0] == loaded.Aliases["primary"] {
t.Error("expected out-of-sync state for this test")
}
// The fix: when updating aliases, also update models
if err := SaveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
t.Fatal(err)
}
loaded, _ = LoadIntegration("claude")
if loaded.Models[0] != loaded.Aliases["primary"] {
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
loaded.Models[0], loaded.Aliases["primary"])
}
})
t.Run("updating primary alias updates models too", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Initial state
if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
t.Fatal(err)
}
if err := SaveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
t.Fatal(err)
}
// Update aliases AND models together
newAliases := map[string]string{"primary": "updated-model"}
if err := SaveAliases("claude", newAliases); err != nil {
t.Fatal(err)
}
if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
t.Fatal(err)
}
loaded, _ := LoadIntegration("claude")
if loaded.Models[0] != "updated-model" {
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
}
if loaded.Aliases["primary"] != "updated-model" {
t.Errorf("aliases.primary should be updated-model, got %q", loaded.Aliases["primary"])
}
})
}

530
cmd/config/config_test.go Normal file
View File

@@ -0,0 +1,530 @@
package config
import (
"os"
"path/filepath"
"strings"
"testing"
)
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
func setTestHome(t *testing.T, dir string) {
t.Setenv("HOME", dir)
t.Setenv("TMPDIR", dir)
t.Setenv("USERPROFILE", dir)
}
func TestIntegrationConfig(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("save and load round-trip", func(t *testing.T) {
models := []string{"llama3.2", "mistral", "qwen2.5"}
if err := SaveIntegration("claude", models); err != nil {
t.Fatal(err)
}
config, err := LoadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if len(config.Models) != len(models) {
t.Errorf("expected %d models, got %d", len(models), len(config.Models))
}
for i, m := range models {
if config.Models[i] != m {
t.Errorf("model %d: expected %s, got %s", i, m, config.Models[i])
}
}
})
t.Run("save and load aliases", func(t *testing.T) {
models := []string{"llama3.2"}
if err := SaveIntegration("claude", models); err != nil {
t.Fatal(err)
}
aliases := map[string]string{
"primary": "llama3.2:70b",
"fast": "llama3.2:8b",
}
if err := SaveAliases("claude", aliases); err != nil {
t.Fatal(err)
}
config, err := LoadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if config.Aliases == nil {
t.Fatal("expected aliases to be saved")
}
for k, v := range aliases {
if config.Aliases[k] != v {
t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k])
}
}
})
t.Run("saveIntegration preserves aliases", func(t *testing.T) {
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
t.Fatal(err)
}
if err := SaveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
t.Fatal(err)
}
if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
t.Fatal(err)
}
config, err := LoadIntegration("claude")
if err != nil {
t.Fatal(err)
}
if config.Aliases["primary"] != "model-a" {
t.Errorf("expected aliases to be preserved, got %v", config.Aliases)
}
})
t.Run("defaultModel returns first model", func(t *testing.T) {
SaveIntegration("codex", []string{"model-a", "model-b"})
config, _ := LoadIntegration("codex")
defaultModel := ""
if len(config.Models) > 0 {
defaultModel = config.Models[0]
}
if defaultModel != "model-a" {
t.Errorf("expected model-a, got %s", defaultModel)
}
})
t.Run("defaultModel returns empty for no models", func(t *testing.T) {
config := &integration{Models: []string{}}
defaultModel := ""
if len(config.Models) > 0 {
defaultModel = config.Models[0]
}
if defaultModel != "" {
t.Errorf("expected empty string, got %s", defaultModel)
}
})
t.Run("app name is case-insensitive", func(t *testing.T) {
SaveIntegration("Claude", []string{"model-x"})
config, err := LoadIntegration("claude")
if err != nil {
t.Fatal(err)
}
defaultModel := ""
if len(config.Models) > 0 {
defaultModel = config.Models[0]
}
if defaultModel != "model-x" {
t.Errorf("expected model-x, got %s", defaultModel)
}
})
t.Run("multiple integrations in single file", func(t *testing.T) {
SaveIntegration("app1", []string{"model-1"})
SaveIntegration("app2", []string{"model-2"})
config1, _ := LoadIntegration("app1")
config2, _ := LoadIntegration("app2")
defaultModel1 := ""
if len(config1.Models) > 0 {
defaultModel1 = config1.Models[0]
}
defaultModel2 := ""
if len(config2.Models) > 0 {
defaultModel2 = config2.Models[0]
}
if defaultModel1 != "model-1" {
t.Errorf("expected model-1, got %s", defaultModel1)
}
if defaultModel2 != "model-2" {
t.Errorf("expected model-2, got %s", defaultModel2)
}
})
}
func TestListIntegrations(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns empty when no integrations", func(t *testing.T) {
configs, err := listIntegrations()
if err != nil {
t.Fatal(err)
}
if len(configs) != 0 {
t.Errorf("expected 0 integrations, got %d", len(configs))
}
})
t.Run("returns all saved integrations", func(t *testing.T) {
SaveIntegration("claude", []string{"model-1"})
SaveIntegration("droid", []string{"model-2"})
configs, err := listIntegrations()
if err != nil {
t.Fatal(err)
}
if len(configs) != 2 {
t.Errorf("expected 2 integrations, got %d", len(configs))
}
})
}
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
dir := filepath.Join(tmpDir, ".ollama")
os.MkdirAll(dir, 0o755)
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
_, err := LoadIntegration("test")
if err == nil {
t.Error("expected error for nonexistent integration in corrupted file")
}
}
func TestSaveIntegration_NilModels(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if err := SaveIntegration("test", nil); err != nil {
t.Fatalf("saveIntegration with nil models failed: %v", err)
}
config, err := LoadIntegration("test")
if err != nil {
t.Fatalf("loadIntegration failed: %v", err)
}
if config.Models == nil {
// nil is acceptable
} else if len(config.Models) != 0 {
t.Errorf("expected empty or nil models, got %v", config.Models)
}
}
func TestSaveIntegration_EmptyAppName(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
err := SaveIntegration("", []string{"model"})
if err == nil {
t.Error("expected error for empty app name, got nil")
}
if err != nil && !strings.Contains(err.Error(), "app name cannot be empty") {
t.Errorf("expected 'app name cannot be empty' error, got: %v", err)
}
}
func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
_, err := LoadIntegration("nonexistent")
if err == nil {
t.Error("expected error for nonexistent integration, got nil")
}
if !os.IsNotExist(err) {
t.Logf("error type is os.ErrNotExist as expected: %v", err)
}
}
func TestConfigPath(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
path, err := configPath()
if err != nil {
t.Fatal(err)
}
expected := filepath.Join(tmpDir, ".ollama", "config.json")
if path != expected {
t.Errorf("expected %s, got %s", expected, path)
}
}
func TestLoad(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns empty config when file does not exist", func(t *testing.T) {
cfg, err := load()
if err != nil {
t.Fatal(err)
}
if cfg == nil {
t.Fatal("expected non-nil config")
}
if cfg.Integrations == nil {
t.Error("expected non-nil Integrations map")
}
if len(cfg.Integrations) != 0 {
t.Errorf("expected empty Integrations, got %d", len(cfg.Integrations))
}
})
t.Run("loads existing config", func(t *testing.T) {
path, _ := configPath()
os.MkdirAll(filepath.Dir(path), 0o755)
os.WriteFile(path, []byte(`{"integrations":{"test":{"models":["model-a"]}}}`), 0o644)
cfg, err := load()
if err != nil {
t.Fatal(err)
}
if cfg.Integrations["test"] == nil {
t.Fatal("expected test integration")
}
if len(cfg.Integrations["test"].Models) != 1 {
t.Errorf("expected 1 model, got %d", len(cfg.Integrations["test"].Models))
}
})
t.Run("returns error for corrupted JSON", func(t *testing.T) {
path, _ := configPath()
os.MkdirAll(filepath.Dir(path), 0o755)
os.WriteFile(path, []byte(`{corrupted`), 0o644)
_, err := load()
if err == nil {
t.Error("expected error for corrupted JSON")
}
})
}
func TestMigrateConfig(t *testing.T) {
t.Run("migrates legacy file to new location", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
data := []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`)
os.WriteFile(filepath.Join(legacyDir, "config.json"), data, 0o644)
migrated, err := migrateConfig()
if err != nil {
t.Fatal(err)
}
if !migrated {
t.Fatal("expected migration to occur")
}
newPath, _ := configPath()
got, err := os.ReadFile(newPath)
if err != nil {
t.Fatalf("new config not found: %v", err)
}
if string(got) != string(data) {
t.Errorf("content mismatch: got %s", got)
}
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
t.Error("legacy file should have been removed")
}
if _, err := os.Stat(legacyDir); !os.IsNotExist(err) {
t.Error("legacy directory should have been removed")
}
})
t.Run("no-op when no legacy file exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
migrated, err := migrateConfig()
if err != nil {
t.Fatal(err)
}
if migrated {
t.Error("expected no migration")
}
})
t.Run("skips corrupt legacy file", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{corrupt`), 0o644)
migrated, err := migrateConfig()
if err != nil {
t.Fatal(err)
}
if migrated {
t.Error("should not migrate corrupt file")
}
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); os.IsNotExist(err) {
t.Error("corrupt legacy file should not have been deleted")
}
})
t.Run("new path takes precedence over legacy", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"old":{"models":["old-model"]}}}`), 0o644)
newDir := filepath.Join(tmpDir, ".ollama")
os.WriteFile(filepath.Join(newDir, "config.json"), []byte(`{"integrations":{"new":{"models":["new-model"]}}}`), 0o644)
cfg, err := load()
if err != nil {
t.Fatal(err)
}
if _, ok := cfg.Integrations["new"]; !ok {
t.Error("expected new-path integration to be loaded")
}
if _, ok := cfg.Integrations["old"]; ok {
t.Error("legacy integration should not have been loaded")
}
})
t.Run("idempotent when called twice", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
if _, err := migrateConfig(); err != nil {
t.Fatal(err)
}
migrated, err := migrateConfig()
if err != nil {
t.Fatal(err)
}
if migrated {
t.Error("second migration should be a no-op")
}
})
t.Run("legacy directory preserved if not empty", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
os.WriteFile(filepath.Join(legacyDir, "other-file.txt"), []byte("keep me"), 0o644)
if _, err := migrateConfig(); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(legacyDir); os.IsNotExist(err) {
t.Error("directory with other files should not have been removed")
}
if _, err := os.Stat(filepath.Join(legacyDir, "other-file.txt")); os.IsNotExist(err) {
t.Error("other files in legacy directory should be untouched")
}
})
t.Run("save writes to new path after migration", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
// load triggers migration, then save should write to new path
if err := SaveIntegration("codex", []string{"qwen2.5"}); err != nil {
t.Fatal(err)
}
newPath := filepath.Join(tmpDir, ".ollama", "config.json")
if _, err := os.Stat(newPath); os.IsNotExist(err) {
t.Error("save should write to new path")
}
// old path should not be recreated
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
t.Error("save should not recreate legacy path")
}
})
t.Run("load triggers migration transparently", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
cfg, err := load()
if err != nil {
t.Fatal(err)
}
if cfg.Integrations["claude"] == nil || cfg.Integrations["claude"].Models[0] != "llama3.2" {
t.Error("migration via load() did not preserve data")
}
})
}
func TestSave(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("creates config file", func(t *testing.T) {
cfg := &config{
Integrations: map[string]*integration{
"test": {Models: []string{"model-a", "model-b"}},
},
}
if err := save(cfg); err != nil {
t.Fatal(err)
}
path, _ := configPath()
if _, err := os.Stat(path); os.IsNotExist(err) {
t.Error("config file was not created")
}
})
t.Run("round-trip preserves data", func(t *testing.T) {
cfg := &config{
Integrations: map[string]*integration{
"claude": {Models: []string{"llama3.2", "mistral"}},
"codex": {Models: []string{"qwen2.5"}},
},
}
if err := save(cfg); err != nil {
t.Fatal(err)
}
loaded, err := load()
if err != nil {
t.Fatal(err)
}
if len(loaded.Integrations) != 2 {
t.Errorf("expected 2 integrations, got %d", len(loaded.Integrations))
}
if loaded.Integrations["claude"] == nil {
t.Error("missing claude integration")
}
if len(loaded.Integrations["claude"].Models) != 2 {
t.Errorf("expected 2 models for claude, got %d", len(loaded.Integrations["claude"].Models))
}
})
}

5
cmd/editor_unix.go Normal file
View File

@@ -0,0 +1,5 @@
//go:build !windows
package cmd
const defaultEditor = "vi"

5
cmd/editor_windows.go Normal file
View File

@@ -0,0 +1,5 @@
//go:build windows
package cmd
const defaultEditor = "edit"

735
cmd/interactive.go Normal file
View File

@@ -0,0 +1,735 @@
package cmd
import (
"cmp"
"errors"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"regexp"
"slices"
"strings"
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
)
type MultilineState int
const (
MultilineNone MultilineState = iota
MultilinePrompt
MultilineSystem
)
func generateInteractive(cmd *cobra.Command, opts runOptions) error {
usage := func() {
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
fmt.Fprintln(os.Stderr, " /clear Clear session context")
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
if opts.MultiModal {
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, .webp images, or .wav audio files.\n", filepath.FromSlash("/path/to/file"))
}
fmt.Fprintln(os.Stderr, "")
}
usageSet := func() {
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter")
fmt.Fprintln(os.Stderr, " /set system <string> Set system message")
fmt.Fprintln(os.Stderr, " /set history Enable history")
fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
fmt.Fprintln(os.Stderr, " /set nowordwrap Disable wordwrap")
fmt.Fprintln(os.Stderr, " /set format json Enable JSON mode")
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
fmt.Fprintln(os.Stderr, " /set think Enable thinking")
fmt.Fprintln(os.Stderr, " /set nothink Disable thinking")
fmt.Fprintln(os.Stderr, "")
}
usageShortcuts := func() {
fmt.Fprintln(os.Stderr, "Available keyboard shortcuts:")
fmt.Fprintln(os.Stderr, " Ctrl + a Move to the beginning of the line (Home)")
fmt.Fprintln(os.Stderr, " Ctrl + e Move to the end of the line (End)")
fmt.Fprintln(os.Stderr, " Alt + b Move back (left) one word")
fmt.Fprintln(os.Stderr, " Alt + f Move forward (right) one word")
fmt.Fprintln(os.Stderr, " Ctrl + k Delete the sentence after the cursor")
fmt.Fprintln(os.Stderr, " Ctrl + u Delete the sentence before the cursor")
fmt.Fprintln(os.Stderr, " Ctrl + w Delete the word before the cursor")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, " Ctrl + l Clear the screen")
fmt.Fprintln(os.Stderr, " Ctrl + g Open default editor to compose a prompt")
fmt.Fprintln(os.Stderr, " Ctrl + c Stop the model from responding")
fmt.Fprintln(os.Stderr, " Ctrl + d Exit ollama (/bye)")
fmt.Fprintln(os.Stderr, "")
}
usageShow := func() {
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /show info Show details for this model")
fmt.Fprintln(os.Stderr, " /show license Show model license")
fmt.Fprintln(os.Stderr, " /show modelfile Show Modelfile for this model")
fmt.Fprintln(os.Stderr, " /show parameters Show parameters for this model")
fmt.Fprintln(os.Stderr, " /show system Show system message")
fmt.Fprintln(os.Stderr, " /show template Show prompt template")
fmt.Fprintln(os.Stderr, "")
}
// only list out the most common parameters
usageParameters := func() {
fmt.Fprintln(os.Stderr, "Available Parameters:")
fmt.Fprintln(os.Stderr, " /set parameter seed <int> Random number seed")
fmt.Fprintln(os.Stderr, " /set parameter num_predict <int> Max number of tokens to predict")
fmt.Fprintln(os.Stderr, " /set parameter top_k <int> Pick from top k num of tokens")
fmt.Fprintln(os.Stderr, " /set parameter top_p <float> Pick token based on sum of probabilities")
fmt.Fprintln(os.Stderr, " /set parameter min_p <float> Pick token based on top token probability * min_p")
fmt.Fprintln(os.Stderr, " /set parameter num_ctx <int> Set the context size")
fmt.Fprintln(os.Stderr, " /set parameter temperature <float> Set creativity level")
fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions")
fmt.Fprintln(os.Stderr, " /set parameter repeat_last_n <int> Set how far back to look for repetitions")
fmt.Fprintln(os.Stderr, " /set parameter num_gpu <int> The number of layers to send to the GPU")
fmt.Fprintln(os.Stderr, " /set parameter stop <string> <string> ... Set the stop parameters")
fmt.Fprintln(os.Stderr, "")
}
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
AltPlaceholder: "Press Enter to send",
})
if err != nil {
return err
}
if envconfig.NoHistory() {
scanner.HistoryDisable()
}
fmt.Print(readline.StartBracketedPaste)
defer fmt.Printf(readline.EndBracketedPaste)
var sb strings.Builder
var multiline MultilineState
var thinkExplicitlySet bool = opts.Think != nil
for {
line, err := scanner.Readline()
switch {
case errors.Is(err, io.EOF):
fmt.Println()
return nil
case errors.Is(err, readline.ErrInterrupt):
if line == "" {
fmt.Println("\nUse Ctrl + d or /bye to exit.")
}
scanner.Prompt.UseAlt = false
sb.Reset()
continue
case errors.Is(err, readline.ErrEditPrompt):
sb.Reset()
content, err := editInExternalEditor(line)
if err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
continue
}
if strings.TrimSpace(content) == "" {
continue
}
scanner.Prefill = content
continue
case err != nil:
return err
}
switch {
case multiline != MultilineNone:
// check if there's a multiline terminating string
before, ok := strings.CutSuffix(line, `"""`)
sb.WriteString(before)
if !ok {
fmt.Fprintln(&sb)
scanner.Prompt.UseAlt = true
continue
}
switch multiline {
case MultilineSystem:
opts.System = sb.String()
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
fmt.Println("Set system message.")
sb.Reset()
}
multiline = MultilineNone
scanner.Prompt.UseAlt = false
case strings.HasPrefix(line, `"""`):
line := strings.TrimPrefix(line, `"""`)
line, ok := strings.CutSuffix(line, `"""`)
sb.WriteString(line)
if !ok {
// no multiline terminating string; need more input
fmt.Fprintln(&sb)
multiline = MultilinePrompt
scanner.Prompt.UseAlt = true
}
case scanner.Pasting:
fmt.Fprintln(&sb, line)
continue
case strings.HasPrefix(line, "/list"):
args := strings.Fields(line)
if err := ListHandler(cmd, args[1:]); err != nil {
return err
}
case strings.HasPrefix(line, "/load"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage:\n /load <modelname>")
continue
}
origOpts := opts.Copy()
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return err
}
opts.Model = args[1]
opts.Messages = []api.Message{}
opts.LoadedMessages = nil
fmt.Printf("Loading model '%s'\n", opts.Model)
info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model})
if err != nil {
if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
opts = origOpts.Copy()
continue
}
return err
}
applyShowResponseToRunOptions(&opts, info)
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkExplicitlySet)
if err != nil {
return err
}
if err := loadOrUnloadModel(cmd, &opts); err != nil {
if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
opts = origOpts.Copy()
continue
}
if strings.Contains(err.Error(), "does not support thinking") {
fmt.Printf("error: %v\n", err)
continue
}
return err
}
continue
case strings.HasPrefix(line, "/save"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage:\n /save <modelname>")
continue
}
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return err
}
req := NewCreateRequest(args[1], opts)
fn := func(resp api.ProgressResponse) error { return nil }
err = client.Create(cmd.Context(), req, fn)
if err != nil {
if strings.Contains(err.Error(), errtypes.InvalidModelNameErrMsg) {
fmt.Printf("error: The model name '%s' is invalid\n", args[1])
continue
}
return err
}
fmt.Printf("Created new model '%s'\n", args[1])
continue
case strings.HasPrefix(line, "/clear"):
opts.Messages = []api.Message{}
if opts.System != "" {
newMessage := api.Message{Role: "system", Content: opts.System}
opts.Messages = append(opts.Messages, newMessage)
}
fmt.Println("Cleared session context")
continue
case strings.HasPrefix(line, "/set"):
args := strings.Fields(line)
if len(args) > 1 {
switch args[1] {
case "history":
scanner.HistoryEnable()
case "nohistory":
scanner.HistoryDisable()
case "wordwrap":
opts.WordWrap = true
fmt.Println("Set 'wordwrap' mode.")
case "nowordwrap":
opts.WordWrap = false
fmt.Println("Set 'nowordwrap' mode.")
case "verbose":
if err := cmd.Flags().Set("verbose", "true"); err != nil {
return err
}
fmt.Println("Set 'verbose' mode.")
case "quiet":
if err := cmd.Flags().Set("verbose", "false"); err != nil {
return err
}
fmt.Println("Set 'quiet' mode.")
case "think":
thinkValue := api.ThinkValue{Value: true}
var maybeLevel string
if len(args) > 2 {
maybeLevel = args[2]
}
if maybeLevel != "" {
// TODO(drifkin): validate the level, could be model dependent
// though... It will also be validated on the server once a call is
// made.
thinkValue.Value = maybeLevel
}
opts.Think = &thinkValue
thinkExplicitlySet = true
if client, err := api.ClientFromEnvironment(); err == nil {
ensureThinkingSupport(cmd.Context(), client, opts.Model)
}
if maybeLevel != "" {
fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel)
} else {
fmt.Println("Set 'think' mode.")
}
case "nothink":
opts.Think = &api.ThinkValue{Value: false}
thinkExplicitlySet = true
if client, err := api.ClientFromEnvironment(); err == nil {
ensureThinkingSupport(cmd.Context(), client, opts.Model)
}
fmt.Println("Set 'nothink' mode.")
case "format":
if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
} else {
opts.Format = args[2]
fmt.Printf("Set format to '%s' mode.\n", args[2])
}
case "noformat":
opts.Format = ""
fmt.Println("Disabled format.")
case "parameter":
if len(args) < 4 {
usageParameters()
continue
}
params := args[3:]
fp, err := api.FormatParams(map[string][]string{args[2]: params})
if err != nil {
fmt.Printf("Couldn't set parameter: %q\n", err)
continue
}
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
opts.Options[args[2]] = fp[args[2]]
case "system":
if len(args) < 3 {
usageSet()
continue
}
multiline = MultilineSystem
line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`)
if !ok {
multiline = MultilineNone
} else {
// only cut suffix if the line is multiline
line, ok = strings.CutSuffix(line, `"""`)
if ok {
multiline = MultilineNone
}
}
sb.WriteString(line)
if multiline != MultilineNone {
scanner.Prompt.UseAlt = true
continue
}
opts.System = sb.String() // for display in modelfile
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
// Replace the last message
opts.Messages[len(opts.Messages)-1] = newMessage
} else {
opts.Messages = append(opts.Messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
continue
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
}
} else {
usageSet()
}
case strings.HasPrefix(line, "/show"):
args := strings.Fields(line)
if len(args) > 1 {
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return err
}
req := &api.ShowRequest{
Name: opts.Model,
System: opts.System,
Options: opts.Options,
}
resp, err := client.Show(cmd.Context(), req)
if err != nil {
fmt.Println("error: couldn't get model")
return err
}
switch args[1] {
case "info":
_ = showInfo(resp, false, os.Stderr)
case "license":
if resp.License == "" {
fmt.Println("No license was specified for this model.")
} else {
fmt.Println(resp.License)
}
case "modelfile":
fmt.Println(resp.Modelfile)
case "parameters":
fmt.Println("Model defined parameters:")
if resp.Parameters == "" {
fmt.Println(" No additional parameters were specified for this model.")
} else {
for _, l := range strings.Split(resp.Parameters, "\n") {
fmt.Printf(" %s\n", l)
}
}
fmt.Println()
if len(opts.Options) > 0 {
fmt.Println("User defined parameters:")
for k, v := range opts.Options {
fmt.Printf(" %-*s %v\n", 30, k, v)
}
fmt.Println()
}
case "system":
switch {
case opts.System != "":
fmt.Println(opts.System + "\n")
case resp.System != "":
fmt.Println(resp.System + "\n")
default:
fmt.Println("No system message was specified for this model.")
}
case "template":
if resp.Template != "" {
fmt.Println(resp.Template)
} else {
fmt.Println("No prompt template was specified for this model.")
}
default:
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
}
} else {
usageShow()
}
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
args := strings.Fields(line)
if len(args) > 1 {
switch args[1] {
case "set", "/set":
usageSet()
case "show", "/show":
usageShow()
case "shortcut", "shortcuts":
usageShortcuts()
}
} else {
usage()
}
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/"):
args := strings.Fields(line)
isFile := false
if opts.MultiModal {
for _, f := range extractFileNames(line) {
if strings.HasPrefix(f, args[0]) {
isFile = true
break
}
}
}
if !isFile {
fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0])
continue
}
sb.WriteString(line)
default:
sb.WriteString(line)
}
if sb.Len() > 0 && multiline == MultilineNone {
newMessage := api.Message{Role: "user", Content: sb.String()}
if opts.MultiModal {
msg, images, err := extractFileData(sb.String())
if err != nil {
return err
}
newMessage.Content = msg
newMessage.Images = images
}
opts.Messages = append(opts.Messages, newMessage)
assistant, err := chat(cmd, opts)
if err != nil {
if strings.Contains(err.Error(), "does not support thinking") ||
strings.Contains(err.Error(), "invalid think value") {
fmt.Printf("error: %v\n", err)
sb.Reset()
continue
}
return err
}
if assistant != nil {
opts.Messages = append(opts.Messages, *assistant)
}
sb.Reset()
}
}
}
func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
parentModel := opts.ParentModel
modelName := model.ParseName(parentModel)
if !modelName.IsValid() {
parentModel = ""
}
// Preserve explicit cloud intent for sessions started with `:cloud`.
// Cloud model metadata can return a source-less parent_model (for example
// "qwen3.5"), which would otherwise make `/save` create a local derivative.
if modelref.HasExplicitCloudSource(opts.Model) && !modelref.HasExplicitCloudSource(parentModel) {
parentModel = ""
}
req := &api.CreateRequest{
Model: name,
From: cmp.Or(parentModel, opts.Model),
}
if opts.System != "" {
req.System = opts.System
}
if len(opts.Options) > 0 {
req.Parameters = opts.Options
}
messages := slices.Clone(opts.LoadedMessages)
messages = append(messages, opts.Messages...)
if len(messages) > 0 {
req.Messages = messages
}
return req
}
func normalizeFilePath(fp string) string {
return strings.NewReplacer(
"\\ ", " ", // Escaped space
"\\(", "(", // Escaped left parenthesis
"\\)", ")", // Escaped right parenthesis
"\\[", "[", // Escaped left square bracket
"\\]", "]", // Escaped right square bracket
"\\{", "{", // Escaped left curly brace
"\\}", "}", // Escaped right curly brace
"\\$", "$", // Escaped dollar sign
"\\&", "&", // Escaped ampersand
"\\;", ";", // Escaped semicolon
"\\'", "'", // Escaped single quote
"\\\\", "\\", // Escaped backslash
"\\*", "*", // Escaped asterisk
"\\?", "?", // Escaped question mark
"\\~", "~", // Escaped tilde
).Replace(fp)
}
func extractFileNames(input string) []string {
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
// and followed by more characters and a file extension
// This will capture non filename strings, but we'll check for file existence to remove mismatches
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp|wav)\b`
re := regexp.MustCompile(regexPattern)
return re.FindAllString(input, -1)
}
func extractFileData(input string) (string, []api.ImageData, error) {
filePaths := extractFileNames(input)
var imgs []api.ImageData
for _, fp := range filePaths {
nfp := normalizeFilePath(fp)
data, err := getImageData(nfp)
if errors.Is(err, os.ErrNotExist) {
continue
} else if err != nil {
fmt.Fprintf(os.Stderr, "Couldn't process file: %q\n", err)
return "", imgs, err
}
ext := strings.ToLower(filepath.Ext(nfp))
switch ext {
case ".wav":
fmt.Fprintf(os.Stderr, "Added audio '%s'\n", nfp)
default:
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
}
input = strings.ReplaceAll(input, "'"+nfp+"'", "")
input = strings.ReplaceAll(input, "'"+fp+"'", "")
input = strings.ReplaceAll(input, fp, "")
imgs = append(imgs, data)
}
return strings.TrimSpace(input), imgs, nil
}
func editInExternalEditor(content string) (string, error) {
editor := envconfig.Editor()
if editor == "" {
editor = os.Getenv("VISUAL")
}
if editor == "" {
editor = os.Getenv("EDITOR")
}
if editor == "" {
editor = defaultEditor
}
// Check that the editor binary exists
name := strings.Fields(editor)[0]
if _, err := exec.LookPath(name); err != nil {
return "", fmt.Errorf("editor %q not found, set OLLAMA_EDITOR to the path of your preferred editor", name)
}
tmpFile, err := os.CreateTemp("", "ollama-prompt-*.txt")
if err != nil {
return "", fmt.Errorf("creating temp file: %w", err)
}
defer os.Remove(tmpFile.Name())
if content != "" {
if _, err := tmpFile.WriteString(content); err != nil {
tmpFile.Close()
return "", fmt.Errorf("writing to temp file: %w", err)
}
}
tmpFile.Close()
args := strings.Fields(editor)
args = append(args, tmpFile.Name())
cmd := exec.Command(args[0], args[1:]...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("editor exited with error: %w", err)
}
data, err := os.ReadFile(tmpFile.Name())
if err != nil {
return "", fmt.Errorf("reading temp file: %w", err)
}
return strings.TrimRight(string(data), "\n"), nil
}
func getImageData(filePath string) ([]byte, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, err
}
defer file.Close()
buf := make([]byte, 512)
_, err = file.Read(buf)
if err != nil {
return nil, err
}
contentType := http.DetectContentType(buf)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp", "audio/wave"}
if !slices.Contains(allowedTypes, contentType) {
return nil, fmt.Errorf("invalid file type: %s", contentType)
}
info, err := file.Stat()
if err != nil {
return nil, err
}
var maxSize int64 = 100 * 1024 * 1024 // 100MB
if info.Size() > maxSize {
return nil, errors.New("file size exceeds maximum limit (100MB)")
}
buf = make([]byte, info.Size())
_, err = file.Seek(0, 0)
if err != nil {
return nil, err
}
_, err = io.ReadFull(file, buf)
if err != nil {
return nil, err
}
return buf, nil
}

116
cmd/interactive_test.go Normal file
View File

@@ -0,0 +1,116 @@
package cmd
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
)
func TestExtractFilenames(t *testing.T) {
// Unix style paths
input := ` some preamble
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG
/unescaped space /six.webp inbetween6 /valid\ path/dir/seven.WEBP`
res := extractFileNames(input)
assert.Len(t, res, 7)
assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[1], "two.jpg")
assert.Contains(t, res[2], "three.jpeg")
assert.Contains(t, res[3], "four.png")
assert.Contains(t, res[4], "five.JPG")
assert.Contains(t, res[5], "six.webp")
assert.Contains(t, res[6], "seven.WEBP")
assert.NotContains(t, res[4], '"')
assert.NotContains(t, res, "inbetween1")
assert.NotContains(t, res, "./1.svg")
// Windows style paths
input = ` some preamble
c:/users/jdoe/one.png inbetween1 c:/program files/someplace/two.jpg inbetween2
/absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4
./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6
d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG
c:/users/jdoe/eleven.webp inbetween11 c:/program files/someplace/twelve.WebP inbetween12
d:\path with\spaces\thirteen.WEBP some ending
`
res = extractFileNames(input)
assert.Len(t, res, 13)
assert.NotContains(t, res, "inbetween2")
assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[0], "c:")
assert.Contains(t, res[1], "two.jpg")
assert.Contains(t, res[1], "c:")
assert.Contains(t, res[2], "three.jpeg")
assert.Contains(t, res[3], "four.png")
assert.Contains(t, res[4], "five.JPG")
assert.Contains(t, res[5], "six.png")
assert.Contains(t, res[6], "seven.JPEG")
assert.Contains(t, res[6], "d:")
assert.Contains(t, res[7], "eight.png")
assert.Contains(t, res[7], "c:")
assert.Contains(t, res[8], "nine.png")
assert.Contains(t, res[8], "d:")
assert.Contains(t, res[9], "ten.PNG")
assert.Contains(t, res[9], "E:")
assert.Contains(t, res[10], "eleven.webp")
assert.Contains(t, res[10], "c:")
assert.Contains(t, res[11], "twelve.WebP")
assert.Contains(t, res[11], "c:")
assert.Contains(t, res[12], "thirteen.WEBP")
assert.Contains(t, res[12], "d:")
}
// Ensure that file paths wrapped in single quotes are removed with the quotes.
func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
dir := t.TempDir()
fp := filepath.Join(dir, "img.jpg")
data := make([]byte, 600)
copy(data, []byte{
0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10, 'J', 'F', 'I', 'F',
0x00, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0xff, 0xd9,
})
if err := os.WriteFile(fp, data, 0o600); err != nil {
t.Fatalf("failed to write test image: %v", err)
}
input := "before '" + fp + "' after"
cleaned, imgs, err := extractFileData(input)
assert.NoError(t, err)
assert.Len(t, imgs, 1)
assert.Equal(t, cleaned, "before after")
}
func TestExtractFileDataWAV(t *testing.T) {
dir := t.TempDir()
fp := filepath.Join(dir, "sample.wav")
data := make([]byte, 600)
copy(data[:44], []byte{
'R', 'I', 'F', 'F',
0x58, 0x02, 0x00, 0x00, // file size - 8
'W', 'A', 'V', 'E',
'f', 'm', 't', ' ',
0x10, 0x00, 0x00, 0x00, // fmt chunk size
0x01, 0x00, // PCM
0x01, 0x00, // mono
0x80, 0x3e, 0x00, 0x00, // 16000 Hz
0x00, 0x7d, 0x00, 0x00, // byte rate
0x02, 0x00, // block align
0x10, 0x00, // 16-bit
'd', 'a', 't', 'a',
0x34, 0x02, 0x00, 0x00, // data size
})
if err := os.WriteFile(fp, data, 0o600); err != nil {
t.Fatalf("failed to write test audio: %v", err)
}
input := "before " + fp + " after"
cleaned, imgs, err := extractFileData(input)
assert.NoError(t, err)
assert.Len(t, imgs, 1)
assert.Equal(t, "before after", cleaned)
}

View File

@@ -0,0 +1,176 @@
// Package fileutil provides small shared helpers for reading JSON files
// and writing config files with backup-on-overwrite semantics.
package fileutil
import (
"bytes"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"time"
)
// Keep a bounded number of backups per file so config backups do not grow
// without limit. We keep the 5 most recent backups and do not pin the oldest.
const maxBackupsPerFile = 5
// ReadJSON reads a JSON object file into a generic map.
func ReadJSON(path string) (map[string]any, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
return nil, err
}
return result, nil
}
func copyFile(src, dst string) error {
info, err := os.Stat(src)
if err != nil {
return err
}
data, err := os.ReadFile(src)
if err != nil {
return err
}
return os.WriteFile(dst, data, info.Mode().Perm())
}
// BackupDir returns the shared backup root used before overwriting files.
func BackupDir() string {
if home, err := os.UserHomeDir(); err == nil && home != "" {
return filepath.Join(home, ".ollama", "backup")
}
return filepath.Join(os.TempDir(), "ollama-backup")
}
func writeBackupCopy(srcPath string, integration string) (string, error) {
dir := BackupDir()
name := filepath.Base(srcPath)
if integration != "" {
dir = filepath.Join(dir, integration)
}
if err := os.MkdirAll(dir, 0o755); err != nil {
return "", err
}
backupPath := filepath.Join(dir, fmt.Sprintf("%s.%d", name, time.Now().Unix()))
if err := copyFile(srcPath, backupPath); err != nil {
return "", err
}
pruneOldBackups(dir, name, maxBackupsPerFile)
return backupPath, nil
}
// WriteWithBackup writes data to path via temp file + rename, backing up any
// existing file first. Callers may optionally pass one integration name to
// store backups under BackupDir()/.../<integration>/.
func WriteWithBackup(path string, data []byte, integration ...string) error {
backupIntegration := ""
if len(integration) > 0 {
backupIntegration = integration[0]
}
var backupPath string
// backup must be created before any writes to the target file
if existingContent, err := os.ReadFile(path); err == nil {
if bytes.Equal(existingContent, data) {
return nil
}
backupPath, err = writeBackupCopy(path, backupIntegration)
if err != nil {
return fmt.Errorf("backup failed: %w", err)
}
} else if !os.IsNotExist(err) {
return fmt.Errorf("read existing file: %w", err)
}
dir := filepath.Dir(path)
tmp, err := os.CreateTemp(dir, ".tmp-*")
if err != nil {
return fmt.Errorf("create temp failed: %w", err)
}
tmpPath := tmp.Name()
if _, err := tmp.Write(data); err != nil {
_ = tmp.Close()
_ = os.Remove(tmpPath)
return fmt.Errorf("write failed: %w", err)
}
if err := tmp.Sync(); err != nil {
_ = tmp.Close()
_ = os.Remove(tmpPath)
return fmt.Errorf("sync failed: %w", err)
}
if err := tmp.Close(); err != nil {
_ = os.Remove(tmpPath)
return fmt.Errorf("close failed: %w", err)
}
if err := os.Rename(tmpPath, path); err != nil {
_ = os.Remove(tmpPath)
if backupPath != "" {
_ = copyFile(backupPath, path)
}
return fmt.Errorf("rename failed: %w", err)
}
return nil
}
func pruneOldBackups(dir, name string, keep int) {
if keep < 1 {
return
}
entries, err := os.ReadDir(dir)
if err != nil {
return
}
type backupEntry struct {
name string
timestamp int64
}
prefix := name + "."
backups := make([]backupEntry, 0, len(entries))
for _, entry := range entries {
if entry.IsDir() || !strings.HasPrefix(entry.Name(), prefix) {
continue
}
timestamp, err := strconv.ParseInt(strings.TrimPrefix(entry.Name(), prefix), 10, 64)
if err != nil {
continue
}
backups = append(backups, backupEntry{
name: entry.Name(),
timestamp: timestamp,
})
}
if len(backups) <= keep {
return
}
sort.Slice(backups, func(i, j int) bool {
if backups[i].timestamp != backups[j].timestamp {
return backups[i].timestamp > backups[j].timestamp
}
return backups[i].name > backups[j].name
})
for _, backup := range backups[keep:] {
_ = os.Remove(filepath.Join(dir, backup.name))
}
}

View File

@@ -0,0 +1,627 @@
package fileutil
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
)
func TestMain(m *testing.M) {
tmpRoot, err := os.MkdirTemp("", "fileutil-test-*")
if err != nil {
panic(err)
}
if err := os.Setenv("TMPDIR", tmpRoot); err != nil {
panic(err)
}
if err := os.Setenv("HOME", tmpRoot); err != nil {
panic(err)
}
if err := os.Setenv("USERPROFILE", tmpRoot); err != nil {
panic(err)
}
code := m.Run()
_ = os.RemoveAll(tmpRoot)
os.Exit(code)
}
func mustMarshal(t *testing.T, v any) []byte {
t.Helper()
data, err := json.MarshalIndent(v, "", " ")
if err != nil {
t.Fatal(err)
}
return data
}
func isolatedTempDir(t *testing.T) string {
t.Helper()
return t.TempDir()
}
func TestWriteWithBackup(t *testing.T) {
tmpDir := isolatedTempDir(t)
t.Run("uses ollama directory under home", func(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
want := filepath.Join(home, ".ollama", "backup")
if got := BackupDir(); got != want {
t.Fatalf("BackupDir() = %q, want %q", got, want)
}
})
t.Run("creates file", func(t *testing.T) {
path := filepath.Join(tmpDir, "new.json")
data := mustMarshal(t, map[string]string{"key": "value"})
if err := WriteWithBackup(path, data); err != nil {
t.Fatal(err)
}
content, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var result map[string]string
if err := json.Unmarshal(content, &result); err != nil {
t.Fatal(err)
}
if result["key"] != "value" {
t.Errorf("expected value, got %s", result["key"])
}
})
t.Run("creates backup in the shared backup directory", func(t *testing.T) {
path := filepath.Join(tmpDir, "backup.json")
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
data := mustMarshal(t, map[string]bool{"updated": true})
if err := WriteWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries, err := os.ReadDir(BackupDir())
if err != nil {
t.Fatal("backup directory not created")
}
var foundBackup bool
for _, entry := range entries {
if filepath.Ext(entry.Name()) != ".json" {
name := entry.Name()
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
backupPath := filepath.Join(BackupDir(), name)
backup, err := os.ReadFile(backupPath)
if err == nil {
var backupData map[string]bool
json.Unmarshal(backup, &backupData)
if backupData["original"] {
foundBackup = true
os.Remove(backupPath)
break
}
}
}
}
}
if !foundBackup {
t.Error("backup file not created in backup directory")
}
current, _ := os.ReadFile(path)
var currentData map[string]bool
json.Unmarshal(current, &currentData)
if !currentData["updated"] {
t.Error("file doesn't contain updated data")
}
})
t.Run("stores hinted backups under a subdirectory", func(t *testing.T) {
path := filepath.Join(tmpDir, "hinted.json")
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
data := mustMarshal(t, map[string]bool{"updated": true})
if err := WriteWithBackup(path, data, "openclaw"); err != nil {
t.Fatal(err)
}
entries, err := os.ReadDir(filepath.Join(BackupDir(), "openclaw"))
if err != nil {
t.Fatal(err)
}
var found bool
for _, entry := range entries {
name := entry.Name()
if len(name) > len("hinted.json.") && name[:len("hinted.json.")] == "hinted.json." {
found = true
_ = os.Remove(filepath.Join(BackupDir(), "openclaw", name))
break
}
}
if !found {
t.Error("backup file was not created under hint directory")
}
})
t.Run("no backup for new file", func(t *testing.T) {
path := filepath.Join(tmpDir, "nobak.json")
data := mustMarshal(t, map[string]string{"new": "file"})
if err := WriteWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries, _ := os.ReadDir(BackupDir())
for _, entry := range entries {
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
t.Error("backup should not exist for new file")
}
}
})
t.Run("no backup when content unchanged", func(t *testing.T) {
path := filepath.Join(tmpDir, "unchanged.json")
data := mustMarshal(t, map[string]string{"key": "value"})
if err := WriteWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries1, _ := os.ReadDir(BackupDir())
countBefore := 0
for _, e := range entries1 {
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
countBefore++
}
}
if err := WriteWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries2, _ := os.ReadDir(BackupDir())
countAfter := 0
for _, e := range entries2 {
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
countAfter++
}
}
if countAfter != countBefore {
t.Errorf("backup was created when content unchanged (before=%d, after=%d)", countBefore, countAfter)
}
})
t.Run("backup filename contains unix timestamp", func(t *testing.T) {
path := filepath.Join(tmpDir, "timestamped.json")
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
data := mustMarshal(t, map[string]int{"v": 2})
if err := WriteWithBackup(path, data); err != nil {
t.Fatal(err)
}
entries, _ := os.ReadDir(BackupDir())
var found bool
for _, entry := range entries {
name := entry.Name()
if len(name) > len("timestamped.json.") && name[:len("timestamped.json.")] == "timestamped.json." {
timestamp := name[len("timestamped.json."):]
for _, c := range timestamp {
if c < '0' || c > '9' {
t.Errorf("backup filename timestamp contains non-numeric character: %s", name)
}
}
found = true
os.Remove(filepath.Join(BackupDir(), name))
break
}
}
if !found {
t.Error("backup file with timestamp not found")
}
})
t.Run("retains only the five newest backups per file", func(t *testing.T) {
path := filepath.Join(tmpDir, "pruned.json")
if err := os.WriteFile(path, []byte(`{"v": 0}`), 0o644); err != nil {
t.Fatal(err)
}
for i := 1; i <= maxBackupsPerFile; i++ {
backupPath := filepath.Join(BackupDir(), fmt.Sprintf("pruned.json.%d", i))
if err := os.WriteFile(backupPath, []byte(fmt.Sprintf(`{"v": %d}`, i)), 0o644); err != nil {
t.Fatal(err)
}
}
if err := WriteWithBackup(path, []byte(`{"v": 1}`)); err != nil {
t.Fatal(err)
}
backups, err := filepath.Glob(filepath.Join(BackupDir(), "pruned.json.*"))
if err != nil {
t.Fatal(err)
}
if len(backups) != maxBackupsPerFile {
t.Fatalf("expected %d backups after pruning, got %d", maxBackupsPerFile, len(backups))
}
if _, err := os.Stat(filepath.Join(BackupDir(), "pruned.json.1")); !os.IsNotExist(err) {
t.Fatalf("expected oldest backup to be pruned, stat err = %v", err)
}
})
}
// Edge case tests for files.go
// TestWriteWithBackup_FailsIfBackupFails documents critical behavior: if backup fails, we must not proceed.
// User could lose their config with no way to recover.
func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := isolatedTempDir(t)
path := filepath.Join(tmpDir, "config.json")
// Create original file
originalContent := []byte(`{"original": true}`)
os.WriteFile(path, originalContent, 0o644)
// Make backup directory read-only to force backup failure
backupDir := BackupDir()
os.MkdirAll(backupDir, 0o755)
os.Chmod(backupDir, 0o444) // Read-only
defer os.Chmod(backupDir, 0o755)
newContent := []byte(`{"updated": true}`)
err := WriteWithBackup(path, newContent)
// Should fail because backup couldn't be created
if err == nil {
t.Error("expected error when backup fails, got nil")
}
// Original file should be preserved
current, _ := os.ReadFile(path)
if string(current) != string(originalContent) {
t.Errorf("original file was modified despite backup failure: got %s", string(current))
}
}
// TestWriteWithBackup_PermissionDenied verifies clear error when target file has wrong permissions.
// Common issue when config owned by root or wrong perms.
func TestWriteWithBackup_PermissionDenied(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := isolatedTempDir(t)
// Create a read-only directory
readOnlyDir := filepath.Join(tmpDir, "readonly")
os.MkdirAll(readOnlyDir, 0o755)
os.Chmod(readOnlyDir, 0o444)
defer os.Chmod(readOnlyDir, 0o755)
path := filepath.Join(readOnlyDir, "config.json")
err := WriteWithBackup(path, []byte(`{"test": true}`))
if err == nil {
t.Error("expected permission error, got nil")
}
}
func TestWriteWithBackup_UnchangedContentIsNoOp(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := isolatedTempDir(t)
path := filepath.Join(tmpDir, "unchanged-noop.json")
data := []byte(`{"same":true}`)
if err := os.WriteFile(path, data, 0o644); err != nil {
t.Fatal(err)
}
if err := os.Chmod(tmpDir, 0o555); err != nil {
t.Fatal(err)
}
defer os.Chmod(tmpDir, 0o755)
if err := WriteWithBackup(path, data); err != nil {
t.Fatalf("expected unchanged write to be a no-op, got %v", err)
}
backups, err := filepath.Glob(filepath.Join(BackupDir(), "unchanged-noop.json.*"))
if err != nil {
t.Fatal(err)
}
if len(backups) != 0 {
t.Fatalf("expected no backups for unchanged content, got %d", len(backups))
}
}
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
// writeWithBackup doesn't create directories - caller is responsible.
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
tmpDir := isolatedTempDir(t)
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
err := WriteWithBackup(path, []byte(`{"test": true}`))
// Should fail because directory doesn't exist
if err == nil {
t.Error("expected error for nonexistent directory, got nil")
}
}
// TestWriteWithBackup_SymlinkTarget documents behavior when target is a symlink.
// Documents what happens if user symlinks their config file.
func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("symlink tests may require admin on Windows")
}
tmpDir := isolatedTempDir(t)
realFile := filepath.Join(tmpDir, "real.json")
symlink := filepath.Join(tmpDir, "link.json")
// Create real file and symlink
os.WriteFile(realFile, []byte(`{"v": 1}`), 0o644)
os.Symlink(realFile, symlink)
// Write through symlink
err := WriteWithBackup(symlink, []byte(`{"v": 2}`))
if err != nil {
t.Fatalf("writeWithBackup through symlink failed: %v", err)
}
// The real file should be updated (symlink followed for temp file creation)
content, _ := os.ReadFile(symlink)
if string(content) != `{"v": 2}` {
t.Errorf("symlink target not updated correctly: got %s", string(content))
}
}
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
// User may have config files with unusual names.
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
tmpDir := isolatedTempDir(t)
// File with spaces and special chars
path := filepath.Join(tmpDir, "my config (backup).json")
os.WriteFile(path, []byte(`{"test": true}`), 0o644)
backupPath, err := writeBackupCopy(path, "")
if err != nil {
t.Fatalf("writeBackupCopy with special chars failed: %v", err)
}
// Verify backup exists and has correct content
content, err := os.ReadFile(backupPath)
if err != nil {
t.Fatalf("could not read backup: %v", err)
}
if string(content) != `{"test": true}` {
t.Errorf("backup content mismatch: got %s", string(content))
}
os.Remove(backupPath)
}
// TestCopyFile_PreservesPermissions verifies that copyFile preserves file permissions.
func TestCopyFile_PreservesPermissions(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission preservation tests unreliable on Windows")
}
tmpDir := isolatedTempDir(t)
src := filepath.Join(tmpDir, "src.json")
dst := filepath.Join(tmpDir, "dst.json")
// Create source with specific permissions
os.WriteFile(src, []byte(`{"test": true}`), 0o600)
err := copyFile(src, dst)
if err != nil {
t.Fatalf("copyFile failed: %v", err)
}
srcInfo, _ := os.Stat(src)
dstInfo, _ := os.Stat(dst)
if srcInfo.Mode().Perm() != dstInfo.Mode().Perm() {
t.Errorf("permissions not preserved: src=%v, dst=%v", srcInfo.Mode().Perm(), dstInfo.Mode().Perm())
}
}
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
func TestCopyFile_SourceNotFound(t *testing.T) {
tmpDir := isolatedTempDir(t)
src := filepath.Join(tmpDir, "nonexistent.json")
dst := filepath.Join(tmpDir, "dst.json")
err := copyFile(src, dst)
if err == nil {
t.Error("expected error for nonexistent source, got nil")
}
}
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
tmpDir := isolatedTempDir(t)
dirPath := filepath.Join(tmpDir, "actualdir")
os.MkdirAll(dirPath, 0o755)
err := WriteWithBackup(dirPath, []byte(`{"test": true}`))
if err == nil {
t.Error("expected error when target is a directory, got nil")
}
}
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
func TestWriteWithBackup_EmptyData(t *testing.T) {
tmpDir := isolatedTempDir(t)
path := filepath.Join(tmpDir, "empty.json")
err := WriteWithBackup(path, []byte{})
if err != nil {
t.Fatalf("writeWithBackup with empty data failed: %v", err)
}
content, err := os.ReadFile(path)
if err != nil {
t.Fatalf("could not read file: %v", err)
}
if len(content) != 0 {
t.Errorf("expected empty file, got %d bytes", len(content))
}
}
// TestWriteWithBackup_FileUnreadableButDirWritable verifies behavior when existing file
// cannot be read (for backup comparison) but directory is writable.
func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := isolatedTempDir(t)
path := filepath.Join(tmpDir, "unreadable.json")
// Create file and make it unreadable
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
os.Chmod(path, 0o000)
defer os.Chmod(path, 0o644)
// Should fail because we can't read the file to compare/backup
err := WriteWithBackup(path, []byte(`{"updated": true}`))
if err == nil {
t.Error("expected error when file is unreadable, got nil")
}
}
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
// within the same second (timestamp collision scenario).
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
tmpDir := isolatedTempDir(t)
path := filepath.Join(tmpDir, "rapid.json")
// Create initial file
os.WriteFile(path, []byte(`{"v": 0}`), 0o644)
// Rapid successive writes
for i := 1; i <= 3; i++ {
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
if err := WriteWithBackup(path, data); err != nil {
t.Fatalf("write %d failed: %v", i, err)
}
}
// Verify final content
content, _ := os.ReadFile(path)
if string(content) != `{"v": 3}` {
t.Errorf("expected final content {\"v\": 3}, got %s", string(content))
}
// Verify at least one backup exists
entries, _ := os.ReadDir(BackupDir())
var backupCount int
for _, e := range entries {
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
backupCount++
}
}
if backupCount == 0 {
t.Error("expected at least one backup file from rapid writes")
}
}
// TestWriteWithBackup_BackupDirIsFile verifies error when backup directory path is a file.
func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("test modifies system temp directory")
}
tmpDir := isolatedTempDir(t)
// Create a file at the backup directory path
backupPath := BackupDir()
// Clean up any existing directory first
os.RemoveAll(backupPath)
// Create a file instead of directory
os.WriteFile(backupPath, []byte("not a directory"), 0o644)
defer func() {
os.Remove(backupPath)
os.MkdirAll(backupPath, 0o755)
}()
path := filepath.Join(tmpDir, "test.json")
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
err := WriteWithBackup(path, []byte(`{"updated": true}`))
if err == nil {
t.Error("expected error when backup dir is a file, got nil")
}
}
// TestWriteWithBackup_NoOrphanTempFiles verifies temp files are cleaned up on failure.
func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission tests unreliable on Windows")
}
tmpDir := isolatedTempDir(t)
// Count existing temp files
countTempFiles := func() int {
entries, _ := os.ReadDir(tmpDir)
count := 0
for _, e := range entries {
if len(e.Name()) > 4 && e.Name()[:4] == ".tmp" {
count++
}
}
return count
}
before := countTempFiles()
// Create a file, then make directory read-only to cause rename failure
path := filepath.Join(tmpDir, "orphan.json")
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
// Make a subdirectory and try to write there after making parent read-only
subDir := filepath.Join(tmpDir, "subdir")
os.MkdirAll(subDir, 0o755)
subPath := filepath.Join(subDir, "config.json")
os.WriteFile(subPath, []byte(`{"v": 1}`), 0o644)
// Make subdir read-only after creating temp file would succeed but rename would fail
// This is tricky to test - the temp file is created in the same dir, so if we can't
// rename, we also couldn't create. Let's just verify normal failure cleanup works.
// Force a failure by making the target a directory
badPath := filepath.Join(tmpDir, "isdir")
os.MkdirAll(badPath, 0o755)
_ = WriteWithBackup(badPath, []byte(`{"test": true}`))
after := countTempFiles()
if after > before {
t.Errorf("orphan temp files left behind: before=%d, after=%d", before, after)
}
}

371
cmd/launch/account.go Normal file
View File

@@ -0,0 +1,371 @@
package launch
import (
"context"
"errors"
"fmt"
"net/http"
"os"
"strings"
"time"
"github.com/ollama/ollama/api"
)
const (
// DefaultUpgradeURL is the fixed destination for subscription upgrades.
DefaultUpgradeURL = "https://ollama.com/upgrade"
accountCheckTimeout = 3 * time.Second
)
var (
ErrPlanVerificationUnavailable = errors.New("Could not verify your plan. Try again in a moment.")
errUpgradeCancelled = errors.New("upgrade cancelled")
)
type accountStateStatus int
const (
accountStateUnknown accountStateStatus = iota
accountStateSignedOut
accountStateSignedIn
)
type AccountState struct {
Status accountStateStatus
Plan string
}
type AccountStatePrefetch struct {
done chan struct{}
state AccountState
}
func StartAccountStatePrefetch(ctx context.Context) *AccountStatePrefetch {
if ctx == nil {
ctx = context.Background()
}
p := &AccountStatePrefetch{done: make(chan struct{})}
go func() {
state := AccountState{Status: accountStateUnknown}
client, err := api.ClientFromEnvironment()
if err == nil {
prefetchCtx, cancel := context.WithTimeout(ctx, accountCheckTimeout)
defer cancel()
if disabled, known := cloudStatusDisabled(prefetchCtx, client); !known || !disabled {
state = launchAccountState(prefetchCtx, client)
}
}
p.state = state
close(p.done)
}()
return p
}
func (p *AccountStatePrefetch) StateIfReady() *AccountState {
if p == nil {
return nil
}
select {
case <-p.done:
state := p.state
return &state
default:
return nil
}
}
func (p *AccountStatePrefetch) StateUpdates(ctx context.Context) <-chan *AccountState {
if p == nil {
return nil
}
if ctx == nil {
ctx = context.Background()
}
out := make(chan *AccountState, 1)
go func() {
defer close(out)
select {
case <-p.done:
if p.state.Status == accountStateUnknown {
return
}
state := p.state
select {
case out <- &state:
case <-ctx.Done():
}
case <-ctx.Done():
}
}()
return out
}
func launchAccountState(ctx context.Context, client *api.Client) AccountState {
if client == nil {
return AccountState{Status: accountStateUnknown}
}
user, err := whoamiWithTimeout(ctx, client)
if err != nil {
var authErr api.AuthorizationError
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
return AccountState{Status: accountStateSignedOut}
}
return AccountState{Status: accountStateUnknown}
}
if user == nil || strings.TrimSpace(user.Name) == "" {
return AccountState{Status: accountStateSignedOut}
}
return AccountState{
Status: accountStateSignedIn,
Plan: strings.TrimSpace(user.Plan),
}
}
func whoamiWithTimeout(ctx context.Context, client *api.Client) (*api.UserResponse, error) {
if ctx == nil {
ctx = context.Background()
}
checkCtx, cancel := context.WithTimeout(ctx, accountCheckTimeout)
defer cancel()
return client.Whoami(checkCtx)
}
func ApplyAccountStateToSelectionItems(items []ModelItem, state AccountState) []SelectionItem {
out := make([]SelectionItem, len(items))
for i, item := range items {
out[i] = SelectionItem{
Name: item.Name,
Description: item.Description,
Recommended: item.Recommended,
AvailabilityBadge: availabilityBadge(item, state),
}
}
return out
}
func SelectionItemsWithAccountState(items []ModelItem, state *AccountState) []SelectionItem {
if state == nil || !selectionItemsNeedAccountState(items) {
return ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateUnknown})
}
return ApplyAccountStateToSelectionItems(items, *state)
}
func selectionItemsNeedAccountState(items []ModelItem) bool {
for _, item := range items {
if isCloudModelName(item.Name) && itemHasRecommendationMetadata(item) {
return true
}
}
return false
}
func (c *launcherClient) selectionItemUpdates(ctx context.Context, items []ModelItem, state *AccountState) <-chan []SelectionItem {
if !selectionItemsNeedAccountState(items) || state != nil {
return nil
}
if ctx == nil {
ctx = context.Background()
}
stateUpdates := c.accountStateUpdateSource(ctx)
if stateUpdates == nil {
return nil
}
out := make(chan []SelectionItem, 1)
go func() {
defer close(out)
select {
case state, ok := <-stateUpdates:
if !ok || state == nil {
return
}
select {
case out <- SelectionItemsWithAccountState(items, state):
case <-ctx.Done():
}
case <-ctx.Done():
}
}()
return out
}
func (c *launcherClient) accountStateUpdateSource(ctx context.Context) <-chan *AccountState {
if c.accountStateUpdates != nil {
return c.accountStateUpdates(ctx)
}
if c.apiClient == nil {
return nil
}
out := make(chan *AccountState, 1)
go func() {
defer close(out)
state := launchAccountState(ctx, c.apiClient)
if state.Status == accountStateUnknown {
return
}
select {
case out <- &state:
case <-ctx.Done():
}
}()
return out
}
func availabilityBadge(item ModelItem, state AccountState) string {
if !isCloudModelName(item.Name) {
return ""
}
switch state.Status {
case accountStateSignedOut:
if itemHasRecommendationMetadata(item) {
return "Sign in required"
}
case accountStateSignedIn:
if item.RequiredPlan != "" && !PlanSatisfies(state.Plan, item.RequiredPlan) {
return "Upgrade required"
}
}
return ""
}
func itemHasRecommendationMetadata(item ModelItem) bool {
return item.Recommended || strings.TrimSpace(item.RequiredPlan) != ""
}
func (c *launcherClient) ensureCloudModelAccess(ctx context.Context, model string) error {
item, ok := c.modelRecommendationItem(ctx, model)
if !ok || strings.TrimSpace(item.RequiredPlan) == "" {
return nil
}
state := launchAccountState(ctx, c.apiClient)
if state.Status != accountStateUnknown {
c.accountState = &state
}
if state.Status == accountStateUnknown {
return ErrPlanVerificationUnavailable
}
if state.Status == accountStateSignedOut {
if err := ensureCloudAuth(ctx, c.apiClient, model); err != nil {
return err
}
state = launchAccountState(ctx, c.apiClient)
if state.Status != accountStateUnknown {
c.accountState = &state
}
if state.Status == accountStateUnknown {
return ErrPlanVerificationUnavailable
}
}
if PlanSatisfies(state.Plan, item.RequiredPlan) {
return nil
}
if err := c.runUpgradeFlow(ctx, item); err != nil {
return err
}
state = launchAccountState(ctx, c.apiClient)
if state.Status == accountStateUnknown {
return ErrPlanVerificationUnavailable
}
if state.Status != accountStateSignedIn || !PlanSatisfies(state.Plan, item.RequiredPlan) {
return errUpgradeCancelled
}
return nil
}
func (c *launcherClient) modelRecommendationItem(ctx context.Context, model string) (ModelItem, bool) {
for _, item := range c.recommendations(ctx) {
if item.Name == model {
return item, true
}
}
return ModelItem{}, false
}
func (c *launcherClient) runUpgradeFlow(ctx context.Context, item ModelItem) error {
if DefaultUpgrade != nil {
if _, err := DefaultUpgrade(item.Name, item.RequiredPlan); err != nil {
if errors.Is(err, ErrCancelled) {
return errUpgradeCancelled
}
return err
}
return nil
}
yes, err := ConfirmPrompt(fmt.Sprintf("Upgrade to use %s?", item.Name))
if errors.Is(err, ErrCancelled) {
return errUpgradeCancelled
}
if err != nil {
return err
}
if !yes {
return errUpgradeCancelled
}
fmt.Fprintf(os.Stderr, "\nTo upgrade, navigate to:\n %s\n\n", DefaultUpgradeURL)
openNow, err := ConfirmPrompt("Open now?")
if errors.Is(err, ErrCancelled) {
return errUpgradeCancelled
}
if err != nil {
return err
}
if openNow {
OpenBrowser(DefaultUpgradeURL)
} else {
return errUpgradeCancelled
}
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
fmt.Fprintf(os.Stderr, "\033[90mwaiting for upgrade to complete... %s\033[0m", spinnerFrames[0])
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for upgrade to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
if frame%10 != 0 {
continue
}
state := launchAccountState(ctx, c.apiClient)
if state.Status == accountStateUnknown {
fmt.Fprintf(os.Stderr, "\r\033[K")
return ErrPlanVerificationUnavailable
}
if state.Status == accountStateSignedIn && PlanSatisfies(state.Plan, item.RequiredPlan) {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1mplan updated\033[0m\n")
return nil
}
}
}
}
// PlanSatisfies reports whether currentPlan can use a model that has a requiredPlan.
func PlanSatisfies(currentPlan, requiredPlan string) bool {
required := normalizePlan(requiredPlan)
if required == "" || required == "free" {
return true
}
current := normalizePlan(currentPlan)
return current != "" && current != "free"
}
func normalizePlan(plan string) string {
return strings.ToLower(strings.TrimSpace(plan))
}

87
cmd/launch/claude.go Normal file
View File

@@ -0,0 +1,87 @@
package launch
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"github.com/ollama/ollama/envconfig"
)
// Claude implements Runner for Claude Code integration.
type Claude struct{}
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string, extra []string) []string {
var args []string
if model != "" {
args = append(args, "--model", model)
}
args = append(args, extra...)
return args
}
func (c *Claude) findPath() (string, error) {
if p, err := exec.LookPath("claude"); err == nil {
return p, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
name := "claude"
if runtime.GOOS == "windows" {
name = "claude.exe"
}
fallback := filepath.Join(home, ".claude", "local", name)
if _, err := os.Stat(fallback); err != nil {
return "", err
}
return fallback, nil
}
func (c *Claude) Run(model string, _ []LaunchModel, args []string) error {
claudePath, err := c.findPath()
if err != nil {
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
}
cmd := exec.Command(claudePath, c.args(model, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
env := append(os.Environ(),
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
"ANTHROPIC_API_KEY=",
"ANTHROPIC_AUTH_TOKEN=ollama",
"CLAUDE_CODE_ATTRIBUTION_HEADER=0",
)
env = append(env, c.modelEnvVars(model)...)
cmd.Env = env
return cmd.Run()
}
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
func (c *Claude) modelEnvVars(model string) []string {
env := []string{
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + model,
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + model,
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + model,
"CLAUDE_CODE_SUBAGENT_MODEL=" + model,
}
if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok {
env = append(env, "CLAUDE_CODE_AUTO_COMPACT_WINDOW="+strconv.Itoa(l.Context))
}
}
return env
}

View File

@@ -0,0 +1,888 @@
package launch
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/internal/fileutil"
"golang.org/x/term"
)
const (
claudeDesktopIntegrationName = "claude-desktop"
claudeDesktopProfileName = "Ollama"
claudeDesktopProfileID = "00000000-0000-4000-8000-000000000114"
claudeDesktopGatewayBaseURL = "https://ollama.com"
claudeDesktopAPIKeyURL = "https://ollama.com/settings/keys"
claudeDesktopModelLabel = "Ollama Cloud"
claudeDesktopUnsupported = "Claude Desktop is no longer supported. Existing installations can be restored with 'ollama launch claude-desktop --restore'."
claudeDesktopSuccessMessage = "Claude Desktop profile changed to Ollama Cloud."
claudeDesktopRestoreMessage = "To restore the usual Claude profile, run: ollama launch claude-desktop --restore"
claudeDesktopRestoredMessage = "Claude Desktop restored to the usual Claude profile."
)
var (
claudeDesktopGOOS = runtime.GOOS
claudeDesktopUserHome = os.UserHomeDir
claudeDesktopStat = os.Stat
claudeDesktopOpenApp = defaultClaudeDesktopOpenApp
claudeDesktopOpenAppPath = defaultClaudeDesktopOpenAppPath
claudeDesktopQuitApp = defaultClaudeDesktopQuitApp
claudeDesktopIsRunning = defaultClaudeDesktopIsRunning
claudeDesktopRunningAppPath = defaultClaudeDesktopRunningAppPath
claudeDesktopGlob = filepath.Glob
claudeDesktopSleep = time.Sleep
claudeDesktopHTTPClient = http.DefaultClient
claudeDesktopPromptAPIKey = promptClaudeDesktopAPIKey
claudeDesktopValidateAPIKey = validateClaudeDesktopAPIKey
)
// ClaudeDesktop configures and launches Claude Desktop in third-party
// inference mode using Ollama Cloud as the gateway.
type ClaudeDesktop struct{}
func (c *ClaudeDesktop) String() string { return "Claude Desktop" }
func (c *ClaudeDesktop) Supported() error { return claudeDesktopSupported() }
func (c *ClaudeDesktop) Paths() []string {
return nil
}
func (c *ClaudeDesktop) AutodiscoveredModel() string {
return claudeDesktopModelLabel
}
func (c *ClaudeDesktop) ConfigureAutodiscovery() error {
if err := claudeDesktopSupported(); err != nil {
return err
}
targets, err := claudeDesktopTargetPaths()
if err != nil {
return err
}
key, err := claudeDesktopValidatedAPIKey(context.Background(), claudeDesktopTargetProfilePaths(targets))
if err != nil {
return err
}
for _, path := range targets.normalConfigs {
if err := writeClaudeDesktopDeploymentMode(path, "3p"); err != nil {
return err
}
}
for _, target := range targets.thirdPartyProfiles {
if err := writeClaudeDesktopDeploymentMode(target.desktopConfig, "3p"); err != nil {
return err
}
if err := writeClaudeDesktopMeta(target.meta, claudeDesktopProfileID, claudeDesktopProfileName); err != nil {
return err
}
if err := writeClaudeDesktopGatewayProfile(target.profile, key, true); err != nil {
return err
}
}
return nil
}
func (c *ClaudeDesktop) RestoreHint() string {
return claudeDesktopRestoreMessage
}
func (c *ClaudeDesktop) ConfigurationSuccessMessage() string {
return claudeDesktopSuccessMessage + "\n" + claudeDesktopRestoreMessage
}
func (c *ClaudeDesktop) RestoreSuccessMessage() string {
return claudeDesktopRestoredMessage
}
func (c *ClaudeDesktop) AutodiscoveryConfigured() bool {
targets, err := claudeDesktopTargetPaths()
if err != nil {
return false
}
return claudeDesktopTargetsConfigured(targets)
}
func (c *ClaudeDesktop) Onboard() error {
return config.MarkIntegrationOnboarded(claudeDesktopIntegrationName)
}
func (c *ClaudeDesktop) RequiresInteractiveOnboarding() bool {
return false
}
func (c *ClaudeDesktop) SkipModelReadiness() bool {
return true
}
func (c *ClaudeDesktop) Run(_ string, _ []LaunchModel, _ []string) error {
return errClaudeDesktopUnsupported()
}
func (c *ClaudeDesktop) Restore() error {
if err := claudeDesktopSupported(); err != nil {
return err
}
targets, err := claudeDesktopTargetPaths()
if err != nil {
return err
}
for _, path := range targets.normalConfigs {
if err := writeClaudeDesktopDeploymentMode(path, "1p"); err != nil {
return err
}
}
for _, target := range targets.thirdPartyProfiles {
if err := writeClaudeDesktopDeploymentMode(target.desktopConfig, "1p"); err != nil {
return err
}
if err := restoreClaudeDesktopMeta(target.meta); err != nil {
return err
}
if err := restoreClaudeDesktopOllamaProfile(target.profile); err != nil {
return err
}
}
return claudeDesktopLaunchOrRestart("Restart Claude Desktop to use the usual Claude profile?")
}
func errClaudeDesktopUnsupported() error {
return errors.New(claudeDesktopUnsupported)
}
func claudeDesktopSupported() error {
switch claudeDesktopGOOS {
case "darwin", "windows":
return nil
default:
return fmt.Errorf("Claude Desktop launch is only supported on macOS and Windows")
}
}
func claudeDesktopInstalled() bool {
if claudeDesktopAppPath() != "" {
return true
}
if claudeDesktopGOOS == "windows" && claudeDesktopIsRunning() {
return true
}
for _, dir := range claudeDesktopProfileDirCandidates(false) {
if _, err := claudeDesktopStat(dir); err == nil {
return true
}
}
return false
}
func claudeDesktopAppPath() string {
if claudeDesktopGOOS != "darwin" && claudeDesktopGOOS != "windows" {
return ""
}
for _, path := range claudeDesktopAppCandidates() {
if _, err := claudeDesktopStat(path); err == nil {
return path
}
}
return ""
}
func claudeDesktopAppCandidates() []string {
switch claudeDesktopGOOS {
case "darwin":
return claudeDesktopDarwinAppCandidates()
case "windows":
return claudeDesktopWindowsAppCandidates()
default:
return nil
}
}
func claudeDesktopDarwinAppCandidates() []string {
candidates := []string{"/Applications/Claude.app"}
if home, err := claudeDesktopUserHome(); err == nil {
candidates = append(candidates, filepath.Join(home, "Applications", "Claude.app"))
}
return candidates
}
func claudeDesktopWindowsAppCandidates() []string {
local, err := claudeDesktopLocalAppData()
if err != nil {
return nil
}
candidates := []string{
filepath.Join(local, "Programs", "Claude", "Claude.exe"),
filepath.Join(local, "Programs", "Claude Desktop", "Claude.exe"),
filepath.Join(local, "Claude", "Claude.exe"),
filepath.Join(local, "Claude Nest", "Claude.exe"),
filepath.Join(local, "Claude Desktop", "Claude.exe"),
filepath.Join(local, "AnthropicClaude", "Claude.exe"),
}
for _, pattern := range []string{
filepath.Join(local, "AnthropicClaude", "app-*", "Claude.exe"),
filepath.Join(local, "Programs", "Claude", "app-*", "Claude.exe"),
filepath.Join(local, "Programs", "Claude Desktop", "app-*", "Claude.exe"),
} {
matches, _ := claudeDesktopGlob(pattern)
candidates = append(candidates, matches...)
}
return claudeDesktopDedupePaths(candidates)
}
func claudeDesktopDedupePaths(paths []string) []string {
out := make([]string, 0, len(paths))
seen := make(map[string]bool, len(paths))
for _, path := range paths {
if strings.TrimSpace(path) == "" {
continue
}
key := strings.ToLower(path)
if seen[key] {
continue
}
seen[key] = true
out = append(out, path)
}
return out
}
type claudeDesktopPaths struct {
normalConfig string
desktopConfig string
meta string
profile string
}
type claudeDesktopThirdPartyPaths struct {
desktopConfig string
meta string
profile string
}
type claudeDesktopTargets struct {
normalConfigs []string
thirdPartyProfiles []claudeDesktopThirdPartyPaths
}
func claudeDesktopConfigPaths() (claudeDesktopPaths, error) {
switch claudeDesktopGOOS {
case "darwin":
return claudeDesktopDarwinConfigPaths()
case "windows":
return claudeDesktopWindowsConfigPaths()
default:
return claudeDesktopPaths{}, claudeDesktopSupported()
}
}
func claudeDesktopDarwinConfigPaths() (claudeDesktopPaths, error) {
normalRoots, thirdPartyRoots, err := claudeDesktopDarwinProfileRoots()
if err != nil {
return claudeDesktopPaths{}, err
}
normalBase := normalRoots[0]
thirdPartyBase := thirdPartyRoots[0]
return claudeDesktopPaths{
normalConfig: filepath.Join(normalBase, "claude_desktop_config.json"),
desktopConfig: filepath.Join(thirdPartyBase, "claude_desktop_config.json"),
meta: filepath.Join(thirdPartyBase, "configLibrary", "_meta.json"),
profile: filepath.Join(thirdPartyBase, "configLibrary", claudeDesktopProfileID+".json"),
}, nil
}
func claudeDesktopWindowsConfigPaths() (claudeDesktopPaths, error) {
normalBase, err := claudeDesktopProfileDir(true)
if err != nil {
return claudeDesktopPaths{}, err
}
thirdPartyBase, err := claudeDesktopProfileDir(false)
if err != nil {
return claudeDesktopPaths{}, err
}
return claudeDesktopPaths{
normalConfig: filepath.Join(normalBase, "claude_desktop_config.json"),
desktopConfig: filepath.Join(thirdPartyBase, "claude_desktop_config.json"),
meta: filepath.Join(thirdPartyBase, "configLibrary", "_meta.json"),
profile: filepath.Join(thirdPartyBase, "configLibrary", claudeDesktopProfileID+".json"),
}, nil
}
func claudeDesktopProfileDir(normal bool) (string, error) {
candidates := claudeDesktopProfileDirCandidates(normal)
if len(candidates) == 0 {
return "", fmt.Errorf("Claude Desktop profile directory could not be resolved")
}
for _, candidate := range candidates {
if _, err := claudeDesktopStat(candidate); err == nil {
return candidate, nil
}
}
return candidates[0], nil
}
func claudeDesktopProfileDirCandidates(normal bool) []string {
if claudeDesktopGOOS != "windows" {
return nil
}
normalRoots, thirdPartyRoots, err := claudeDesktopWindowsProfileRoots()
if err != nil {
return nil
}
if normal {
return normalRoots
}
return thirdPartyRoots
}
func claudeDesktopDarwinProfileRoots() ([]string, []string, error) {
home, err := claudeDesktopUserHome()
if err != nil {
return nil, nil, err
}
base := filepath.Join(home, "Library", "Application Support")
return []string{filepath.Join(base, "Claude")}, []string{filepath.Join(base, "Claude-3p")}, nil
}
func claudeDesktopWindowsProfileRoots() ([]string, []string, error) {
local, err := claudeDesktopLocalAppData()
if err != nil {
return nil, nil, err
}
normalRoots := []string{
filepath.Join(local, "Claude"),
filepath.Join(local, "Claude Nest"),
}
thirdPartyRoots := []string{
filepath.Join(local, "Claude-3p"),
filepath.Join(local, "Claude Nest-3p"),
}
return normalRoots, thirdPartyRoots, nil
}
func claudeDesktopTargetPaths() (claudeDesktopTargets, error) {
var (
normalRoots []string
thirdPartyRoots []string
err error
)
switch claudeDesktopGOOS {
case "darwin":
normalRoots, thirdPartyRoots, err = claudeDesktopDarwinProfileRoots()
case "windows":
normalRoots, thirdPartyRoots, err = claudeDesktopWindowsProfileRoots()
default:
err = claudeDesktopSupported()
}
if err != nil {
return claudeDesktopTargets{}, err
}
return newClaudeDesktopTargets(normalRoots, thirdPartyRoots), nil
}
func newClaudeDesktopTargets(normalRoots, thirdPartyRoots []string) claudeDesktopTargets {
targets := claudeDesktopTargets{}
for _, root := range claudeDesktopDedupePaths(normalRoots) {
targets.normalConfigs = append(targets.normalConfigs, filepath.Join(root, "claude_desktop_config.json"))
}
for _, root := range claudeDesktopDedupePaths(thirdPartyRoots) {
targets.thirdPartyProfiles = append(targets.thirdPartyProfiles, claudeDesktopThirdPartyPaths{
desktopConfig: filepath.Join(root, "claude_desktop_config.json"),
meta: filepath.Join(root, "configLibrary", "_meta.json"),
profile: filepath.Join(root, "configLibrary", claudeDesktopProfileID+".json"),
})
}
return targets
}
func claudeDesktopTargetProfilePaths(targets claudeDesktopTargets) []string {
paths := make([]string, 0, len(targets.thirdPartyProfiles))
for _, target := range targets.thirdPartyProfiles {
paths = append(paths, target.profile)
}
return paths
}
func claudeDesktopLocalAppData() (string, error) {
if local := strings.TrimSpace(os.Getenv("LOCALAPPDATA")); local != "" {
return local, nil
}
if home := strings.TrimSpace(os.Getenv("USERPROFILE")); home != "" {
return filepath.Join(home, "AppData", "Local"), nil
}
home, err := claudeDesktopUserHome()
if err != nil {
return "", err
}
return filepath.Join(home, "AppData", "Local"), nil
}
type claudeDesktopAPIKeySource int
const (
claudeDesktopAPIKeySourceNone claudeDesktopAPIKeySource = iota
claudeDesktopAPIKeySourceEnv
claudeDesktopAPIKeySourceProfile
)
func claudeDesktopValidatedAPIKey(ctx context.Context, profilePaths []string) (string, error) {
key, source, err := claudeDesktopAPIKey(profilePaths)
if err != nil {
return "", err
}
if err := claudeDesktopValidateAPIKey(ctx, key); err == nil {
return key, nil
} else if source != claudeDesktopAPIKeySourceProfile || !canPromptClaudeDesktopAPIKey() {
return "", err
}
return promptValidClaudeDesktopAPIKey(ctx)
}
func claudeDesktopAPIKey(profilePaths []string) (string, claudeDesktopAPIKeySource, error) {
if key := strings.TrimSpace(os.Getenv("OLLAMA_API_KEY")); key != "" {
return key, claudeDesktopAPIKeySourceEnv, nil
}
for _, profilePath := range profilePaths {
if key := readClaudeDesktopGatewayAPIKey(profilePath); key != "" {
return key, claudeDesktopAPIKeySourceProfile, nil
}
}
key, err := promptClaudeDesktopAPIKeyValue()
return key, claudeDesktopAPIKeySourceNone, err
}
func canPromptClaudeDesktopAPIKey() bool {
return isInteractiveSession() && !currentLaunchConfirmPolicy.requireYesMessage
}
func promptValidClaudeDesktopAPIKey(ctx context.Context) (string, error) {
key, err := promptClaudeDesktopAPIKeyValue()
if err != nil {
return "", err
}
if err := claudeDesktopValidateAPIKey(ctx, key); err != nil {
return "", err
}
return key, nil
}
func promptClaudeDesktopAPIKeyValue() (string, error) {
if !canPromptClaudeDesktopAPIKey() {
return "", missingClaudeDesktopAPIKeyError()
}
key, err := claudeDesktopPromptAPIKey()
if err != nil {
return "", err
}
key = strings.TrimSpace(key)
if key == "" {
return "", missingClaudeDesktopAPIKeyError()
}
return key, nil
}
func missingClaudeDesktopAPIKeyError() error {
return fmt.Errorf("OLLAMA_API_KEY is required for Claude Desktop. Create an API key at %s, then re-run with OLLAMA_API_KEY set", claudeDesktopAPIKeyURL)
}
func promptClaudeDesktopAPIKey() (string, error) {
fmt.Fprint(os.Stderr, claudeDesktopAPIKeyPrompt())
key, err := term.ReadPassword(int(os.Stdin.Fd()))
fmt.Fprintln(os.Stderr)
if err != nil {
return "", err
}
return string(key), nil
}
func claudeDesktopAPIKeyPrompt() string {
return fmt.Sprintf("Create an Ollama API key at %s\nEnter Ollama API key (input hidden): ", claudeDesktopAPIKeyURL)
}
func readClaudeDesktopGatewayAPIKey(path string) string {
cfg, err := readClaudeDesktopJSON(path)
if err != nil {
return ""
}
key, _ := cfg["inferenceGatewayApiKey"].(string)
return strings.TrimSpace(key)
}
func validateClaudeDesktopAPIKey(ctx context.Context, key string) error {
ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
if claudeDesktopAPIKeyHasInvalidHeaderChars(key) {
return claudeDesktopAPIKeyVerificationError()
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeDesktopGatewayBaseURL+"/v1/models", nil)
if err != nil {
return claudeDesktopAPIKeyVerificationError()
}
req.Header.Set("Authorization", "Bearer "+key)
req.Header.Set("Accept", "application/json")
resp, err := claudeDesktopHTTPClient.Do(req)
if err != nil {
return claudeDesktopAPIKeyVerificationError()
}
defer resp.Body.Close()
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4<<10))
switch {
case resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden:
return fmt.Errorf("Ollama API key was rejected; create a valid key at %s", claudeDesktopAPIKeyURL)
case resp.StatusCode >= 200 && resp.StatusCode < 300:
return nil
default:
return fmt.Errorf("could not verify Ollama API key; ollama.com returned status %d, try again later", resp.StatusCode)
}
}
func claudeDesktopAPIKeyHasInvalidHeaderChars(key string) bool {
return strings.ContainsFunc(key, func(r rune) bool {
return r < ' ' || r == 0x7f
})
}
func claudeDesktopAPIKeyVerificationError() error {
return fmt.Errorf("could not verify Ollama API key; copy a key from %s and try again", claudeDesktopAPIKeyURL)
}
func writeClaudeDesktopDeploymentMode(path, mode string) error {
cfg, err := readClaudeDesktopJSONAllowMissing(path)
if err != nil {
return fmt.Errorf("parse Claude Desktop config: %w", err)
}
cfg["deploymentMode"] = mode
return writeClaudeDesktopJSON(path, cfg)
}
func writeClaudeDesktopMeta(path, id, name string) error {
meta, err := readClaudeDesktopJSONAllowMissing(path)
if err != nil {
return fmt.Errorf("parse Claude Desktop config metadata: %w", err)
}
meta["appliedId"] = id
entries := make([]any, 0)
for _, entry := range claudeDesktopAnySlice(meta["entries"]) {
entryMap, _ := entry.(map[string]any)
if entryMap == nil {
entries = append(entries, entry)
continue
}
if entryID, _ := entryMap["id"].(string); entryID == id {
continue
}
entries = append(entries, entryMap)
}
entries = append(entries, map[string]any{
"id": id,
"name": name,
})
meta["entries"] = entries
return writeClaudeDesktopJSON(path, meta)
}
func writeClaudeDesktopGatewayProfile(path string, apiKey string, forceChooser bool) error {
cfg, err := readClaudeDesktopJSONAllowMissing(path)
if err != nil {
return fmt.Errorf("parse Claude Desktop Ollama profile: %w", err)
}
cfg["inferenceProvider"] = "gateway"
cfg["inferenceGatewayBaseUrl"] = claudeDesktopGatewayBaseURL
cfg["inferenceGatewayApiKey"] = apiKey
cfg["inferenceGatewayAuthScheme"] = "bearer"
delete(cfg, "inferenceModels")
cfg["disableDeploymentModeChooser"] = forceChooser
return writeClaudeDesktopJSON(path, cfg)
}
func restoreClaudeDesktopMeta(path string) error {
meta, err := readClaudeDesktopJSONAllowMissing(path)
if err != nil {
return fmt.Errorf("parse Claude Desktop config metadata: %w", err)
}
if len(meta) == 0 {
return nil
}
changed := false
if appliedID, _ := meta["appliedId"].(string); appliedID == claudeDesktopProfileID {
delete(meta, "appliedId")
changed = true
}
entries := claudeDesktopAnySlice(meta["entries"])
if entries != nil {
filtered := make([]any, 0, len(entries))
for _, entry := range entries {
entryMap, _ := entry.(map[string]any)
if entryID, _ := entryMap["id"].(string); entryID == claudeDesktopProfileID {
changed = true
continue
}
filtered = append(filtered, entry)
}
meta["entries"] = filtered
}
if !changed {
return nil
}
return writeClaudeDesktopJSON(path, meta)
}
func restoreClaudeDesktopOllamaProfile(path string) error {
cfg, err := readClaudeDesktopJSONAllowMissing(path)
if err != nil {
return fmt.Errorf("parse Claude Desktop Ollama profile: %w", err)
}
if len(cfg) == 0 {
return nil
}
cfg["disableDeploymentModeChooser"] = false
delete(cfg, "inferenceProvider")
delete(cfg, "inferenceGatewayBaseUrl")
delete(cfg, "inferenceGatewayAuthScheme")
delete(cfg, "inferenceModels")
return writeClaudeDesktopJSON(path, cfg)
}
func readClaudeDesktopAppliedID(path string) string {
meta, err := readClaudeDesktopJSON(path)
if err != nil {
return ""
}
applied, _ := meta["appliedId"].(string)
return applied
}
func readClaudeDesktopDeploymentMode(path string) string {
cfg, err := readClaudeDesktopJSON(path)
if err != nil {
return ""
}
mode, _ := cfg["deploymentMode"].(string)
return mode
}
func claudeDesktopTargetsConfigured(targets claudeDesktopTargets) bool {
if len(targets.normalConfigs) == 0 || len(targets.thirdPartyProfiles) == 0 {
return false
}
for _, path := range targets.normalConfigs {
if readClaudeDesktopDeploymentMode(path) != "3p" {
return false
}
}
for _, target := range targets.thirdPartyProfiles {
if readClaudeDesktopDeploymentMode(target.desktopConfig) != "3p" {
return false
}
if !claudeDesktopThirdPartyProfileConfigured(target) {
return false
}
}
return true
}
func claudeDesktopThirdPartyProfileConfigured(target claudeDesktopThirdPartyPaths) bool {
if readClaudeDesktopAppliedID(target.meta) != claudeDesktopProfileID {
return false
}
cfg, err := readClaudeDesktopJSON(target.profile)
if err != nil {
return false
}
if s, _ := cfg["inferenceProvider"].(string); s != "gateway" {
return false
}
if s, _ := cfg["inferenceGatewayBaseUrl"].(string); strings.TrimRight(s, "/") != claudeDesktopGatewayBaseURL {
return false
}
if s, _ := cfg["inferenceGatewayApiKey"].(string); strings.TrimSpace(s) == "" {
return false
}
return true
}
func readClaudeDesktopJSONAllowMissing(path string) (map[string]any, error) {
cfg, err := readClaudeDesktopJSON(path)
if errors.Is(err, os.ErrNotExist) {
return map[string]any{}, nil
}
return cfg, err
}
func readClaudeDesktopJSON(path string) (map[string]any, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, err
}
if cfg == nil {
cfg = map[string]any{}
}
return cfg, nil
}
func writeClaudeDesktopJSON(path string, cfg any) error {
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return err
}
data = append(data, '\n')
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
return fileutil.WriteWithBackup(path, data)
}
func claudeDesktopAnySlice(value any) []any {
switch v := value.(type) {
case []any:
return v
case nil:
return nil
default:
return nil
}
}
func claudeDesktopLaunchOrRestart(prompt string) error {
if !claudeDesktopIsRunning() {
return claudeDesktopOpenApp()
}
restartAppPath := ""
if claudeDesktopGOOS == "windows" {
restartAppPath = claudeDesktopRunningAppPath()
}
restart, err := ConfirmPrompt(prompt)
if err != nil {
return err
}
if !restart {
fmt.Fprintln(os.Stderr, "\nQuit and reopen Claude Desktop when you're ready for the profile change to take effect.")
return nil
}
if err := claudeDesktopQuitApp(); err != nil {
return fmt.Errorf("quit Claude Desktop: %w", err)
}
if err := waitForClaudeDesktopExit(30 * time.Second); err != nil {
return err
}
if restartAppPath != "" {
return claudeDesktopOpenAppPath(restartAppPath)
}
return claudeDesktopOpenApp()
}
func waitForClaudeDesktopExit(timeout time.Duration) error {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if !claudeDesktopIsRunning() {
return nil
}
claudeDesktopSleep(200 * time.Millisecond)
}
return fmt.Errorf("Claude Desktop did not quit; quit it manually and re-run the command")
}
func defaultClaudeDesktopIsRunning() bool {
switch claudeDesktopGOOS {
case "darwin":
out, err := exec.Command("pgrep", "-f", "Claude.app/Contents/MacOS/Claude").Output()
return err == nil && strings.TrimSpace(string(out)) != ""
case "windows":
out, err := exec.Command("powershell.exe", "-NoProfile", "-Command", `(Get-Process claude -ErrorAction SilentlyContinue | Where-Object { $_.MainWindowHandle -ne 0 } | Select-Object -First 1).Id`).Output()
return err == nil && strings.TrimSpace(string(out)) != ""
default:
return false
}
}
func defaultClaudeDesktopOpenApp() error {
switch claudeDesktopGOOS {
case "windows":
if path := claudeDesktopAppPath(); path != "" {
return claudeDesktopOpenAppPath(path)
}
if path := claudeDesktopRunningAppPath(); path != "" {
return claudeDesktopOpenAppPath(path)
}
return fmt.Errorf("Claude Desktop executable was not found; open Claude Desktop manually once and re-run 'ollama launch claude-desktop --restore'")
case "darwin":
return openClaudeDesktopDarwin()
default:
return claudeDesktopSupported()
}
}
func defaultClaudeDesktopOpenAppPath(path string) error {
switch claudeDesktopGOOS {
case "windows":
return exec.Command("powershell.exe", "-NoProfile", "-Command", "Start-Process -FilePath "+quotePowerShellString(path)).Run()
case "darwin":
return openClaudeDesktopDarwin()
default:
return claudeDesktopSupported()
}
}
func openClaudeDesktopDarwin() error {
cmd := exec.Command("open", "-a", "Claude")
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func defaultClaudeDesktopRunningAppPath() string {
if claudeDesktopGOOS != "windows" {
return ""
}
script := `(Get-Process claude -ErrorAction SilentlyContinue | Where-Object { $_.MainWindowHandle -ne 0 -and $_.Path } | Select-Object -First 1 -ExpandProperty Path)`
out, err := exec.Command("powershell.exe", "-NoProfile", "-Command", script).Output()
if err != nil {
return ""
}
return strings.TrimSpace(string(out))
}
func defaultClaudeDesktopQuitApp() error {
if claudeDesktopGOOS == "windows" {
script := `Get-Process claude -ErrorAction SilentlyContinue | Where-Object { $_.MainWindowHandle -ne 0 } | ForEach-Object { [void]$_.CloseMainWindow() }`
return exec.Command("powershell.exe", "-NoProfile", "-Command", script).Run()
}
return exec.Command("osascript", "-e", `tell application "Claude" to quit`).Run()
}
func quotePowerShellString(s string) string {
return "'" + strings.ReplaceAll(s, "'", "''") + "'"
}

View File

@@ -0,0 +1,946 @@
package launch
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func withClaudeDesktopPlatform(t *testing.T, goos string) {
t.Helper()
old := claudeDesktopGOOS
claudeDesktopGOOS = goos
t.Cleanup(func() {
claudeDesktopGOOS = old
})
}
func withClaudeDesktopValidation(t *testing.T, fn func(context.Context, string) error) {
t.Helper()
old := claudeDesktopValidateAPIKey
claudeDesktopValidateAPIKey = fn
t.Cleanup(func() {
claudeDesktopValidateAPIKey = old
})
}
func withClaudeDesktopPrompt(t *testing.T, fn func() (string, error)) {
t.Helper()
old := claudeDesktopPromptAPIKey
claudeDesktopPromptAPIKey = fn
t.Cleanup(func() {
claudeDesktopPromptAPIKey = old
})
}
func withClaudeDesktopProcessHooks(t *testing.T, running func() bool, quit func() error, open func() error) {
t.Helper()
oldRunning := claudeDesktopIsRunning
oldQuit := claudeDesktopQuitApp
oldOpen := claudeDesktopOpenApp
oldOpenPath := claudeDesktopOpenAppPath
oldRunningPath := claudeDesktopRunningAppPath
oldSleep := claudeDesktopSleep
claudeDesktopIsRunning = running
claudeDesktopQuitApp = quit
claudeDesktopOpenApp = open
claudeDesktopOpenAppPath = oldOpenPath
claudeDesktopRunningAppPath = oldRunningPath
claudeDesktopSleep = func(time.Duration) {}
t.Cleanup(func() {
claudeDesktopIsRunning = oldRunning
claudeDesktopQuitApp = oldQuit
claudeDesktopOpenApp = oldOpen
claudeDesktopOpenAppPath = oldOpenPath
claudeDesktopRunningAppPath = oldRunningPath
claudeDesktopSleep = oldSleep
})
}
func claudeDesktopReadJSON(t *testing.T, path string) map[string]any {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("read %s: %v", path, err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("parse %s: %v", path, err)
}
return cfg
}
func TestClaudeDesktopIntegration(t *testing.T) {
c := &ClaudeDesktop{}
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
t.Run("implements managed autodiscovery integration", func(t *testing.T) {
var _ ManagedAutodiscoveryIntegration = c
})
t.Run("does not use local Ollama Cloud auth gate", func(t *testing.T) {
if _, ok := any(c).(ManagedAutodiscoveryCloudIntegration); ok {
t.Fatal("Claude Desktop should validate OLLAMA_API_KEY directly instead of requiring local Ollama Cloud sign-in")
}
})
t.Run("implements restore", func(t *testing.T) {
var _ RestorableIntegration = c
})
t.Run("has restore hint", func(t *testing.T) {
var _ RestoreHintIntegration = c
if !strings.Contains(c.RestoreHint(), "--restore") {
t.Fatalf("expected restore hint to mention --restore, got %q", c.RestoreHint())
}
if strings.Contains(c.RestoreHint(), "Tip:") {
t.Fatalf("restore hint should not use Tip wording, got %q", c.RestoreHint())
}
})
t.Run("has success messages", func(t *testing.T) {
var _ ConfigurationSuccessIntegration = c
var _ RestoreSuccessIntegration = c
if got := c.ConfigurationSuccessMessage(); got != "Claude Desktop profile changed to Ollama Cloud.\nTo restore the usual Claude profile, run: ollama launch claude-desktop --restore" {
t.Fatalf("configuration success message = %q", got)
}
if got := c.RestoreSuccessMessage(); got != "Claude Desktop restored to the usual Claude profile." {
t.Fatalf("restore success message = %q", got)
}
})
t.Run("skips local model readiness", func(t *testing.T) {
var _ ManagedModelReadinessSkipper = c
if !c.SkipModelReadiness() {
t.Fatal("expected Claude Desktop to skip local model readiness")
}
})
}
func TestLaunchIntegration_ClaudeDesktopLaunchReturnsUnsupported(t *testing.T) {
for _, name := range []string{"claude-desktop", "claude-app"} {
t.Run(name, func(t *testing.T) {
err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: name})
if err == nil {
t.Fatal("expected Claude Desktop launch to fail")
}
if !strings.Contains(err.Error(), "Claude Desktop is no longer supported") {
t.Fatalf("expected unsupported guidance, got %v", err)
}
if !strings.Contains(err.Error(), "ollama launch claude-desktop --restore") {
t.Fatalf("expected restore guidance, got %v", err)
}
})
}
}
func TestLaunchIntegration_ClaudeDesktopRestoreStillWorks(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "darwin")
withClaudeDesktopProcessHooks(t, func() bool { return false }, func() error { return nil }, func() error { return nil })
if err := os.MkdirAll(filepath.Join(tmpDir, "Applications", "Claude.app"), 0o755); err != nil {
t.Fatal(err)
}
paths, err := claudeDesktopConfigPaths()
if err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Dir(paths.profile), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(paths.meta, []byte(`{"appliedId":"`+claudeDesktopProfileID+`","entries":[{"id":"`+claudeDesktopProfileID+`","name":"Ollama"}]}`), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(paths.profile, []byte(`{"disableDeploymentModeChooser":true,"inferenceGatewayApiKey":"keep","inferenceProvider":"gateway","inferenceGatewayBaseUrl":"https://ollama.com","inferenceGatewayAuthScheme":"bearer"}`), 0o644); err != nil {
t.Fatal(err)
}
stderr := captureStderr(t, func() {
err = LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "claude-desktop", Restore: true})
})
if err != nil {
t.Fatalf("LaunchIntegration restore returned error: %v", err)
}
if !strings.Contains(stderr, claudeDesktopRestoredMessage) {
t.Fatalf("expected restore success message, got stderr: %q", stderr)
}
desktopConfig := claudeDesktopReadJSON(t, paths.desktopConfig)
if desktopConfig["deploymentMode"] != "1p" {
t.Fatalf("deploymentMode = %v, want 1p", desktopConfig["deploymentMode"])
}
}
func TestClaudeDesktopConfigureWritesOllamaCloudProfile(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "darwin")
t.Setenv("OLLAMA_API_KEY", "test-api-key")
var validatedKey string
withClaudeDesktopValidation(t, func(_ context.Context, key string) error {
validatedKey = key
return nil
})
paths, err := claudeDesktopConfigPaths()
if err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Dir(paths.desktopConfig), 0o755); err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Dir(paths.meta), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(paths.desktopConfig, []byte(`{"existing":true}`), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(paths.meta, []byte(`{"entries":[{"id":"custom","name":"Custom"}]}`), 0o644); err != nil {
t.Fatal(err)
}
if err := (&ClaudeDesktop{}).ConfigureAutodiscovery(); err != nil {
t.Fatalf("Configure returned error: %v", err)
}
if validatedKey != "test-api-key" {
t.Fatalf("validated key = %q, want test API key", validatedKey)
}
desktopConfig := claudeDesktopReadJSON(t, paths.desktopConfig)
if desktopConfig["existing"] != true {
t.Fatalf("existing desktop config key was not preserved: %v", desktopConfig)
}
if desktopConfig["deploymentMode"] != "3p" {
t.Fatalf("deploymentMode = %v, want 3p", desktopConfig["deploymentMode"])
}
normalConfig := claudeDesktopReadJSON(t, paths.normalConfig)
if normalConfig["deploymentMode"] != "3p" {
t.Fatalf("normal deploymentMode = %v, want 3p", normalConfig["deploymentMode"])
}
meta := claudeDesktopReadJSON(t, paths.meta)
if meta["appliedId"] != claudeDesktopProfileID {
t.Fatalf("appliedId = %v, want %s", meta["appliedId"], claudeDesktopProfileID)
}
entries, _ := meta["entries"].([]any)
if len(entries) != 2 {
t.Fatalf("entries len = %d, want 2: %v", len(entries), entries)
}
profile := claudeDesktopReadJSON(t, paths.profile)
if profile["inferenceProvider"] != "gateway" {
t.Fatalf("inferenceProvider = %v, want gateway", profile["inferenceProvider"])
}
if profile["inferenceGatewayBaseUrl"] != claudeDesktopGatewayBaseURL {
t.Fatalf("base URL = %v, want %s", profile["inferenceGatewayBaseUrl"], claudeDesktopGatewayBaseURL)
}
if profile["inferenceGatewayApiKey"] != "test-api-key" {
t.Fatal("expected configured API key to be written")
}
if profile["inferenceGatewayAuthScheme"] != "bearer" {
t.Fatalf("auth scheme = %v, want bearer", profile["inferenceGatewayAuthScheme"])
}
if profile["disableDeploymentModeChooser"] != true {
t.Fatalf("disableDeploymentModeChooser = %v, want true", profile["disableDeploymentModeChooser"])
}
if _, ok := profile["inferenceModels"]; ok {
t.Fatalf("inferenceModels should be omitted so Claude can discover models, got %v", profile["inferenceModels"])
}
}
func TestClaudeDesktopConfigureAutodiscoveryRemovesExistingModelCatalog(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "darwin")
t.Setenv("OLLAMA_API_KEY", "test-api-key")
withClaudeDesktopValidation(t, func(context.Context, string) error { return nil })
paths, err := claudeDesktopConfigPaths()
if err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Dir(paths.profile), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(paths.profile, []byte(`{"inferenceModels":["qwen3.5"],"inferenceGatewayApiKey":"old"}`), 0o644); err != nil {
t.Fatal(err)
}
if err := (&ClaudeDesktop{}).ConfigureAutodiscovery(); err != nil {
t.Fatalf("ConfigureAutodiscovery returned error: %v", err)
}
profile := claudeDesktopReadJSON(t, paths.profile)
if _, ok := profile["inferenceModels"]; ok {
t.Fatalf("inferenceModels should be removed, got %v", profile["inferenceModels"])
}
if profile["inferenceGatewayApiKey"] != "test-api-key" {
t.Fatal("expected env API key to replace the old key")
}
}
func TestClaudeDesktopWindowsConfigPathsUseLocalAppData(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "windows")
t.Setenv("LOCALAPPDATA", filepath.Join(tmpDir, "LocalAppData"))
paths, err := claudeDesktopConfigPaths()
if err != nil {
t.Fatal(err)
}
if want := filepath.Join(tmpDir, "LocalAppData", "Claude-3p", "claude_desktop_config.json"); paths.desktopConfig != want {
t.Fatalf("desktop config = %q, want %q", paths.desktopConfig, want)
}
if want := filepath.Join(tmpDir, "LocalAppData", "Claude", "claude_desktop_config.json"); paths.normalConfig != want {
t.Fatalf("normal config = %q, want %q", paths.normalConfig, want)
}
}
func TestClaudeDesktopWindowsConfigPathsFallbackToNestProfile(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "windows")
local := filepath.Join(tmpDir, "LocalAppData")
t.Setenv("LOCALAPPDATA", local)
if err := os.MkdirAll(filepath.Join(local, "Claude Nest-3p"), 0o755); err != nil {
t.Fatal(err)
}
paths, err := claudeDesktopConfigPaths()
if err != nil {
t.Fatal(err)
}
if want := filepath.Join(local, "Claude Nest-3p", "claude_desktop_config.json"); paths.desktopConfig != want {
t.Fatalf("desktop config = %q, want %q", paths.desktopConfig, want)
}
}
func TestClaudeDesktopAutodiscoveryConfiguredOnWindows(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "windows")
t.Setenv("LOCALAPPDATA", filepath.Join(tmpDir, "LocalAppData"))
t.Setenv("OLLAMA_API_KEY", "test-api-key")
withClaudeDesktopValidation(t, func(context.Context, string) error { return nil })
c := &ClaudeDesktop{}
if err := c.ConfigureAutodiscovery(); err != nil {
t.Fatalf("Configure returned error: %v", err)
}
if !c.AutodiscoveryConfigured() {
t.Fatal("expected Claude Desktop autodiscovery config to be detected on Windows")
}
}
func TestClaudeDesktopConfigureAutodiscoveryTouchesAllWindowsProfileCandidates(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "windows")
local := filepath.Join(tmpDir, "LocalAppData")
t.Setenv("LOCALAPPDATA", local)
t.Setenv("OLLAMA_API_KEY", "test-api-key")
withClaudeDesktopValidation(t, func(context.Context, string) error { return nil })
targets, err := claudeDesktopTargetPaths()
if err != nil {
t.Fatal(err)
}
if len(targets.normalConfigs) != 2 {
t.Fatalf("normal config target count = %d, want 2", len(targets.normalConfigs))
}
if len(targets.thirdPartyProfiles) != 2 {
t.Fatalf("third-party target count = %d, want 2", len(targets.thirdPartyProfiles))
}
c := &ClaudeDesktop{}
if err := c.ConfigureAutodiscovery(); err != nil {
t.Fatalf("ConfigureAutodiscovery returned error: %v", err)
}
for _, path := range targets.normalConfigs {
cfg := claudeDesktopReadJSON(t, path)
if cfg["deploymentMode"] != "3p" {
t.Fatalf("%s deploymentMode = %v, want 3p", path, cfg["deploymentMode"])
}
}
for _, target := range targets.thirdPartyProfiles {
cfg := claudeDesktopReadJSON(t, target.desktopConfig)
if cfg["deploymentMode"] != "3p" {
t.Fatalf("%s deploymentMode = %v, want 3p", target.desktopConfig, cfg["deploymentMode"])
}
meta := claudeDesktopReadJSON(t, target.meta)
if meta["appliedId"] != claudeDesktopProfileID {
t.Fatalf("%s appliedId = %v, want %s", target.meta, meta["appliedId"], claudeDesktopProfileID)
}
profile := claudeDesktopReadJSON(t, target.profile)
if profile["inferenceProvider"] != "gateway" {
t.Fatalf("%s inferenceProvider = %v, want gateway", target.profile, profile["inferenceProvider"])
}
if profile["inferenceGatewayBaseUrl"] != claudeDesktopGatewayBaseURL {
t.Fatalf("%s base URL = %v, want %s", target.profile, profile["inferenceGatewayBaseUrl"], claudeDesktopGatewayBaseURL)
}
if profile["inferenceGatewayApiKey"] != "test-api-key" {
t.Fatalf("%s should contain the configured API key", target.profile)
}
if _, ok := profile["inferenceModels"]; ok {
t.Fatalf("%s inferenceModels should be omitted, got %v", target.profile, profile["inferenceModels"])
}
}
if !c.AutodiscoveryConfigured() {
t.Fatal("expected all Windows profile candidates to be considered configured")
}
if err := writeClaudeDesktopDeploymentMode(targets.thirdPartyProfiles[1].desktopConfig, "1p"); err != nil {
t.Fatal(err)
}
if c.AutodiscoveryConfigured() {
t.Fatal("expected a stale Windows candidate to force reconfiguration")
}
}
func TestClaudeDesktopInstalledOnWindowsRecognizesLocalProfileDir(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "windows")
local := filepath.Join(tmpDir, "LocalAppData")
t.Setenv("LOCALAPPDATA", local)
withClaudeDesktopProcessHooks(t, func() bool { return false }, func() error { return nil }, func() error { return nil })
if err := os.MkdirAll(filepath.Join(local, "Claude-3p"), 0o755); err != nil {
t.Fatal(err)
}
if !claudeDesktopInstalled() {
t.Fatal("expected Claude Desktop to be installed when the Windows profile directory exists")
}
}
func TestClaudeDesktopWindowsAppPathFindsAnthropicClaudeInstall(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "windows")
local := filepath.Join(tmpDir, "LocalAppData")
t.Setenv("LOCALAPPDATA", local)
want := filepath.Join(local, "AnthropicClaude", "app-1.2.3", "Claude.exe")
if err := os.MkdirAll(filepath.Dir(want), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(want, []byte(""), 0o755); err != nil {
t.Fatal(err)
}
if got := claudeDesktopAppPath(); got != want {
t.Fatalf("claudeDesktopAppPath() = %q, want %q", got, want)
}
}
func TestWaitForClaudeDesktopExitUsesRunningHook(t *testing.T) {
withClaudeDesktopPlatform(t, "windows")
runningChecks := 0
withClaudeDesktopProcessHooks(t,
func() bool {
runningChecks++
return runningChecks == 1
},
func() error { return nil },
func() error { return nil },
)
if err := waitForClaudeDesktopExit(time.Second); err != nil {
t.Fatalf("waitForClaudeDesktopExit returned error: %v", err)
}
if runningChecks < 2 {
t.Fatalf("expected running hook to be checked until the visible window exits, got %d checks", runningChecks)
}
}
func TestClaudeDesktopWindowsRestoreRestartUsesCapturedDesktopPath(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "windows")
t.Setenv("LOCALAPPDATA", filepath.Join(tmpDir, "LocalAppData"))
restoreConfirm := withLaunchConfirmPolicy(launchConfirmPolicy{yes: true})
defer restoreConfirm()
paths, err := claudeDesktopConfigPaths()
if err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Dir(paths.profile), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(paths.meta, []byte(`{"appliedId":"`+claudeDesktopProfileID+`","entries":[{"id":"`+claudeDesktopProfileID+`","name":"Ollama"}]}`), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(paths.profile, []byte(`{"disableDeploymentModeChooser":true,"inferenceGatewayApiKey":"keep"}`), 0o644); err != nil {
t.Fatal(err)
}
desktopPath := `C:\Users\parth\AppData\Local\AnthropicClaude\app-1.2.3\Claude.exe`
running := true
var openedPath string
withClaudeDesktopProcessHooks(t,
func() bool { return running },
func() error {
running = false
return nil
},
func() error {
t.Fatal("expected restart to open the captured Desktop executable path, not the generic launcher")
return nil
},
)
claudeDesktopRunningAppPath = func() string { return desktopPath }
claudeDesktopOpenAppPath = func(path string) error {
openedPath = path
return nil
}
if err := (&ClaudeDesktop{}).Restore(); err != nil {
t.Fatalf("Restore returned error: %v", err)
}
if openedPath != desktopPath {
t.Fatalf("opened path = %q, want %q", openedPath, desktopPath)
}
}
func TestClaudeDesktopWindowsOpenDoesNotFallBackToClaudeCommand(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "windows")
t.Setenv("LOCALAPPDATA", filepath.Join(tmpDir, "LocalAppData"))
oldRunningPath := claudeDesktopRunningAppPath
claudeDesktopRunningAppPath = func() string { return "" }
t.Cleanup(func() { claudeDesktopRunningAppPath = oldRunningPath })
err := defaultClaudeDesktopOpenApp()
if err == nil || !strings.Contains(err.Error(), "Claude Desktop executable was not found") {
t.Fatalf("defaultClaudeDesktopOpenApp error = %v, want executable-not-found error", err)
}
}
func TestClaudeDesktopConfigureStopsBeforeWriteWhenKeyValidationFails(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "darwin")
t.Setenv("OLLAMA_API_KEY", "bad-key")
withClaudeDesktopValidation(t, func(context.Context, string) error {
return errors.New("invalid key")
})
err := (&ClaudeDesktop{}).ConfigureAutodiscovery()
if err == nil || !strings.Contains(err.Error(), "invalid key") {
t.Fatalf("Configure error = %v, want invalid key", err)
}
paths, err := claudeDesktopConfigPaths()
if err != nil {
t.Fatal(err)
}
if _, err := os.Stat(paths.desktopConfig); !errors.Is(err, os.ErrNotExist) {
t.Fatalf("desktop config should not be written after validation failure, stat err = %v", err)
}
}
func TestValidateClaudeDesktopAPIKeyUsesClaudeModelsRoute(t *testing.T) {
oldClient := claudeDesktopHTTPClient
var gotPath, gotAuth string
claudeDesktopHTTPClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
gotPath = req.URL.Path
gotAuth = req.Header.Get("Authorization")
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`{"data":[]}`)),
Header: make(http.Header),
}, nil
})}
t.Cleanup(func() {
claudeDesktopHTTPClient = oldClient
})
if err := validateClaudeDesktopAPIKey(context.Background(), "test-key"); err != nil {
t.Fatalf("validateClaudeDesktopAPIKey returned error: %v", err)
}
if gotPath != "/v1/models" {
t.Fatalf("validation path = %q, want /v1/models", gotPath)
}
if gotAuth != "Bearer test-key" {
t.Fatalf("Authorization header = %q, want bearer key", gotAuth)
}
}
func TestValidateClaudeDesktopAPIKeyHidesInvalidHeaderDetails(t *testing.T) {
err := validateClaudeDesktopAPIKey(context.Background(), "bad\nkey")
if err == nil {
t.Fatal("expected validation error for key with newline")
}
if !strings.Contains(err.Error(), "could not verify Ollama API key") {
t.Fatalf("validation error = %v, want friendly verification message", err)
}
if strings.Contains(err.Error(), "invalid header") || strings.Contains(err.Error(), "net/http") {
t.Fatalf("validation error should not expose transport internals: %v", err)
}
if !strings.Contains(err.Error(), "https://ollama.com/settings/keys") {
t.Fatalf("validation error should include settings link: %v", err)
}
}
func TestClaudeDesktopConfigureRequiresAPIKey(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "darwin")
t.Setenv("OLLAMA_API_KEY", "")
withClaudeDesktopValidation(t, func(context.Context, string) error {
t.Fatal("validation should not run without an API key")
return nil
})
err := (&ClaudeDesktop{}).ConfigureAutodiscovery()
if err == nil || !strings.Contains(err.Error(), "OLLAMA_API_KEY is required") {
t.Fatalf("Configure error = %v, want missing key guidance", err)
}
}
func TestClaudeDesktopAPIKeyPromptIncludesSettingsLink(t *testing.T) {
prompt := claudeDesktopAPIKeyPrompt()
if !strings.Contains(prompt, "Enter Ollama API key") {
t.Fatalf("prompt should ask for the API key, got %q", prompt)
}
if !strings.Contains(prompt, "https://ollama.com/settings/keys") {
t.Fatalf("prompt should include API key settings link, got %q", prompt)
}
}
func TestClaudeDesktopConfigureReusesExistingAPIKey(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "darwin")
t.Setenv("OLLAMA_API_KEY", "")
paths, err := claudeDesktopConfigPaths()
if err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Dir(paths.profile), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(paths.profile, []byte(`{"inferenceGatewayApiKey":"existing-key"}`), 0o644); err != nil {
t.Fatal(err)
}
var validatedKey string
withClaudeDesktopValidation(t, func(_ context.Context, key string) error {
validatedKey = key
return nil
})
if err := (&ClaudeDesktop{}).ConfigureAutodiscovery(); err != nil {
t.Fatalf("ConfigureAutodiscovery returned error: %v", err)
}
if validatedKey != "existing-key" {
t.Fatalf("validated key = %q, want existing-key", validatedKey)
}
}
func TestClaudeDesktopConfigureReplacesInvalidExistingAPIKey(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "darwin")
withInteractiveSession(t, true)
t.Setenv("OLLAMA_API_KEY", "")
paths, err := claudeDesktopConfigPaths()
if err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Dir(paths.profile), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(paths.profile, []byte(`{"inferenceGatewayApiKey":"stale-key"}`), 0o644); err != nil {
t.Fatal(err)
}
var validated []string
withClaudeDesktopValidation(t, func(_ context.Context, key string) error {
validated = append(validated, key)
if key == "stale-key" {
return errors.New("invalid key")
}
return nil
})
withClaudeDesktopPrompt(t, func() (string, error) {
return "replacement-key", nil
})
if err := (&ClaudeDesktop{}).ConfigureAutodiscovery(); err != nil {
t.Fatalf("ConfigureAutodiscovery returned error: %v", err)
}
if diff := compareStrings(validated, []string{"stale-key", "replacement-key"}); diff != "" {
t.Fatalf("validated keys mismatch: %s", diff)
}
profile := claudeDesktopReadJSON(t, paths.profile)
if profile["inferenceGatewayApiKey"] != "replacement-key" {
t.Fatalf("configured key = %v, want replacement-key", profile["inferenceGatewayApiKey"])
}
}
func TestClaudeDesktopConfigureReusesExistingAPIKeyFromAnyWindowsProfile(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "windows")
local := filepath.Join(tmpDir, "LocalAppData")
t.Setenv("LOCALAPPDATA", local)
t.Setenv("OLLAMA_API_KEY", "")
targets, err := claudeDesktopTargetPaths()
if err != nil {
t.Fatal(err)
}
fallbackProfile := targets.thirdPartyProfiles[1].profile
if err := os.MkdirAll(filepath.Dir(fallbackProfile), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(fallbackProfile, []byte(`{"inferenceGatewayApiKey":"fallback-key"}`), 0o644); err != nil {
t.Fatal(err)
}
var validatedKey string
withClaudeDesktopValidation(t, func(_ context.Context, key string) error {
validatedKey = key
return nil
})
if err := (&ClaudeDesktop{}).ConfigureAutodiscovery(); err != nil {
t.Fatalf("ConfigureAutodiscovery returned error: %v", err)
}
if validatedKey != "fallback-key" {
t.Fatalf("validated key = %q, want fallback-key", validatedKey)
}
for _, target := range targets.thirdPartyProfiles {
profile := claudeDesktopReadJSON(t, target.profile)
if profile["inferenceGatewayApiKey"] != "fallback-key" {
t.Fatalf("%s should reuse fallback key, got %v", target.profile, profile["inferenceGatewayApiKey"])
}
}
}
func TestClaudeDesktopAutodiscoveryConfiguredRequiresAppliedOllamaProfile(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "darwin")
t.Setenv("OLLAMA_API_KEY", "test-api-key")
withClaudeDesktopValidation(t, func(context.Context, string) error { return nil })
c := &ClaudeDesktop{}
if err := c.ConfigureAutodiscovery(); err != nil {
t.Fatalf("Configure returned error: %v", err)
}
if !c.AutodiscoveryConfigured() {
t.Fatal("expected Claude Desktop autodiscovery config to be detected")
}
paths, err := claudeDesktopConfigPaths()
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(paths.meta, []byte(`{"appliedId":"custom"}`), 0o644); err != nil {
t.Fatal(err)
}
if c.AutodiscoveryConfigured() {
t.Fatal("expected another applied profile to hide Claude Desktop autodiscovery config")
}
}
func TestClaudeDesktopAutodiscoveryConfiguredRequiresAPIKey(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "darwin")
t.Setenv("OLLAMA_API_KEY", "test-api-key")
withClaudeDesktopValidation(t, func(context.Context, string) error { return nil })
c := &ClaudeDesktop{}
if err := c.ConfigureAutodiscovery(); err != nil {
t.Fatalf("Configure returned error: %v", err)
}
paths, err := claudeDesktopConfigPaths()
if err != nil {
t.Fatal(err)
}
profile := claudeDesktopReadJSON(t, paths.profile)
delete(profile, "inferenceGatewayApiKey")
data, err := json.Marshal(profile)
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(paths.profile, data, 0o644); err != nil {
t.Fatal(err)
}
if c.AutodiscoveryConfigured() {
t.Fatal("expected missing gateway API key to force Claude Desktop reconfiguration")
}
}
func TestClaudeDesktopRestoreSwitchesBackToFirstPartyMode(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "darwin")
withClaudeDesktopProcessHooks(t, func() bool { return false }, func() error { return nil }, func() error { return nil })
paths, err := claudeDesktopConfigPaths()
if err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Dir(paths.profile), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(paths.meta, []byte(`{"appliedId":"`+claudeDesktopProfileID+`","entries":[{"id":"`+claudeDesktopProfileID+`","name":"Ollama"}]}`), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(paths.profile, []byte(`{"disableDeploymentModeChooser":true,"inferenceGatewayApiKey":"keep","inferenceProvider":"gateway","inferenceGatewayBaseUrl":"https://ollama.com","inferenceGatewayAuthScheme":"bearer","inferenceModels":["legacy"]}`), 0o644); err != nil {
t.Fatal(err)
}
if err := (&ClaudeDesktop{}).Restore(); err != nil {
t.Fatalf("Restore returned error: %v", err)
}
desktopConfig := claudeDesktopReadJSON(t, paths.desktopConfig)
if desktopConfig["deploymentMode"] != "1p" {
t.Fatalf("deploymentMode = %v, want 1p", desktopConfig["deploymentMode"])
}
normalConfig := claudeDesktopReadJSON(t, paths.normalConfig)
if normalConfig["deploymentMode"] != "1p" {
t.Fatalf("normal deploymentMode = %v, want 1p", normalConfig["deploymentMode"])
}
profile := claudeDesktopReadJSON(t, paths.profile)
if profile["disableDeploymentModeChooser"] != false {
t.Fatalf("disableDeploymentModeChooser = %v, want false", profile["disableDeploymentModeChooser"])
}
if profile["inferenceGatewayApiKey"] != "keep" {
t.Fatal("restore should leave existing Ollama profile credentials in place")
}
for _, key := range []string{"inferenceProvider", "inferenceGatewayBaseUrl", "inferenceGatewayAuthScheme", "inferenceModels"} {
if _, ok := profile[key]; ok {
t.Fatalf("restore should clear stale %s from the Ollama profile: %v", key, profile)
}
}
meta := claudeDesktopReadJSON(t, paths.meta)
if _, ok := meta["appliedId"]; ok {
t.Fatalf("restore should clear the applied Ollama third-party profile: %v", meta)
}
if (&ClaudeDesktop{}).AutodiscoveryConfigured() {
t.Fatal("restore should leave Claude Desktop autodiscovery unconfigured")
}
}
func TestClaudeDesktopRestoreTouchesAllWindowsProfileCandidates(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withClaudeDesktopPlatform(t, "windows")
local := filepath.Join(tmpDir, "LocalAppData")
t.Setenv("LOCALAPPDATA", local)
withClaudeDesktopProcessHooks(t, func() bool { return false }, func() error { return nil }, func() error { return nil })
targets, err := claudeDesktopTargetPaths()
if err != nil {
t.Fatal(err)
}
if len(targets.normalConfigs) != 2 {
t.Fatalf("normal config target count = %d, want 2", len(targets.normalConfigs))
}
if len(targets.thirdPartyProfiles) != 2 {
t.Fatalf("third-party target count = %d, want 2", len(targets.thirdPartyProfiles))
}
for _, target := range targets.thirdPartyProfiles {
if err := os.MkdirAll(filepath.Dir(target.profile), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(target.meta, []byte(`{"appliedId":"`+claudeDesktopProfileID+`","entries":[{"id":"`+claudeDesktopProfileID+`","name":"Ollama"}]}`), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(target.profile, []byte(`{"disableDeploymentModeChooser":true,"inferenceGatewayApiKey":"keep","inferenceProvider":"gateway","inferenceGatewayBaseUrl":"https://ollama.com","inferenceGatewayAuthScheme":"bearer","inferenceModels":["legacy"]}`), 0o644); err != nil {
t.Fatal(err)
}
}
if err := (&ClaudeDesktop{}).Restore(); err != nil {
t.Fatalf("Restore returned error: %v", err)
}
for _, path := range targets.normalConfigs {
cfg := claudeDesktopReadJSON(t, path)
if cfg["deploymentMode"] != "1p" {
t.Fatalf("%s deploymentMode = %v, want 1p", path, cfg["deploymentMode"])
}
}
for _, target := range targets.thirdPartyProfiles {
cfg := claudeDesktopReadJSON(t, target.desktopConfig)
if cfg["deploymentMode"] != "1p" {
t.Fatalf("%s deploymentMode = %v, want 1p", target.desktopConfig, cfg["deploymentMode"])
}
meta := claudeDesktopReadJSON(t, target.meta)
if _, ok := meta["appliedId"]; ok {
t.Fatalf("%s should not keep the Ollama applied profile: %v", target.meta, meta)
}
profile := claudeDesktopReadJSON(t, target.profile)
if profile["disableDeploymentModeChooser"] != false {
t.Fatalf("%s disableDeploymentModeChooser = %v, want false", target.profile, profile["disableDeploymentModeChooser"])
}
if profile["inferenceGatewayApiKey"] != "keep" {
t.Fatalf("%s should preserve gateway API key", target.profile)
}
for _, key := range []string{"inferenceProvider", "inferenceGatewayBaseUrl", "inferenceGatewayAuthScheme", "inferenceModels"} {
if _, ok := profile[key]; ok {
t.Fatalf("%s should clear stale %s: %v", target.profile, key, profile)
}
}
}
}
func TestClaudeDesktopRunReturnsUnsupported(t *testing.T) {
withClaudeDesktopPlatform(t, "darwin")
withClaudeDesktopProcessHooks(t,
func() bool {
t.Fatal("Run should not inspect Claude Desktop process state")
return false
},
func() error {
t.Fatal("Run should not quit Claude Desktop")
return nil
},
func() error {
t.Fatal("Run should not open Claude Desktop")
return nil
},
)
for _, args := range [][]string{nil, {"--foo"}} {
err := (&ClaudeDesktop{}).Run("qwen3.5", nil, args)
if err == nil {
t.Fatal("expected Run to fail")
}
if !strings.Contains(err.Error(), "Claude Desktop is no longer supported") {
t.Fatalf("expected unsupported guidance, got %v", err)
}
if !strings.Contains(err.Error(), "ollama launch claude-desktop --restore") {
t.Fatalf("expected restore guidance, got %v", err)
}
}
}

171
cmd/launch/claude_test.go Normal file
View File

@@ -0,0 +1,171 @@
package launch
import (
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
)
func TestClaudeIntegration(t *testing.T) {
c := &Claude{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "Claude Code" {
t.Errorf("String() = %q, want %q", got, "Claude Code")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
}
func TestClaudeFindPath(t *testing.T) {
c := &Claude{}
t.Run("finds claude in PATH", func(t *testing.T) {
tmpDir := t.TempDir()
name := "claude"
if runtime.GOOS == "windows" {
name = "claude.exe"
}
fakeBin := filepath.Join(tmpDir, name)
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
t.Setenv("PATH", tmpDir)
got, err := c.findPath()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != fakeBin {
t.Errorf("findPath() = %q, want %q", got, fakeBin)
}
})
t.Run("falls back to ~/.claude/local/claude", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
name := "claude"
if runtime.GOOS == "windows" {
name = "claude.exe"
}
fallback := filepath.Join(tmpDir, ".claude", "local", name)
os.MkdirAll(filepath.Dir(fallback), 0o755)
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
got, err := c.findPath()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != fallback {
t.Errorf("findPath() = %q, want %q", got, fallback)
}
})
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
_, err := c.findPath()
if err == nil {
t.Fatal("expected error, got nil")
}
})
}
func TestClaudeArgs(t *testing.T) {
c := &Claude{}
tests := []struct {
name string
model string
args []string
want []string
}{
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
{"empty model", "", nil, nil},
{"with model and verbose", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
{"with allowed tools", "llama3.2", []string{"--allowedTools", "Read,Write,Bash"}, []string{"--model", "llama3.2", "--allowedTools", "Read,Write,Bash"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model, tt.args)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
}
})
}
}
func TestClaudeModelEnvVars(t *testing.T) {
c := &Claude{}
envMap := func(envs []string) map[string]string {
m := make(map[string]string)
for _, e := range envs {
k, v, _ := strings.Cut(e, "=")
m[k] = v
}
return m
}
t.Run("maps all Claude model env vars to the provided model", func(t *testing.T) {
got := envMap(c.modelEnvVars("llama3.2"))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
}
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2" {
t.Errorf("SONNET = %q, want llama3.2", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
}
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2" {
t.Errorf("HAIKU = %q, want llama3.2", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
}
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
}
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty for local models", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
}
})
t.Run("supports empty model", func(t *testing.T) {
got := envMap(c.modelEnvVars(""))
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "" {
t.Errorf("OPUS = %q, want empty", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
}
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "" {
t.Errorf("SONNET = %q, want empty", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
}
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "" {
t.Errorf("HAIKU = %q, want empty", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
}
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "" {
t.Errorf("SUBAGENT = %q, want empty", got["CLAUDE_CODE_SUBAGENT_MODEL"])
}
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
}
})
t.Run("sets auto compact window for known cloud models", func(t *testing.T) {
got := envMap(c.modelEnvVars("glm-5:cloud"))
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "202752" {
t.Errorf("AUTO_COMPACT_WINDOW = %q, want 202752", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
}
})
t.Run("does not set auto compact window for unknown cloud models", func(t *testing.T) {
got := envMap(c.modelEnvVars("unknown-model:cloud"))
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
}
})
}

104
cmd/launch/cline.go Normal file
View File

@@ -0,0 +1,104 @@
package launch
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
)
// Cline implements Runner and Editor for the Cline CLI integration
type Cline struct{}
func (c *Cline) String() string { return "Cline" }
func (c *Cline) Run(model string, _ []LaunchModel, args []string) error {
if _, err := exec.LookPath("cline"); err != nil {
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
}
cmd := exec.Command("cline", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (c *Cline) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
p := filepath.Join(home, ".cline", "data", "globalState.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
return nil
}
func (c *Cline) Edit(models []LaunchModel) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".cline", "data", "globalState.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
if err := json.Unmarshal(data, &config); err != nil {
return fmt.Errorf("failed to parse config: %w, at: %s", err, configPath)
}
}
// Set Ollama as the provider for both act and plan modes
baseURL := envconfig.Host().String()
config["ollamaBaseUrl"] = baseURL
config["actModeApiProvider"] = "ollama"
config["actModeOllamaModelId"] = models[0].Name
config["actModeOllamaBaseUrl"] = baseURL
config["planModeApiProvider"] = "ollama"
config["planModeOllamaModelId"] = models[0].Name
config["planModeOllamaBaseUrl"] = baseURL
config["welcomeViewCompleted"] = true
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
return fileutil.WriteWithBackup(configPath, data, "cline")
}
func (c *Cline) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := fileutil.ReadJSON(filepath.Join(home, ".cline", "data", "globalState.json"))
if err != nil {
return nil
}
if config["actModeApiProvider"] != "ollama" {
return nil
}
modelID, _ := config["actModeOllamaModelId"].(string)
if modelID == "" {
return nil
}
return []string{modelID}
}

204
cmd/launch/cline_test.go Normal file
View File

@@ -0,0 +1,204 @@
package launch
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestClineIntegration(t *testing.T) {
c := &Cline{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "Cline" {
t.Errorf("String() = %q, want %q", got, "Cline")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = c
})
}
func TestClineEdit(t *testing.T) {
c := &Cline{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".cline", "data")
configPath := filepath.Join(configDir, "globalState.json")
readConfig := func() map[string]any {
data, _ := os.ReadFile(configPath)
var config map[string]any
json.Unmarshal(data, &config)
return config
}
t.Run("creates config from scratch", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit(testLaunchModels("kimi-k2.5:cloud")); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["actModeApiProvider"] != "ollama" {
t.Errorf("actModeApiProvider = %v, want ollama", config["actModeApiProvider"])
}
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud", config["actModeOllamaModelId"])
}
if config["planModeApiProvider"] != "ollama" {
t.Errorf("planModeApiProvider = %v, want ollama", config["planModeApiProvider"])
}
if config["planModeOllamaModelId"] != "kimi-k2.5:cloud" {
t.Errorf("planModeOllamaModelId = %v, want kimi-k2.5:cloud", config["planModeOllamaModelId"])
}
if config["welcomeViewCompleted"] != true {
t.Errorf("welcomeViewCompleted = %v, want true", config["welcomeViewCompleted"])
}
})
t.Run("preserves existing fields", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
os.MkdirAll(configDir, 0o755)
existing := map[string]any{
"remoteRulesToggles": map[string]any{},
"remoteWorkflowToggles": map[string]any{},
"customSetting": "keep-me",
}
data, _ := json.Marshal(existing)
os.WriteFile(configPath, data, 0o644)
if err := c.Edit(testLaunchModels("glm-5:cloud")); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["customSetting"] != "keep-me" {
t.Errorf("customSetting was not preserved")
}
if config["actModeOllamaModelId"] != "glm-5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
}
})
t.Run("updates model on re-edit", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit(testLaunchModels("kimi-k2.5:cloud")); err != nil {
t.Fatal(err)
}
if err := c.Edit(testLaunchModels("glm-5:cloud")); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["actModeOllamaModelId"] != "glm-5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
}
if config["planModeOllamaModelId"] != "glm-5:cloud" {
t.Errorf("planModeOllamaModelId = %v, want glm-5:cloud", config["planModeOllamaModelId"])
}
})
t.Run("empty models is no-op", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit(nil); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
t.Error("expected no config file to be created for empty models")
}
})
t.Run("uses first model as primary", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit(testLaunchModels("kimi-k2.5:cloud", "glm-5:cloud")); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud (first model)", config["actModeOllamaModelId"])
}
})
}
func TestClineModels(t *testing.T) {
c := &Cline{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".cline", "data")
configPath := filepath.Join(configDir, "globalState.json")
t.Run("returns nil when no config", func(t *testing.T) {
if models := c.Models(); models != nil {
t.Errorf("Models() = %v, want nil", models)
}
})
t.Run("returns nil when provider is not ollama", func(t *testing.T) {
os.MkdirAll(configDir, 0o755)
config := map[string]any{
"actModeApiProvider": "anthropic",
"actModeOllamaModelId": "some-model",
}
data, _ := json.Marshal(config)
os.WriteFile(configPath, data, 0o644)
if models := c.Models(); models != nil {
t.Errorf("Models() = %v, want nil", models)
}
})
t.Run("returns model when ollama is configured", func(t *testing.T) {
os.MkdirAll(configDir, 0o755)
config := map[string]any{
"actModeApiProvider": "ollama",
"actModeOllamaModelId": "kimi-k2.5:cloud",
}
data, _ := json.Marshal(config)
os.WriteFile(configPath, data, 0o644)
models := c.Models()
if len(models) != 1 || models[0] != "kimi-k2.5:cloud" {
t.Errorf("Models() = %v, want [kimi-k2.5:cloud]", models)
}
})
}
func TestClinePaths(t *testing.T) {
c := &Cline{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns nil when no config exists", func(t *testing.T) {
if paths := c.Paths(); paths != nil {
t.Errorf("Paths() = %v, want nil", paths)
}
})
t.Run("returns path when config exists", func(t *testing.T) {
configDir := filepath.Join(tmpDir, ".cline", "data")
os.MkdirAll(configDir, 0o755)
configPath := filepath.Join(configDir, "globalState.json")
os.WriteFile(configPath, []byte("{}"), 0o644)
paths := c.Paths()
if len(paths) != 1 || paths[0] != configPath {
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
}
})
}

671
cmd/launch/codex.go Normal file
View File

@@ -0,0 +1,671 @@
package launch
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
"github.com/pelletier/go-toml/v2"
"golang.org/x/mod/semver"
)
// Codex implements Runner for Codex integration
type Codex struct{}
func (c *Codex) String() string { return "Codex" }
const (
codexProfileName = "ollama-launch"
codexProviderName = "Ollama"
codexFallbackContextWindow = 128_000
codexRootProfileKey = "profile"
codexRootModelKey = "model"
codexRootModelProviderKey = "model_provider"
codexRootModelCatalogJSONKey = "model_catalog_json"
)
func (c *Codex) args(model, modelCatalogPath string, extra []string) []string {
args := []string{"--profile", codexProfileName}
if modelCatalogPath != "" {
args = append(args, "-c", fmt.Sprintf("%s=%q", codexRootModelCatalogJSONKey, modelCatalogPath))
}
if model != "" {
args = append(args, "-m", model)
}
args = append(args, extra...)
return args
}
func (c *Codex) Run(model string, models []LaunchModel, args []string) error {
if err := checkCodexVersion(); err != nil {
return err
}
if err := ensureCodexConfig(model, models); err != nil {
return fmt.Errorf("failed to configure codex: %w", err)
}
catalogPath, err := codexModelCatalogPath()
if err != nil {
return fmt.Errorf("failed to configure codex: %w", err)
}
cmd := exec.Command("codex", c.args(model, catalogPath, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(),
"OPENAI_API_KEY=ollama",
)
return cmd.Run()
}
// ensureCodexConfig writes a Codex profile and model catalog so Codex uses the
// local Ollama server and has model metadata available.
func ensureCodexConfig(modelName string, models []LaunchModel) error {
configPath, err := codexConfigPath()
if err != nil {
return err
}
codexDir := filepath.Dir(configPath)
if err := os.MkdirAll(codexDir, 0o755); err != nil {
return err
}
catalogPath := codexModelCatalogPathForConfig(configPath)
if err := writeCodexModelCatalog(catalogPath, codexCatalogModel(modelName, models)); err != nil {
return err
}
return writeCodexProfile(configPath, catalogPath)
}
func codexConfigPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".codex", "config.toml"), nil
}
func codexModelCatalogPath() (string, error) {
configPath, err := codexConfigPath()
if err != nil {
return "", err
}
return codexModelCatalogPathForConfig(configPath), nil
}
func codexModelCatalogPathForConfig(configPath string) string {
return filepath.Join(filepath.Dir(configPath), "model.json")
}
// writeCodexProfile ensures ~/.codex/config.toml has the ollama-launch profile
// and model provider sections with the correct base URL.
func writeCodexProfile(configPath string, modelCatalogPath ...string) error {
opts := codexLaunchProfileOptions{
forceAPIAuth: true,
}
if len(modelCatalogPath) > 0 {
opts.modelCatalogPath = modelCatalogPath[0]
}
return writeCodexLaunchProfile(configPath, opts)
}
type codexLaunchProfileOptions struct {
activate bool
profileName string
forceAPIAuth bool
setRootModelConfig bool
model string
modelCatalogPath string
backupIntegration string
}
func writeCodexLaunchProfile(configPath string, opts codexLaunchProfileOptions) error {
baseURL := codexBaseURL()
profileName := codexLaunchProfileName(opts)
profileHeader := codexProfileHeaderFor(profileName)
providerHeader := codexProviderHeaderFor(profileName)
content, readErr := os.ReadFile(configPath)
text := ""
if readErr == nil {
text = string(content)
} else if !os.IsNotExist(readErr) {
return readErr
}
parsed, err := codexParseConfig(text)
if err != nil {
return err
}
model := strings.TrimSpace(opts.model)
if model == "" {
model = parsed.ProfileString(profileName, codexRootModelKey)
}
modelCatalogPath := strings.TrimSpace(opts.modelCatalogPath)
if modelCatalogPath == "" {
modelCatalogPath = parsed.ProfileString(profileName, codexRootModelCatalogJSONKey)
}
profileLines := []string{}
if model != "" {
profileLines = append(profileLines, fmt.Sprintf("%s = %q", codexRootModelKey, model))
}
profileLines = append(profileLines,
fmt.Sprintf("openai_base_url = %q", baseURL),
fmt.Sprintf("%s = %q", codexRootModelProviderKey, profileName),
)
if opts.forceAPIAuth {
profileLines = append(profileLines, `forced_login_method = "api"`)
}
if modelCatalogPath != "" {
profileLines = append(profileLines, fmt.Sprintf("%s = %q", codexRootModelCatalogJSONKey, modelCatalogPath))
}
sections := []struct {
header string
lines []string
}{
{
header: profileHeader,
lines: profileLines,
},
{
header: providerHeader,
lines: []string{
fmt.Sprintf("name = %q", codexProviderName),
fmt.Sprintf("base_url = %q", baseURL),
`wire_api = "responses"`,
},
},
}
if opts.activate {
text = codexSetRootStringValue(text, codexRootProfileKey, profileName)
}
if opts.setRootModelConfig {
if model != "" {
text = codexSetRootStringValue(text, codexRootModelKey, model)
}
text = codexSetRootStringValue(text, codexRootModelProviderKey, profileName)
if modelCatalogPath != "" {
text = codexSetRootStringValue(text, codexRootModelCatalogJSONKey, modelCatalogPath)
}
}
for _, s := range sections {
text = codexUpsertSection(text, s.header, s.lines)
}
parsed, err = codexParseConfig(text)
if err != nil {
return err
}
if err := codexValidateLaunchProfileText(parsed, profileName, opts, model, modelCatalogPath, baseURL); err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
return fileutil.WriteWithBackup(configPath, []byte(text), opts.backupIntegration)
}
func codexLaunchProfileName(opts codexLaunchProfileOptions) string {
if name := strings.TrimSpace(opts.profileName); name != "" {
return name
}
return codexProfileName
}
func codexBaseURL() string {
return strings.TrimRight(envconfig.ConnectableHost().String(), "/") + "/v1/"
}
func codexProfileHeader() string {
return codexProfileHeaderFor(codexProfileName)
}
func codexProviderHeader() string {
return codexProviderHeaderFor(codexProfileName)
}
func codexProfileHeaderFor(profileName string) string {
return fmt.Sprintf("[profiles.%s]", profileName)
}
func codexProviderHeaderFor(profileName string) string {
return fmt.Sprintf("[model_providers.%s]", profileName)
}
func codexValidateLaunchProfileText(config codexParsedConfig, profileName string, opts codexLaunchProfileOptions, model, modelCatalogPath, baseURL string) error {
for _, check := range []struct {
path []string
want string
}{
{[]string{"profiles", profileName, "openai_base_url"}, baseURL},
{[]string{"profiles", profileName, codexRootModelProviderKey}, profileName},
{[]string{"model_providers", profileName, "name"}, codexProviderName},
{[]string{"model_providers", profileName, "base_url"}, baseURL},
{[]string{"model_providers", profileName, "wire_api"}, "responses"},
} {
if got, ok := config.String(check.path...); !ok || got != check.want {
return fmt.Errorf("generated Codex config missing %s = %q", strings.Join(check.path, "."), check.want)
}
}
if opts.forceAPIAuth {
if got, ok := config.String("profiles", profileName, "forced_login_method"); !ok || got != "api" {
return fmt.Errorf("generated Codex config missing profiles.%s.forced_login_method = %q", profileName, "api")
}
}
if model != "" {
if got, ok := config.String("profiles", profileName, codexRootModelKey); !ok || got != model {
return fmt.Errorf("generated Codex config missing profiles.%s.model = %q", profileName, model)
}
}
if modelCatalogPath != "" {
if got, ok := config.String("profiles", profileName, codexRootModelCatalogJSONKey); !ok || got != modelCatalogPath {
return fmt.Errorf("generated Codex config missing profiles.%s.model_catalog_json = %q", profileName, modelCatalogPath)
}
}
if opts.activate {
if got := config.RootString(codexRootProfileKey); got != profileName {
return fmt.Errorf("generated Codex config missing profile = %q", profileName)
}
}
if opts.setRootModelConfig {
if model != "" {
if got := config.RootString(codexRootModelKey); got != model {
return fmt.Errorf("generated Codex config missing model = %q", model)
}
}
if got := config.RootString(codexRootModelProviderKey); got != profileName {
return fmt.Errorf("generated Codex config missing model_provider = %q", profileName)
}
if modelCatalogPath != "" {
if got := config.RootString(codexRootModelCatalogJSONKey); got != modelCatalogPath {
return fmt.Errorf("generated Codex config missing model_catalog_json = %q", modelCatalogPath)
}
}
}
return nil
}
func codexUpsertSection(text, header string, lines []string) string {
block := strings.Join(append([]string{header}, lines...), "\n") + "\n"
if targetPath, ok := codexTableHeaderPath(header); ok {
if start, end, found := codexSectionRange(text, targetPath); found {
return text[:start] + block + text[end:]
}
}
if text != "" && !strings.HasSuffix(text, "\n") {
text += "\n"
}
if text != "" {
text += "\n"
}
return text + block
}
func codexRemoveSection(text, header string) string {
targetPath, ok := codexTableHeaderPath(header)
if !ok {
return text
}
start, end, found := codexSectionRange(text, targetPath)
if !found {
return text
}
return text[:start] + text[end:]
}
type codexParsedConfig struct {
values map[string]any
}
func (c codexParsedConfig) String(path ...string) (string, bool) {
if len(path) == 0 {
return "", false
}
var current any = c.values
for _, part := range path {
table, ok := current.(map[string]any)
if !ok {
return "", false
}
current, ok = table[part]
if !ok {
return "", false
}
}
value, ok := current.(string)
if !ok {
return "", false
}
return value, true
}
func (c codexParsedConfig) RootString(key string) string {
value, _ := c.RootStringOK(key)
return value
}
func (c codexParsedConfig) RootStringOK(key string) (string, bool) {
return c.String(key)
}
func (c codexParsedConfig) ProfileString(profileName, key string) string {
value, _ := c.String("profiles", profileName, key)
return value
}
func (c codexParsedConfig) ProviderString(profileName, key string) string {
value, _ := c.String("model_providers", profileName, key)
return value
}
func codexRootStringValue(text, key string) string {
config, err := codexParseConfig(text)
if err != nil {
return ""
}
return config.RootString(key)
}
func codexRootStringValueOK(text, key string) (string, bool) {
config, err := codexParseConfig(text)
if err != nil {
return "", false
}
return config.RootStringOK(key)
}
func codexStringValue(text string, path ...string) (string, bool) {
config, err := codexParseConfig(text)
if err != nil {
return "", false
}
return config.String(path...)
}
func codexSectionStringValue(text, header, key string) string {
path, ok := codexTableHeaderPath(header)
if !ok {
return ""
}
value, _ := codexStringValue(text, append(path, key)...)
return value
}
func codexParseConfig(text string) (codexParsedConfig, error) {
values, err := codexParseConfigText(text)
if err != nil {
return codexParsedConfig{}, err
}
return codexParsedConfig{values: values}, nil
}
func codexParseConfigText(text string) (map[string]any, error) {
cfg := map[string]any{}
if strings.TrimSpace(text) == "" {
return cfg, nil
}
if err := toml.Unmarshal([]byte(text), &cfg); err != nil {
return nil, fmt.Errorf("invalid Codex config TOML: %w", err)
}
return cfg, nil
}
func codexValidateConfigText(text string) error {
_, err := codexParseConfig(text)
return err
}
func codexSectionRange(text string, targetPath []string) (int, int, bool) {
lines := strings.SplitAfter(text, "\n")
offset := 0
start := -1
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if !strings.HasPrefix(trimmed, "[") || strings.HasPrefix(trimmed, "#") {
offset += len(line)
continue
}
if start >= 0 {
return start, offset, true
}
if path, ok := codexTableHeaderPath(trimmed); ok && codexSamePath(path, targetPath) {
start = offset
}
offset += len(line)
}
if start >= 0 {
return start, len(text), true
}
return 0, 0, false
}
func codexTableHeaderPath(header string) ([]string, bool) {
trimmed := strings.TrimSpace(header)
if !strings.HasPrefix(trimmed, "[") || strings.HasPrefix(trimmed, "[[") {
return nil, false
}
const probeKey = "__ollama_launch_probe"
cfg := map[string]any{}
if err := toml.Unmarshal([]byte(trimmed+"\n"+probeKey+" = true\n"), &cfg); err != nil {
return nil, false
}
return codexFindProbePath(cfg, probeKey, nil)
}
func codexFindProbePath(value any, probeKey string, path []string) ([]string, bool) {
table, ok := value.(map[string]any)
if !ok {
return nil, false
}
if probe, ok := table[probeKey].(bool); ok && probe {
return path, true
}
for key, child := range table {
if key == probeKey {
continue
}
if childPath, ok := codexFindProbePath(child, probeKey, append(path, key)); ok {
return childPath, true
}
}
return nil, false
}
func codexSamePath(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func codexSetRootStringValue(text, key, value string) string {
lines := strings.SplitAfter(text, "\n")
rootEnd := len(lines)
for i, line := range lines {
if strings.HasPrefix(strings.TrimSpace(line), "[") {
rootEnd = i
break
}
}
assignment := fmt.Sprintf("%s = %q", key, value)
for i := range rootEnd {
line := lines[i]
trimmed := strings.TrimSpace(line)
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
continue
}
if codexRootLineHasKey(trimmed, key) {
if strings.HasSuffix(line, "\n") {
lines[i] = assignment + "\n"
} else {
lines[i] = assignment
}
return strings.Join(lines, "")
}
}
insert := assignment + "\n"
root := strings.Join(lines[:rootEnd], "")
rest := strings.Join(lines[rootEnd:], "")
if root != "" && !strings.HasSuffix(root, "\n") {
root += "\n"
}
if rest != "" && !strings.HasSuffix(insert, "\n\n") {
insert += "\n"
}
return root + insert + rest
}
func codexRemoveRootValue(text, key string) string {
lines := strings.SplitAfter(text, "\n")
rootEnd := len(lines)
for i, line := range lines {
if strings.HasPrefix(strings.TrimSpace(line), "[") {
rootEnd = i
break
}
}
out := make([]string, 0, len(lines))
for i, line := range lines {
if i < rootEnd {
trimmed := strings.TrimSpace(line)
if trimmed != "" && !strings.HasPrefix(trimmed, "#") && codexRootLineHasKey(trimmed, key) {
continue
}
}
out = append(out, line)
}
return strings.Join(out, "")
}
func codexRootLineHasKey(line, key string) bool {
cfg := map[string]any{}
if err := toml.Unmarshal([]byte(line+"\n"), &cfg); err != nil {
return false
}
_, ok := cfg[key]
return ok
}
func codexCatalogModel(modelName string, models []LaunchModel) LaunchModel {
if model, ok := findLaunchModel(models, modelName); ok {
return model.WithCloudLimits()
}
return fallbackLaunchModel(modelName)
}
func writeCodexModelCatalog(catalogPath string, model LaunchModel) error {
entry := buildCodexModelEntry(model)
catalog := map[string]any{
"models": []any{entry},
}
data, err := json.MarshalIndent(catalog, "", " ")
if err != nil {
return err
}
return os.WriteFile(catalogPath, data, 0o644)
}
func buildCodexModelEntry(launchModel LaunchModel) map[string]any {
modelName := launchModel.Name
contextWindow := codexFallbackContextWindow
systemPrompt := ""
if launchModel.ContextLength > 0 {
contextWindow = launchModel.ContextLength
} else if launchModel.Details.ContextLength > 0 {
contextWindow = launchModel.Details.ContextLength
}
if l, ok := lookupCloudModelLimit(modelName); ok {
contextWindow = l.Context
}
if !isCloudModelName(modelName) && launchModel.Details.Format != "safetensors" {
if ctxLen := envconfig.ContextLength(); ctxLen > 0 {
contextWindow = int(ctxLen)
}
}
modalities := []string{"text"}
if launchModel.HasCapability(model.CapabilityVision) {
modalities = append(modalities, "image")
}
truncationMode := "bytes"
if isCloudModelName(modelName) {
truncationMode = "tokens"
}
return map[string]any{
"slug": modelName,
"display_name": modelName,
"context_window": contextWindow,
"shell_type": "default",
"visibility": "list",
"supported_in_api": true,
"priority": 0,
"truncation_policy": map[string]any{"mode": truncationMode, "limit": 10000},
"input_modalities": modalities,
"base_instructions": systemPrompt,
"support_verbosity": true,
"default_verbosity": "low",
"supports_parallel_tool_calls": false,
"supports_reasoning_summaries": false,
"supported_reasoning_levels": []any{},
"experimental_supported_tools": []any{},
}
}
func checkCodexVersion() error {
if _, err := exec.LookPath("codex"); err != nil {
return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
}
out, err := exec.Command("codex", "--version").Output()
if err != nil {
return fmt.Errorf("failed to get codex version: %w", err)
}
// Parse output like "codex-cli 0.87.0"
fields := strings.Fields(strings.TrimSpace(string(out)))
if len(fields) < 2 {
return fmt.Errorf("unexpected codex version output: %s", string(out))
}
version := "v" + fields[len(fields)-1]
minVersion := "v0.81.0"
if semver.Compare(version, minVersion) < 0 {
return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0")
}
return nil
}

1042
cmd/launch/codex_app.go Normal file

File diff suppressed because it is too large Load Diff

1436
cmd/launch/codex_app_test.go Normal file

File diff suppressed because it is too large Load Diff

595
cmd/launch/codex_test.go Normal file
View File

@@ -0,0 +1,595 @@
package launch
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/internal/fileutil"
modelpkg "github.com/ollama/ollama/types/model"
)
func TestCodexArgs(t *testing.T) {
c := &Codex{}
catalogPath := filepath.Join("tmp", "model.json")
catalogArg := fmt.Sprintf("%s=%q", codexRootModelCatalogJSONKey, catalogPath)
tests := []struct {
name string
model string
args []string
want []string
}{
{"with model", "llama3.2", nil, []string{"--profile", "ollama-launch", "-c", catalogArg, "-m", "llama3.2"}},
{"empty model", "", nil, []string{"--profile", "ollama-launch", "-c", catalogArg}},
{"with model and extra args", "qwen3.5", []string{"-p", "myprofile"}, []string{"--profile", "ollama-launch", "-c", catalogArg, "-m", "qwen3.5", "-p", "myprofile"}},
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--profile", "ollama-launch", "-c", catalogArg, "-m", "llama3.2", "--sandbox", "workspace-write"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model, catalogPath, tt.args)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
}
})
}
}
func TestWriteCodexProfile(t *testing.T) {
t.Run("creates new file when none exists", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml")
catalogPath := filepath.Join(tmpDir, "model.json")
if err := writeCodexProfile(configPath, catalogPath); err != nil {
t.Fatal(err)
}
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatal(err)
}
content := string(data)
if !strings.Contains(content, "[profiles.ollama-launch]") {
t.Error("missing [profiles.ollama-launch] header")
}
if !strings.Contains(content, "openai_base_url") {
t.Error("missing openai_base_url key")
}
if !strings.Contains(content, "/v1/") {
t.Error("missing /v1/ suffix in base URL")
}
if !strings.Contains(content, `forced_login_method = "api"`) {
t.Error("missing forced_login_method key")
}
if !strings.Contains(content, `model_provider = "ollama-launch"`) {
t.Error("missing model_provider key")
}
if !strings.Contains(content, fmt.Sprintf("model_catalog_json = %q", catalogPath)) {
t.Error("missing model_catalog_json key")
}
if !strings.Contains(content, "[model_providers.ollama-launch]") {
t.Error("missing [model_providers.ollama-launch] section")
}
if !strings.Contains(content, `name = "Ollama"`) {
t.Error("missing model provider name")
}
if err := codexValidateConfigText(content); err != nil {
t.Fatalf("generated config should be valid TOML: %v\n%s", err, content)
}
})
t.Run("appends profile to existing file without profile", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml")
catalogPath := filepath.Join(tmpDir, "model.json")
existing := "[some_other_section]\nkey = \"value\"\n"
os.WriteFile(configPath, []byte(existing), 0o644)
if err := writeCodexProfile(configPath, catalogPath); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
content := string(data)
if !strings.Contains(content, "[some_other_section]") {
t.Error("existing section was removed")
}
if !strings.Contains(content, "[profiles.ollama-launch]") {
t.Error("missing [profiles.ollama-launch] header")
}
})
t.Run("replaces existing profile section", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml")
catalogPath := filepath.Join(tmpDir, "model.json")
existing := "[profiles.ollama-launch]\nopenai_base_url = \"http://old:1234/v1/\"\n\n[model_providers.ollama-launch]\nname = \"Ollama\"\nbase_url = \"http://old:1234/v1/\"\n"
os.WriteFile(configPath, []byte(existing), 0o644)
if err := writeCodexProfile(configPath, catalogPath); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
content := string(data)
if strings.Contains(content, "old:1234") {
t.Error("old URL was not replaced")
}
if strings.Count(content, "[profiles.ollama-launch]") != 1 {
t.Errorf("expected exactly one [profiles.ollama-launch] section, got %d", strings.Count(content, "[profiles.ollama-launch]"))
}
if strings.Count(content, "[model_providers.ollama-launch]") != 1 {
t.Errorf("expected exactly one [model_providers.ollama-launch] section, got %d", strings.Count(content, "[model_providers.ollama-launch]"))
}
if err := codexValidateConfigText(content); err != nil {
t.Fatalf("generated config should be valid TOML: %v\n%s", err, content)
}
})
t.Run("replaces equivalent quoted profile table", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml")
existing := "" +
`profile = "default"` + "\n\n" +
`[profiles."ollama-launch"]` + "\n" +
`openai_base_url = "http://old:1234/v1/"` + "\n\n" +
`[model_providers."ollama-launch"]` + "\n" +
`name = "Old"` + "\n" +
`base_url = "http://old:1234/v1/"` + "\n\n" +
`[profiles.default]` + "\n" +
`model = "gpt-5.5"` + "\n"
os.WriteFile(configPath, []byte(existing), 0o644)
if err := writeCodexProfile(configPath); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
content := string(data)
if strings.Contains(content, `profiles."ollama-launch"`) {
t.Fatalf("quoted profile table should be replaced, got:\n%s", content)
}
if strings.Contains(content, "old:1234") {
t.Fatalf("old URL was not replaced, got:\n%s", content)
}
if got := codexSectionStringValue(content, codexProfileHeader(), "model_provider"); got != codexProfileName {
t.Fatalf("profile model_provider = %q, want %q", got, codexProfileName)
}
if got := codexSectionStringValue(content, codexProviderHeader(), "base_url"); !strings.Contains(got, "/v1/") {
t.Fatalf("provider base_url = %q, want /v1/ URL", got)
}
if err := codexValidateConfigText(content); err != nil {
t.Fatalf("generated config should be valid TOML: %v\n%s", err, content)
}
})
t.Run("rejects invalid existing toml without writing", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml")
existing := "profile = \n"
os.WriteFile(configPath, []byte(existing), 0o644)
err := writeCodexProfile(configPath)
if err == nil || !strings.Contains(err.Error(), "invalid Codex config TOML") {
t.Fatalf("writeCodexProfile error = %v, want invalid TOML", err)
}
data, _ := os.ReadFile(configPath)
if string(data) != existing {
t.Fatalf("invalid config should be left untouched, got:\n%s", data)
}
})
t.Run("rejects malformed existing toml variants without writing", func(t *testing.T) {
tests := map[string]string{
"duplicate root key": "profile = \"default\"\nprofile = \"other\"\n",
"unterminated string": "model = \"gpt-5.5\n",
"bad table": "[profiles.ollama-launch\nmodel = \"llama3.2\"\n",
"duplicate table key": "[profiles.ollama-launch]\nmodel = \"a\"\nmodel = \"b\"\n",
}
for name, existing := range tests {
t.Run(name, func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml")
if err := os.WriteFile(configPath, []byte(existing), 0o644); err != nil {
t.Fatal(err)
}
err := writeCodexProfile(configPath)
if err == nil || !strings.Contains(err.Error(), "invalid Codex config TOML") {
t.Fatalf("writeCodexProfile error = %v, want invalid TOML", err)
}
data, _ := os.ReadFile(configPath)
if string(data) != existing {
t.Fatalf("invalid config should be left untouched, got:\n%s", data)
}
})
}
})
t.Run("backs up previous config before overwrite", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configPath := filepath.Join(tmpDir, ".codex", "config.toml")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
t.Fatal(err)
}
existing := "# original-codex-backup-marker\n[profiles.default]\nmodel = \"gpt-5.5\"\n"
if err := os.WriteFile(configPath, []byte(existing), 0o644); err != nil {
t.Fatal(err)
}
if err := writeCodexProfile(configPath); err != nil {
t.Fatal(err)
}
assertBackupContains(t, filepath.Join(fileutil.BackupDir(), "config.toml.*"), "original-codex-backup-marker")
})
t.Run("updates equivalent quoted root keys", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml")
existing := "" +
`"profile" = "default"` + "\n" +
`"model" = "gpt-5.5"` + "\n" +
`"model_provider" = "openai"` + "\n\n" +
`[profiles.default]` + "\n" +
`model = "gpt-5.5"` + "\n"
os.WriteFile(configPath, []byte(existing), 0o644)
err := writeCodexLaunchProfile(configPath, codexLaunchProfileOptions{
activate: true,
setRootModelConfig: true,
model: "llama3.2",
})
if err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
content := string(data)
for key, want := range map[string]string{
"profile": codexProfileName,
"model": "llama3.2",
"model_provider": codexProfileName,
} {
if got := codexRootStringValue(content, key); got != want {
t.Fatalf("root %s = %q, want %q in:\n%s", key, got, want, content)
}
}
if strings.Contains(content, `"profile"`) || strings.Contains(content, `"model_provider"`) {
t.Fatalf("quoted root keys should be rewritten once, got:\n%s", content)
}
if err := codexValidateConfigText(content); err != nil {
t.Fatalf("generated config should be valid TOML: %v\n%s", err, content)
}
})
t.Run("replaces profile while preserving following sections", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml")
catalogPath := filepath.Join(tmpDir, "model.json")
existing := "[profiles.ollama-launch]\nopenai_base_url = \"http://old:1234/v1/\"\n[another_section]\nfoo = \"bar\"\n"
os.WriteFile(configPath, []byte(existing), 0o644)
if err := writeCodexProfile(configPath, catalogPath); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
content := string(data)
if strings.Contains(content, "old:1234") {
t.Error("old URL was not replaced")
}
if !strings.Contains(content, "[another_section]") {
t.Error("following section was removed")
}
if !strings.Contains(content, "foo = \"bar\"") {
t.Error("following section content was removed")
}
})
t.Run("appends newline to file not ending with newline", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml")
catalogPath := filepath.Join(tmpDir, "model.json")
existing := "[other]\nkey = \"val\""
os.WriteFile(configPath, []byte(existing), 0o644)
if err := writeCodexProfile(configPath, catalogPath); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
content := string(data)
if !strings.Contains(content, "[profiles.ollama-launch]") {
t.Error("missing [profiles.ollama-launch] header")
}
// Should not have double blank lines from missing trailing newline
if strings.Contains(content, "\n\n\n") {
t.Error("unexpected triple newline in output")
}
})
t.Run("uses custom OLLAMA_HOST", func(t *testing.T) {
t.Setenv("OLLAMA_HOST", "http://myhost:9999")
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml")
catalogPath := filepath.Join(tmpDir, "model.json")
if err := writeCodexProfile(configPath, catalogPath); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
content := string(data)
if !strings.Contains(content, "myhost:9999/v1/") {
t.Errorf("expected custom host in URL, got:\n%s", content)
}
})
t.Run("uses connectable host for unspecified bind address", func(t *testing.T) {
t.Setenv("OLLAMA_HOST", "http://0.0.0.0:11434")
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml")
if err := writeCodexProfile(configPath); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
content := string(data)
if strings.Contains(content, "0.0.0.0") {
t.Fatalf("config should not write bind-only host, got:\n%s", content)
}
if !strings.Contains(content, "127.0.0.1:11434/v1/") {
t.Fatalf("expected connectable loopback URL, got:\n%s", content)
}
})
}
func TestEnsureCodexConfig(t *testing.T) {
t.Run("creates .codex dir and config.toml", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if err := ensureCodexConfig("llama3.2", launchModelsFromNames([]string{"llama3.2"})); err != nil {
t.Fatal(err)
}
configPath := filepath.Join(tmpDir, ".codex", "config.toml")
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatalf("config.toml not created: %v", err)
}
content := string(data)
if !strings.Contains(content, "[profiles.ollama-launch]") {
t.Error("missing [profiles.ollama-launch] header")
}
if !strings.Contains(content, "openai_base_url") {
t.Error("missing openai_base_url key")
}
catalogPath := filepath.Join(tmpDir, ".codex", "model.json")
data, err = os.ReadFile(catalogPath)
if err != nil {
t.Fatalf("model.json not created: %v", err)
}
if !strings.Contains(string(data), `"slug": "llama3.2"`) {
t.Error("missing model catalog entry for selected model")
}
})
t.Run("is idempotent", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if err := ensureCodexConfig("llama3.2", launchModelsFromNames([]string{"llama3.2"})); err != nil {
t.Fatal(err)
}
if err := ensureCodexConfig("llama3.2", launchModelsFromNames([]string{"llama3.2"})); err != nil {
t.Fatal(err)
}
configPath := filepath.Join(tmpDir, ".codex", "config.toml")
data, _ := os.ReadFile(configPath)
content := string(data)
if strings.Count(content, "[profiles.ollama-launch]") != 1 {
t.Errorf("expected exactly one [profiles.ollama-launch] section after two calls, got %d", strings.Count(content, "[profiles.ollama-launch]"))
}
if strings.Count(content, "[model_providers.ollama-launch]") != 1 {
t.Errorf("expected exactly one [model_providers.ollama-launch] section after two calls, got %d", strings.Count(content, "[model_providers.ollama-launch]"))
}
})
}
func assertBackupContains(t *testing.T, pattern, marker string) {
t.Helper()
backups, err := filepath.Glob(pattern)
if err != nil {
t.Fatal(err)
}
for _, backupPath := range backups {
data, err := os.ReadFile(backupPath)
if err != nil {
t.Fatal(err)
}
if strings.Contains(string(data), marker) {
return
}
}
t.Fatalf("backup matching %q with marker %q not found", pattern, marker)
}
func TestModelInfoContextLength(t *testing.T) {
tests := []struct {
name string
modelInfo map[string]any
want int
}{
{"float64 value", map[string]any{"qwen3_5_moe.context_length": float64(262144)}, 262144},
{"int value", map[string]any{"llama.context_length": 131072}, 131072},
{"no context_length key", map[string]any{"llama.embedding_length": float64(4096)}, 0},
{"empty map", map[string]any{}, 0},
{"nil map", nil, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, _ := modelInfoContextLength(tt.modelInfo)
if got != tt.want {
t.Errorf("modelInfoContextLength() = %d, want %d", got, tt.want)
}
})
}
}
func TestBuildCodexModelEntryContextWindow(t *testing.T) {
tests := []struct {
name string
model LaunchModel
envContextLen string
wantContext int
}{
{
name: "inventory context length as fallback",
model: LaunchModel{
Name: "llama3.2",
ContextLength: 131072,
Details: api.ModelDetails{Format: "gguf"},
},
wantContext: 131072,
},
{
name: "details context length is used when model context is empty",
model: LaunchModel{
Name: "llama3.2",
Details: api.ModelDetails{Format: "gguf", ContextLength: 131072},
},
wantContext: 131072,
},
{
name: "OLLAMA_CONTEXT_LENGTH overrides local gguf inventory context",
model: LaunchModel{
Name: "llama3.2",
ContextLength: 131072,
Details: api.ModelDetails{Format: "gguf"},
},
envContextLen: "64000",
wantContext: 64000,
},
{
name: "safetensors uses inventory context only",
model: LaunchModel{
Name: "llama3.2",
ContextLength: 131072,
Details: api.ModelDetails{Format: "safetensors"},
},
envContextLen: "64000",
wantContext: 131072,
},
{
name: "cloud model uses hardcoded limits",
model: LaunchModel{
Name: "qwen3.5:cloud",
ContextLength: 131072,
Details: api.ModelDetails{Format: "gguf"},
},
envContextLen: "64000",
wantContext: 262144,
},
{
name: "unknown cloud model without metadata uses fallback context",
model: LaunchModel{
Name: "deepseek-v4-pro:cloud",
},
envContextLen: "64000",
wantContext: codexFallbackContextWindow,
},
{
name: "vision capability without reasoning advertisement",
model: LaunchModel{
Name: "llama3.2",
ContextLength: 131072,
Details: api.ModelDetails{Format: "gguf"},
Capabilities: []modelpkg.Capability{modelpkg.CapabilityVision, modelpkg.CapabilityThinking},
},
wantContext: 131072,
},
{
name: "missing metadata uses fallback context",
model: LaunchModel{Name: "llama3.2"},
wantContext: codexFallbackContextWindow,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envContextLen != "" {
t.Setenv("OLLAMA_CONTEXT_LENGTH", tt.envContextLen)
} else {
t.Setenv("OLLAMA_CONTEXT_LENGTH", "")
}
entry := buildCodexModelEntry(tt.model)
gotContext, _ := entry["context_window"].(int)
if gotContext != tt.wantContext {
t.Errorf("context_window = %d, want %d", gotContext, tt.wantContext)
}
if tt.name == "vision capability without reasoning advertisement" {
modalities, _ := entry["input_modalities"].([]string)
if !slices.Contains(modalities, "image") {
t.Error("expected image in input_modalities")
}
levels, _ := entry["supported_reasoning_levels"].([]any)
if len(levels) != 0 {
t.Errorf("supported_reasoning_levels length = %d, want 0", len(levels))
}
if got, _ := entry["supports_reasoning_summaries"].(bool); got {
t.Error("supports_reasoning_summaries = true, want false")
}
}
if tt.name == "cloud model uses hardcoded limits" {
truncationPolicy, _ := entry["truncation_policy"].(map[string]any)
if mode, _ := truncationPolicy["mode"].(string); mode != "tokens" {
t.Errorf("truncation_policy mode = %q, want %q", mode, "tokens")
}
}
requiredKeys := []string{"slug", "display_name", "shell_type"}
for _, key := range requiredKeys {
if _, ok := entry[key]; !ok {
t.Errorf("missing required key %q", key)
}
}
if _, ok := entry["apply_patch_tool_type"]; ok {
t.Error("apply_patch_tool_type should be omitted so Codex CLI defaults can handle schema changes")
}
if _, err := json.Marshal(entry); err != nil {
t.Errorf("entry is not JSON serializable: %v", err)
}
})
}
}

689
cmd/launch/command_test.go Normal file
View File

@@ -0,0 +1,689 @@
package launch
import (
"bytes"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/cmd/config"
"github.com/spf13/cobra"
)
func captureStderr(t *testing.T, fn func()) string {
t.Helper()
oldStderr := os.Stderr
r, w, err := os.Pipe()
if err != nil {
t.Fatalf("failed to create stderr pipe: %v", err)
}
os.Stderr = w
defer func() {
os.Stderr = oldStderr
}()
done := make(chan string, 1)
go func() {
var buf bytes.Buffer
_, _ = io.Copy(&buf, r)
done <- buf.String()
}()
fn()
_ = w.Close()
return <-done
}
func TestLaunchCmd(t *testing.T) {
mockCheck := func(cmd *cobra.Command, args []string) error {
return nil
}
mockTUI := func(cmd *cobra.Command) {}
cmd := LaunchCmd(mockCheck, mockTUI)
t.Run("command structure", func(t *testing.T) {
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
}
if cmd.Short == "" {
t.Error("Short description should not be empty")
}
if cmd.Long == "" {
t.Error("Long description should not be empty")
}
if !strings.Contains(cmd.Long, "hermes") {
t.Error("Long description should mention hermes")
}
if !strings.Contains(cmd.Long, "kimi") {
t.Error("Long description should mention kimi")
}
})
t.Run("flags exist", func(t *testing.T) {
if cmd.Flags().Lookup("model") == nil {
t.Error("--model flag should exist")
}
if cmd.Flags().Lookup("config") == nil {
t.Error("--config flag should exist")
}
if cmd.Flags().Lookup("restore") == nil {
t.Error("--restore flag should exist")
}
if cmd.Flags().Lookup("yes") == nil {
t.Error("--yes flag should exist")
}
})
t.Run("PreRunE is set", func(t *testing.T) {
if cmd.PreRunE == nil {
t.Error("PreRunE should be set to checkServerHeartbeat")
}
})
}
func TestLaunchCmdTUICallback(t *testing.T) {
mockCheck := func(cmd *cobra.Command, args []string) error {
return nil
}
t.Run("no args calls TUI", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{})
_ = cmd.Execute()
if !tuiCalled {
t.Error("TUI callback should be called when no args provided")
}
})
t.Run("integration arg bypasses TUI", func(t *testing.T) {
srv := httptest.NewServer(http.NotFoundHandler())
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"claude"})
_ = cmd.Execute()
if tuiCalled {
t.Error("TUI callback should NOT be called when integration arg provided")
}
})
t.Run("--model flag without integration returns error", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--model", "test-model"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected --model without an integration to fail")
}
if !strings.Contains(err.Error(), "require an integration name") {
t.Fatalf("expected integration-name guidance, got %v", err)
}
if tuiCalled {
t.Error("TUI callback should NOT be called when --model is provided without an integration")
}
})
t.Run("--config flag without integration returns error", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--config"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected --config without an integration to fail")
}
if !strings.Contains(err.Error(), "require an integration name") {
t.Fatalf("expected integration-name guidance, got %v", err)
}
if tuiCalled {
t.Error("TUI callback should NOT be called when --config is provided without an integration")
}
})
t.Run("--yes flag without integration returns error", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--yes"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected --yes without an integration to fail")
}
if !strings.Contains(err.Error(), "require an integration name") {
t.Fatalf("expected integration-name guidance, got %v", err)
}
if tuiCalled {
t.Error("TUI callback should NOT be called when --yes is provided without an integration")
}
})
t.Run("extra args without integration return error", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--model", "test-model", "--", "--sandbox", "workspace-write"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected flags and extra args without an integration to fail")
}
if !strings.Contains(err.Error(), "require an integration name") {
t.Fatalf("expected integration-name guidance, got %v", err)
}
if tuiCalled {
t.Error("TUI callback should NOT be called when flags or extra args are provided without an integration")
}
})
t.Run("--restore flag without integration returns error", func(t *testing.T) {
tuiCalled := false
mockTUI := func(cmd *cobra.Command) {
tuiCalled = true
}
cmd := LaunchCmd(mockCheck, mockTUI)
cmd.SetArgs([]string{"--restore"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected --restore without an integration to fail")
}
if !strings.Contains(err.Error(), "require an integration name") {
t.Fatalf("expected integration-name guidance, got %v", err)
}
if tuiCalled {
t.Error("TUI callback should NOT be called when --restore is provided without an integration")
}
})
}
func TestLaunchCmdClaudeDesktopLaunchReturnsUnsupported(t *testing.T) {
for _, name := range []string{"claude-desktop", "claude-app"} {
t.Run(name, func(t *testing.T) {
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error {
t.Fatal("heartbeat check should not run before Claude Desktop unsupported error")
return nil
}, func(cmd *cobra.Command) {
t.Fatal("TUI callback should not run for direct integration launch")
})
cmd.SetArgs([]string{name})
err := cmd.Execute()
if err == nil {
t.Fatal("expected Claude Desktop launch command to fail")
}
if !strings.Contains(err.Error(), "Claude Desktop is no longer supported") {
t.Fatalf("expected unsupported guidance, got %v", err)
}
if !strings.Contains(err.Error(), "ollama launch claude-desktop --restore") {
t.Fatalf("expected restore guidance, got %v", err)
}
})
}
}
func TestLaunchCmdNilHeartbeat(t *testing.T) {
cmd := LaunchCmd(nil, nil)
if cmd == nil {
t.Fatal("LaunchCmd returned nil")
}
if cmd.PreRunE != nil {
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
}
}
func TestLaunchCmdModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
if err := config.SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil {
t.Fatalf("failed to seed saved config: %v", err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/status":
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
case "/api/show":
fmt.Fprintf(w, `{"model":"llama3.2"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherEditorRunner{}
restore := OverrideIntegration("stubeditor", stub)
defer restore()
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
if err := cmd.Execute(); err != nil {
t.Fatalf("launch command failed: %v", err)
}
saved, err := config.LoadIntegration("stubeditor")
if err != nil {
t.Fatalf("failed to reload integration config: %v", err)
}
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
}
if stub.ranModel != "llama3.2" {
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
}
}
func TestLaunchCmdModelFlagClearsDisabledCloudOverride(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/status":
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model":"llama3.2"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherSingleRunner{}
restore := OverrideIntegration("stubapp", stub)
defer restore()
oldSelector := DefaultSingleSelector
defer func() { DefaultSingleSelector = oldSelector }()
var selectorCalls int
var gotCurrent string
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
selectorCalls++
gotCurrent = current
return "llama3.2", nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubapp", "--model", "glm-5:cloud"})
stderr := captureStderr(t, func() {
if err := cmd.Execute(); err != nil {
t.Fatalf("launch command failed: %v", err)
}
})
if selectorCalls != 1 {
t.Fatalf("expected disabled cloud override to fall back to selector, got %d calls", selectorCalls)
}
if gotCurrent != "" {
t.Fatalf("expected disabled override to be cleared before selection, got current %q", gotCurrent)
}
if stub.ranModel != "llama3.2" {
t.Fatalf("expected launch to run with replacement local model, got %q", stub.ranModel)
}
if !strings.Contains(stderr, "Warning: ignoring --model glm-5:cloud because cloud is disabled") {
t.Fatalf("expected disabled-cloud warning, got stderr: %q", stderr)
}
saved, err := config.LoadIntegration("stubapp")
if err != nil {
t.Fatalf("failed to reload integration config: %v", err)
}
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
}
}
func TestLaunchCmdAutodiscoveryDefaultLaunchDoesNotForceConfigure(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, true)
withLauncherHooks(t)
runner := &launcherManagedAutodiscoveryRunner{
autodiscoveryConfigured: true,
}
restore := OverrideIntegration("stubauto", runner)
defer restore()
if err := config.SaveIntegration("stubauto", []string{"Ollama Cloud"}); err != nil {
t.Fatalf("failed to save managed integration config: %v", err)
}
if err := config.MarkIntegrationOnboarded("stubauto"); err != nil {
t.Fatalf("failed to mark integration onboarded: %v", err)
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {
t.Fatal("TUI callback should not run for direct integration launch")
})
cmd.SetArgs([]string{"stubauto"})
if err := cmd.Execute(); err != nil {
t.Fatalf("launch command failed: %v", err)
}
if runner.autodiscoveryConfigures != 0 {
t.Fatalf("expected default autodiscovery launch to reuse existing config, got %d configures", runner.autodiscoveryConfigures)
}
if runner.ranModel != "Ollama Cloud" {
t.Fatalf("expected launch to run autodiscovery label, got %q", runner.ranModel)
}
}
func TestLaunchCmdYes_AutoConfirmsLaunchPromptPath(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
withInteractiveSession(t, false)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model":"llama3.2"}`)
case "/api/status":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"not found"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
restore := OverrideIntegration("stubeditor", stub)
defer restore()
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
t.Fatalf("unexpected prompt with --yes: %q", prompt)
return false, nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2", "--yes"})
if err := cmd.Execute(); err != nil {
t.Fatalf("launch command with --yes failed: %v", err)
}
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
}
if stub.ranModel != "llama3.2" {
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
}
}
func TestLaunchCmdHeadlessWithYes_AutoPullsMissingLocalModel(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
withInteractiveSession(t, false)
var pullCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"model not found"}`)
case "/api/pull":
pullCalled = true
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, `{"status":"success"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherSingleRunner{}
restore := OverrideIntegration("stubapp", stub)
defer restore()
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
t.Fatalf("unexpected prompt with --yes in headless autopull path: %q", prompt)
return false, nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubapp", "--model", "missing-model", "--yes"})
if err := cmd.Execute(); err != nil {
t.Fatalf("launch command with --yes failed: %v", err)
}
if !pullCalled {
t.Fatal("expected missing local model to be auto-pulled with --yes in headless mode")
}
if stub.ranModel != "missing-model" {
t.Fatalf("expected launch to run with pulled model, got %q", stub.ranModel)
}
}
func TestLaunchCmdHeadlessWithoutYes_AllowsConfiguredLaunch(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
withInteractiveSession(t, false)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model":"llama3.2"}`)
case "/api/status":
w.WriteHeader(http.StatusNotFound)
fmt.Fprint(w, `{"error":"not found"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
restore := OverrideIntegration("stubeditor", stub)
defer restore()
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
t.Fatalf("unexpected prompt in headless non-yes mode: %q", prompt)
return false, nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
err := cmd.Execute()
if err != nil {
t.Fatalf("expected launch command to succeed without --yes when an explicit model is provided, got %v", err)
}
if diff := compareStringSlices(stub.edited, [][]string{{"llama3.2"}}); diff != "" {
t.Fatalf("unexpected editor writes (-want +got):\n%s", diff)
}
if stub.ranModel != "llama3.2" {
t.Fatalf("expected launch to run configured model, got %q", stub.ranModel)
}
}
func TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
t.Fatalf("failed to seed saved config: %v", err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model":"qwen3:8b"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherSingleRunner{}
restore := OverrideIntegration("stubapp", stub)
defer restore()
oldSelector := DefaultSingleSelector
defer func() { DefaultSingleSelector = oldSelector }()
var gotCurrent string
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
gotCurrent = current
return "qwen3:8b", nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubapp"})
if err := cmd.Execute(); err != nil {
t.Fatalf("launch command failed: %v", err)
}
if gotCurrent != "llama3.2" {
t.Fatalf("expected selector current model to be saved model llama3.2, got %q", gotCurrent)
}
if stub.ranModel != "qwen3:8b" {
t.Fatalf("expected launch to run selected model qwen3:8b, got %q", stub.ranModel)
}
saved, err := config.LoadIntegration("stubapp")
if err != nil {
t.Fatalf("failed to reload integration config: %v", err)
}
if diff := cmp.Diff([]string{"qwen3:8b"}, saved.Models); diff != "" {
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
}
}
func TestLaunchCmdHeadlessYes_IntegrationRequiresModelEvenWhenSaved(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
withInteractiveSession(t, false)
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
t.Fatalf("failed to seed saved config: %v", err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model":"llama3.2"}`)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherSingleRunner{}
restore := OverrideIntegration("stubapp", stub)
defer restore()
oldSelector := DefaultSingleSelector
defer func() { DefaultSingleSelector = oldSelector }()
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
t.Fatal("selector should not be called for headless --yes saved-model launch")
return "", nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubapp", "--yes"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
}
if !strings.Contains(err.Error(), "requires --model <model>") {
t.Fatalf("expected actionable --model guidance, got %v", err)
}
if stub.ranModel != "" {
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
}
}
func TestLaunchCmdHeadlessYes_IntegrationWithoutSavedModelReturnsError(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withLauncherHooks(t)
withInteractiveSession(t, false)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
stub := &launcherSingleRunner{}
restore := OverrideIntegration("stubapp", stub)
defer restore()
oldSelector := DefaultSingleSelector
defer func() { DefaultSingleSelector = oldSelector }()
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
t.Fatal("selector should not be called for headless --yes without saved model")
return "", nil
}
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
cmd.SetArgs([]string{"stubapp", "--yes"})
err := cmd.Execute()
if err == nil {
t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
}
if !strings.Contains(err.Error(), "requires --model <model>") {
t.Fatalf("expected actionable --model guidance, got %v", err)
}
if stub.ranModel != "" {
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
}
}

76
cmd/launch/copilot.go Normal file
View File

@@ -0,0 +1,76 @@
package launch
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/ollama/ollama/envconfig"
)
// Copilot implements Runner for GitHub Copilot CLI integration.
type Copilot struct{}
func (c *Copilot) String() string { return "Copilot CLI" }
func (c *Copilot) args(model string, extra []string) []string {
var args []string
if model != "" {
args = append(args, "--model", model)
}
args = append(args, extra...)
return args
}
func (c *Copilot) findPath() (string, error) {
if p, err := exec.LookPath("copilot"); err == nil {
return p, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
name := "copilot"
if runtime.GOOS == "windows" {
name = "copilot.exe"
}
fallback := filepath.Join(home, ".local", "bin", name)
if _, err := os.Stat(fallback); err != nil {
return "", err
}
return fallback, nil
}
func (c *Copilot) Run(model string, _ []LaunchModel, args []string) error {
copilotPath, err := c.findPath()
if err != nil {
return fmt.Errorf("copilot is not installed, install from https://docs.github.com/en/copilot/how-tos/set-up/install-copilot-cli")
}
cmd := exec.Command(copilotPath, c.args(model, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(), c.envVars(model)...)
return cmd.Run()
}
// envVars returns the environment variables that configure Copilot CLI
// to use Ollama as its model provider.
func (c *Copilot) envVars(model string) []string {
env := []string{
"COPILOT_PROVIDER_BASE_URL=" + envconfig.Host().String() + "/v1",
"COPILOT_PROVIDER_API_KEY=",
"COPILOT_PROVIDER_WIRE_API=responses",
}
if model != "" {
env = append(env, "COPILOT_MODEL="+model)
}
return env
}

161
cmd/launch/copilot_test.go Normal file
View File

@@ -0,0 +1,161 @@
package launch
import (
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
)
func TestCopilotIntegration(t *testing.T) {
c := &Copilot{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "Copilot CLI" {
t.Errorf("String() = %q, want %q", got, "Copilot CLI")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
}
func TestCopilotFindPath(t *testing.T) {
c := &Copilot{}
t.Run("finds copilot in PATH", func(t *testing.T) {
tmpDir := t.TempDir()
name := "copilot"
if runtime.GOOS == "windows" {
name = "copilot.exe"
}
fakeBin := filepath.Join(tmpDir, name)
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
t.Setenv("PATH", tmpDir)
got, err := c.findPath()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != fakeBin {
t.Errorf("findPath() = %q, want %q", got, fakeBin)
}
})
t.Run("returns error when not in PATH", func(t *testing.T) {
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
_, err := c.findPath()
if err == nil {
t.Fatal("expected error, got nil")
}
})
t.Run("falls back to ~/.local/bin/copilot", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
name := "copilot"
if runtime.GOOS == "windows" {
name = "copilot.exe"
}
fallback := filepath.Join(tmpDir, ".local", "bin", name)
os.MkdirAll(filepath.Dir(fallback), 0o755)
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
got, err := c.findPath()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != fallback {
t.Errorf("findPath() = %q, want %q", got, fallback)
}
})
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
_, err := c.findPath()
if err == nil {
t.Fatal("expected error, got nil")
}
})
}
func TestCopilotArgs(t *testing.T) {
c := &Copilot{}
tests := []struct {
name string
model string
args []string
want []string
}{
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
{"empty model", "", nil, nil},
{"with model and extra", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model, tt.args)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
}
})
}
}
func TestCopilotEnvVars(t *testing.T) {
c := &Copilot{}
envMap := func(envs []string) map[string]string {
m := make(map[string]string)
for _, e := range envs {
k, v, _ := strings.Cut(e, "=")
m[k] = v
}
return m
}
t.Run("sets required provider env vars with model", func(t *testing.T) {
got := envMap(c.envVars("llama3.2"))
if got["COPILOT_PROVIDER_BASE_URL"] == "" {
t.Error("COPILOT_PROVIDER_BASE_URL should be set")
}
if !strings.HasSuffix(got["COPILOT_PROVIDER_BASE_URL"], "/v1") {
t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want /v1 suffix", got["COPILOT_PROVIDER_BASE_URL"])
}
if _, ok := got["COPILOT_PROVIDER_API_KEY"]; !ok {
t.Error("COPILOT_PROVIDER_API_KEY should be set (empty)")
}
if got["COPILOT_PROVIDER_WIRE_API"] != "responses" {
t.Errorf("COPILOT_PROVIDER_WIRE_API = %q, want %q", got["COPILOT_PROVIDER_WIRE_API"], "responses")
}
if got["COPILOT_MODEL"] != "llama3.2" {
t.Errorf("COPILOT_MODEL = %q, want %q", got["COPILOT_MODEL"], "llama3.2")
}
})
t.Run("omits COPILOT_MODEL when model is empty", func(t *testing.T) {
got := envMap(c.envVars(""))
if _, ok := got["COPILOT_MODEL"]; ok {
t.Errorf("COPILOT_MODEL should not be set for empty model, got %q", got["COPILOT_MODEL"])
}
})
t.Run("uses custom OLLAMA_HOST", func(t *testing.T) {
t.Setenv("OLLAMA_HOST", "http://myhost:9999")
got := envMap(c.envVars("test"))
if !strings.Contains(got["COPILOT_PROVIDER_BASE_URL"], "myhost:9999") {
t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want custom host", got["COPILOT_PROVIDER_BASE_URL"])
}
})
}

188
cmd/launch/droid.go Normal file
View File

@@ -0,0 +1,188 @@
package launch
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"slices"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
)
// Droid implements Runner and Editor for Droid integration
type Droid struct{}
// droidSettings represents the Droid settings.json file (only fields we use)
type droidSettings struct {
CustomModels []modelEntry `json:"customModels"`
SessionDefaultSettings sessionSettings `json:"sessionDefaultSettings"`
}
type sessionSettings struct {
Model string `json:"model"`
ReasoningEffort string `json:"reasoningEffort"`
}
type modelEntry struct {
Model string `json:"model"`
DisplayName string `json:"displayName"`
BaseURL string `json:"baseUrl"`
APIKey string `json:"apiKey"`
Provider string `json:"provider"`
MaxOutputTokens int `json:"maxOutputTokens"`
SupportsImages bool `json:"supportsImages"`
ID string `json:"id"`
Index int `json:"index"`
}
func (d *Droid) String() string { return "Droid" }
func (d *Droid) Run(model string, _ []LaunchModel, args []string) error {
if _, err := exec.LookPath("droid"); err != nil {
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
}
cmd := exec.Command("droid", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (d *Droid) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
p := filepath.Join(home, ".factory", "settings.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
return nil
}
func (d *Droid) Edit(models []LaunchModel) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
settingsPath := filepath.Join(home, ".factory", "settings.json")
if err := os.MkdirAll(filepath.Dir(settingsPath), 0o755); err != nil {
return err
}
// Read file once, unmarshal twice:
// map preserves unknown fields for writing back (including extra fields in model entries)
settingsMap := make(map[string]any)
var settings droidSettings
if data, err := os.ReadFile(settingsPath); err == nil {
if err := json.Unmarshal(data, &settingsMap); err != nil {
return fmt.Errorf("failed to parse settings file: %w, at: %s", err, settingsPath)
}
json.Unmarshal(data, &settings) // ignore error, zero values are fine
}
settingsMap = updateDroidSettings(settingsMap, settings, models)
data, err := json.MarshalIndent(settingsMap, "", " ")
if err != nil {
return err
}
return fileutil.WriteWithBackup(settingsPath, data, "droid")
}
func updateDroidSettings(settingsMap map[string]any, settings droidSettings, models []LaunchModel) map[string]any {
// Keep only non-Ollama models from the raw map (preserves extra fields)
// Rebuild Ollama models
var nonOllamaModels []any
if rawModels, ok := settingsMap["customModels"].([]any); ok {
for _, raw := range rawModels {
if m, ok := raw.(map[string]any); ok {
if m["apiKey"] != "ollama" {
nonOllamaModels = append(nonOllamaModels, raw)
}
}
}
}
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
var newModels []any
var defaultModelID string
for i, model := range models {
maxOutput := 64000
if model.MaxOutputTokens > 0 {
maxOutput = model.MaxOutputTokens
}
modelID := fmt.Sprintf("custom:%s-%d", model.Name, i)
newModels = append(newModels, modelEntry{
Model: model.Name,
DisplayName: model.Name,
BaseURL: envconfig.Host().String() + "/v1",
APIKey: "ollama",
Provider: "generic-chat-completion-api",
MaxOutputTokens: maxOutput,
SupportsImages: model.HasCapability("vision"),
ID: modelID,
Index: i,
})
if i == 0 {
defaultModelID = modelID
}
}
settingsMap["customModels"] = append(newModels, nonOllamaModels...)
// Update session default settings (preserve unknown fields in the nested object)
sessionSettings, ok := settingsMap["sessionDefaultSettings"].(map[string]any)
if !ok {
sessionSettings = make(map[string]any)
}
sessionSettings["model"] = defaultModelID
if !isValidReasoningEffort(settings.SessionDefaultSettings.ReasoningEffort) {
sessionSettings["reasoningEffort"] = "none"
}
settingsMap["sessionDefaultSettings"] = sessionSettings
return settingsMap
}
func (d *Droid) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
data, err := os.ReadFile(filepath.Join(home, ".factory", "settings.json"))
if err != nil {
return nil
}
var settings droidSettings
if err := json.Unmarshal(data, &settings); err != nil {
return nil
}
var result []string
for _, m := range settings.CustomModels {
if m.APIKey == "ollama" {
result = append(result, m.Model)
}
}
return result
}
var validReasoningEfforts = []string{"high", "medium", "low", "none"}
func isValidReasoningEffort(effort string) bool {
return slices.Contains(validReasoningEfforts, effort)
}

1345
cmd/launch/droid_test.go Normal file

File diff suppressed because it is too large Load Diff

679
cmd/launch/hermes.go Normal file
View File

@@ -0,0 +1,679 @@
package launch
import (
"bufio"
"bytes"
"context"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"gopkg.in/yaml.v3"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
)
const (
hermesInstallScript = "curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash -s -- --skip-setup"
hermesProviderName = "Ollama"
hermesProviderKey = "ollama-launch"
hermesLegacyKey = "ollama"
hermesPlaceholderKey = "ollama"
hermesGatewaySetupHint = "hermes gateway setup"
hermesGatewaySetupTitle = "Connect a messaging app now?"
)
var (
hermesGOOS = runtime.GOOS
hermesLookPath = exec.LookPath
hermesCommand = exec.Command
hermesUserHome = os.UserHomeDir
hermesOllamaURL = envconfig.ConnectableHost
)
var hermesMessagingEnvGroups = [][]string{
{"TELEGRAM_BOT_TOKEN"},
{"DISCORD_BOT_TOKEN"},
{"SLACK_BOT_TOKEN"},
{"SIGNAL_ACCOUNT"},
{"EMAIL_ADDRESS"},
{"TWILIO_ACCOUNT_SID"},
{"MATRIX_ACCESS_TOKEN", "MATRIX_PASSWORD"},
{"MATTERMOST_TOKEN"},
{"WHATSAPP_PHONE_NUMBER_ID"},
{"DINGTALK_CLIENT_ID"},
{"FEISHU_APP_ID"},
{"WECOM_BOT_ID"},
{"WEIXIN_ACCOUNT_ID"},
{"BLUEBUBBLES_SERVER_URL"},
{"WEBHOOK_ENABLED"},
}
// Hermes is intentionally not an Editor integration: launch owns one primary
// model and the local Ollama endpoint, while Hermes keeps its own discovery and
// switching UX after startup.
type Hermes struct{}
func (h *Hermes) String() string { return "Hermes Agent" }
func (h *Hermes) Run(_ string, _ []LaunchModel, args []string) error {
// Hermes reads its primary model from config.yaml. launch configures that
// default model ahead of time so we can keep runtime invocation simple and
// still let Hermes discover additional models later via its own UX.
bin, err := h.binary()
if err != nil {
return err
}
if err := h.runGatewaySetupPreflight(args, func() error {
return hermesAttachedCommand(bin, "gateway", "setup").Run()
}); err != nil {
return err
}
return hermesAttachedCommand(bin, args...).Run()
}
func (h *Hermes) Paths() []string {
configPath, err := hermesConfigPath()
if err != nil {
return nil
}
return []string{configPath}
}
func (h *Hermes) Configure(model string) error {
configPath, err := hermesConfigPath()
if err != nil {
return err
}
cfg := map[string]any{}
if data, err := os.ReadFile(configPath); err == nil {
if err := yaml.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("parse hermes config: %w", err)
}
} else if !os.IsNotExist(err) {
return err
}
modelSection, _ := cfg["model"].(map[string]any)
if modelSection == nil {
modelSection = make(map[string]any)
}
models := h.listModels(model)
applyHermesManagedProviders(cfg, hermesBaseURL(), model, models)
// launch writes the minimum provider/default-model settings needed to
// bootstrap Hermes against Ollama. The active provider stays on a
// launch-owned key so /model stays aligned with the launcher-managed entry,
// and the Ollama endpoint lives in providers: so the picker shows one row.
modelSection["provider"] = hermesProviderKey
modelSection["default"] = model
modelSection["base_url"] = hermesBaseURL()
modelSection["api_key"] = hermesPlaceholderKey
cfg["model"] = modelSection
// use Hermes' built-in web toolset for now.
// TODO(parthsareen): move this to using Ollama web search
cfg["toolsets"] = mergeHermesToolsets(cfg["toolsets"])
data, err := yaml.Marshal(cfg)
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
return fileutil.WriteWithBackup(configPath, data, "hermes")
}
func (h *Hermes) CurrentModel() string {
configPath, err := hermesConfigPath()
if err != nil {
return ""
}
data, err := os.ReadFile(configPath)
if err != nil {
return ""
}
cfg := map[string]any{}
if yaml.Unmarshal(data, &cfg) != nil {
return ""
}
return hermesManagedCurrentModel(cfg, hermesBaseURL())
}
func (h *Hermes) Onboard() error {
return config.MarkIntegrationOnboarded("hermes")
}
func (h *Hermes) RequiresInteractiveOnboarding() bool {
return false
}
func (h *Hermes) RefreshRuntimeAfterConfigure() error {
running, err := h.gatewayRunning()
if err != nil {
return fmt.Errorf("check Hermes gateway status: %w", err)
}
if !running {
return nil
}
fmt.Fprintf(os.Stderr, "%sRefreshing Hermes messaging gateway...%s\n", ansiGray, ansiReset)
if err := h.restartGateway(); err != nil {
return fmt.Errorf("restart Hermes gateway: %w", err)
}
fmt.Fprintln(os.Stderr)
return nil
}
func (h *Hermes) installed() bool {
_, err := h.binary()
return err == nil
}
func (h *Hermes) ensureInstalled() error {
if h.installed() {
return nil
}
if hermesGOOS == "windows" {
return hermesWindowsHint()
}
var missing []string
for _, dep := range []string{"bash", "curl", "git"} {
if _, err := hermesLookPath(dep); err != nil {
missing = append(missing, dep)
}
}
if len(missing) > 0 {
return fmt.Errorf("Hermes is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch hermes", strings.Join(missing, "\n "))
}
ok, err := ConfirmPrompt("Hermes is not installed. Install now?")
if err != nil {
return err
}
if !ok {
return fmt.Errorf("hermes installation cancelled")
}
fmt.Fprintf(os.Stderr, "\nInstalling Hermes...\n")
if err := hermesAttachedCommand("bash", "-lc", hermesInstallScript).Run(); err != nil {
return fmt.Errorf("failed to install hermes: %w", err)
}
if !h.installed() {
return fmt.Errorf("hermes was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
}
fmt.Fprintf(os.Stderr, "%sHermes installed successfully%s\n\n", ansiGreen, ansiReset)
return nil
}
func (h *Hermes) listModels(defaultModel string) []string {
client := hermesOllamaClient()
resp, err := client.List(context.Background())
if err != nil {
return []string{defaultModel}
}
models := make([]string, 0, len(resp.Models)+1)
seen := make(map[string]struct{}, len(resp.Models)+1)
add := func(name string) {
name = strings.TrimSpace(name)
if name == "" {
return
}
if _, ok := seen[name]; ok {
return
}
seen[name] = struct{}{}
models = append(models, name)
}
add(defaultModel)
for _, entry := range resp.Models {
add(entry.Name)
}
if len(models) == 0 {
return []string{defaultModel}
}
return models
}
func (h *Hermes) binary() (string, error) {
if path, err := hermesLookPath("hermes"); err == nil {
return path, nil
}
if hermesGOOS == "windows" {
return "", hermesWindowsHint()
}
home, err := hermesUserHome()
if err != nil {
return "", err
}
fallback := filepath.Join(home, ".local", "bin", "hermes")
if _, err := os.Stat(fallback); err == nil {
return fallback, nil
}
return "", fmt.Errorf("hermes is not installed")
}
func hermesConfigPath() (string, error) {
home, err := hermesUserHome()
if err != nil {
return "", err
}
return filepath.Join(home, ".hermes", "config.yaml"), nil
}
func hermesBaseURL() string {
return strings.TrimRight(hermesOllamaURL().String(), "/") + "/v1"
}
func hermesEnvPath() (string, error) {
home, err := hermesUserHome()
if err != nil {
return "", err
}
return filepath.Join(home, ".hermes", ".env"), nil
}
func (h *Hermes) runGatewaySetupPreflight(args []string, runSetup func() error) error {
if len(args) > 0 || !isInteractiveSession() || currentLaunchConfirmPolicy.yes || currentLaunchConfirmPolicy.requireYesMessage {
return nil
}
if h.messagingConfigured() {
return nil
}
fmt.Fprintf(os.Stderr, "\nHermes can message you on Telegram, Discord, Slack, and more.\n\n")
ok, err := ConfirmPromptWithOptions(hermesGatewaySetupTitle, ConfirmOptions{
YesLabel: "Yes",
NoLabel: "Set up later",
})
if err != nil {
return err
}
if !ok {
return nil
}
if err := runSetup(); err != nil {
return fmt.Errorf("hermes messaging setup failed: %w\n\nTry running: %s", err, hermesGatewaySetupHint)
}
return nil
}
func (h *Hermes) messagingConfigured() bool {
envVars, err := h.gatewayEnvVars()
if err != nil {
return false
}
for _, group := range hermesMessagingEnvGroups {
for _, key := range group {
if strings.TrimSpace(envVars[key]) != "" {
return true
}
}
}
return false
}
func (h *Hermes) gatewayEnvVars() (map[string]string, error) {
envVars := make(map[string]string)
envFilePath, err := hermesEnvPath()
if err != nil {
return nil, err
}
switch data, err := os.ReadFile(envFilePath); {
case err == nil:
for key, value := range hermesParseEnvFile(data) {
envVars[key] = value
}
case os.IsNotExist(err):
// nothing persisted yet
default:
return nil, err
}
for _, group := range hermesMessagingEnvGroups {
for _, key := range group {
if value, ok := os.LookupEnv(key); ok {
envVars[key] = value
}
}
}
return envVars, nil
}
func (h *Hermes) gatewayRunning() (bool, error) {
status, err := h.gatewayStatusOutput()
if err != nil {
return false, err
}
return hermesGatewayStatusRunning(status), nil
}
func (h *Hermes) gatewayStatusOutput() (string, error) {
bin, err := h.binary()
if err != nil {
return "", err
}
out, err := hermesCommand(bin, "gateway", "status").CombinedOutput()
return string(out), err
}
func (h *Hermes) restartGateway() error {
bin, err := h.binary()
if err != nil {
return err
}
return hermesAttachedCommand(bin, "gateway", "restart").Run()
}
func hermesGatewayStatusRunning(output string) bool {
status := strings.ToLower(output)
switch {
case strings.Contains(status, "gateway is not running"):
return false
case strings.Contains(status, "gateway service is stopped"):
return false
case strings.Contains(status, "gateway service is not loaded"):
return false
case strings.Contains(status, "gateway is running"):
return true
case strings.Contains(status, "gateway service is running"):
return true
case strings.Contains(status, "gateway service is loaded"):
return true
default:
return false
}
}
func hermesParseEnvFile(data []byte) map[string]string {
out := make(map[string]string)
scanner := bufio.NewScanner(bytes.NewReader(data))
for scanner.Scan() {
line := strings.TrimSpace(strings.TrimPrefix(scanner.Text(), "\ufeff"))
if line == "" || strings.HasPrefix(line, "#") {
continue
}
if strings.HasPrefix(line, "export ") {
line = strings.TrimSpace(strings.TrimPrefix(line, "export "))
}
key, value, ok := strings.Cut(line, "=")
if !ok {
continue
}
key = strings.TrimSpace(key)
if key == "" {
continue
}
value = strings.TrimSpace(value)
if len(value) >= 2 {
switch {
case value[0] == '"' && value[len(value)-1] == '"':
if unquoted, err := strconv.Unquote(value); err == nil {
value = unquoted
}
case value[0] == '\'' && value[len(value)-1] == '\'':
value = value[1 : len(value)-1]
}
}
out[key] = value
}
return out
}
func hermesOllamaClient() *api.Client {
// Hermes queries the same launch-resolved Ollama host that launch writes
// into config, so model discovery follows the configured endpoint.
return api.NewClient(hermesOllamaURL(), http.DefaultClient)
}
func applyHermesManagedProviders(cfg map[string]any, baseURL string, model string, models []string) {
providers := hermesUserProviders(cfg["providers"])
entry := hermesManagedProviderEntry(providers)
if entry == nil {
entry = make(map[string]any)
}
entry["name"] = hermesProviderName
entry["api"] = baseURL
entry["default_model"] = model
entry["models"] = hermesStringListAny(models)
providers[hermesProviderKey] = entry
delete(providers, hermesLegacyKey)
cfg["providers"] = providers
customProviders := hermesWithoutManagedCustomProviders(cfg["custom_providers"])
if len(customProviders) == 0 {
delete(cfg, "custom_providers")
return
}
cfg["custom_providers"] = customProviders
}
func hermesManagedCurrentModel(cfg map[string]any, baseURL string) string {
modelCfg, _ := cfg["model"].(map[string]any)
if modelCfg == nil {
return ""
}
provider, _ := modelCfg["provider"].(string)
if strings.TrimSpace(strings.ToLower(provider)) != hermesProviderKey {
return ""
}
configBaseURL, _ := modelCfg["base_url"].(string)
if hermesNormalizeURL(configBaseURL) != hermesNormalizeURL(baseURL) {
return ""
}
current, _ := modelCfg["default"].(string)
current = strings.TrimSpace(current)
if current == "" {
return ""
}
providers := hermesUserProviders(cfg["providers"])
entry, _ := providers[hermesProviderKey].(map[string]any)
if entry == nil {
return ""
}
if hermesHasManagedCustomProvider(cfg["custom_providers"]) {
return ""
}
apiURL, _ := entry["api"].(string)
if hermesNormalizeURL(apiURL) != hermesNormalizeURL(baseURL) {
return ""
}
defaultModel, _ := entry["default_model"].(string)
if strings.TrimSpace(defaultModel) != current {
return ""
}
return current
}
func hermesUserProviders(current any) map[string]any {
switch existing := current.(type) {
case map[string]any:
out := make(map[string]any, len(existing))
for key, value := range existing {
out[key] = value
}
return out
case map[any]any:
out := make(map[string]any, len(existing))
for key, value := range existing {
if s, ok := key.(string); ok {
out[s] = value
}
}
return out
default:
return make(map[string]any)
}
}
func hermesCustomProviders(current any) []any {
switch existing := current.(type) {
case []any:
return append([]any(nil), existing...)
case []map[string]any:
out := make([]any, 0, len(existing))
for _, entry := range existing {
out = append(out, entry)
}
return out
default:
return nil
}
}
func hermesManagedProviderEntry(providers map[string]any) map[string]any {
for _, key := range []string{hermesProviderKey, hermesLegacyKey} {
if entry, _ := providers[key].(map[string]any); entry != nil {
return entry
}
}
return nil
}
func hermesWithoutManagedCustomProviders(current any) []any {
customProviders := hermesCustomProviders(current)
preserved := make([]any, 0, len(customProviders))
for _, item := range customProviders {
entry, _ := item.(map[string]any)
if entry == nil {
preserved = append(preserved, item)
continue
}
if hermesManagedCustomProvider(entry) {
continue
}
preserved = append(preserved, entry)
}
return preserved
}
func hermesHasManagedCustomProvider(current any) bool {
for _, item := range hermesCustomProviders(current) {
entry, _ := item.(map[string]any)
if entry != nil && hermesManagedCustomProvider(entry) {
return true
}
}
return false
}
func hermesManagedCustomProvider(entry map[string]any) bool {
name, _ := entry["name"].(string)
return strings.EqualFold(strings.TrimSpace(name), hermesProviderName)
}
func hermesNormalizeURL(raw string) string {
return strings.TrimRight(strings.TrimSpace(raw), "/")
}
func hermesStringListAny(models []string) []any {
out := make([]any, 0, len(models))
for _, model := range dedupeModelList(models) {
model = strings.TrimSpace(model)
if model == "" {
continue
}
out = append(out, model)
}
return out
}
func mergeHermesToolsets(current any) any {
added := false
switch existing := current.(type) {
case []any:
out := make([]any, 0, len(existing)+1)
for _, item := range existing {
out = append(out, item)
if s, _ := item.(string); s == "web" {
added = true
}
}
if !added {
out = append(out, "web")
}
return out
case []string:
out := append([]string(nil), existing...)
if !slices.Contains(out, "web") {
out = append(out, "web")
}
asAny := make([]any, 0, len(out))
for _, item := range out {
asAny = append(asAny, item)
}
return asAny
case string:
if strings.TrimSpace(existing) == "" {
return []any{"hermes-cli", "web"}
}
parts := strings.Split(existing, ",")
out := make([]any, 0, len(parts)+1)
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
if part == "web" {
added = true
}
out = append(out, part)
}
if !added {
out = append(out, "web")
}
return out
default:
return []any{"hermes-cli", "web"}
}
}
func hermesAttachedCommand(name string, args ...string) *exec.Cmd {
cmd := hermesCommand(name, args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd
}
func hermesWindowsHint() error {
return fmt.Errorf("Hermes on Windows requires WSL2. Install WSL with: wsl --install\n" +
"Then run 'ollama launch hermes' from inside your WSL shell.\n" +
"Docs: https://hermes-agent.nousresearch.com/docs/getting-started/installation/")
}

1109
cmd/launch/hermes_test.go Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

315
cmd/launch/kimi.go Normal file
View File

@@ -0,0 +1,315 @@
package launch
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// Kimi implements Runner for Kimi Code CLI integration.
type Kimi struct{}
const (
kimiDefaultModelAlias = "ollama"
kimiDefaultMaxContextSize = 32768
)
var (
kimiGOOS = runtime.GOOS
kimiModelShowTimeout = 5 * time.Second
)
func (k *Kimi) String() string { return "Kimi Code CLI" }
func (k *Kimi) args(config string, extra []string) []string {
args := []string{"--config", config}
args = append(args, extra...)
return args
}
func (k *Kimi) Run(model string, _ []LaunchModel, args []string) error {
if strings.TrimSpace(model) == "" {
return fmt.Errorf("model is required")
}
if err := validateKimiPassthroughArgs(args); err != nil {
return err
}
config, err := buildKimiInlineConfig(model, resolveKimiMaxContextSize(model))
if err != nil {
return fmt.Errorf("failed to build kimi config: %w", err)
}
bin, err := ensureKimiInstalled()
if err != nil {
return err
}
cmd := exec.Command(bin, k.args(config, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func findKimiBinary() (string, error) {
if path, err := exec.LookPath("kimi"); err == nil {
return path, nil
}
home, _ := os.UserHomeDir()
var candidates []string
switch kimiGOOS {
case "windows":
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(home, ".local", "bin"))
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(home, "bin"))
if appData := strings.TrimSpace(os.Getenv("APPDATA")); appData != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(appData, "uv", "bin"))
}
if localAppData := strings.TrimSpace(os.Getenv("LOCALAPPDATA")); localAppData != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(localAppData, "uv", "bin"))
}
default:
candidates = append(candidates,
filepath.Join(home, ".local", "bin", "kimi"),
filepath.Join(home, "bin", "kimi"),
filepath.Join(home, ".local", "share", "uv", "tools", "kimi-cli", "bin", "kimi"),
filepath.Join(home, ".local", "share", "uv", "tools", "kimi", "bin", "kimi"),
)
if xdgDataHome := strings.TrimSpace(os.Getenv("XDG_DATA_HOME")); xdgDataHome != "" {
candidates = append(candidates,
filepath.Join(xdgDataHome, "uv", "tools", "kimi-cli", "bin", "kimi"),
filepath.Join(xdgDataHome, "uv", "tools", "kimi", "bin", "kimi"),
)
}
// WSL users can inherit Windows env vars while launching from Linux shells.
if profile := windowsPathToWSL(os.Getenv("USERPROFILE")); profile != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(profile, ".local", "bin"))
}
if appData := windowsPathToWSL(os.Getenv("APPDATA")); appData != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(appData, "uv", "bin"))
}
if localAppData := windowsPathToWSL(os.Getenv("LOCALAPPDATA")); localAppData != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(localAppData, "uv", "bin"))
}
}
for _, candidate := range candidates {
if info, err := os.Stat(candidate); err == nil && !info.IsDir() {
return candidate, nil
}
}
return "", fmt.Errorf("kimi binary not found")
}
func appendWindowsKimiCandidates(candidates []string, dir string) []string {
if strings.TrimSpace(dir) == "" {
return candidates
}
return append(candidates,
filepath.Join(dir, "kimi.exe"),
filepath.Join(dir, "kimi.cmd"),
filepath.Join(dir, "kimi.bat"),
)
}
func windowsPathToWSL(path string) string {
trimmed := strings.TrimSpace(path)
if len(trimmed) < 3 || trimmed[1] != ':' {
return ""
}
drive := strings.ToLower(string(trimmed[0]))
rest := strings.ReplaceAll(trimmed[2:], "\\", "/")
rest = strings.TrimPrefix(rest, "/")
if rest == "" {
return filepath.Join("/mnt", drive)
}
return filepath.Join("/mnt", drive, rest)
}
func validateKimiPassthroughArgs(args []string) error {
for _, arg := range args {
switch {
case arg == "--config", strings.HasPrefix(arg, "--config="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --config", arg)
case arg == "--config-file", strings.HasPrefix(arg, "--config-file="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --config-file", arg)
case arg == "--model", strings.HasPrefix(arg, "--model="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --model", arg)
case arg == "-m", strings.HasPrefix(arg, "-m="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages -m/--model", arg)
}
}
return nil
}
func buildKimiInlineConfig(model string, maxContextSize int) (string, error) {
cfg := map[string]any{
"default_model": kimiDefaultModelAlias,
"providers": map[string]any{
kimiDefaultModelAlias: map[string]any{
"type": "openai_legacy",
"base_url": envconfig.ConnectableHost().String() + "/v1",
"api_key": "ollama",
},
},
"models": map[string]any{
kimiDefaultModelAlias: map[string]any{
"provider": kimiDefaultModelAlias,
"model": model,
"max_context_size": maxContextSize,
},
},
}
data, err := json.Marshal(cfg)
if err != nil {
return "", err
}
return string(data), nil
}
func resolveKimiMaxContextSize(model string) int {
if l, ok := lookupCloudModelLimit(model); ok {
return l.Context
}
client, err := api.ClientFromEnvironment()
if err != nil {
return kimiDefaultMaxContextSize
}
ctx, cancel := context.WithTimeout(context.Background(), kimiModelShowTimeout)
defer cancel()
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
if err != nil {
return kimiDefaultMaxContextSize
}
if n, ok := modelInfoContextLength(resp.ModelInfo); ok {
return n
}
return kimiDefaultMaxContextSize
}
func modelInfoContextLength(modelInfo map[string]any) (int, bool) {
for key, val := range modelInfo {
if !strings.HasSuffix(key, ".context_length") {
continue
}
switch v := val.(type) {
case float64:
if v > 0 {
return int(v), true
}
case int:
if v > 0 {
return v, true
}
case int64:
if v > 0 {
return int(v), true
}
}
}
return 0, false
}
func ensureKimiInstalled() (string, error) {
if path, err := findKimiBinary(); err == nil {
return path, nil
}
if err := checkKimiInstallerDependencies(); err != nil {
return "", err
}
ok, err := ConfirmPrompt("Kimi is not installed. Install now?")
if err != nil {
return "", err
}
if !ok {
return "", fmt.Errorf("kimi installation cancelled")
}
bin, args, err := kimiInstallerCommand(kimiGOOS)
if err != nil {
return "", err
}
fmt.Fprintf(os.Stderr, "\nInstalling Kimi...\n")
cmd := exec.Command(bin, args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("failed to install kimi: %w", err)
}
path, err := findKimiBinary()
if err != nil {
return "", fmt.Errorf("kimi was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
}
fmt.Fprintf(os.Stderr, "%sKimi installed successfully%s\n\n", ansiGreen, ansiReset)
return path, nil
}
func checkKimiInstallerDependencies() error {
switch kimiGOOS {
case "windows":
if _, err := exec.LookPath("powershell"); err != nil {
return fmt.Errorf("kimi is not installed and required dependencies are missing\n\nInstall the following first:\n PowerShell: https://learn.microsoft.com/powershell/\n\nThen re-run:\n ollama launch kimi")
}
default:
var missing []string
if _, err := exec.LookPath("curl"); err != nil {
missing = append(missing, "curl: https://curl.se/")
}
if _, err := exec.LookPath("bash"); err != nil {
missing = append(missing, "bash: https://www.gnu.org/software/bash/")
}
if len(missing) > 0 {
return fmt.Errorf("kimi is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch kimi", strings.Join(missing, "\n "))
}
}
return nil
}
func kimiInstallerCommand(goos string) (string, []string, error) {
switch goos {
case "windows":
return "powershell", []string{
"-NoProfile",
"-ExecutionPolicy",
"Bypass",
"-Command",
"Invoke-RestMethod https://code.kimi.com/install.ps1 | Invoke-Expression",
}, nil
case "darwin", "linux":
return "bash", []string{
"-c",
"curl -LsSf https://code.kimi.com/install.sh | bash",
}, nil
default:
return "", nil, fmt.Errorf("unsupported platform for kimi install: %s", goos)
}
}

636
cmd/launch/kimi_test.go Normal file
View File

@@ -0,0 +1,636 @@
package launch
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
)
func assertKimiBinPath(t *testing.T, bin string) {
t.Helper()
base := strings.ToLower(filepath.Base(bin))
if !strings.HasPrefix(base, "kimi") {
t.Fatalf("bin = %q, want path to kimi executable", bin)
}
}
func TestKimiIntegration(t *testing.T) {
k := &Kimi{}
t.Run("String", func(t *testing.T) {
if got := k.String(); got != "Kimi Code CLI" {
t.Errorf("String() = %q, want %q", got, "Kimi Code CLI")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = k
})
}
func TestKimiArgs(t *testing.T) {
k := &Kimi{}
got := k.args(`{"foo":"bar"}`, []string{"--quiet", "--print"})
want := []string{"--config", `{"foo":"bar"}`, "--quiet", "--print"}
if !slices.Equal(got, want) {
t.Fatalf("args() = %v, want %v", got, want)
}
}
func TestWindowsPathToWSL(t *testing.T) {
tests := []struct {
name string
in string
want string
valid bool
}{
{
name: "user profile path",
in: `C:\Users\parth`,
want: filepath.Join("/mnt", "c", "Users", "parth"),
valid: true,
},
{
name: "path with trailing slash",
in: `D:\tools\bin\`,
want: filepath.Join("/mnt", "d", "tools", "bin"),
valid: true,
},
{
name: "non windows path",
in: "/home/parth",
valid: false,
},
{
name: "empty",
in: "",
valid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := windowsPathToWSL(tt.in)
if !tt.valid {
if got != "" {
t.Fatalf("windowsPathToWSL(%q) = %q, want empty", tt.in, got)
}
return
}
if got != tt.want {
t.Fatalf("windowsPathToWSL(%q) = %q, want %q", tt.in, got, tt.want)
}
})
}
}
func TestFindKimiBinaryFallbacks(t *testing.T) {
oldGOOS := kimiGOOS
t.Cleanup(func() { kimiGOOS = oldGOOS })
t.Run("linux/ubuntu uv tool path", func(t *testing.T) {
homeDir := t.TempDir()
setTestHome(t, homeDir)
t.Setenv("PATH", t.TempDir())
kimiGOOS = "linux"
target := filepath.Join(homeDir, ".local", "share", "uv", "tools", "kimi-cli", "bin", "kimi")
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
t.Fatalf("failed to create candidate dir: %v", err)
}
if err := os.WriteFile(target, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatalf("failed to write kimi candidate: %v", err)
}
got, err := findKimiBinary()
if err != nil {
t.Fatalf("findKimiBinary() error = %v", err)
}
if got != target {
t.Fatalf("findKimiBinary() = %q, want %q", got, target)
}
})
t.Run("windows appdata uv bin", func(t *testing.T) {
setTestHome(t, t.TempDir())
t.Setenv("PATH", t.TempDir())
kimiGOOS = "windows"
appDataDir := t.TempDir()
t.Setenv("APPDATA", appDataDir)
t.Setenv("LOCALAPPDATA", "")
target := filepath.Join(appDataDir, "uv", "bin", "kimi.cmd")
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
t.Fatalf("failed to create candidate dir: %v", err)
}
if err := os.WriteFile(target, []byte("@echo off\r\nexit /b 0\r\n"), 0o755); err != nil {
t.Fatalf("failed to write kimi candidate: %v", err)
}
got, err := findKimiBinary()
if err != nil {
t.Fatalf("findKimiBinary() error = %v", err)
}
if got != target {
t.Fatalf("findKimiBinary() = %q, want %q", got, target)
}
})
}
func TestValidateKimiPassthroughArgs_RejectsConflicts(t *testing.T) {
tests := []struct {
name string
args []string
want string
}{
{name: "--config", args: []string{"--config", "{}"}, want: "--config"},
{name: "--config=", args: []string{"--config={}"}, want: "--config={"},
{name: "--config-file", args: []string{"--config-file", "x.toml"}, want: "--config-file"},
{name: "--config-file=", args: []string{"--config-file=x.toml"}, want: "--config-file=x.toml"},
{name: "--model", args: []string{"--model", "foo"}, want: "--model"},
{name: "--model=", args: []string{"--model=foo"}, want: "--model=foo"},
{name: "-m", args: []string{"-m", "foo"}, want: "-m"},
{name: "-m=", args: []string{"-m=foo"}, want: "-m=foo"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateKimiPassthroughArgs(tt.args)
if err == nil {
t.Fatalf("expected error for args %v", tt.args)
}
if !strings.Contains(err.Error(), tt.want) {
t.Fatalf("error %q does not contain %q", err.Error(), tt.want)
}
})
}
}
func TestBuildKimiInlineConfig(t *testing.T) {
t.Setenv("OLLAMA_HOST", "http://127.0.0.1:11434")
cfg, err := buildKimiInlineConfig("llama3.2", 65536)
if err != nil {
t.Fatalf("buildKimiInlineConfig() error = %v", err)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
t.Fatalf("config is not valid JSON: %v", err)
}
if parsed["default_model"] != "ollama" {
t.Fatalf("default_model = %v, want ollama", parsed["default_model"])
}
providers, ok := parsed["providers"].(map[string]any)
if !ok {
t.Fatalf("providers missing or wrong type: %T", parsed["providers"])
}
ollamaProvider, ok := providers["ollama"].(map[string]any)
if !ok {
t.Fatalf("providers.ollama missing or wrong type: %T", providers["ollama"])
}
if ollamaProvider["type"] != "openai_legacy" {
t.Fatalf("provider type = %v, want openai_legacy", ollamaProvider["type"])
}
if ollamaProvider["base_url"] != "http://127.0.0.1:11434/v1" {
t.Fatalf("provider base_url = %v, want http://127.0.0.1:11434/v1", ollamaProvider["base_url"])
}
if ollamaProvider["api_key"] != "ollama" {
t.Fatalf("provider api_key = %v, want ollama", ollamaProvider["api_key"])
}
models, ok := parsed["models"].(map[string]any)
if !ok {
t.Fatalf("models missing or wrong type: %T", parsed["models"])
}
ollamaModel, ok := models["ollama"].(map[string]any)
if !ok {
t.Fatalf("models.ollama missing or wrong type: %T", models["ollama"])
}
if ollamaModel["provider"] != "ollama" {
t.Fatalf("model provider = %v, want ollama", ollamaModel["provider"])
}
if ollamaModel["model"] != "llama3.2" {
t.Fatalf("model model = %v, want llama3.2", ollamaModel["model"])
}
if ollamaModel["max_context_size"] != float64(65536) {
t.Fatalf("model max_context_size = %v, want 65536", ollamaModel["max_context_size"])
}
}
func TestBuildKimiInlineConfig_UsesConnectableHostForUnspecifiedBind(t *testing.T) {
t.Setenv("OLLAMA_HOST", "http://0.0.0.0:11434")
cfg, err := buildKimiInlineConfig("llama3.2", 65536)
if err != nil {
t.Fatalf("buildKimiInlineConfig() error = %v", err)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
t.Fatalf("config is not valid JSON: %v", err)
}
providers, ok := parsed["providers"].(map[string]any)
if !ok {
t.Fatalf("providers missing or wrong type: %T", parsed["providers"])
}
ollamaProvider, ok := providers["ollama"].(map[string]any)
if !ok {
t.Fatalf("providers.ollama missing or wrong type: %T", providers["ollama"])
}
if got, _ := ollamaProvider["base_url"].(string); got != "http://127.0.0.1:11434/v1" {
t.Fatalf("provider base_url = %q, want %q", got, "http://127.0.0.1:11434/v1")
}
}
func TestResolveKimiMaxContextSize(t *testing.T) {
t.Run("uses cloud limit when known", func(t *testing.T) {
got := resolveKimiMaxContextSize("kimi-k2.5:cloud")
if got != 262_144 {
t.Fatalf("resolveKimiMaxContextSize() = %d, want 262144", got)
}
})
t.Run("uses model show context length for local models", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/show" {
http.NotFound(w, r)
return
}
fmt.Fprint(w, `{"model_info":{"llama.context_length":131072}}`)
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
got := resolveKimiMaxContextSize("llama3.2")
if got != 131_072 {
t.Fatalf("resolveKimiMaxContextSize() = %d, want 131072", got)
}
})
t.Run("falls back to default when show fails", func(t *testing.T) {
srv := httptest.NewServer(http.NotFoundHandler())
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
oldTimeout := kimiModelShowTimeout
kimiModelShowTimeout = 100 * 1000 * 1000 // 100ms
t.Cleanup(func() { kimiModelShowTimeout = oldTimeout })
got := resolveKimiMaxContextSize("llama3.2")
if got != kimiDefaultMaxContextSize {
t.Fatalf("resolveKimiMaxContextSize() = %d, want %d", got, kimiDefaultMaxContextSize)
}
})
}
func TestKimiRun_RejectsConflictingArgsBeforeInstall(t *testing.T) {
k := &Kimi{}
oldConfirm := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
t.Fatalf("did not expect install prompt, got %q", prompt)
return false, nil
}
t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
err := k.Run("llama3.2", nil, []string{"--model", "other"})
if err == nil || !strings.Contains(err.Error(), "--model") {
t.Fatalf("expected conflict error mentioning --model, got %v", err)
}
}
func TestKimiRun_PassesInlineConfigAndExtraArgs(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binary")
}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
logPath := filepath.Join(tmpDir, "kimi-args.log")
script := fmt.Sprintf(`#!/bin/sh
for arg in "$@"; do
printf "%%s\n" "$arg" >> %q
done
exit 0
`, logPath)
if err := os.WriteFile(filepath.Join(tmpDir, "kimi"), []byte(script), 0o755); err != nil {
t.Fatalf("failed to write fake kimi: %v", err)
}
t.Setenv("PATH", tmpDir)
srv := httptest.NewServer(http.NotFoundHandler())
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
k := &Kimi{}
if err := k.Run("llama3.2", nil, []string{"--quiet", "--print"}); err != nil {
t.Fatalf("Run() error = %v", err)
}
data, err := os.ReadFile(logPath)
if err != nil {
t.Fatalf("failed to read args log: %v", err)
}
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
if len(lines) < 4 {
t.Fatalf("expected at least 4 args, got %v", lines)
}
if lines[0] != "--config" {
t.Fatalf("first arg = %q, want --config", lines[0])
}
var cfg map[string]any
if err := json.Unmarshal([]byte(lines[1]), &cfg); err != nil {
t.Fatalf("config arg is not valid JSON: %v", err)
}
providers := cfg["providers"].(map[string]any)
ollamaProvider := providers["ollama"].(map[string]any)
if ollamaProvider["type"] != "openai_legacy" {
t.Fatalf("provider type = %v, want openai_legacy", ollamaProvider["type"])
}
if lines[2] != "--quiet" || lines[3] != "--print" {
t.Fatalf("extra args = %v, want [--quiet --print]", lines[2:])
}
}
func TestEnsureKimiInstalled(t *testing.T) {
oldGOOS := kimiGOOS
t.Cleanup(func() { kimiGOOS = oldGOOS })
withConfirm := func(t *testing.T, fn func(prompt string) (bool, error)) {
t.Helper()
oldConfirm := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
return fn(prompt)
}
t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
}
t.Run("already installed", func(t *testing.T) {
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
writeFakeBinary(t, tmpDir, "kimi")
kimiGOOS = runtime.GOOS
withConfirm(t, func(prompt string) (bool, error) {
t.Fatalf("did not expect prompt, got %q", prompt)
return false, nil
})
bin, err := ensureKimiInstalled()
if err != nil {
t.Fatalf("ensureKimiInstalled() error = %v", err)
}
assertKimiBinPath(t, bin)
})
t.Run("missing dependencies", func(t *testing.T) {
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
kimiGOOS = "linux"
withConfirm(t, func(prompt string) (bool, error) {
t.Fatalf("did not expect prompt, got %q", prompt)
return false, nil
})
_, err := ensureKimiInstalled()
if err == nil || !strings.Contains(err.Error(), "required dependencies are missing") {
t.Fatalf("expected missing dependency error, got %v", err)
}
})
t.Run("missing and user declines install", func(t *testing.T) {
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
writeFakeBinary(t, tmpDir, "curl")
writeFakeBinary(t, tmpDir, "bash")
kimiGOOS = "linux"
withConfirm(t, func(prompt string) (bool, error) {
if !strings.Contains(prompt, "Kimi is not installed.") {
t.Fatalf("unexpected prompt: %q", prompt)
}
return false, nil
})
_, err := ensureKimiInstalled()
if err == nil || !strings.Contains(err.Error(), "installation cancelled") {
t.Fatalf("expected cancellation error, got %v", err)
}
})
t.Run("missing and user confirms install succeeds", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binaries")
}
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
kimiGOOS = "linux"
writeFakeBinary(t, tmpDir, "curl")
installLog := filepath.Join(tmpDir, "bash.log")
kimiPath := filepath.Join(tmpDir, "kimi")
bashScript := fmt.Sprintf(`#!/bin/sh
echo "$@" >> %q
if [ "$1" = "-c" ]; then
/bin/cat > %q <<'EOS'
#!/bin/sh
exit 0
EOS
/bin/chmod +x %q
fi
exit 0
`, installLog, kimiPath, kimiPath)
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte(bashScript), 0o755); err != nil {
t.Fatalf("failed to write fake bash: %v", err)
}
withConfirm(t, func(prompt string) (bool, error) {
return true, nil
})
bin, err := ensureKimiInstalled()
if err != nil {
t.Fatalf("ensureKimiInstalled() error = %v", err)
}
assertKimiBinPath(t, bin)
logData, err := os.ReadFile(installLog)
if err != nil {
t.Fatalf("failed to read install log: %v", err)
}
if !strings.Contains(string(logData), "https://code.kimi.com/install.sh") {
t.Fatalf("expected install.sh command in log, got:\n%s", string(logData))
}
})
t.Run("install succeeds and kimi is in home local bin without PATH update", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binaries")
}
homeDir := t.TempDir()
setTestHome(t, homeDir)
tmpBin := t.TempDir()
t.Setenv("PATH", tmpBin)
kimiGOOS = "linux"
writeFakeBinary(t, tmpBin, "curl")
installedKimi := filepath.Join(homeDir, ".local", "bin", "kimi")
bashScript := fmt.Sprintf(`#!/bin/sh
if [ "$1" = "-c" ]; then
/bin/mkdir -p %q
/bin/cat > %q <<'EOS'
#!/bin/sh
exit 0
EOS
/bin/chmod +x %q
fi
exit 0
`, filepath.Dir(installedKimi), installedKimi, installedKimi)
if err := os.WriteFile(filepath.Join(tmpBin, "bash"), []byte(bashScript), 0o755); err != nil {
t.Fatalf("failed to write fake bash: %v", err)
}
withConfirm(t, func(prompt string) (bool, error) {
return true, nil
})
bin, err := ensureKimiInstalled()
if err != nil {
t.Fatalf("ensureKimiInstalled() error = %v", err)
}
if bin != installedKimi {
t.Fatalf("bin = %q, want %q", bin, installedKimi)
}
})
t.Run("install command fails", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binaries")
}
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
kimiGOOS = "linux"
writeFakeBinary(t, tmpDir, "curl")
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte("#!/bin/sh\nexit 1\n"), 0o755); err != nil {
t.Fatalf("failed to write fake bash: %v", err)
}
withConfirm(t, func(prompt string) (bool, error) {
return true, nil
})
_, err := ensureKimiInstalled()
if err == nil || !strings.Contains(err.Error(), "failed to install kimi") {
t.Fatalf("expected install failure error, got %v", err)
}
})
t.Run("install succeeds but binary missing on PATH", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binaries")
}
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
kimiGOOS = "linux"
writeFakeBinary(t, tmpDir, "curl")
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatalf("failed to write fake bash: %v", err)
}
withConfirm(t, func(prompt string) (bool, error) {
return true, nil
})
_, err := ensureKimiInstalled()
if err == nil || !strings.Contains(err.Error(), "binary was not found on PATH") {
t.Fatalf("expected PATH guidance error, got %v", err)
}
})
}
func TestKimiInstallerCommand(t *testing.T) {
tests := []struct {
name string
goos string
wantBin string
wantParts []string
wantErr bool
}{
{
name: "linux",
goos: "linux",
wantBin: "bash",
wantParts: []string{"-c", "install.sh"},
},
{
name: "darwin",
goos: "darwin",
wantBin: "bash",
wantParts: []string{"-c", "install.sh"},
},
{
name: "windows",
goos: "windows",
wantBin: "powershell",
wantParts: []string{"-Command", "install.ps1"},
},
{
name: "unsupported",
goos: "freebsd",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bin, args, err := kimiInstallerCommand(tt.goos)
if tt.wantErr {
if err == nil {
t.Fatal("expected error")
}
return
}
if err != nil {
t.Fatalf("kimiInstallerCommand() error = %v", err)
}
if bin != tt.wantBin {
t.Fatalf("bin = %q, want %q", bin, tt.wantBin)
}
joined := strings.Join(args, " ")
for _, part := range tt.wantParts {
if !strings.Contains(joined, part) {
t.Fatalf("args %q missing %q", joined, part)
}
}
})
}
}

1495
cmd/launch/launch.go Normal file

File diff suppressed because it is too large Load Diff

3539
cmd/launch/launch_test.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,201 @@
package launch
import (
"context"
"slices"
"strings"
"sync"
"github.com/ollama/ollama/api"
modelpkg "github.com/ollama/ollama/types/model"
)
// LaunchModel is the model metadata Launch passes to integration config
// writers after resolving selected model names through the per-run inventory.
type LaunchModel struct {
Name string
Remote bool
ToolCapable bool
Capabilities []modelpkg.Capability
ContextLength int
MaxOutputTokens int
EmbeddingLength int
Size int64
Details api.ModelDetails
}
type modelInfo = LaunchModel
// ModelInfo re-exports launcher model inventory details for callers.
type ModelInfo = LaunchModel
func (m LaunchModel) HasCapability(capability modelpkg.Capability) bool {
return slices.Contains(m.Capabilities, capability)
}
func (m LaunchModel) WithCloudLimits() LaunchModel {
if limit, ok := lookupCloudModelLimit(m.Name); ok {
if m.ContextLength <= 0 {
m.ContextLength = limit.Context
}
if m.MaxOutputTokens <= 0 {
m.MaxOutputTokens = limit.Output
}
}
return m
}
type modelInventory struct {
client *api.Client
mu sync.Mutex
loaded bool
models []LaunchModel
err error
}
func newModelInventory(client *api.Client) *modelInventory {
return &modelInventory{client: client}
}
func (i *modelInventory) Load(ctx context.Context) ([]LaunchModel, error) {
return i.load(ctx, false)
}
func (i *modelInventory) Refresh(ctx context.Context) ([]LaunchModel, error) {
return i.load(ctx, true)
}
func (i *modelInventory) load(ctx context.Context, force bool) ([]LaunchModel, error) {
if i == nil || i.client == nil {
return nil, nil
}
i.mu.Lock()
defer i.mu.Unlock()
if i.loaded && !force {
return cloneLaunchModels(i.models), i.err
}
resp, err := i.client.List(ctx)
if err != nil {
i.models = nil
i.err = err
i.loaded = true
return nil, err
}
i.models = make([]LaunchModel, 0, len(resp.Models))
for _, model := range resp.Models {
i.models = append(i.models, launchModelFromListResponse(model))
}
i.err = nil
i.loaded = true
return cloneLaunchModels(i.models), i.err
}
func (i *modelInventory) Resolve(ctx context.Context, names []string) []LaunchModel {
names = dedupeModelList(names)
if len(names) == 0 {
return nil
}
models, err := i.Load(ctx)
if err != nil {
models = nil
}
resolved, localMiss := resolveLaunchModels(names, models)
if localMiss {
if refreshed, err := i.Refresh(ctx); err == nil {
resolved, _ = resolveLaunchModels(names, refreshed)
}
}
return resolved
}
func resolveLaunchModels(names []string, models []LaunchModel) ([]LaunchModel, bool) {
resolved := make([]LaunchModel, 0, len(names))
localMiss := false
for _, name := range names {
if model, ok := findLaunchModel(models, name); ok {
resolved = append(resolved, model.WithCloudLimits())
continue
}
if !isCloudModelName(name) {
localMiss = true
}
resolved = append(resolved, fallbackLaunchModel(name))
}
return resolved, localMiss
}
func launchModelFromListResponse(model api.ListModelResponse) LaunchModel {
return LaunchModel{
Name: model.Name,
Remote: model.RemoteModel != "",
ToolCapable: slices.Contains(model.Capabilities, modelpkg.CapabilityTools),
Capabilities: append([]modelpkg.Capability(nil), model.Capabilities...),
ContextLength: model.Details.ContextLength,
EmbeddingLength: model.Details.EmbeddingLength,
Size: model.Size,
Details: model.Details,
}.WithCloudLimits()
}
func fallbackLaunchModel(name string) LaunchModel {
return LaunchModel{Name: name, Remote: isCloudModelName(name)}.WithCloudLimits()
}
func findLaunchModel(models []LaunchModel, name string) (LaunchModel, bool) {
for _, model := range models {
if launchModelMatches(model.Name, name) {
return cloneLaunchModel(model), true
}
}
return LaunchModel{}, false
}
func launchModelMatches(candidate, name string) bool {
if candidate == name {
return true
}
return strings.TrimSuffix(candidate, ":latest") == name
}
func cloneLaunchModel(model LaunchModel) LaunchModel {
model.Capabilities = append([]modelpkg.Capability(nil), model.Capabilities...)
model.Details.Families = append([]string(nil), model.Details.Families...)
return model
}
func cloneLaunchModels(models []LaunchModel) []LaunchModel {
cloned := make([]LaunchModel, len(models))
for i, model := range models {
cloned[i] = cloneLaunchModel(model)
}
return cloned
}
func launchModelNames(models []LaunchModel) []string {
names := make([]string, 0, len(models))
for _, model := range models {
if model.Name != "" {
names = append(names, model.Name)
}
}
return names
}
func launchModelsFromNames(names []string) []LaunchModel {
models := make([]LaunchModel, 0, len(names))
for _, name := range names {
if name == "" {
continue
}
models = append(models, fallbackLaunchModel(name))
}
return models
}

View File

@@ -0,0 +1,80 @@
package launch
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/ollama/ollama/api"
modelpkg "github.com/ollama/ollama/types/model"
)
func TestModelInventoryResolveRefreshesLocalMiss(t *testing.T) {
calls := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/tags" {
http.NotFound(w, r)
return
}
calls++
if calls == 1 {
fmt.Fprint(w, `{"models":[]}`)
return
}
fmt.Fprint(w, `{"models":[{"name":"new-model","size":123,"details":{"context_length":65536,"embedding_length":1024},"capabilities":["vision","tools"]}]}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
inventory := newModelInventory(api.NewClient(u, srv.Client()))
got := inventory.Resolve(context.Background(), []string{"new-model"})
if calls != 2 {
t.Fatalf("List calls = %d, want 2", calls)
}
if len(got) != 1 {
t.Fatalf("Resolve returned %d models, want 1", len(got))
}
if got[0].Name != "new-model" {
t.Fatalf("Name = %q, want new-model", got[0].Name)
}
if got[0].ContextLength != 65_536 || got[0].EmbeddingLength != 1_024 {
t.Fatalf("metadata = context %d embedding %d, want refreshed metadata", got[0].ContextLength, got[0].EmbeddingLength)
}
if !got[0].HasCapability(modelpkg.CapabilityVision) || !got[0].ToolCapable {
t.Fatalf("capabilities = %v toolCapable=%v, want refreshed capabilities", got[0].Capabilities, got[0].ToolCapable)
}
}
func TestModelInventoryResolveDoesNotRefreshCloudMiss(t *testing.T) {
calls := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/tags" {
http.NotFound(w, r)
return
}
calls++
fmt.Fprint(w, `{"models":[]}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
inventory := newModelInventory(api.NewClient(u, srv.Client()))
got := inventory.Resolve(context.Background(), []string{"glm-5.1:cloud"})
if calls != 1 {
t.Fatalf("List calls = %d, want 1", calls)
}
if len(got) != 1 {
t.Fatalf("Resolve returned %d models, want 1", len(got))
}
if got[0].Name != "glm-5.1:cloud" || !got[0].Remote {
t.Fatalf("resolved model = %#v, want cloud fallback", got[0])
}
if got[0].ContextLength <= 0 || got[0].MaxOutputTokens <= 0 {
t.Fatalf("cloud limits not applied: %#v", got[0])
}
}

593
cmd/launch/models.go Normal file
View File

@@ -0,0 +1,593 @@
package launch
import (
"context"
"errors"
"fmt"
"math"
"net/http"
"os"
"os/exec"
"runtime"
"slices"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/format"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/progress"
)
var recommendedModels = []ModelItem{
{Name: "kimi-k2.6:cloud", Description: "State-of-the-art coding, long-horizon execution, and multimodal agent swarm capability", Recommended: true, Details: api.ModelDetails{ContextLength: 262_144}, MaxOutputTokens: 262_144},
{Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true, Details: api.ModelDetails{ContextLength: 262_144}, MaxOutputTokens: 32_768},
{Name: "glm-5.1:cloud", Description: "Reasoning and code generation", Recommended: true, Details: api.ModelDetails{ContextLength: 202_752}, MaxOutputTokens: 131_072},
{Name: "minimax-m2.7:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true, Details: api.ModelDetails{ContextLength: 204_800}, MaxOutputTokens: 128_000},
{Name: "gemma4", Description: "Reasoning and code generation locally", Recommended: true, VRAMBytes: 12 * format.GigaByte},
{Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true, VRAMBytes: 14 * format.GigaByte},
}
func displayVRAM(vramBytes int64) string {
if vramBytes <= 0 {
return ""
}
gb := float64(vramBytes) / format.GigaByte
if gb == math.Trunc(gb) {
return fmt.Sprintf("~%.0fGB", gb)
}
return fmt.Sprintf("~%.1fGB", gb)
}
// cloudModelLimit holds context and output token limits for a cloud model.
type cloudModelLimit struct {
Context int
Output int
}
// extraCloudModelLimits maps cloud model base names to token limits for models
// that are not already covered by recommendedModels fallback entries.
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
var extraCloudModelLimits = map[string]cloudModelLimit{
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
"gemma4:31b": {Context: 262_144, Output: 131_072},
"glm-4.6": {Context: 202_752, Output: 131_072},
"glm-4.7": {Context: 202_752, Output: 131_072},
"glm-5": {Context: 202_752, Output: 131_072},
"glm-5.1": {Context: 202_752, Output: 131_072},
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
"kimi-k2.5": {Context: 262_144, Output: 262_144},
"kimi-k2.6": {Context: 262_144, Output: 262_144},
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
"qwen3.5": {Context: 262_144, Output: 32_768},
}
var cloudModelLimits = mergeCloudModelLimits(cloudModelLimitsFromRecommendations(recommendedModels), extraCloudModelLimits)
var (
dynamicCloudModelLimitsMu sync.RWMutex
dynamicCloudModelLimits = map[string]cloudModelLimit{}
)
// lookupCloudModelLimit returns the token limits for a cloud model.
// It normalizes explicit cloud source suffixes before checking the shared limit map.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
base, stripped := modelref.StripCloudSourceTag(name)
if stripped {
dynamicCloudModelLimitsMu.RLock()
l, ok := dynamicCloudModelLimits[base]
dynamicCloudModelLimitsMu.RUnlock()
if ok {
return l, true
}
if l, ok := cloudModelLimits[base]; ok {
return l, true
}
}
return cloudModelLimit{}, false
}
func setDynamicCloudModelLimits(limits map[string]cloudModelLimit) {
dynamicCloudModelLimitsMu.Lock()
defer dynamicCloudModelLimitsMu.Unlock()
if limits == nil {
dynamicCloudModelLimits = map[string]cloudModelLimit{}
return
}
cp := make(map[string]cloudModelLimit, len(limits))
for k, v := range limits {
cp[k] = v
}
dynamicCloudModelLimits = cp
}
func cloudModelLimitsFromRecommendations(recommendations []ModelItem) map[string]cloudModelLimit {
limits := make(map[string]cloudModelLimit, len(recommendations))
for _, rec := range recommendations {
if !isCloudModelName(rec.Name) || rec.Details.ContextLength <= 0 || rec.MaxOutputTokens <= 0 {
continue
}
base, stripped := modelref.StripCloudSourceTag(rec.Name)
if !stripped || base == "" {
continue
}
limits[base] = cloudModelLimit{
Context: rec.Details.ContextLength,
Output: rec.MaxOutputTokens,
}
}
return limits
}
func mergeCloudModelLimits(base map[string]cloudModelLimit, overlay map[string]cloudModelLimit) map[string]cloudModelLimit {
out := make(map[string]cloudModelLimit, len(base)+len(overlay))
for name, limit := range base {
out[name] = limit
}
for name, limit := range overlay {
out[name] = limit
}
return out
}
// missingModelPolicy controls how model-not-found errors should be handled.
type missingModelPolicy int
const (
// missingModelPromptPull prompts the user to download missing local models.
missingModelPromptPull missingModelPolicy = iota
// missingModelAutoPull downloads missing local models without prompting.
missingModelAutoPull
// missingModelFail returns an error for missing local models without prompting.
missingModelFail
)
// OpenBrowser opens the URL in the user's browser.
func OpenBrowser(url string) {
switch runtime.GOOS {
case "darwin":
_ = exec.Command("open", url).Start()
case "linux":
// Skip on headless systems where no display server is available
if os.Getenv("DISPLAY") == "" && os.Getenv("WAYLAND_DISPLAY") == "" {
return
}
_ = exec.Command("xdg-open", url).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
}
}
// ensureAuth ensures the user is signed in before cloud-backed models run.
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
var selectedCloudModels []string
for _, m := range selected {
if cloudModels[m] {
selectedCloudModels = append(selectedCloudModels, m)
}
}
if len(selectedCloudModels) == 0 {
return nil
}
return ensureCloudAuth(ctx, client, strings.Join(selectedCloudModels, ", "))
}
func ensureCloudAuth(ctx context.Context, client *api.Client, modelList string) error {
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
}
user, err := whoamiWithTimeout(ctx, client)
if err == nil && user != nil && user.Name != "" {
return nil
}
var aErr api.AuthorizationError
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
if err != nil {
return err
}
return fmt.Errorf("%s requires sign in", modelList)
}
if DefaultSignIn != nil {
_, err := DefaultSignIn(modelList, aErr.SigninURL)
if errors.Is(err, ErrCancelled) {
return ErrCancelled
}
if err != nil {
return fmt.Errorf("%s requires sign in", modelList)
}
return nil
}
yes, err := ConfirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
if errors.Is(err, ErrCancelled) {
return ErrCancelled
}
if err != nil {
return err
}
if !yes {
return ErrCancelled
}
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
OpenBrowser(aErr.SigninURL)
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\r\033[K")
return ctx.Err()
case <-ticker.C:
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
if frame%10 == 0 {
u, err := whoamiWithTimeout(ctx, client)
if err == nil && u != nil && u.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
return nil
}
}
}
}
}
// showOrPullWithPolicy checks if a model exists and applies the provided missing-model policy.
func showOrPullWithPolicy(ctx context.Context, client *api.Client, model string, policy missingModelPolicy, isCloudModel bool) error {
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
return nil
} else {
var statusErr api.StatusError
if !errors.As(err, &statusErr) || statusErr.StatusCode != http.StatusNotFound {
return err
}
}
if isCloudModel {
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
}
return fmt.Errorf("model %q not found", model)
}
switch policy {
case missingModelAutoPull:
return pullMissingModel(ctx, client, model)
case missingModelFail:
return fmt.Errorf("model %q not found; run 'ollama pull %s' first, or use --yes to auto-pull", model, model)
default:
return confirmAndPull(ctx, client, model)
}
}
func confirmAndPull(ctx context.Context, client *api.Client, model string) error {
if ok, err := ConfirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
return err
} else if !ok {
return errCancelled
}
fmt.Fprintf(os.Stderr, "\n")
return pullMissingModel(ctx, client, model)
}
func pullMissingModel(ctx context.Context, client *api.Client, model string) error {
if err := pullModel(ctx, client, model, false); err != nil {
return fmt.Errorf("failed to pull %s: %w", model, err)
}
return nil
}
// prepareEditorIntegration persists models and applies editor-managed config files.
func prepareEditorIntegration(name string, editor Editor, models []LaunchModel) error {
if err := editor.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
if err := config.SaveIntegration(name, launchModelNames(models)); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
return nil
}
func prepareManagedSingleIntegration(name string, managed ManagedSingleModel, model string, models []LaunchModel) error {
var err error
if withModels, ok := managed.(ManagedModelListConfigurer); ok {
err = withModels.ConfigureWithModels(model, models)
} else {
err = managed.Configure(model)
}
if err != nil {
return fmt.Errorf("setup failed: %w", err)
}
if err := config.SaveIntegration(name, []string{model}); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
return nil
}
func prepareManagedAutodiscoveryIntegration(name string, autodiscovery ManagedAutodiscoveryIntegration, model string) error {
if err := autodiscovery.ConfigureAutodiscovery(); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
if err := config.SaveIntegration(name, []string{model}); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
return nil
}
// buildModelList merges existing models with recommendations for selection UIs.
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
return buildModelListWithRecommendations(existing, recommendedModels, preChecked, current)
}
func buildModelListWithRecommendations(existing []modelInfo, recommendations []ModelItem, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
existingModels = make(map[string]bool)
cloudModels = make(map[string]bool)
recommended := make(map[string]bool)
var hasLocalModel, hasCloudModel bool
recDesc := make(map[string]string)
recByName := make(map[string]ModelItem)
for _, rec := range recommendations {
recommended[rec.Name] = true
recDesc[rec.Name] = rec.Description
recByName[rec.Name] = rec
}
for _, m := range existing {
existingModels[m.Name] = true
if m.Remote {
cloudModels[m.Name] = true
hasCloudModel = true
} else {
hasLocalModel = true
}
displayName := strings.TrimSuffix(m.Name, ":latest")
existingModels[displayName] = true
if rec, ok := recByName[displayName]; ok {
items = append(items, modelItemFromInventory(displayName, m, copyModelRecommendationFields(displayName, rec)))
} else {
items = append(items, modelItemFromInventory(displayName, m, ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]}))
}
}
for _, rec := range recommendations {
if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
continue
}
items = append(items, rec)
if isCloudModelName(rec.Name) {
cloudModels[rec.Name] = true
}
}
checked := make(map[string]bool, len(preChecked))
for _, n := range preChecked {
checked[n] = true
}
if current != "" {
matchedCurrent := false
for _, item := range items {
if item.Name == current {
current = item.Name
matchedCurrent = true
break
}
}
if !matchedCurrent {
for _, item := range items {
if strings.HasPrefix(item.Name, current+":") {
current = item.Name
break
}
}
}
}
if checked[current] {
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
}
notInstalled := make(map[string]bool)
for i := range items {
if !existingModels[items[i].Name] && !cloudModels[items[i].Name] {
notInstalled[items[i].Name] = true
var parts []string
if items[i].Description != "" {
parts = append(parts, items[i].Description)
}
if vram := displayVRAM(items[i].VRAMBytes); vram != "" {
parts = append(parts, vram)
}
parts = append(parts, "(not downloaded)")
items[i].Description = strings.Join(parts, ", ")
}
}
recRank := make(map[string]int)
for i, rec := range recommendations {
recRank[rec.Name] = i + 1
}
if hasLocalModel || hasCloudModel {
// Keep the Recommended section pinned to recommendation order. Checked
// and default-model priority only apply within the More section.
slices.SortStableFunc(items, func(a, b ModelItem) int {
ac, bc := checked[a.Name], checked[b.Name]
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0
if aRec != bRec {
if aRec {
return -1
}
return 1
}
if aRec && bRec {
return recRank[a.Name] - recRank[b.Name]
}
if ac != bc {
if ac {
return -1
}
return 1
}
// Among checked non-recommended items - put the default first
if ac && !aRec && current != "" {
aCurrent := a.Name == current
bCurrent := b.Name == current
if aCurrent != bCurrent {
if aCurrent {
return -1
}
return 1
}
}
if aNew != bNew {
if aNew {
return 1
}
return -1
}
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
}
return items, preChecked, existingModels, cloudModels
}
func copyModelRecommendationFields(name string, rec ModelItem) ModelItem {
rec.Name = name
rec.Recommended = true
return rec
}
func modelItemFromInventory(name string, info modelInfo, item ModelItem) ModelItem {
item.Name = name
item.ToolCapable = info.ToolCapable
item.Capabilities = slices.Clone(info.Capabilities)
item.Size = info.Size
item.Details = info.Details
return item
}
// isCloudModelName reports whether the model name has an explicit cloud source.
func isCloudModelName(name string) bool {
return modelref.HasExplicitCloudSource(name)
}
// filterCloudModels drops remote-only models from the given inventory.
func filterCloudModels(existing []modelInfo) []modelInfo {
filtered := existing[:0]
for _, m := range existing {
if !m.Remote {
filtered = append(filtered, m)
}
}
return filtered
}
// filterCloudItems removes cloud models from selection items.
func filterCloudItems(items []ModelItem) []ModelItem {
filtered := items[:0]
for _, item := range items {
if !isCloudModelName(item.Name) {
filtered = append(filtered, item)
}
}
return filtered
}
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
if client == nil {
return false
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
if err != nil {
return false
}
return resp.RemoteModel != ""
}
// cloudStatusDisabled returns whether cloud usage is currently disabled.
func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
status, err := client.CloudStatusExperimental(ctx)
if err != nil {
var statusErr api.StatusError
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
return false, false
}
return false, false
}
return status.Cloud.Disabled, true
}
// TODO(parthsareen): this duplicates the pull progress UI in cmd.PullHandler.
// Move the shared pull rendering to a small utility once the package boundary settles.
func pullModel(ctx context.Context, client *api.Client, model string, insecure bool) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
bars := make(map[string]*progress.Bar)
var status string
var spinner *progress.Spinner
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
if resp.Completed == 0 {
return nil
}
if spinner != nil {
spinner.Stop()
}
bar, ok := bars[resp.Digest]
if !ok {
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
name = strings.TrimSpace(name)
if isDigest {
name = name[:min(12, len(name))]
}
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
bar.Set(resp.Completed)
} else if status != resp.Status {
if spinner != nil {
spinner.Stop()
}
status = resp.Status
spinner = progress.NewSpinner(status)
p.Add(status, spinner)
}
return nil
}
request := api.PullRequest{Name: model, Insecure: insecure}
return client.Pull(ctx, &request, fn)
}

83
cmd/launch/models_test.go Normal file
View File

@@ -0,0 +1,83 @@
package launch
import (
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
modelpkg "github.com/ollama/ollama/types/model"
)
func TestBuildModelList_UsesInventoryMetadataForInstalledModels(t *testing.T) {
existing := []modelInfo{
{
Name: "custom-tools:latest",
ToolCapable: true,
Capabilities: []modelpkg.Capability{modelpkg.CapabilityCompletion, modelpkg.CapabilityTools, modelpkg.CapabilityThinking},
Size: 7500 * format.MegaByte,
Details: api.ModelDetails{
ParameterSize: "8B",
QuantizationLevel: "Q4_K_M",
ContextLength: 131_072,
EmbeddingLength: 4096,
},
},
}
items, _, _, _ := buildModelList(existing, nil, "")
var got ModelItem
for _, item := range items {
if item.Name == "custom-tools" {
got = item
break
}
}
if got.Name == "" {
t.Fatal("custom-tools not found in items")
}
if !got.ToolCapable {
t.Fatal("expected installed model to preserve tool capability from tags metadata")
}
if got.Details.ContextLength != 131_072 {
t.Fatalf("Details.ContextLength = %d, want 131072", got.Details.ContextLength)
}
if got.Size != 7500*format.MegaByte {
t.Fatalf("Size = %d, want %d", got.Size, 7500*format.MegaByte)
}
if got.Description != "" {
t.Fatalf("Description = %q, want empty for installed model without recommendation copy", got.Description)
}
}
func TestBuildModelList_InstalledRecommendedPreservesRecommendationAndMetadata(t *testing.T) {
existing := []modelInfo{
{
Name: "qwen3.5",
ToolCapable: true,
Capabilities: []modelpkg.Capability{modelpkg.CapabilityCompletion, modelpkg.CapabilityTools, modelpkg.CapabilityVision},
Size: 14 * format.GigaByte,
Details: api.ModelDetails{ContextLength: 262_144},
},
}
items, _, _, _ := buildModelList(existing, nil, "")
var got ModelItem
for _, item := range items {
if item.Name == "qwen3.5" {
got = item
break
}
}
if got.Name == "" {
t.Fatal("qwen3.5 not found in items")
}
if !got.Recommended || !got.ToolCapable {
t.Fatalf("recommended/tool metadata = %v/%v, want true/true", got.Recommended, got.ToolCapable)
}
if got.Details.ContextLength != 262_144 {
t.Fatalf("Details.ContextLength = %d, want 262144", got.Details.ContextLength)
}
if got.Description != "Reasoning, coding, and visual understanding locally" {
t.Fatalf("Description = %q, want recommendation description", got.Description)
}
}

991
cmd/launch/openclaw.go Normal file
View File

@@ -0,0 +1,991 @@
package launch
import (
"encoding/json"
"fmt"
"net"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
)
const defaultGatewayPort = 18789
// openclawFreshInstall is set to true when ensureOpenclawInstalled performs an install
var openclawFreshInstall bool
var openclawCanInstallDaemon = canInstallDaemon
type Openclaw struct{}
func (c *Openclaw) String() string { return "OpenClaw" }
func (c *Openclaw) Run(model string, _ []LaunchModel, args []string) error {
bin, err := ensureOpenclawInstalled()
if err != nil {
return err
}
firstLaunch := !c.onboarded()
if firstLaunch {
fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
fmt.Fprintf(os.Stderr, " OpenClaw can read files and run actions when tools are enabled.\n")
fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)
ok, err := ConfirmPrompt("I understand the risks. Continue?")
if err != nil {
return err
}
if !ok {
return nil
}
// Ensure the latest version is installed before onboarding so we get
// the newest wizard flags (e.g. --auth-choice ollama).
if !openclawFreshInstall {
update := exec.Command(bin, "update")
update.Env = openclawInstallEnv()
update.Stdout = os.Stdout
update.Stderr = os.Stderr
_ = update.Run() // best-effort; continue even if update fails
}
fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset)
onboardArgs := []string{
"onboard",
"--non-interactive",
"--accept-risk",
"--auth-choice", "ollama",
"--custom-base-url", envconfig.Host().String(),
"--custom-model-id", model,
// Launch owns the first real gateway startup immediately after onboarding,
// so don't let OpenClaw fail the whole first-run flow on a transient
// daemon health probe.
"--skip-health",
"--skip-channels",
"--skip-skills",
}
if openclawCanInstallDaemon() {
onboardArgs = append(onboardArgs, "--install-daemon")
}
cmd := exec.Command(bin, onboardArgs...)
cmd.Env = openclawInstallEnv()
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return windowsHint(fmt.Errorf("openclaw onboarding failed: %w\n\nTry running: openclaw onboard", err))
}
patchDeviceScopes()
}
configureOllamaWebSearch()
// When extra args are passed through, run exactly what the user asked for
// after setup and skip the built-in gateway+TUI convenience flow.
if len(args) > 0 {
cleanup := func() {}
if shouldEnsureGatewayForArgs(args) {
cleanupFn, _, _, err := c.ensureGatewayReady(bin)
if err != nil {
return windowsHint(err)
}
if cleanupFn != nil {
cleanup = cleanupFn
}
}
defer cleanup()
cmd := exec.Command(bin, args...)
cmd.Env = openclawEnv()
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return windowsHint(err)
}
return nil
}
if err := c.runChannelSetupPreflight(bin); err != nil {
return err
}
// Keep local pairing scopes up to date before the gateway lifecycle
// (restart/start) regardless of channel preflight branch behavior.
patchDeviceScopes()
fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
cleanup, token, port, err := c.ensureGatewayReady(bin)
if err != nil {
return windowsHint(err)
}
defer cleanup()
printOpenclawReady(bin, token, port, firstLaunch)
tuiArgs := []string{"tui"}
if firstLaunch {
tuiArgs = append(tuiArgs, "--message", "Wake up, my friend!")
}
tui := exec.Command(bin, tuiArgs...)
tui.Env = openclawEnv()
tui.Stdin = os.Stdin
tui.Stdout = os.Stdout
tui.Stderr = os.Stderr
if err := tui.Run(); err != nil {
return windowsHint(err)
}
return nil
}
func shouldEnsureGatewayForArgs(args []string) bool {
return len(args) > 0 && args[0] == "tui"
}
func (c *Openclaw) ensureGatewayReady(bin string) (func(), string, int, error) {
token, port := c.gatewayInfo()
addr := fmt.Sprintf("127.0.0.1:%d", port)
// If the gateway is already running (e.g. via the daemon), restart it
// so it picks up any config changes (model, provider, etc.).
if portOpen(addr) {
restart := exec.Command(bin, "daemon", "restart")
restart.Env = openclawEnv()
if err := restart.Run(); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: daemon restart failed: %v%s\n", ansiYellow, err, ansiReset)
}
if !waitForPort(addr, 10*time.Second) {
fmt.Fprintf(os.Stderr, "%s Warning: gateway did not come back after restart%s\n", ansiYellow, ansiReset)
}
}
// If the daemon is installed but not currently listening, try to bring it
// up before falling back to a foreground child process.
if openclawCanInstallDaemon() && !portOpen(addr) {
start := exec.Command(bin, "daemon", "start")
start.Env = openclawEnv()
if err := start.Run(); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: daemon start failed: %v%s\n", ansiYellow, err, ansiReset)
} else if waitForPort(addr, 10*time.Second) {
fmt.Fprintf(os.Stderr, "%sStarting gateway...%s\n", ansiGray, ansiReset)
return func() {}, token, port, nil
}
}
cleanup := func() {}
// If the gateway still isn't running, start it as a background child process.
if !portOpen(addr) {
gw := exec.Command(bin, "gateway", "run", "--force")
gw.Env = openclawEnv()
if err := gw.Start(); err != nil {
return nil, "", 0, fmt.Errorf("failed to start gateway: %w", err)
}
cleanup = func() {
if gw.Process != nil {
_ = gw.Process.Kill()
_ = gw.Wait()
}
}
}
fmt.Fprintf(os.Stderr, "%sStarting gateway...%s\n", ansiGray, ansiReset)
if !waitForPort(addr, 30*time.Second) {
cleanup()
return nil, "", 0, fmt.Errorf("gateway did not start on %s", addr)
}
return cleanup, token, port, nil
}
// runChannelSetupPreflight prompts users to connect a messaging channel before
// starting the built-in gateway+TUI flow. In interactive sessions, it loops
// until a channel is configured, unless the user chooses "Set up later".
func (c *Openclaw) runChannelSetupPreflight(bin string) error {
if !isInteractiveSession() {
return nil
}
// --yes is headless; channel setup spawns an interactive picker we can't
// auto-answer, so skip it. Users can run `openclaw channels add` later.
if currentLaunchConfirmPolicy.yes {
return nil
}
for {
if c.channelsConfigured() {
return nil
}
fmt.Fprintf(os.Stderr, "\nYour assistant can message you on WhatsApp, Telegram, Discord, and more.\n\n")
ok, err := ConfirmPromptWithOptions("Connect a channel (messaging app) now?", ConfirmOptions{
YesLabel: "Yes",
NoLabel: "Set up later",
})
if err != nil {
return err
}
if !ok {
return nil
}
cmd := exec.Command(bin, "channels", "add")
cmd.Env = openclawEnv()
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return windowsHint(fmt.Errorf("openclaw channel setup failed: %w\n\nTry running: %s channels add", err, bin))
}
}
}
// channelsConfigured reports whether local OpenClaw config contains at least
// one meaningfully configured channel entry.
func (c *Openclaw) channelsConfigured() bool {
home, err := os.UserHomeDir()
if err != nil {
return false
}
for _, path := range []string{
filepath.Join(home, ".openclaw", "openclaw.json"),
filepath.Join(home, ".clawdbot", "clawdbot.json"),
} {
data, err := os.ReadFile(path)
if err != nil {
continue
}
var cfg map[string]any
if json.Unmarshal(data, &cfg) != nil {
continue
}
channels, _ := cfg["channels"].(map[string]any)
if channels == nil {
return false
}
for key, value := range channels {
if key == "defaults" || key == "modelByChannel" {
continue
}
entry, ok := value.(map[string]any)
if !ok {
continue
}
for entryKey := range entry {
if entryKey != "enabled" {
return true
}
}
}
return false
}
return false
}
// gatewayInfo reads the gateway auth token and port from the OpenClaw config.
func (c *Openclaw) gatewayInfo() (token string, port int) {
port = defaultGatewayPort
home, err := os.UserHomeDir()
if err != nil {
return "", port
}
for _, path := range []string{
filepath.Join(home, ".openclaw", "openclaw.json"),
filepath.Join(home, ".clawdbot", "clawdbot.json"),
} {
data, err := os.ReadFile(path)
if err != nil {
continue
}
var config map[string]any
if json.Unmarshal(data, &config) != nil {
continue
}
gw, _ := config["gateway"].(map[string]any)
if p, ok := gw["port"].(float64); ok && p > 0 {
port = int(p)
}
auth, _ := gw["auth"].(map[string]any)
if t, _ := auth["token"].(string); t != "" {
token = t
}
return token, port
}
return "", port
}
func printOpenclawReady(bin, token string, port int, firstLaunch bool) {
u := fmt.Sprintf("http://127.0.0.1:%d", port)
if token != "" {
u += "/#token=" + url.QueryEscape(token)
}
fmt.Fprintf(os.Stderr, "\n%s✓ OpenClaw is running%s\n\n", ansiGreen, ansiReset)
fmt.Fprintf(os.Stderr, " Open the Web UI:\n")
fmt.Fprintf(os.Stderr, " %s\n\n", hyperlink(u, u))
if firstLaunch {
fmt.Fprintf(os.Stderr, "%s Quick start:%s\n", ansiBold, ansiReset)
fmt.Fprintf(os.Stderr, "%s /help see all commands%s\n", ansiGray, ansiReset)
fmt.Fprintf(os.Stderr, "%s %s skills browse and install skills%s\n\n", ansiGray, bin, ansiReset)
fmt.Fprintf(os.Stderr, "%s The OpenClaw gateway is running in the background.%s\n", ansiYellow, ansiReset)
fmt.Fprintf(os.Stderr, "%s Stop it with: %s gateway stop%s\n\n", ansiYellow, bin, ansiReset)
}
}
// openclawEnv returns the current environment with provider API keys cleared
// so openclaw only uses the Ollama gateway, not keys from the user's shell.
func openclawEnv() []string {
clear := map[string]bool{
"ANTHROPIC_API_KEY": true,
"ANTHROPIC_OAUTH_TOKEN": true,
"OPENAI_API_KEY": true,
"GEMINI_API_KEY": true,
"MISTRAL_API_KEY": true,
"GROQ_API_KEY": true,
"XAI_API_KEY": true,
"OPENROUTER_API_KEY": true,
}
var env []string
for _, e := range os.Environ() {
key, _, _ := strings.Cut(e, "=")
if !clear[key] {
env = append(env, e)
}
}
if _, ok := os.LookupEnv("OPENCLAW_PLUGIN_STAGE_DIR"); !ok {
if dir := openclawPluginStageDir(); dir != "" {
env = append(env, "OPENCLAW_PLUGIN_STAGE_DIR="+dir)
}
}
return env
}
func openclawInstallEnv() []string {
env := openclawEnv()
if _, ok := os.LookupEnv("OPENCLAW_EAGER_BUNDLED_PLUGIN_DEPS"); !ok {
env = append(env, "OPENCLAW_EAGER_BUNDLED_PLUGIN_DEPS=1")
}
return env
}
func openclawPluginStageDir() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return filepath.Join(home, ".openclaw", "plugin-runtime-deps")
}
// portOpen checks if a TCP port is currently accepting connections.
func portOpen(addr string) bool {
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
if err != nil {
return false
}
conn.Close()
return true
}
func waitForPort(addr string, timeout time.Duration) bool {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
if err == nil {
conn.Close()
return true
}
time.Sleep(250 * time.Millisecond)
}
return false
}
func windowsHint(err error) error {
if runtime.GOOS != "windows" {
return err
}
return fmt.Errorf("%w\n\n"+
"OpenClaw runs best on WSL2.\n"+
"Quick setup: wsl --install\n"+
"Guide: https://docs.openclaw.ai/windows", err)
}
// onboarded checks if OpenClaw onboarding wizard was completed
// by looking for the wizard.lastRunAt marker in the config
func (c *Openclaw) onboarded() bool {
home, err := os.UserHomeDir()
if err != nil {
return false
}
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
} else if data, err := os.ReadFile(legacyPath); err == nil {
_ = json.Unmarshal(data, &config)
} else {
return false
}
// Check for wizard.lastRunAt marker (set when onboarding completes)
wizard, _ := config["wizard"].(map[string]any)
if wizard == nil {
return false
}
lastRunAt, _ := wizard["lastRunAt"].(string)
return lastRunAt != ""
}
// patchDeviceScopes upgrades the local CLI device's paired operator scopes so
// newer gateway auth baselines (approvedScopes) allow launch+TUI reconnects
// without forcing an interactive re-pair. Only patches the local device,
// not remote ones. Best-effort: silently returns on any error.
func patchDeviceScopes() {
home, err := os.UserHomeDir()
if err != nil {
return
}
deviceID := readLocalDeviceID(home)
if deviceID == "" {
return
}
path := filepath.Join(home, ".openclaw", "devices", "paired.json")
data, err := os.ReadFile(path)
if err != nil {
return
}
var devices map[string]map[string]any
if err := json.Unmarshal(data, &devices); err != nil {
return
}
dev, ok := devices[deviceID]
if !ok {
return
}
required := []string{
"operator.read",
"operator.admin",
"operator.approvals",
"operator.pairing",
}
changed := patchScopes(dev, "scopes", required)
if patchScopes(dev, "approvedScopes", required) {
changed = true
}
if tokens, ok := dev["tokens"].(map[string]any); ok {
for role, tok := range tokens {
if tokenMap, ok := tok.(map[string]any); ok {
if !isOperatorToken(role, tokenMap) {
continue
}
if patchScopes(tokenMap, "scopes", required) {
changed = true
}
}
}
}
if !changed {
return
}
out, err := json.MarshalIndent(devices, "", " ")
if err != nil {
return
}
_ = os.WriteFile(path, out, 0o600)
}
// readLocalDeviceID reads the local device ID from openclaw's identity file.
func readLocalDeviceID(home string) string {
data, err := os.ReadFile(filepath.Join(home, ".openclaw", "identity", "device-auth.json"))
if err != nil {
return ""
}
var auth map[string]any
if err := json.Unmarshal(data, &auth); err != nil {
return ""
}
id, _ := auth["deviceId"].(string)
return id
}
// patchScopes ensures obj[key] contains all required scopes. Returns true if
// any scopes were added.
func patchScopes(obj map[string]any, key string, required []string) bool {
existing, _ := obj[key].([]any)
have := make(map[string]bool, len(existing))
for _, s := range existing {
if str, ok := s.(string); ok {
have[str] = true
}
}
added := false
for _, s := range required {
if !have[s] {
existing = append(existing, s)
added = true
}
}
if added {
obj[key] = existing
}
return added
}
func isOperatorToken(tokenRole string, token map[string]any) bool {
if strings.EqualFold(strings.TrimSpace(tokenRole), "operator") {
return true
}
role, _ := token["role"].(string)
return strings.EqualFold(strings.TrimSpace(role), "operator")
}
// canInstallDaemon reports whether the openclaw daemon can be installed as a
// background service. Returns false on Linux when systemd is absent (e.g.
// containers) so that --install-daemon is omitted and the gateway is started
// as a foreground child process instead. Returns true in all other cases.
func canInstallDaemon() bool {
if runtime.GOOS != "linux" {
return true
}
// /run/systemd/system exists as a directory when systemd is the init system.
// This is absent in most containers.
fi, err := os.Stat("/run/systemd/system")
if err != nil || !fi.IsDir() {
return false
}
// Even when systemd is the init system, user services require a user
// manager instance. XDG_RUNTIME_DIR being set is a prerequisite.
return os.Getenv("XDG_RUNTIME_DIR") != ""
}
func ensureOpenclawInstalled() (string, error) {
if _, err := exec.LookPath("openclaw"); err == nil {
return "openclaw", nil
}
if _, err := exec.LookPath("clawdbot"); err == nil {
return "clawdbot", nil
}
_, npmErr := exec.LookPath("npm")
_, gitErr := exec.LookPath("git")
if npmErr != nil || gitErr != nil {
var missing []string
if npmErr != nil {
missing = append(missing, "npm (Node.js): https://nodejs.org/")
}
if gitErr != nil {
missing = append(missing, "git: https://git-scm.com/")
}
return "", fmt.Errorf("OpenClaw is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch openclaw", strings.Join(missing, "\n "))
}
ok, err := ConfirmPrompt("OpenClaw is not installed. Install with npm?")
if err != nil {
return "", err
}
if !ok {
return "", fmt.Errorf("openclaw installation cancelled")
}
fmt.Fprintf(os.Stderr, "\nInstalling OpenClaw...\n")
cmd := exec.Command("npm", "install", "-g", "openclaw@latest")
cmd.Env = openclawInstallEnv()
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("failed to install openclaw: %w", err)
}
if _, err := exec.LookPath("openclaw"); err != nil {
return "", fmt.Errorf("openclaw was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
}
fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
openclawFreshInstall = true
return "openclaw", nil
}
func (c *Openclaw) Paths() []string {
home, _ := os.UserHomeDir()
p := filepath.Join(home, ".openclaw", "openclaw.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
legacy := filepath.Join(home, ".clawdbot", "clawdbot.json")
if _, err := os.Stat(legacy); err == nil {
return []string{legacy}
}
return nil
}
func (c *Openclaw) Edit(models []LaunchModel) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
// Read into map[string]any to preserve unknown fields
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
} else if data, err := os.ReadFile(legacyPath); err == nil {
_ = json.Unmarshal(data, &config)
}
// Navigate/create: models.providers.ollama (preserving other providers)
modelsSection, _ := config["models"].(map[string]any)
if modelsSection == nil {
modelsSection = make(map[string]any)
}
providers, _ := modelsSection["providers"].(map[string]any)
if providers == nil {
providers = make(map[string]any)
}
ollama, _ := providers["ollama"].(map[string]any)
if ollama == nil {
ollama = make(map[string]any)
}
ollama["baseUrl"] = envconfig.Host().String()
// needed to register provider
ollama["apiKey"] = "ollama-local"
ollama["api"] = "ollama"
// Build map of existing models to preserve user customizations
existingModels, _ := ollama["models"].([]any)
existingByID := make(map[string]map[string]any)
for _, m := range existingModels {
if entry, ok := m.(map[string]any); ok {
if id, ok := entry["id"].(string); ok {
existingByID[id] = entry
}
}
}
var newModels []any
for _, m := range models {
entry, _ := openclawModelConfig(m)
// Merge existing fields (user customizations)
if existing, ok := existingByID[m.Name]; ok {
for k, v := range existing {
if _, isNew := entry[k]; !isNew {
entry[k] = v
}
}
}
newModels = append(newModels, entry)
}
ollama["models"] = newModels
providers["ollama"] = ollama
modelsSection["providers"] = providers
config["models"] = modelsSection
// Update agents.defaults.model.primary (preserving other agent settings)
agents, _ := config["agents"].(map[string]any)
if agents == nil {
agents = make(map[string]any)
}
defaults, _ := agents["defaults"].(map[string]any)
if defaults == nil {
defaults = make(map[string]any)
}
modelConfig, _ := defaults["model"].(map[string]any)
if modelConfig == nil {
modelConfig = make(map[string]any)
}
modelConfig["primary"] = "ollama/" + models[0].Name
defaults["model"] = modelConfig
agents["defaults"] = defaults
config["agents"] = agents
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
if err := fileutil.WriteWithBackup(configPath, data, "openclaw"); err != nil {
return err
}
// Clear any per-session model overrides so the new primary takes effect
// immediately rather than being shadowed by a cached modelOverride.
clearSessionModelOverride(models[0].Name)
return nil
}
// clearSessionModelOverride removes per-session model overrides from the main
// agent session so the global primary model takes effect on the next TUI launch.
func clearSessionModelOverride(primary string) {
home, err := os.UserHomeDir()
if err != nil {
return
}
path := filepath.Join(home, ".openclaw", "agents", "main", "sessions", "sessions.json")
data, err := os.ReadFile(path)
if err != nil {
return
}
var sessions map[string]map[string]any
if json.Unmarshal(data, &sessions) != nil {
return
}
changed := false
for _, sess := range sessions {
if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
delete(sess, "modelOverride")
delete(sess, "providerOverride")
}
if model, _ := sess["model"].(string); model != "" && model != primary {
sess["model"] = primary
changed = true
}
}
if !changed {
return
}
out, err := json.MarshalIndent(sessions, "", " ")
if err != nil {
return
}
_ = os.WriteFile(path, out, 0o600)
}
// configureOllamaWebSearch keeps launch-managed OpenClaw installs on the
// bundled Ollama web_search provider. Older launch builds installed an
// external openclaw-web-search plugin that added custom ollama_web_search and
// ollama_web_fetch tools. Current OpenClaw versions ship Ollama web_search as
// the bundled "ollama" plugin instead, so we migrate stale config and ensure
// fresh installs select the bundled provider.
func configureOllamaWebSearch() {
home, err := os.UserHomeDir()
if err != nil {
return
}
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
data, err := os.ReadFile(configPath)
if err != nil {
return
}
var config map[string]any
if json.Unmarshal(data, &config) != nil {
return
}
stalePluginConfigured := false
plugins, _ := config["plugins"].(map[string]any)
if plugins == nil {
plugins = make(map[string]any)
}
entries, _ := plugins["entries"].(map[string]any)
if entries == nil {
entries = make(map[string]any)
}
tools, _ := config["tools"].(map[string]any)
if tools == nil {
tools = make(map[string]any)
}
web, _ := tools["web"].(map[string]any)
if web == nil {
web = make(map[string]any)
}
search, _ := web["search"].(map[string]any)
if search == nil {
search = make(map[string]any)
}
fetch, _ := web["fetch"].(map[string]any)
if fetch == nil {
fetch = make(map[string]any)
}
alsoAllow, _ := tools["alsoAllow"].([]any)
var filteredAlsoAllow []any
for _, v := range alsoAllow {
s, ok := v.(string)
if !ok {
filteredAlsoAllow = append(filteredAlsoAllow, v)
continue
}
if s == "ollama_web_search" || s == "ollama_web_fetch" {
stalePluginConfigured = true
continue
}
filteredAlsoAllow = append(filteredAlsoAllow, v)
}
if len(filteredAlsoAllow) > 0 {
tools["alsoAllow"] = filteredAlsoAllow
} else {
delete(tools, "alsoAllow")
}
if _, ok := entries["openclaw-web-search"]; ok {
delete(entries, "openclaw-web-search")
stalePluginConfigured = true
}
ollamaEntry, _ := entries["ollama"].(map[string]any)
if ollamaEntry == nil {
ollamaEntry = make(map[string]any)
}
ollamaEntry["enabled"] = true
entries["ollama"] = ollamaEntry
plugins["entries"] = entries
if allow, ok := plugins["allow"].([]any); ok {
var nextAllow []any
hasOllama := false
for _, v := range allow {
s, ok := v.(string)
if ok && s == "openclaw-web-search" {
stalePluginConfigured = true
continue
}
if ok && s == "ollama" {
hasOllama = true
}
nextAllow = append(nextAllow, v)
}
if !hasOllama {
nextAllow = append(nextAllow, "ollama")
}
plugins["allow"] = nextAllow
}
if installs, ok := plugins["installs"].(map[string]any); ok {
if _, exists := installs["openclaw-web-search"]; exists {
delete(installs, "openclaw-web-search")
stalePluginConfigured = true
}
if len(installs) > 0 {
plugins["installs"] = installs
} else {
delete(plugins, "installs")
}
}
if stalePluginConfigured || search["provider"] == nil {
search["provider"] = "ollama"
}
if stalePluginConfigured {
fetch["enabled"] = true
}
search["enabled"] = true
web["search"] = search
if len(fetch) > 0 {
web["fetch"] = fetch
}
tools["web"] = web
config["plugins"] = plugins
config["tools"] = tools
out, err := json.MarshalIndent(config, "", " ")
if err != nil {
return
}
_ = os.WriteFile(configPath, out, 0o600)
}
// openclawModelConfig builds an OpenClaw model config entry with capability detection.
// The second return value indicates whether the model is a cloud (remote) model.
func openclawModelConfig(model LaunchModel) (map[string]any, bool) {
entry := map[string]any{
"id": model.Name,
"name": model.Name,
"input": []any{"text"},
"cost": map[string]any{
"input": 0,
"output": 0,
"cacheRead": 0,
"cacheWrite": 0,
},
}
// Set input types based on vision capability
if model.HasCapability("vision") {
entry["input"] = []any{"text", "image"}
}
// Set reasoning based on thinking capability
if model.HasCapability("thinking") {
entry["reasoning"] = true
}
if model.ContextLength > 0 {
entry["contextWindow"] = model.ContextLength
}
if model.MaxOutputTokens > 0 {
entry["maxTokens"] = model.MaxOutputTokens
}
return entry, model.Remote || isCloudModelName(model.Name)
}
func (c *Openclaw) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := fileutil.ReadJSON(filepath.Join(home, ".openclaw", "openclaw.json"))
if err != nil {
config, err = fileutil.ReadJSON(filepath.Join(home, ".clawdbot", "clawdbot.json"))
if err != nil {
return nil
}
}
modelsSection, _ := config["models"].(map[string]any)
providers, _ := modelsSection["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
modelList, _ := ollama["models"].([]any)
var result []string
for _, m := range modelList {
if entry, ok := m.(map[string]any); ok {
if id, ok := entry["id"].(string); ok {
result = append(result, id)
}
}
}
return result
}

2778
cmd/launch/openclaw_test.go Normal file

File diff suppressed because it is too large Load Diff

294
cmd/launch/opencode.go Normal file
View File

@@ -0,0 +1,294 @@
package launch
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
)
// OpenCode implements Runner and Editor for OpenCode integration.
// Config is passed via OPENCODE_CONFIG_CONTENT env var at launch time
// instead of writing to opencode's config files.
type OpenCode struct {
configContent string // JSON config built by Edit, passed to Run via env var
}
func (o *OpenCode) String() string { return "OpenCode" }
// findOpenCode returns the opencode binary path, checking PATH first then the
// curl installer location (~/.opencode/bin) which may not be on PATH yet.
func findOpenCode() (string, bool) {
if p, err := exec.LookPath("opencode"); err == nil {
return p, true
}
home, err := os.UserHomeDir()
if err != nil {
return "", false
}
name := "opencode"
if runtime.GOOS == "windows" {
name = "opencode.exe"
}
fallback := filepath.Join(home, ".opencode", "bin", name)
if _, err := os.Stat(fallback); err == nil {
return fallback, true
}
return "", false
}
func (o *OpenCode) Run(model string, models []LaunchModel, args []string) error {
opencodePath, ok := findOpenCode()
if !ok {
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
}
cmd := exec.Command(opencodePath, args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = os.Environ()
if content := o.resolveContent(model, models); content != "" {
cmd.Env = append(cmd.Env, "OPENCODE_CONFIG_CONTENT="+content)
}
return cmd.Run()
}
// resolveContent returns the inline config to send via OPENCODE_CONFIG_CONTENT.
// Returns content built by Edit if available, otherwise builds from model.json
// with the requested model as primary (e.g. re-launch with saved config).
func (o *OpenCode) resolveContent(model string, models []LaunchModel) string {
if o.configContent != "" {
return o.configContent
}
resolvedModels := resolveOpenCodeRunModels(model, models, readModelJSONModels())
if len(resolvedModels) == 0 {
return ""
}
content, err := buildInlineConfig(resolvedModels[0], resolvedModels)
if err != nil {
return ""
}
return content
}
func resolveOpenCodeRunModels(primary string, models []LaunchModel, stateModels []string) []LaunchModel {
if primary == "" {
return nil
}
resolved := make([]LaunchModel, 0, 1+len(models)+len(stateModels))
appendModel := func(name string) {
if name == "" || hasLaunchModel(resolved, name) {
return
}
if model, ok := findLaunchModel(models, name); ok {
resolved = append(resolved, model)
return
}
resolved = append(resolved, fallbackLaunchModel(name))
}
appendModel(primary)
for _, model := range models {
appendModel(model.Name)
}
for _, model := range stateModels {
appendModel(model)
}
return resolved
}
func hasLaunchModel(models []LaunchModel, name string) bool {
for _, model := range models {
if launchModelMatches(model.Name, name) || launchModelMatches(name, model.Name) {
return true
}
}
return false
}
func (o *OpenCode) Paths() []string {
sp, err := openCodeStatePath()
if err != nil {
return nil
}
if _, err := os.Stat(sp); err == nil {
return []string{sp}
}
return nil
}
// openCodeStatePath returns the path to opencode's model state file.
// TODO: this hardcodes the Linux/macOS XDG path. On Windows, opencode stores
// state under %LOCALAPPDATA% (or similar) — verify and branch on runtime.GOOS.
func openCodeStatePath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".local", "state", "opencode", "model.json"), nil
}
func (o *OpenCode) Edit(models []LaunchModel) error {
modelList := launchModelNames(models)
if len(modelList) == 0 {
return nil
}
content, err := buildInlineConfig(models[0], models)
if err != nil {
return err
}
o.configContent = content
// Write model state file so models appear in OpenCode's model picker
statePath, err := openCodeStatePath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
return err
}
state := map[string]any{
"recent": []any{},
"favorite": []any{},
"variant": map[string]any{},
}
if data, err := os.ReadFile(statePath); err == nil {
_ = json.Unmarshal(data, &state) // Ignore parse errors; use defaults
}
recent, _ := state["recent"].([]any)
modelSet := make(map[string]bool)
for _, m := range modelList {
modelSet[m] = true
}
// Filter out existing Ollama models we're about to re-add
newRecent := slices.DeleteFunc(slices.Clone(recent), func(entry any) bool {
e, ok := entry.(map[string]any)
if !ok || e["providerID"] != "ollama" {
return false
}
modelID, _ := e["modelID"].(string)
return modelSet[modelID]
})
// Prepend models in reverse order so first model ends up first
for _, model := range slices.Backward(modelList) {
newRecent = slices.Insert(newRecent, 0, any(map[string]any{
"providerID": "ollama",
"modelID": model,
}))
}
const maxRecentModels = 10
newRecent = newRecent[:min(len(newRecent), maxRecentModels)]
state["recent"] = newRecent
stateData, err := json.MarshalIndent(state, "", " ")
if err != nil {
return err
}
return fileutil.WriteWithBackup(statePath, stateData, "opencode")
}
func (o *OpenCode) Models() []string {
return nil
}
// buildInlineConfig produces the JSON string for OPENCODE_CONFIG_CONTENT.
// primary is the model to launch with, models is the full list of available models.
func buildInlineConfig(primary LaunchModel, models []LaunchModel) (string, error) {
if primary.Name == "" || len(models) == 0 {
return "", fmt.Errorf("buildInlineConfig: primary and models are required")
}
config := map[string]any{
"$schema": "https://opencode.ai/config.json",
"provider": map[string]any{
"ollama": map[string]any{
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama",
"options": map[string]any{
"baseURL": envconfig.Host().String() + "/v1",
},
"models": buildModelEntries(models),
},
},
"model": "ollama/" + primary.Name,
}
data, err := json.Marshal(config)
if err != nil {
return "", err
}
return string(data), nil
}
// readModelJSONModels reads ollama model IDs from the opencode model.json state file
func readModelJSONModels() []string {
statePath, err := openCodeStatePath()
if err != nil {
return nil
}
data, err := os.ReadFile(statePath)
if err != nil {
return nil
}
var state map[string]any
if err := json.Unmarshal(data, &state); err != nil {
return nil
}
recent, _ := state["recent"].([]any)
var models []string
for _, entry := range recent {
e, ok := entry.(map[string]any)
if !ok {
continue
}
if e["providerID"] != "ollama" {
continue
}
if id, ok := e["modelID"].(string); ok && id != "" {
models = append(models, id)
}
}
return models
}
func buildModelEntries(modelList []LaunchModel) map[string]any {
models := make(map[string]any)
for _, model := range modelList {
entry := map[string]any{
"name": model.Name,
}
if model.HasCapability("vision") {
entry["modalities"] = map[string]any{
"input": []string{"text", "image"},
"output": []string{"text"},
}
}
if model.ContextLength > 0 || model.MaxOutputTokens > 0 {
limit := make(map[string]any)
if model.ContextLength > 0 {
limit["context"] = model.ContextLength
}
if model.MaxOutputTokens > 0 {
limit["output"] = model.MaxOutputTokens
}
entry["limit"] = limit
}
models[model.Name] = entry
}
return models
}

862
cmd/launch/opencode_test.go Normal file
View File

@@ -0,0 +1,862 @@
package launch
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/ollama/ollama/types/model"
)
func TestOpenCodeIntegration(t *testing.T) {
o := &OpenCode{}
t.Run("String", func(t *testing.T) {
if got := o.String(); got != "OpenCode" {
t.Errorf("String() = %q, want %q", got, "OpenCode")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = o
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = o
})
}
func TestOpenCodeEdit(t *testing.T) {
t.Run("builds config content with provider", func(t *testing.T) {
setTestHome(t, t.TempDir())
o := &OpenCode{}
if err := o.Edit(testLaunchModels("llama3.2")); err != nil {
t.Fatal(err)
}
var cfg map[string]any
if err := json.Unmarshal([]byte(o.configContent), &cfg); err != nil {
t.Fatalf("configContent is not valid JSON: %v", err)
}
// Verify provider structure
provider, _ := cfg["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
if ollama["name"] != "Ollama" {
t.Errorf("provider name = %v, want Ollama", ollama["name"])
}
if ollama["npm"] != "@ai-sdk/openai-compatible" {
t.Errorf("npm = %v, want @ai-sdk/openai-compatible", ollama["npm"])
}
// Verify model exists
models, _ := ollama["models"].(map[string]any)
if models["llama3.2"] == nil {
t.Error("model llama3.2 not found in config content")
}
// Verify default model
if cfg["model"] != "ollama/llama3.2" {
t.Errorf("model = %v, want ollama/llama3.2", cfg["model"])
}
})
t.Run("multiple models", func(t *testing.T) {
setTestHome(t, t.TempDir())
o := &OpenCode{}
if err := o.Edit(testLaunchModels("llama3.2", "qwen3:32b")); err != nil {
t.Fatal(err)
}
var cfg map[string]any
json.Unmarshal([]byte(o.configContent), &cfg)
provider, _ := cfg["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
models, _ := ollama["models"].(map[string]any)
if models["llama3.2"] == nil {
t.Error("model llama3.2 not found")
}
if models["qwen3:32b"] == nil {
t.Error("model qwen3:32b not found")
}
// First model should be the default
if cfg["model"] != "ollama/llama3.2" {
t.Errorf("default model = %v, want ollama/llama3.2", cfg["model"])
}
})
t.Run("empty models is no-op", func(t *testing.T) {
setTestHome(t, t.TempDir())
o := &OpenCode{}
if err := o.Edit(testLaunchModels()); err != nil {
t.Fatal(err)
}
if o.configContent != "" {
t.Errorf("expected empty configContent for no models, got %s", o.configContent)
}
})
t.Run("does not write config files", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
o := &OpenCode{}
o.Edit(testLaunchModels("llama3.2"))
configDir := filepath.Join(tmpDir, ".config", "opencode")
if _, err := os.Stat(filepath.Join(configDir, "opencode.json")); !os.IsNotExist(err) {
t.Error("opencode.json should not be created")
}
if _, err := os.Stat(filepath.Join(configDir, "opencode.jsonc")); !os.IsNotExist(err) {
t.Error("opencode.jsonc should not be created")
}
})
t.Run("cloud model has limits", func(t *testing.T) {
setTestHome(t, t.TempDir())
o := &OpenCode{}
if err := o.Edit(testLaunchModels("glm-4.7:cloud")); err != nil {
t.Fatal(err)
}
var cfg map[string]any
json.Unmarshal([]byte(o.configContent), &cfg)
provider, _ := cfg["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
models, _ := ollama["models"].(map[string]any)
entry, _ := models["glm-4.7:cloud"].(map[string]any)
limit, ok := entry["limit"].(map[string]any)
if !ok {
t.Fatal("cloud model should have limit set")
}
expected := cloudModelLimits["glm-4.7"]
if limit["context"] != float64(expected.Context) {
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
}
if limit["output"] != float64(expected.Output) {
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
}
})
t.Run("local model has no limits", func(t *testing.T) {
setTestHome(t, t.TempDir())
o := &OpenCode{}
o.Edit(testLaunchModels("llama3.2"))
var cfg map[string]any
json.Unmarshal([]byte(o.configContent), &cfg)
provider, _ := cfg["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
models, _ := ollama["models"].(map[string]any)
entry, _ := models["llama3.2"].(map[string]any)
if entry["limit"] != nil {
t.Errorf("local model should not have limit, got %v", entry["limit"])
}
})
t.Run("vision model gets image input modalities", func(t *testing.T) {
models := buildModelEntries([]LaunchModel{{Name: "gemma4:26b", Capabilities: []model.Capability{"vision"}}})
entry, _ := models["gemma4:26b"].(map[string]any)
modalities, _ := entry["modalities"].(map[string]any)
input, _ := modalities["input"].([]string)
output, _ := modalities["output"].([]string)
if len(input) != 2 || input[0] != "text" || input[1] != "image" {
t.Fatalf("modalities.input = %v, want [text image]", input)
}
if len(output) != 1 || output[0] != "text" {
t.Fatalf("modalities.output = %v, want [text]", output)
}
})
}
func TestBuildModelEntries(t *testing.T) {
t.Run("defaults to model name without capabilities", func(t *testing.T) {
models := buildModelEntries(testLaunchModels("llama3.2"))
entry, _ := models["llama3.2"].(map[string]any)
if entry["name"] != "llama3.2" {
t.Fatalf("name = %v, want llama3.2", entry["name"])
}
if _, ok := entry["modalities"]; ok {
t.Fatalf("modalities should not be set without capabilities, got %v", entry["modalities"])
}
})
t.Run("uses context and output limits from metadata", func(t *testing.T) {
models := buildModelEntries([]LaunchModel{{Name: "glm-5:cloud", ContextLength: 202_752, MaxOutputTokens: 131_072}})
entry, _ := models["glm-5:cloud"].(map[string]any)
limit, _ := entry["limit"].(map[string]any)
if limit["context"] != 202_752 || limit["output"] != 131_072 {
t.Fatalf("limit = %v, want context/output", limit)
}
})
}
func TestOpenCodeModels_ReturnsNil(t *testing.T) {
o := &OpenCode{}
if models := o.Models(); models != nil {
t.Errorf("Models() = %v, want nil", models)
}
}
func TestOpenCodePaths(t *testing.T) {
t.Run("returns nil when model.json does not exist", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
o := &OpenCode{}
if paths := o.Paths(); paths != nil {
t.Errorf("Paths() = %v, want nil", paths)
}
})
t.Run("returns model.json path when it exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
os.MkdirAll(stateDir, 0o755)
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
o := &OpenCode{}
paths := o.Paths()
if len(paths) != 1 {
t.Fatalf("Paths() returned %d paths, want 1", len(paths))
}
if paths[0] != filepath.Join(stateDir, "model.json") {
t.Errorf("Paths() = %v, want %v", paths[0], filepath.Join(stateDir, "model.json"))
}
})
}
func TestLookupCloudModelLimit(t *testing.T) {
tests := []struct {
name string
wantOK bool
wantContext int
wantOutput int
}{
{"glm-4.7", false, 0, 0},
{"glm-4.7:cloud", true, 202_752, 131_072},
{"glm-5:cloud", true, 202_752, 131_072},
{"glm-5.1:cloud", true, 202_752, 131_072},
{"gemma4:31b-cloud", true, 262_144, 131_072},
{"gpt-oss:120b-cloud", true, 131_072, 131_072},
{"gpt-oss:20b-cloud", true, 131_072, 131_072},
{"kimi-k2.5", false, 0, 0},
{"kimi-k2.5:cloud", true, 262_144, 262_144},
{"deepseek-v3.2", false, 0, 0},
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
{"qwen3.5", false, 0, 0},
{"qwen3.5:cloud", true, 262_144, 32_768},
{"qwen3-coder:480b", false, 0, 0},
{"qwen3-coder:480b:cloud", true, 262_144, 65_536},
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
{"llama3.2", false, 0, 0},
{"unknown-model:cloud", false, 0, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l, ok := lookupCloudModelLimit(tt.name)
if ok != tt.wantOK {
t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
}
if ok {
if l.Context != tt.wantContext {
t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
}
if l.Output != tt.wantOutput {
t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
}
}
})
}
}
func TestFindOpenCode(t *testing.T) {
t.Run("fallback to ~/.opencode/bin", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Ensure opencode is not on PATH
t.Setenv("PATH", tmpDir)
// Without the fallback binary, findOpenCode should fail
if _, ok := findOpenCode(); ok {
t.Fatal("findOpenCode should fail when binary is not on PATH or in fallback location")
}
// Create a fake binary at the curl install fallback location
binDir := filepath.Join(tmpDir, ".opencode", "bin")
os.MkdirAll(binDir, 0o755)
name := "opencode"
if runtime.GOOS == "windows" {
name = "opencode.exe"
}
fakeBin := filepath.Join(binDir, name)
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
// Now findOpenCode should succeed via fallback
path, ok := findOpenCode()
if !ok {
t.Fatal("findOpenCode should succeed with fallback binary")
}
if path != fakeBin {
t.Errorf("findOpenCode = %q, want %q", path, fakeBin)
}
})
}
// Verify that the BackfillsCloudModelLimitOnExistingEntry test from the old
// file-based approach is covered by the new inline config approach.
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
expected := cloudModelLimits["glm-4.7"]
if err := o.Edit(testLaunchModels("glm-4.7:cloud")); err != nil {
t.Fatal(err)
}
var cfg map[string]any
json.Unmarshal([]byte(o.configContent), &cfg)
provider, _ := cfg["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
models, _ := ollama["models"].(map[string]any)
entry, _ := models["glm-4.7:cloud"].(map[string]any)
limit, ok := entry["limit"].(map[string]any)
if !ok {
t.Fatal("cloud model limit was not set")
}
if limit["context"] != float64(expected.Context) {
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
}
if limit["output"] != float64(expected.Output) {
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
}
}
func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
specialModel := `model-with-"quotes"`
err := o.Edit(testLaunchModels(specialModel))
if err != nil {
t.Fatalf("Edit with special chars failed: %v", err)
}
var cfg map[string]any
if err := json.Unmarshal([]byte(o.configContent), &cfg); err != nil {
t.Fatalf("resulting config is invalid JSON: %v", err)
}
provider, _ := cfg["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
models, _ := ollama["models"].(map[string]any)
if models[specialModel] == nil {
t.Errorf("model with special chars not found in config")
}
}
func TestReadModelJSONModels(t *testing.T) {
t.Run("reads ollama models from model.json", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
os.MkdirAll(stateDir, 0o755)
state := map[string]any{
"recent": []any{
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
map[string]any{"providerID": "ollama", "modelID": "qwen3:32b"},
},
}
data, _ := json.MarshalIndent(state, "", " ")
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
models := readModelJSONModels()
if len(models) != 2 {
t.Fatalf("got %d models, want 2", len(models))
}
if models[0] != "llama3.2" || models[1] != "qwen3:32b" {
t.Errorf("got %v, want [llama3.2 qwen3:32b]", models)
}
})
t.Run("skips non-ollama providers", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
os.MkdirAll(stateDir, 0o755)
state := map[string]any{
"recent": []any{
map[string]any{"providerID": "openai", "modelID": "gpt-4"},
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
},
}
data, _ := json.MarshalIndent(state, "", " ")
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
models := readModelJSONModels()
if len(models) != 1 || models[0] != "llama3.2" {
t.Errorf("got %v, want [llama3.2]", models)
}
})
t.Run("returns nil when file does not exist", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
if models := readModelJSONModels(); models != nil {
t.Errorf("got %v, want nil", models)
}
})
t.Run("returns nil for corrupt JSON", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
os.MkdirAll(stateDir, 0o755)
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{corrupt`), 0o644)
if models := readModelJSONModels(); models != nil {
t.Errorf("got %v, want nil", models)
}
})
}
func TestOpenCodeResolveContent(t *testing.T) {
t.Run("returns Edit's content when set", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
o := &OpenCode{}
if err := o.Edit(testLaunchModels("gemma4")); err != nil {
t.Fatal(err)
}
editContent := o.configContent
// Write a different model.json — should be ignored
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
state := map[string]any{
"recent": []any{
map[string]any{"providerID": "ollama", "modelID": "different-model"},
},
}
data, _ := json.MarshalIndent(state, "", " ")
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
got := o.resolveContent("gemma4", nil)
if got != editContent {
t.Errorf("resolveContent returned different content than Edit set\ngot: %s\nwant: %s", got, editContent)
}
})
t.Run("falls back to model.json when Edit was not called", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
os.MkdirAll(stateDir, 0o755)
state := map[string]any{
"recent": []any{
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
map[string]any{"providerID": "ollama", "modelID": "qwen3:32b"},
},
}
data, _ := json.MarshalIndent(state, "", " ")
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
o := &OpenCode{}
content := o.resolveContent("llama3.2", nil)
if content == "" {
t.Fatal("resolveContent returned empty")
}
var cfg map[string]any
json.Unmarshal([]byte(content), &cfg)
if cfg["model"] != "ollama/llama3.2" {
t.Errorf("primary = %v, want ollama/llama3.2", cfg["model"])
}
provider, _ := cfg["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
cfgModels, _ := ollama["models"].(map[string]any)
if cfgModels["llama3.2"] == nil || cfgModels["qwen3:32b"] == nil {
t.Errorf("expected both models in config, got %v", cfgModels)
}
})
t.Run("uses requested model as primary even when not first in model.json", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
os.MkdirAll(stateDir, 0o755)
state := map[string]any{
"recent": []any{
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
map[string]any{"providerID": "ollama", "modelID": "qwen3:32b"},
},
}
data, _ := json.MarshalIndent(state, "", " ")
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
o := &OpenCode{}
content := o.resolveContent("qwen3:32b", nil)
var cfg map[string]any
json.Unmarshal([]byte(content), &cfg)
if cfg["model"] != "ollama/qwen3:32b" {
t.Errorf("primary = %v, want ollama/qwen3:32b", cfg["model"])
}
})
t.Run("injects requested model when missing from model.json", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
os.MkdirAll(stateDir, 0o755)
state := map[string]any{
"recent": []any{
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
},
}
data, _ := json.MarshalIndent(state, "", " ")
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
o := &OpenCode{}
content := o.resolveContent("gemma4", nil)
var cfg map[string]any
json.Unmarshal([]byte(content), &cfg)
provider, _ := cfg["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
cfgModels, _ := ollama["models"].(map[string]any)
if cfgModels["gemma4"] == nil {
t.Error("requested model gemma4 not injected into config")
}
if cfg["model"] != "ollama/gemma4" {
t.Errorf("primary = %v, want ollama/gemma4", cfg["model"])
}
})
t.Run("returns empty when no model.json and no model param", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
o := &OpenCode{}
if got := o.resolveContent("", nil); got != "" {
t.Errorf("resolveContent(\"\") = %q, want empty", got)
}
})
t.Run("uses run model metadata when Edit was not called", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
os.MkdirAll(stateDir, 0o755)
state := map[string]any{
"recent": []any{
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
},
}
data, _ := json.MarshalIndent(state, "", " ")
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
o := &OpenCode{}
content := o.resolveContent("gemma4", []LaunchModel{
{
Name: "gemma4",
Capabilities: []model.Capability{model.CapabilityVision},
ContextLength: 65_536,
MaxOutputTokens: 8_192,
},
})
if content == "" {
t.Fatal("resolveContent returned empty")
}
var cfg map[string]any
json.Unmarshal([]byte(content), &cfg)
provider, _ := cfg["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
cfgModels, _ := ollama["models"].(map[string]any)
entry, _ := cfgModels["gemma4"].(map[string]any)
limit, _ := entry["limit"].(map[string]any)
if limit["context"] != float64(65_536) || limit["output"] != float64(8_192) {
t.Fatalf("limit = %v, want context/output from launch metadata", limit)
}
if _, ok := entry["modalities"].(map[string]any); !ok {
t.Fatalf("modalities should be set from launch metadata, got %v", entry["modalities"])
}
if cfgModels["llama3.2"] == nil {
t.Fatalf("state model missing from fallback config: %v", cfgModels)
}
})
t.Run("does not mutate configContent on fallback", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
os.MkdirAll(stateDir, 0o755)
state := map[string]any{
"recent": []any{
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
},
}
data, _ := json.MarshalIndent(state, "", " ")
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
o := &OpenCode{}
_ = o.resolveContent("llama3.2", nil)
if o.configContent != "" {
t.Errorf("resolveContent should not mutate configContent, got %q", o.configContent)
}
})
}
func TestBuildInlineConfig(t *testing.T) {
t.Run("returns error for empty primary", func(t *testing.T) {
if _, err := buildInlineConfig(LaunchModel{}, testLaunchModels("llama3.2")); err == nil {
t.Error("expected error for empty primary")
}
})
t.Run("returns error for empty models", func(t *testing.T) {
if _, err := buildInlineConfig(fallbackLaunchModel("llama3.2"), nil); err == nil {
t.Error("expected error for empty models")
}
})
t.Run("primary differs from first model in list", func(t *testing.T) {
content, err := buildInlineConfig(fallbackLaunchModel("qwen3:32b"), testLaunchModels("llama3.2", "qwen3:32b"))
if err != nil {
t.Fatal(err)
}
var cfg map[string]any
json.Unmarshal([]byte(content), &cfg)
if cfg["model"] != "ollama/qwen3:32b" {
t.Errorf("primary = %v, want ollama/qwen3:32b", cfg["model"])
}
})
}
func TestOpenCodeEdit_PreservesRecentEntries(t *testing.T) {
t.Run("prepends new models to existing recent", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
os.MkdirAll(stateDir, 0o755)
initial := map[string]any{
"recent": []any{
map[string]any{"providerID": "ollama", "modelID": "old-A"},
map[string]any{"providerID": "ollama", "modelID": "old-B"},
},
}
data, _ := json.MarshalIndent(initial, "", " ")
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
o := &OpenCode{}
if err := o.Edit(testLaunchModels("new-X")); err != nil {
t.Fatal(err)
}
stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
var state map[string]any
json.Unmarshal(stored, &state)
recent, _ := state["recent"].([]any)
if len(recent) != 3 {
t.Fatalf("expected 3 entries, got %d", len(recent))
}
first, _ := recent[0].(map[string]any)
if first["modelID"] != "new-X" {
t.Errorf("first entry = %v, want new-X", first["modelID"])
}
})
t.Run("prepends multiple new models in order", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
os.MkdirAll(stateDir, 0o755)
initial := map[string]any{
"recent": []any{
map[string]any{"providerID": "ollama", "modelID": "old-A"},
map[string]any{"providerID": "ollama", "modelID": "old-B"},
},
}
data, _ := json.MarshalIndent(initial, "", " ")
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
o := &OpenCode{}
if err := o.Edit(testLaunchModels("X", "Y", "Z")); err != nil {
t.Fatal(err)
}
stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
var state map[string]any
json.Unmarshal(stored, &state)
recent, _ := state["recent"].([]any)
want := []string{"X", "Y", "Z", "old-A", "old-B"}
if len(recent) != len(want) {
t.Fatalf("expected %d entries, got %d", len(want), len(recent))
}
for i, w := range want {
e, _ := recent[i].(map[string]any)
if e["modelID"] != w {
t.Errorf("recent[%d] = %v, want %v", i, e["modelID"], w)
}
}
})
t.Run("preserves non-ollama entries", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
os.MkdirAll(stateDir, 0o755)
initial := map[string]any{
"recent": []any{
map[string]any{"providerID": "openai", "modelID": "gpt-4"},
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
},
}
data, _ := json.MarshalIndent(initial, "", " ")
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
o := &OpenCode{}
if err := o.Edit(testLaunchModels("qwen3:32b")); err != nil {
t.Fatal(err)
}
stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
var state map[string]any
json.Unmarshal(stored, &state)
recent, _ := state["recent"].([]any)
// Should have: qwen3:32b (new), gpt-4 (preserved openai), llama3.2 (preserved ollama)
var foundOpenAI bool
for _, entry := range recent {
e, _ := entry.(map[string]any)
if e["providerID"] == "openai" && e["modelID"] == "gpt-4" {
foundOpenAI = true
}
}
if !foundOpenAI {
t.Errorf("non-ollama gpt-4 entry was not preserved, got %v", recent)
}
})
t.Run("deduplicates ollama models being re-added", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
os.MkdirAll(stateDir, 0o755)
initial := map[string]any{
"recent": []any{
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
},
}
data, _ := json.MarshalIndent(initial, "", " ")
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
o := &OpenCode{}
if err := o.Edit(testLaunchModels("llama3.2")); err != nil {
t.Fatal(err)
}
stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
var state map[string]any
json.Unmarshal(stored, &state)
recent, _ := state["recent"].([]any)
count := 0
for _, entry := range recent {
e, _ := entry.(map[string]any)
if e["modelID"] == "llama3.2" {
count++
}
}
if count != 1 {
t.Errorf("expected 1 llama3.2 entry, got %d", count)
}
})
t.Run("caps recent list at 10", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
os.MkdirAll(stateDir, 0o755)
// Pre-populate with 9 distinct ollama models
recentEntries := make([]any, 0, 9)
for i := range 9 {
recentEntries = append(recentEntries, map[string]any{
"providerID": "ollama",
"modelID": fmt.Sprintf("old-%d", i),
})
}
initial := map[string]any{"recent": recentEntries}
data, _ := json.MarshalIndent(initial, "", " ")
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
// Add 5 new models — should cap at 10 total
o := &OpenCode{}
if err := o.Edit(testLaunchModels("new-0", "new-1", "new-2", "new-3", "new-4")); err != nil {
t.Fatal(err)
}
stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
var state map[string]any
json.Unmarshal(stored, &state)
recent, _ := state["recent"].([]any)
if len(recent) != 10 {
t.Errorf("expected 10 entries (capped), got %d", len(recent))
}
})
}
func TestOpenCodeEdit_BaseURL(t *testing.T) {
o := &OpenCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Default OLLAMA_HOST
o.Edit(testLaunchModels("llama3.2"))
var cfg map[string]any
json.Unmarshal([]byte(o.configContent), &cfg)
provider, _ := cfg["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
options, _ := ollama["options"].(map[string]any)
baseURL, _ := options["baseURL"].(string)
if baseURL == "" {
t.Error("baseURL should be set")
}
}

365
cmd/launch/pi.go Normal file
View File

@@ -0,0 +1,365 @@
package launch
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
)
// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
type Pi struct{}
const (
piNpmPackage = "@mariozechner/pi-coding-agent"
piWebSearchSource = "npm:@ollama/pi-web-search"
piWebSearchPkg = "@ollama/pi-web-search"
)
func (p *Pi) String() string { return "Pi" }
func (p *Pi) Run(_ string, _ []LaunchModel, args []string) error {
fmt.Fprintf(os.Stderr, "\n%sPreparing Pi...%s\n", ansiGray, ansiReset)
if err := ensureNpmInstalled(); err != nil {
return err
}
fmt.Fprintf(os.Stderr, "%sChecking Pi installation...%s\n", ansiGray, ansiReset)
bin, err := ensurePiInstalled()
if err != nil {
return err
}
ensurePiWebSearchPackage(bin)
fmt.Fprintf(os.Stderr, "\n%sLaunching Pi...%s\n\n", ansiGray, ansiReset)
cmd := exec.Command(bin, args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func ensureNpmInstalled() error {
if _, err := exec.LookPath("npm"); err != nil {
return fmt.Errorf("npm (Node.js) is required to launch pi\n\nInstall it first:\n https://nodejs.org/\n\nThen re-run:\n ollama launch pi")
}
return nil
}
func ensurePiInstalled() (string, error) {
if _, err := exec.LookPath("pi"); err == nil {
return "pi", nil
}
if _, err := exec.LookPath("npm"); err != nil {
return "", fmt.Errorf("pi is not installed and required dependencies are missing\n\nInstall the following first:\n npm (Node.js): https://nodejs.org/\n\nThen re-run:\n ollama launch pi")
}
ok, err := ConfirmPrompt("Pi is not installed. Install with npm?")
if err != nil {
return "", err
}
if !ok {
return "", fmt.Errorf("pi installation cancelled")
}
fmt.Fprintf(os.Stderr, "\nInstalling Pi...\n")
cmd := exec.Command("npm", "install", "-g", piNpmPackage+"@latest")
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("failed to install pi: %w", err)
}
if _, err := exec.LookPath("pi"); err != nil {
return "", fmt.Errorf("pi was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
}
fmt.Fprintf(os.Stderr, "%sPi installed successfully%s\n\n", ansiGreen, ansiReset)
return "pi", nil
}
func ensurePiWebSearchPackage(bin string) {
if !shouldManagePiWebSearch() {
fmt.Fprintf(os.Stderr, "%sCloud is disabled; skipping %s setup.%s\n", ansiGray, piWebSearchPkg, ansiReset)
return
}
fmt.Fprintf(os.Stderr, "%sChecking Pi web search package...%s\n", ansiGray, ansiReset)
installed, err := piPackageInstalled(bin, piWebSearchSource)
if err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not check %s installation: %v%s\n", ansiYellow, piWebSearchPkg, err, ansiReset)
return
}
if !installed {
fmt.Fprintf(os.Stderr, "%sInstalling %s...%s\n", ansiGray, piWebSearchPkg, ansiReset)
cmd := exec.Command(bin, "install", piWebSearchSource)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not install %s: %v%s\n", ansiYellow, piWebSearchPkg, err, ansiReset)
return
}
fmt.Fprintf(os.Stderr, "%s ✓ Installed %s%s\n", ansiGreen, piWebSearchPkg, ansiReset)
return
}
fmt.Fprintf(os.Stderr, "%sUpdating %s...%s\n", ansiGray, piWebSearchPkg, ansiReset)
cmd := exec.Command(bin, "update", piWebSearchSource)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not update %s: %v%s\n", ansiYellow, piWebSearchPkg, err, ansiReset)
return
}
fmt.Fprintf(os.Stderr, "%s ✓ Updated %s%s\n", ansiGreen, piWebSearchPkg, ansiReset)
}
func shouldManagePiWebSearch() bool {
client, err := api.ClientFromEnvironment()
if err != nil {
return true
}
disabled, known := cloudStatusDisabled(context.Background(), client)
if known && disabled {
return false
}
return true
}
func piPackageInstalled(bin, source string) (bool, error) {
cmd := exec.Command(bin, "list")
out, err := cmd.CombinedOutput()
if err != nil {
msg := strings.TrimSpace(string(out))
if msg == "" {
return false, err
}
return false, fmt.Errorf("%w: %s", err, msg)
}
for _, line := range strings.Split(string(out), "\n") {
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, source) {
return true, nil
}
}
return false, nil
}
func (p *Pi) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
var paths []string
modelsPath := filepath.Join(home, ".pi", "agent", "models.json")
if _, err := os.Stat(modelsPath); err == nil {
paths = append(paths, modelsPath)
}
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
if _, err := os.Stat(settingsPath); err == nil {
paths = append(paths, settingsPath)
}
return paths
}
func (p *Pi) Edit(models []LaunchModel) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".pi", "agent", "models.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config)
}
providers, ok := config["providers"].(map[string]any)
if !ok {
providers = make(map[string]any)
}
ollama, ok := providers["ollama"].(map[string]any)
if !ok {
ollama = map[string]any{
"baseUrl": envconfig.Host().String() + "/v1",
"api": "openai-completions",
"apiKey": "ollama",
}
}
existingModels, ok := ollama["models"].([]any)
if !ok {
existingModels = make([]any, 0)
}
// Build set of selected models to track which need to be added
selectedSet := make(map[string]bool, len(models))
for _, m := range models {
selectedSet[m.Name] = true
}
// Build new models list:
// 1. Keep user-managed models (no _launch marker) - untouched
// 2. Keep ollama-managed models (_launch marker) that are still selected,
// except stale cloud entries that should be rebuilt below
// 3. Add new ollama-managed models
var newModels []any
for _, m := range existingModels {
if modelObj, ok := m.(map[string]any); ok {
if id, ok := modelObj["id"].(string); ok {
// User-managed model (no _launch marker) - always preserve
if !isPiOllamaModel(modelObj) {
newModels = append(newModels, m)
} else if selectedSet[id] {
// Rebuild stale managed cloud entries so createConfig refreshes
// the whole entry instead of patching it in place.
if !hasContextWindow(modelObj) {
if _, ok := lookupCloudModelLimit(id); ok {
continue
}
}
newModels = append(newModels, m)
selectedSet[id] = false
}
}
}
}
// Add newly selected models that weren't already in the list
for _, model := range models {
if selectedSet[model.Name] {
newModels = append(newModels, createConfig(model))
}
}
ollama["models"] = newModels
providers["ollama"] = ollama
config["providers"] = providers
configData, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
if err := fileutil.WriteWithBackup(configPath, configData, "pi"); err != nil {
return err
}
// Update settings.json with default provider and model
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
settings := make(map[string]any)
if data, err := os.ReadFile(settingsPath); err == nil {
_ = json.Unmarshal(data, &settings)
}
settings["defaultProvider"] = "ollama"
settings["defaultModel"] = models[0].Name
settingsData, err := json.MarshalIndent(settings, "", " ")
if err != nil {
return err
}
return fileutil.WriteWithBackup(settingsPath, settingsData, "pi")
}
func (p *Pi) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
configPath := filepath.Join(home, ".pi", "agent", "models.json")
config, err := fileutil.ReadJSON(configPath)
if err != nil {
return nil
}
providers, _ := config["providers"].(map[string]any)
ollama, _ := providers["ollama"].(map[string]any)
models, _ := ollama["models"].([]any)
var result []string
for _, m := range models {
if modelObj, ok := m.(map[string]any); ok {
if id, ok := modelObj["id"].(string); ok {
result = append(result, id)
}
}
}
slices.Sort(result)
return result
}
// isPiOllamaModel reports whether a model config entry is managed by ollama launch
func isPiOllamaModel(cfg map[string]any) bool {
if v, ok := cfg["_launch"].(bool); ok && v {
return true
}
return false
}
func hasContextWindow(cfg map[string]any) bool {
switch v := cfg["contextWindow"].(type) {
case float64:
return v > 0
case int:
return v > 0
case int64:
return v > 0
default:
return false
}
}
// createConfig builds Pi model config with capability detection.
func createConfig(model LaunchModel) map[string]any {
cfg := map[string]any{
"id": model.Name,
"_launch": true,
}
// Set input types based on vision capability
if model.HasCapability("vision") {
cfg["input"] = []string{"text", "image"}
} else {
cfg["input"] = []string{"text"}
}
// Set reasoning based on thinking capability
if model.HasCapability("thinking") {
cfg["reasoning"] = true
}
if model.ContextLength > 0 {
cfg["contextWindow"] = model.ContextLength
}
return cfg
}

1210
cmd/launch/pi_test.go Normal file

File diff suppressed because it is too large Load Diff

51
cmd/launch/poolside.go Normal file
View File

@@ -0,0 +1,51 @@
package launch
import (
"fmt"
"os"
"os/exec"
"runtime"
"github.com/ollama/ollama/envconfig"
)
// Poolside implements Runner for Poolside's CLI.
type Poolside struct{}
var poolsideGOOS = runtime.GOOS
func (p *Poolside) String() string { return "Pool" }
func poolsideUnsupportedError() error {
return fmt.Errorf("Warning: Poolside is not currently supported on Windows")
}
func (p *Poolside) args(model string, extra []string) []string {
var args []string
if model != "" {
args = append(args, "-m", model)
}
args = append(args, extra...)
return args
}
func (p *Poolside) Run(model string, _ []LaunchModel, args []string) error {
if poolsideGOOS == "windows" {
return poolsideUnsupportedError()
}
bin, err := exec.LookPath("pool")
if err != nil {
return fmt.Errorf("pool is not installed")
}
cmd := exec.Command(bin, p.args(model, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(),
"POOLSIDE_STANDALONE_BASE_URL="+envconfig.Host().String()+"/v1",
"POOLSIDE_API_KEY=ollama",
)
return cmd.Run()
}

View File

@@ -0,0 +1,88 @@
package launch
import (
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
)
func TestPoolsideArgs(t *testing.T) {
p := &Poolside{}
tests := []struct {
name string
model string
extra []string
want []string
}{
{name: "with model", model: "qwen3.5", want: []string{"-m", "qwen3.5"}},
{name: "without model", extra: []string{"session"}, want: []string{"session"}},
{name: "with model and extra args", model: "llama3.2", extra: []string{"--foo", "bar"}, want: []string{"-m", "llama3.2", "--foo", "bar"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := p.args(tt.model, tt.extra)
if !slices.Equal(got, tt.want) {
t.Fatalf("args(%q, %v) = %v, want %v", tt.model, tt.extra, got, tt.want)
}
})
}
}
func TestPoolsideRunSetsOllamaEnv(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binary")
}
tmpDir := t.TempDir()
logPath := filepath.Join(tmpDir, "pool.log")
poolPath := filepath.Join(tmpDir, "pool")
script := "#!/bin/sh\n" +
"printf 'base=%s\\nkey=%s\\nargs=%s\\n' \"$POOLSIDE_STANDALONE_BASE_URL\" \"$POOLSIDE_API_KEY\" \"$*\" > \"" + logPath + "\"\n"
if err := os.WriteFile(poolPath, []byte(script), 0o755); err != nil {
t.Fatalf("failed to write fake pool binary: %v", err)
}
t.Setenv("PATH", tmpDir)
t.Setenv("OLLAMA_HOST", "http://127.0.0.1:11434")
p := &Poolside{}
if err := p.Run("qwen3.5", nil, []string{"session"}); err != nil {
t.Fatalf("Run returned error: %v", err)
}
data, err := os.ReadFile(logPath)
if err != nil {
t.Fatalf("failed to read pool log: %v", err)
}
got := string(data)
if !strings.Contains(got, "base=http://127.0.0.1:11434/v1") {
t.Fatalf("expected Poolside base URL override in log, got:\n%s", got)
}
if !strings.Contains(got, "key=ollama") {
t.Fatalf("expected Poolside API key override in log, got:\n%s", got)
}
if !strings.Contains(got, "args=-m qwen3.5 session") {
t.Fatalf("expected model and extra args in log, got:\n%s", got)
}
}
func TestPoolsideRunWindowsUnsupported(t *testing.T) {
prev := poolsideGOOS
poolsideGOOS = "windows"
t.Cleanup(func() { poolsideGOOS = prev })
p := &Poolside{}
err := p.Run("kimi-k2.6:cloud", nil, nil)
if err == nil {
t.Fatal("expected Windows unsupported error")
}
if !strings.Contains(err.Error(), "not currently supported on Windows") {
t.Fatalf("expected Windows warning, got %v", err)
}
}

470
cmd/launch/registry.go Normal file
View File

@@ -0,0 +1,470 @@
package launch
import (
"fmt"
"os"
"os/exec"
"slices"
"strings"
)
// IntegrationInstallSpec describes how launcher should detect and guide installation.
type IntegrationInstallSpec struct {
CheckInstalled func() bool
EnsureInstalled func() error
URL string
Command []string
}
// IntegrationSpec is the canonical registry entry for one integration.
type IntegrationSpec struct {
Name string
Runner Runner
Aliases []string
Hidden bool
Description string
Install IntegrationInstallSpec
}
// IntegrationInfo contains display information about a registered integration.
type IntegrationInfo struct {
Name string
DisplayName string
Description string
}
var launcherIntegrationOrder = []string{"claude", "codex-app", "hermes", "openclaw", "opencode", "codex", "copilot", "droid", "pi", "pool"}
var integrationSpecs = []*IntegrationSpec{
{
Name: "claude",
Runner: &Claude{},
Description: "Anthropic's coding tool with subagents",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := (&Claude{}).findPath()
return err == nil
},
URL: "https://code.claude.com/docs/en/quickstart",
},
},
{
Name: "claude-desktop",
Runner: &ClaudeDesktop{},
Aliases: []string{"claude-app"},
Description: "Claude Desktop with Ollama Cloud",
Hidden: true,
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
return claudeDesktopInstalled()
},
URL: "https://claude.com/download",
},
},
{
Name: "cline",
Runner: &Cline{},
Description: "Autonomous coding agent with parallel execution",
Hidden: true,
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("cline")
return err == nil
},
Command: []string{"npm", "install", "-g", "cline"},
},
},
{
Name: "codex",
Runner: &Codex{},
Description: "OpenAI's open-source coding agent",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("codex")
return err == nil
},
URL: "https://developers.openai.com/codex/cli/",
Command: []string{"npm", "install", "-g", "@openai/codex"},
},
},
{
Name: "codex-app",
Runner: &CodexApp{},
Aliases: []string{"codex-desktop", "codex-gui"},
Description: "An AI agent you can delegate real work to, by OpenAI",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
return codexAppInstalled()
},
URL: "https://developers.openai.com/codex/quickstart",
},
},
{
Name: "kimi",
Runner: &Kimi{},
Description: "Moonshot's coding agent for terminal and IDEs",
Hidden: true,
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("kimi")
return err == nil
},
EnsureInstalled: func() error {
_, err := ensureKimiInstalled()
return err
},
URL: "https://moonshotai.github.io/kimi-cli/en/guides/getting-started.html",
},
},
{
Name: "copilot",
Runner: &Copilot{},
Aliases: []string{"copilot-cli"},
Description: "GitHub's AI coding agent for the terminal",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := (&Copilot{}).findPath()
return err == nil
},
URL: "https://github.com/features/copilot/cli/",
},
},
{
Name: "droid",
Runner: &Droid{},
Description: "Factory's coding agent across terminal and IDEs",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("droid")
return err == nil
},
URL: "https://docs.factory.ai/cli/getting-started/quickstart",
},
},
{
Name: "opencode",
Runner: &OpenCode{},
Description: "Anomaly's open-source coding agent",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, ok := findOpenCode()
return ok
},
URL: "https://opencode.ai",
},
},
{
Name: "openclaw",
Runner: &Openclaw{},
Aliases: []string{"clawdbot", "moltbot"},
Description: "Personal AI with 100+ skills",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
if _, err := exec.LookPath("openclaw"); err == nil {
return true
}
if _, err := exec.LookPath("clawdbot"); err == nil {
return true
}
return false
},
EnsureInstalled: func() error {
_, err := ensureOpenclawInstalled()
return err
},
URL: "https://docs.openclaw.ai",
},
},
{
Name: "pi",
Runner: &Pi{},
Description: "Minimal AI agent toolkit with plugin support",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("pi")
return err == nil
},
EnsureInstalled: func() error {
_, err := ensurePiInstalled()
return err
},
Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent@latest"},
},
},
{
Name: "pool",
Runner: &Poolside{},
Description: "Poolside's software agent for enterprise development",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("pool")
return err == nil
},
URL: "https://github.com/poolsideai/pool",
},
},
{
Name: "hermes",
Runner: &Hermes{},
Description: "Self-improving AI agent built by Nous Research",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
return (&Hermes{}).installed()
},
EnsureInstalled: func() error {
return (&Hermes{}).ensureInstalled()
},
URL: "https://hermes-agent.nousresearch.com/docs/getting-started/installation/",
},
},
{
Name: "vscode",
Runner: &VSCode{},
Aliases: []string{"code"},
Description: "Microsoft's open-source AI code editor",
Hidden: true,
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
return (&VSCode{}).findBinary() != ""
},
URL: "https://code.visualstudio.com",
},
},
}
var integrationSpecsByName map[string]*IntegrationSpec
func init() {
rebuildIntegrationSpecIndexes()
}
func hyperlink(url, text string) string {
return fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", url, text)
}
func rebuildIntegrationSpecIndexes() {
integrationSpecsByName = make(map[string]*IntegrationSpec, len(integrationSpecs))
canonical := make(map[string]bool, len(integrationSpecs))
for _, spec := range integrationSpecs {
key := strings.ToLower(spec.Name)
if key == "" {
panic("launch: integration spec missing name")
}
if canonical[key] {
panic(fmt.Sprintf("launch: duplicate integration name %q", key))
}
canonical[key] = true
integrationSpecsByName[key] = spec
}
seenAliases := make(map[string]string)
for _, spec := range integrationSpecs {
for _, alias := range spec.Aliases {
key := strings.ToLower(alias)
if key == "" {
panic(fmt.Sprintf("launch: integration %q has empty alias", spec.Name))
}
if canonical[key] {
panic(fmt.Sprintf("launch: alias %q collides with canonical integration name", key))
}
if owner, exists := seenAliases[key]; exists {
panic(fmt.Sprintf("launch: alias %q collides between %q and %q", key, owner, spec.Name))
}
seenAliases[key] = spec.Name
integrationSpecsByName[key] = spec
}
}
orderSeen := make(map[string]bool, len(launcherIntegrationOrder))
for _, name := range launcherIntegrationOrder {
key := strings.ToLower(name)
if orderSeen[key] {
panic(fmt.Sprintf("launch: duplicate launcher order entry %q", key))
}
orderSeen[key] = true
spec, ok := integrationSpecsByName[key]
if !ok {
panic(fmt.Sprintf("launch: unknown launcher order entry %q", key))
}
if spec.Name != key {
panic(fmt.Sprintf("launch: launcher order entry %q must use canonical name, not alias", key))
}
if spec.Hidden {
panic(fmt.Sprintf("launch: hidden integration %q cannot appear in launcher order", key))
}
}
}
// LookupIntegrationSpec resolves either a canonical integration name or alias to its spec.
func LookupIntegrationSpec(name string) (*IntegrationSpec, error) {
spec, ok := integrationSpecsByName[strings.ToLower(name)]
if !ok {
return nil, fmt.Errorf("unknown integration: %s", name)
}
return spec, nil
}
// LookupIntegration resolves a registry name to the canonical key and runner.
func LookupIntegration(name string) (string, Runner, error) {
spec, err := LookupIntegrationSpec(name)
if err != nil {
return "", nil, err
}
return spec.Name, spec.Runner, nil
}
// ListVisibleIntegrationSpecs returns the canonical integrations that should appear in interactive UIs.
func ListVisibleIntegrationSpecs() []IntegrationSpec {
visible := make([]IntegrationSpec, 0, len(integrationSpecs))
for _, spec := range integrationSpecs {
if spec.Hidden {
continue
}
if supported, ok := spec.Runner.(SupportedIntegration); ok && supported.Supported() != nil {
continue
}
if spec.Name == "pool" && poolsideGOOS == "windows" {
continue
}
visible = append(visible, *spec)
}
orderRank := make(map[string]int, len(launcherIntegrationOrder))
for i, name := range launcherIntegrationOrder {
orderRank[name] = i + 1
}
slices.SortFunc(visible, func(a, b IntegrationSpec) int {
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
if aRank > 0 && bRank > 0 {
return aRank - bRank
}
if aRank > 0 {
return -1
}
if bRank > 0 {
return 1
}
return strings.Compare(a.Name, b.Name)
})
return visible
}
// ListIntegrationInfos returns the registered integrations in launcher display order.
func ListIntegrationInfos() []IntegrationInfo {
visible := ListVisibleIntegrationSpecs()
infos := make([]IntegrationInfo, 0, len(visible))
for _, spec := range visible {
infos = append(infos, IntegrationInfo{
Name: spec.Name,
DisplayName: spec.Runner.String(),
Description: spec.Description,
})
}
return infos
}
// IntegrationSelectionItems returns the sorted integration items shown by launcher selection UIs.
func IntegrationSelectionItems() ([]ModelItem, error) {
visible := ListVisibleIntegrationSpecs()
if len(visible) == 0 {
return nil, fmt.Errorf("no integrations available")
}
items := make([]ModelItem, 0, len(visible))
for _, spec := range visible {
description := spec.Runner.String()
if conn, err := loadStoredIntegrationConfig(spec.Name); err == nil && len(conn.Models) > 0 {
description = fmt.Sprintf("%s (%s)", spec.Runner.String(), conn.Models[0])
}
items = append(items, ModelItem{Name: spec.Name, Description: description})
}
return items, nil
}
// IsIntegrationInstalled checks if an integration binary is installed.
func IsIntegrationInstalled(name string) bool {
integration, err := integrationFor(name)
if err != nil {
fmt.Fprintf(os.Stderr, "Ollama couldn't find integration %q, so it'll show up as not installed.\n", name)
return false
}
return integration.installed
}
// integration is resolved registry metadata used by launcher state and install checks.
// It combines immutable registry spec data with computed runtime traits.
type integration struct {
spec *IntegrationSpec
installed bool
autoInstallable bool
editor bool
installHint string
}
// integrationFor resolves an integration name into the canonical spec plus
// derived launcher/install traits used across registry and launch flows.
func integrationFor(name string) (integration, error) {
spec, err := LookupIntegrationSpec(name)
if err != nil {
return integration{}, err
}
installed := true
if spec.Install.CheckInstalled != nil {
installed = spec.Install.CheckInstalled()
}
_, editor := spec.Runner.(Editor)
hint := ""
if spec.Install.URL != "" {
hint = "Install from " + hyperlink(spec.Install.URL, spec.Install.URL)
} else if len(spec.Install.Command) > 0 {
hint = "Install with: " + strings.Join(spec.Install.Command, " ")
}
return integration{
spec: spec,
installed: installed,
autoInstallable: spec.Install.EnsureInstalled != nil,
editor: editor,
installHint: hint,
}, nil
}
// EnsureIntegrationInstalled installs auto-installable integrations when missing.
func EnsureIntegrationInstalled(name string, runner Runner) error {
integration, err := integrationFor(name)
if err != nil {
return fmt.Errorf("%s is not installed", runner)
}
if supported, ok := runner.(SupportedIntegration); ok {
if err := supported.Supported(); err != nil {
return err
}
}
if integration.spec.Name == "pool" && poolsideGOOS == "windows" {
return poolsideUnsupportedError()
}
if integration.installed {
return nil
}
if integration.autoInstallable {
return integration.spec.Install.EnsureInstalled()
}
switch {
case integration.spec.Install.URL != "":
return fmt.Errorf("%s is not installed, install from %s", integration.spec.Name, integration.spec.Install.URL)
case len(integration.spec.Install.Command) > 0:
return fmt.Errorf("%s is not installed, install with: %s", integration.spec.Name, strings.Join(integration.spec.Install.Command, " "))
default:
return fmt.Errorf("%s is not installed", runner)
}
}

View File

@@ -0,0 +1,21 @@
package launch
import "strings"
// OverrideIntegration replaces one registry entry's runner for tests and returns a restore function.
func OverrideIntegration(name string, runner Runner) func() {
spec, err := LookupIntegrationSpec(name)
if err != nil {
key := strings.ToLower(name)
integrationSpecsByName[key] = &IntegrationSpec{Name: key, Runner: runner}
return func() {
delete(integrationSpecsByName, key)
}
}
original := spec.Runner
spec.Runner = runner
return func() {
spec.Runner = original
}
}

View File

@@ -0,0 +1,95 @@
package launch
import (
"os"
"path/filepath"
"testing"
)
func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
tests := []struct {
name string
binary string
runner Runner
checkPath func(home string) string
}{
{
name: "droid",
binary: "droid",
runner: &Droid{},
checkPath: func(home string) string {
return filepath.Join(home, ".factory", "settings.json")
},
},
{
name: "opencode",
binary: "opencode",
runner: &OpenCode{},
checkPath: func(home string) string {
return filepath.Join(home, ".local", "state", "opencode", "model.json")
},
},
{
name: "cline",
binary: "cline",
runner: &Cline{},
checkPath: func(home string) string {
return filepath.Join(home, ".cline", "data", "globalState.json")
},
},
{
name: "pi",
binary: "pi",
runner: &Pi{},
checkPath: func(home string) string {
return filepath.Join(home, ".pi", "agent", "models.json")
},
},
{
name: "pool",
binary: "pool",
runner: &Poolside{},
checkPath: func(home string) string {
return filepath.Join(home, ".poolside", "config")
},
},
{
name: "kimi",
binary: "kimi",
runner: &Kimi{},
checkPath: func(home string) string {
return filepath.Join(home, ".kimi", "config.toml")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.name == "pool" && poolsideGOOS == "windows" {
t.Skip("Poolside is intentionally unsupported on Windows")
}
home := t.TempDir()
setTestHome(t, home)
binDir := t.TempDir()
writeFakeBinary(t, binDir, tt.binary)
if tt.name == "pi" {
writeFakeBinary(t, binDir, "npm")
}
if tt.name == "kimi" {
writeFakeBinary(t, binDir, "curl")
writeFakeBinary(t, binDir, "bash")
}
t.Setenv("PATH", binDir)
configPath := tt.checkPath(home)
if err := tt.runner.Run("llama3.2", nil, nil); err != nil {
t.Fatalf("Run returned error: %v", err)
}
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
t.Fatalf("expected Run to leave %s untouched, got err=%v", configPath, err)
}
})
}
}

View File

@@ -0,0 +1,131 @@
package launch
import (
"errors"
"fmt"
"os"
"golang.org/x/term"
)
// ANSI escape sequences for terminal formatting.
const (
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiGreen = "\033[32m"
ansiYellow = "\033[33m"
)
// ErrCancelled is returned when the user cancels a selection.
var ErrCancelled = errors.New("cancelled")
// errCancelled is kept as an internal alias for existing call sites.
var errCancelled = ErrCancelled
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
// When set, ConfirmPrompt delegates to it instead of using raw terminal I/O.
var DefaultConfirmPrompt func(prompt string, options ConfirmOptions) (bool, error)
// ConfirmOptions customizes labels for confirmation prompts.
type ConfirmOptions struct {
YesLabel string
NoLabel string
}
// SingleSelector is a function type for single item selection.
// current is the name of the previously selected item to highlight; empty means no pre-selection.
type SingleSelector func(title string, items []SelectionItem, current string) (string, error)
// SingleSelectorWithUpdates is a single item selector that can receive refreshed item state while open.
type SingleSelectorWithUpdates func(title string, items []SelectionItem, current string, updates <-chan []SelectionItem) (string, error)
// MultiSelector is a function type for multi item selection.
type MultiSelector func(title string, items []SelectionItem, preChecked []string) ([]string, error)
// MultiSelectorWithUpdates is a multi item selector that can receive refreshed item state while open.
type MultiSelectorWithUpdates func(title string, items []SelectionItem, preChecked []string, updates <-chan []SelectionItem) ([]string, error)
// DefaultSingleSelector is the default single-select implementation.
var DefaultSingleSelector SingleSelector
// DefaultSingleSelectorWithUpdates is the default single-select implementation with live updates.
var DefaultSingleSelectorWithUpdates SingleSelectorWithUpdates
// DefaultMultiSelector is the default multi-select implementation.
var DefaultMultiSelector MultiSelector
// DefaultMultiSelectorWithUpdates is the default multi-select implementation with live updates.
var DefaultMultiSelectorWithUpdates MultiSelectorWithUpdates
// DefaultSignIn provides a TUI-based sign-in flow.
// When set, ensureAuth uses it instead of plain text prompts.
// Returns the signed-in username or an error.
var DefaultSignIn func(modelName, signInURL string) (string, error)
// DefaultUpgrade provides a TUI-based upgrade flow.
// Returns the updated plan or an error.
var DefaultUpgrade func(modelName, requiredPlan string) (string, error)
type launchConfirmPolicy struct {
yes bool
requireYesMessage bool
}
var currentLaunchConfirmPolicy launchConfirmPolicy
func withLaunchConfirmPolicy(policy launchConfirmPolicy) func() {
old := currentLaunchConfirmPolicy
currentLaunchConfirmPolicy = policy
return func() {
currentLaunchConfirmPolicy = old
}
}
// ConfirmPrompt is the shared confirmation gate for launch flows (integration
// edits, missing-model pulls, sign-in prompts, OpenClaw install/security, etc).
// Behavior is controlled by currentLaunchConfirmPolicy, typically scoped by
// withLaunchConfirmPolicy in LaunchCmd (e.g. auto-approve with --yes).
func ConfirmPrompt(prompt string) (bool, error) {
return ConfirmPromptWithOptions(prompt, ConfirmOptions{})
}
// ConfirmPromptWithOptions is the shared confirmation gate for launch flows
// that need custom yes/no labels in interactive UIs.
func ConfirmPromptWithOptions(prompt string, options ConfirmOptions) (bool, error) {
if currentLaunchConfirmPolicy.yes {
return true, nil
}
if currentLaunchConfirmPolicy.requireYesMessage {
return false, fmt.Errorf("%s requires confirmation; re-run with --yes to continue", prompt)
}
if DefaultConfirmPrompt != nil {
return DefaultConfirmPrompt(prompt, options)
}
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return false, err
}
defer term.Restore(fd, oldState)
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
buf := make([]byte, 1)
for {
if _, err := os.Stdin.Read(buf); err != nil {
return false, err
}
switch buf[0] {
case 'Y', 'y', 13:
fmt.Fprintf(os.Stderr, "yes\r\n")
return true, nil
case 'N', 'n', 27, 3:
fmt.Fprintf(os.Stderr, "no\r\n")
return false, nil
}
}
}

112
cmd/launch/selector_test.go Normal file
View File

@@ -0,0 +1,112 @@
package launch
import (
"strings"
"testing"
)
func TestErrCancelled(t *testing.T) {
t.Run("NotNil", func(t *testing.T) {
if errCancelled == nil {
t.Error("errCancelled should not be nil")
}
})
t.Run("Message", func(t *testing.T) {
if errCancelled.Error() != "cancelled" {
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
}
})
}
func TestWithLaunchConfirmPolicy_ScopesAndRestores(t *testing.T) {
oldPolicy := currentLaunchConfirmPolicy
oldHook := DefaultConfirmPrompt
t.Cleanup(func() {
currentLaunchConfirmPolicy = oldPolicy
DefaultConfirmPrompt = oldHook
})
currentLaunchConfirmPolicy = launchConfirmPolicy{}
var hookCalls int
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
hookCalls++
return true, nil
}
restoreOuter := withLaunchConfirmPolicy(launchConfirmPolicy{requireYesMessage: true})
restoreInner := withLaunchConfirmPolicy(launchConfirmPolicy{yes: true})
ok, err := ConfirmPrompt("test prompt")
if err != nil {
t.Fatalf("expected --yes policy to allow prompt, got error: %v", err)
}
if !ok {
t.Fatal("expected --yes policy to auto-accept prompt")
}
if hookCalls != 0 {
t.Fatalf("expected --yes to skip hook, got %d hook calls", hookCalls)
}
restoreInner()
_, err = ConfirmPrompt("test prompt")
if err == nil {
t.Fatal("expected requireYesMessage policy to block prompt")
}
if !strings.Contains(err.Error(), "re-run with --yes") {
t.Fatalf("expected actionable --yes error, got: %v", err)
}
if hookCalls != 0 {
t.Fatalf("expected blocking policy to skip hook, got %d hook calls", hookCalls)
}
restoreOuter()
ok, err = ConfirmPrompt("test prompt")
if err != nil {
t.Fatalf("expected restored default behavior to use hook, got error: %v", err)
}
if !ok {
t.Fatal("expected hook to return true")
}
if hookCalls != 1 {
t.Fatalf("expected one hook call after restore, got %d", hookCalls)
}
}
func TestConfirmPromptWithOptions_DelegatesToOptionsHook(t *testing.T) {
oldPolicy := currentLaunchConfirmPolicy
oldHook := DefaultConfirmPrompt
t.Cleanup(func() {
currentLaunchConfirmPolicy = oldPolicy
DefaultConfirmPrompt = oldHook
})
currentLaunchConfirmPolicy = launchConfirmPolicy{}
called := false
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
called = true
if prompt != "Connect now?" {
t.Fatalf("unexpected prompt: %q", prompt)
}
if options.YesLabel != "Yes" || options.NoLabel != "Set up later" {
t.Fatalf("unexpected options: %+v", options)
}
return true, nil
}
ok, err := ConfirmPromptWithOptions("Connect now?", ConfirmOptions{
YesLabel: "Yes",
NoLabel: "Set up later",
})
if err != nil {
t.Fatalf("ConfirmPromptWithOptions() error = %v", err)
}
if !ok {
t.Fatal("expected confirm to return true")
}
if !called {
t.Fatal("expected options hook to be called")
}
}

View File

@@ -0,0 +1,86 @@
package launch
import (
"strings"
"testing"
"github.com/ollama/ollama/cmd/config"
)
var (
integrations map[string]Runner
integrationAliases map[string]bool
integrationOrder = launcherIntegrationOrder
)
func init() {
integrations = buildTestIntegrations()
integrationAliases = buildTestIntegrationAliases()
}
func buildTestIntegrations() map[string]Runner {
result := make(map[string]Runner, len(integrationSpecsByName))
for name, spec := range integrationSpecsByName {
result[strings.ToLower(name)] = spec.Runner
}
return result
}
func buildTestIntegrationAliases() map[string]bool {
result := make(map[string]bool)
for _, spec := range integrationSpecs {
for _, alias := range spec.Aliases {
result[strings.ToLower(alias)] = true
}
}
return result
}
func setTestHome(t *testing.T, dir string) {
t.Helper()
setLaunchTestHome(t, dir)
}
func testLaunchModels(names ...string) []LaunchModel {
return launchModelsFromNames(names)
}
func SaveIntegration(appName string, models []string) error {
return config.SaveIntegration(appName, models)
}
func LoadIntegration(appName string) (*config.IntegrationConfig, error) {
return config.LoadIntegration(appName)
}
func SaveAliases(appName string, aliases map[string]string) error {
return config.SaveAliases(appName, aliases)
}
func LastModel() string {
return config.LastModel()
}
func SetLastModel(model string) error {
return config.SetLastModel(model)
}
func LastSelection() string {
return config.LastSelection()
}
func SetLastSelection(selection string) error {
return config.SetLastSelection(selection)
}
func IntegrationModel(appName string) string {
return config.IntegrationModel(appName)
}
func IntegrationModels(appName string) []string {
return config.IntegrationModels(appName)
}
func integrationOnboarded(appName string) error {
return config.MarkIntegrationOnboarded(appName)
}

591
cmd/launch/vscode.go Normal file
View File

@@ -0,0 +1,591 @@
package launch
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
_ "github.com/mattn/go-sqlite3"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
)
// VSCode implements Runner and Editor for Visual Studio Code integration.
type VSCode struct{}
func (v *VSCode) String() string { return "Visual Studio Code" }
// findBinary returns the path/command to launch VS Code, or "" if not found.
// It checks platform-specific locations only.
func (v *VSCode) findBinary() string {
var candidates []string
switch runtime.GOOS {
case "darwin":
candidates = []string{
"/Applications/Visual Studio Code.app",
}
case "windows":
if localAppData := os.Getenv("LOCALAPPDATA"); localAppData != "" {
candidates = append(candidates, filepath.Join(localAppData, "Programs", "Microsoft VS Code", "bin", "code.cmd"))
}
default: // linux
candidates = []string{
"/usr/bin/code",
"/snap/bin/code",
}
}
for _, c := range candidates {
if _, err := os.Stat(c); err == nil {
return c
}
}
return ""
}
// IsRunning reports whether VS Code is currently running.
// Each platform uses a pattern specific enough to avoid matching Cursor or
// other VS Code forks.
func (v *VSCode) IsRunning() bool {
switch runtime.GOOS {
case "darwin":
out, err := exec.Command("pgrep", "-f", "Visual Studio Code.app/Contents/MacOS/Code").Output()
return err == nil && len(out) > 0
case "windows":
// Match VS Code by executable path to avoid matching Cursor or other forks.
out, err := exec.Command("powershell", "-NoProfile", "-Command",
`Get-Process Code -ErrorAction SilentlyContinue | Where-Object { $_.Path -like '*Microsoft VS Code*' } | Select-Object -First 1`).Output()
return err == nil && len(strings.TrimSpace(string(out))) > 0
default:
// Match VS Code specifically by its install path to avoid matching
// Cursor (/cursor/) or other forks.
for _, pattern := range []string{"/usr/share/code/", "/snap/code/"} {
out, err := exec.Command("pgrep", "-f", pattern).Output()
if err == nil && len(out) > 0 {
return true
}
}
return false
}
}
// Quit gracefully quits VS Code and waits for it to exit so that it flushes
// its in-memory state back to the database.
func (v *VSCode) Quit() {
if !v.IsRunning() {
return
}
switch runtime.GOOS {
case "darwin":
_ = exec.Command("osascript", "-e", `quit app "Visual Studio Code"`).Run()
case "windows":
// Kill VS Code by executable path to avoid killing Cursor or other forks.
_ = exec.Command("powershell", "-NoProfile", "-Command",
`Get-Process Code -ErrorAction SilentlyContinue | Where-Object { $_.Path -like '*Microsoft VS Code*' } | Stop-Process -Force`).Run()
default:
for _, pattern := range []string{"/usr/share/code/", "/snap/code/"} {
_ = exec.Command("pkill", "-f", pattern).Run()
}
}
// Wait for the process to fully exit and flush its state to disk
// TODO(hoyyeva): update spinner to use bubble tea
spinnerFrames := []string{"|", "/", "-", "\\"}
frame := 0
fmt.Fprintf(os.Stderr, "\033[90mRestarting VS Code... %s\033[0m", spinnerFrames[0])
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for range 150 { // 150 ticks × 200ms = 30s timeout
<-ticker.C
frame++
fmt.Fprintf(os.Stderr, "\r\033[90mRestarting VS Code... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
if frame%5 == 0 { // check every ~1s
if !v.IsRunning() {
fmt.Fprintf(os.Stderr, "\r\033[K")
// Give VS Code a moment to finish writing its state DB
time.Sleep(1 * time.Second)
return
}
}
}
fmt.Fprintf(os.Stderr, "\r\033[K")
}
const (
minCopilotChatVersion = "0.41.0"
minVSCodeVersion = "1.113"
)
func (v *VSCode) Run(model string, _ []LaunchModel, args []string) error {
v.checkVSCodeVersion()
v.checkCopilotChatVersion()
// Get all configured models (saved by the launcher framework before Run is called)
models := []string{model}
if cfg, err := loadStoredIntegrationConfig("vscode"); err == nil && len(cfg.Models) > 0 {
models = cfg.Models
}
// VS Code discovers models from ollama ls. Cloud models that pass Show
// (the server knows about them) but aren't in ls need to be pulled to
// register them so VS Code can find them.
if client, err := api.ClientFromEnvironment(); err == nil {
v.ensureModelsRegistered(context.Background(), client, models)
}
// Warn if the default model doesn't support tool calling
if client, err := api.ClientFromEnvironment(); err == nil {
if resp, err := client.Show(context.Background(), &api.ShowRequest{Model: models[0]}); err == nil {
hasTools := false
for _, c := range resp.Capabilities {
if c == "tools" {
hasTools = true
break
}
}
if !hasTools {
fmt.Fprintf(os.Stderr, "Note: %s does not support tool calling and may not appear in the Copilot Chat model picker.\n", models[0])
}
}
}
v.printModelAccessTip()
if v.IsRunning() {
restart, err := ConfirmPrompt("Restart VS Code?")
if err != nil {
restart = false
}
if restart {
v.Quit()
if err := v.ShowInModelPicker(models); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not update VS Code model picker: %v%s\n", ansiYellow, err, ansiReset)
}
v.FocusVSCode()
} else {
fmt.Fprintf(os.Stderr, "\nTo get the latest model configuration, restart VS Code when you're ready.\n")
}
} else {
if err := v.ShowInModelPicker(models); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not update VS Code model picker: %v%s\n", ansiYellow, err, ansiReset)
}
v.FocusVSCode()
}
return nil
}
// ensureModelsRegistered pulls models that the server knows about (Show succeeds)
// but aren't in ollama ls yet. This is needed for cloud models so that VS Code
// can discover them from the Ollama API.
func (v *VSCode) ensureModelsRegistered(ctx context.Context, client *api.Client, models []string) {
listed, err := client.List(ctx)
if err != nil {
return
}
registered := make(map[string]bool, len(listed.Models))
for _, m := range listed.Models {
registered[m.Name] = true
}
for _, model := range models {
if registered[model] {
continue
}
// Also check without :latest suffix
if !strings.Contains(model, ":") && registered[model+":latest"] {
continue
}
if err := pullModel(ctx, client, model, false); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not register model %s: %v%s\n", ansiYellow, model, err, ansiReset)
}
}
}
// FocusVSCode brings VS Code to the foreground.
func (v *VSCode) FocusVSCode() {
binary := v.findBinary()
if binary == "" {
return
}
if runtime.GOOS == "darwin" && strings.HasSuffix(binary, ".app") {
_ = exec.Command("open", "-a", binary).Run()
} else {
_ = exec.Command(binary).Start()
}
}
// printModelAccessTip shows instructions for finding Ollama models in VS Code.
func (v *VSCode) printModelAccessTip() {
fmt.Fprintf(os.Stderr, "\nTip: To use Ollama models, open Copilot Chat and click the model picker.\n")
fmt.Fprintf(os.Stderr, " If you don't see your models, click \"Other models\" to find them.\n\n")
}
func (v *VSCode) Paths() []string {
if p := v.chatLanguageModelsPath(); fileExists(p) {
return []string{p}
}
return nil
}
func (v *VSCode) Edit(models []LaunchModel) error {
if len(models) == 0 {
return nil
}
// Write chatLanguageModels.json with Ollama vendor entry
clmPath := v.chatLanguageModelsPath()
if err := os.MkdirAll(filepath.Dir(clmPath), 0o755); err != nil {
return err
}
var entries []map[string]any
if data, err := os.ReadFile(clmPath); err == nil {
_ = json.Unmarshal(data, &entries)
}
// Remove any existing Ollama entries, preserve others
filtered := make([]map[string]any, 0, len(entries))
for _, entry := range entries {
if vendor, _ := entry["vendor"].(string); vendor != "ollama" {
filtered = append(filtered, entry)
}
}
// Add new Ollama entry
filtered = append(filtered, map[string]any{
"vendor": "ollama",
"name": "Ollama",
"url": envconfig.Host().String(),
})
data, err := json.MarshalIndent(filtered, "", " ")
if err != nil {
return err
}
if err := fileutil.WriteWithBackup(clmPath, data, "vscode"); err != nil {
return err
}
// Clean up legacy settings from older Ollama integrations
v.updateSettings()
return nil
}
func (v *VSCode) Models() []string {
if !v.hasOllamaVendor() {
return nil
}
if cfg, err := loadStoredIntegrationConfig("vscode"); err == nil {
return cfg.Models
}
return nil
}
// hasOllamaVendor checks if chatLanguageModels.json contains an Ollama vendor entry.
func (v *VSCode) hasOllamaVendor() bool {
data, err := os.ReadFile(v.chatLanguageModelsPath())
if err != nil {
return false
}
var entries []map[string]any
if err := json.Unmarshal(data, &entries); err != nil {
return false
}
for _, entry := range entries {
if vendor, _ := entry["vendor"].(string); vendor == "ollama" {
return true
}
}
return false
}
func (v *VSCode) chatLanguageModelsPath() string {
return v.vscodePath("chatLanguageModels.json")
}
func (v *VSCode) settingsPath() string {
return v.vscodePath("settings.json")
}
// updateSettings cleans up legacy settings from older Ollama integrations.
func (v *VSCode) updateSettings() {
settingsPath := v.settingsPath()
data, err := os.ReadFile(settingsPath)
if err != nil {
return
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
return
}
changed := false
for _, key := range []string{"github.copilot.chat.byok.ollamaEndpoint", "ollama.launch.configured"} {
if _, ok := settings[key]; ok {
delete(settings, key)
changed = true
}
}
if !changed {
return
}
updated, err := json.MarshalIndent(settings, "", " ")
if err != nil {
return
}
_ = fileutil.WriteWithBackup(settingsPath, updated, "vscode")
}
func (v *VSCode) statePath() string {
return v.vscodePath("globalStorage", "state.vscdb")
}
// ShowInModelPicker ensures the given models are visible in VS Code's Copilot
// Chat model picker. It sets the configured models to true in the picker
// preferences so they appear in the dropdown. Models use the VS Code identifier
// format "ollama/Ollama/<name>".
func (v *VSCode) ShowInModelPicker(models []string) error {
if len(models) == 0 {
return nil
}
dbPath := v.statePath()
needsCreate := !fileExists(dbPath)
if needsCreate {
if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil {
return fmt.Errorf("creating state directory: %w", err)
}
}
db, err := sql.Open("sqlite3", dbPath+"?_busy_timeout=5000")
if err != nil {
return fmt.Errorf("opening state database: %w", err)
}
defer db.Close()
// Create the table if this is a fresh DB. Schema must match what VS Code creates.
if needsCreate {
if _, err := db.Exec("CREATE TABLE ItemTable (key TEXT UNIQUE ON CONFLICT REPLACE, value BLOB)"); err != nil {
return fmt.Errorf("initializing state database: %w", err)
}
}
// Read existing preferences
prefs := make(map[string]bool)
var prefsJSON string
if err := db.QueryRow("SELECT value FROM ItemTable WHERE key = 'chatModelPickerPreferences'").Scan(&prefsJSON); err == nil {
_ = json.Unmarshal([]byte(prefsJSON), &prefs)
}
// Build name→ID map from VS Code's cached model list.
// VS Code uses numeric IDs like "ollama/Ollama/4", not "ollama/Ollama/kimi-k2.5:cloud".
nameToID := make(map[string]string)
var cacheJSON string
if err := db.QueryRow("SELECT value FROM ItemTable WHERE key = 'chat.cachedLanguageModels.v2'").Scan(&cacheJSON); err == nil {
var cached []map[string]any
if json.Unmarshal([]byte(cacheJSON), &cached) == nil {
for _, entry := range cached {
meta, _ := entry["metadata"].(map[string]any)
if meta == nil {
continue
}
if vendor, _ := meta["vendor"].(string); vendor == "ollama" {
name, _ := meta["name"].(string)
id, _ := entry["identifier"].(string)
if name != "" && id != "" {
nameToID[name] = id
}
}
}
}
}
// Ollama config is authoritative: always show configured models,
// hide Ollama models that are no longer in the config.
configuredIDs := make(map[string]bool)
for _, m := range models {
for _, id := range v.modelVSCodeIDs(m, nameToID) {
prefs[id] = true
configuredIDs[id] = true
}
}
for id := range prefs {
if strings.HasPrefix(id, "ollama/") && !configuredIDs[id] {
prefs[id] = false
}
}
data, _ := json.Marshal(prefs)
if _, err = db.Exec("INSERT OR REPLACE INTO ItemTable (key, value) VALUES ('chatModelPickerPreferences', ?)", string(data)); err != nil {
return err
}
return nil
}
// modelVSCodeIDs returns all possible VS Code picker IDs for a model name.
func (v *VSCode) modelVSCodeIDs(model string, nameToID map[string]string) []string {
var ids []string
if id, ok := nameToID[model]; ok {
ids = append(ids, id)
} else if !strings.Contains(model, ":") {
if id, ok := nameToID[model+":latest"]; ok {
ids = append(ids, id)
}
}
ids = append(ids, "ollama/Ollama/"+model)
if !strings.Contains(model, ":") {
ids = append(ids, "ollama/Ollama/"+model+":latest")
}
return ids
}
func (v *VSCode) vscodePath(parts ...string) string {
home, _ := os.UserHomeDir()
var base string
switch runtime.GOOS {
case "darwin":
base = filepath.Join(home, "Library", "Application Support", "Code", "User")
case "windows":
base = filepath.Join(os.Getenv("APPDATA"), "Code", "User")
default:
base = filepath.Join(home, ".config", "Code", "User")
}
return filepath.Join(append([]string{base}, parts...)...)
}
// checkVSCodeVersion warns if VS Code is older than minVSCodeVersion.
func (v *VSCode) checkVSCodeVersion() {
codeCLI := v.findCodeCLI()
if codeCLI == "" {
return
}
out, err := exec.Command(codeCLI, "--version").Output()
if err != nil {
return
}
// "code --version" outputs: version\ncommit\narch
lines := strings.Split(strings.TrimSpace(string(out)), "\n")
if len(lines) == 0 || lines[0] == "" {
return
}
version := strings.TrimSpace(lines[0])
if compareVersions(version, minVSCodeVersion) < 0 {
fmt.Fprintf(os.Stderr, "\n%sWarning: VS Code version (%s) is older than the recommended version (%s)%s\n", ansiYellow, version, minVSCodeVersion, ansiReset)
fmt.Fprintf(os.Stderr, "Please update VS Code to the latest version.\n\n")
}
}
// checkCopilotChatVersion warns if the GitHub Copilot Chat extension is
// missing or older than minCopilotChatVersion.
func (v *VSCode) checkCopilotChatVersion() {
codeCLI := v.findCodeCLI()
if codeCLI == "" {
return
}
out, err := exec.Command(codeCLI, "--list-extensions", "--show-versions").Output()
if err != nil {
return
}
installed, version := parseCopilotChatVersion(string(out))
if !installed {
fmt.Fprintf(os.Stderr, "\n%sWarning: GitHub Copilot Chat extension is not installed%s\n", ansiYellow, ansiReset)
fmt.Fprintf(os.Stderr, "Install it in VS Code: Extensions → search \"GitHub Copilot Chat\" → Install\n\n")
return
}
if compareVersions(version, minCopilotChatVersion) < 0 {
fmt.Fprintf(os.Stderr, "\n%sWarning: GitHub Copilot Chat extension version (%s) is older than the recommended version (%s)%s\n", ansiYellow, version, minCopilotChatVersion, ansiReset)
fmt.Fprintf(os.Stderr, "Please update it in VS Code: Extensions → search \"GitHub Copilot Chat\" → Update\n\n")
}
}
// findCodeCLI returns the path to the VS Code CLI for querying extensions.
// On macOS, findBinary may return an .app bundle which can't run --list-extensions,
// so this resolves to the actual CLI binary inside the bundle.
func (v *VSCode) findCodeCLI() string {
binary := v.findBinary()
if binary == "" {
return ""
}
if runtime.GOOS == "darwin" && strings.HasSuffix(binary, ".app") {
bundleCLI := binary + "/Contents/Resources/app/bin/code"
if _, err := os.Stat(bundleCLI); err == nil {
return bundleCLI
}
return ""
}
return binary
}
// parseCopilotChatVersion extracts the version of the GitHub Copilot Chat
// extension from "code --list-extensions --show-versions" output.
func parseCopilotChatVersion(output string) (installed bool, version string) {
for _, line := range strings.Split(output, "\n") {
// Format: github.copilot-chat@0.40.1
if !strings.HasPrefix(strings.ToLower(line), "github.copilot-chat@") {
continue
}
parts := strings.SplitN(line, "@", 2)
if len(parts) != 2 {
continue
}
return true, strings.TrimSpace(parts[1])
}
return false, ""
}
// compareVersions compares two dot-separated version strings.
// Returns -1 if a < b, 0 if a == b, 1 if a > b.
func compareVersions(a, b string) int {
aParts := strings.Split(a, ".")
bParts := strings.Split(b, ".")
maxLen := len(aParts)
if len(bParts) > maxLen {
maxLen = len(bParts)
}
for i := range maxLen {
var aNum, bNum int
if i < len(aParts) {
aNum, _ = strconv.Atoi(aParts[i])
}
if i < len(bParts) {
bNum, _ = strconv.Atoi(bParts[i])
}
if aNum < bNum {
return -1
}
if aNum > bNum {
return 1
}
}
return 0
}
func fileExists(path string) bool {
_, err := os.Stat(path)
return err == nil
}

533
cmd/launch/vscode_test.go Normal file
View File

@@ -0,0 +1,533 @@
package launch
import (
"database/sql"
"encoding/json"
"os"
"path/filepath"
"runtime"
"testing"
_ "github.com/mattn/go-sqlite3"
"github.com/ollama/ollama/cmd/internal/fileutil"
)
func TestVSCodeIntegration(t *testing.T) {
v := &VSCode{}
t.Run("String", func(t *testing.T) {
if got := v.String(); got != "Visual Studio Code" {
t.Errorf("String() = %q, want %q", got, "Visual Studio Code")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = v
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = v
})
}
func TestVSCodeEdit(t *testing.T) {
v := &VSCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
clmPath := testVSCodePath(t, tmpDir, "chatLanguageModels.json")
tests := []struct {
name string
setup string // initial chatLanguageModels.json content, empty means no file
models []string
validate func(t *testing.T, data []byte)
}{
{
name: "fresh install",
models: []string{"llama3.2"},
validate: func(t *testing.T, data []byte) {
assertOllamaVendorConfigured(t, data)
},
},
{
name: "preserve other vendor entries",
setup: `[{"vendor": "azure", "name": "Azure", "url": "https://example.com"}]`,
models: []string{"llama3.2"},
validate: func(t *testing.T, data []byte) {
var entries []map[string]any
json.Unmarshal(data, &entries)
if len(entries) != 2 {
t.Errorf("expected 2 entries, got %d", len(entries))
}
// Check Azure entry preserved
found := false
for _, e := range entries {
if v, _ := e["vendor"].(string); v == "azure" {
found = true
}
}
if !found {
t.Error("azure vendor entry was not preserved")
}
assertOllamaVendorConfigured(t, data)
},
},
{
name: "update existing ollama entry",
setup: `[{"vendor": "ollama", "name": "Ollama", "url": "http://old:11434"}]`,
models: []string{"llama3.2"},
validate: func(t *testing.T, data []byte) {
assertOllamaVendorConfigured(t, data)
},
},
{
name: "empty models is no-op",
setup: `[{"vendor": "azure", "name": "Azure"}]`,
models: []string{},
validate: func(t *testing.T, data []byte) {
if string(data) != `[{"vendor": "azure", "name": "Azure"}]` {
t.Error("empty models should not modify file")
}
},
},
{
name: "corrupted JSON treated as empty",
setup: `{corrupted json`,
models: []string{"llama3.2"},
validate: func(t *testing.T, data []byte) {
var entries []map[string]any
if err := json.Unmarshal(data, &entries); err != nil {
t.Errorf("result is not valid JSON: %v", err)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.RemoveAll(filepath.Dir(clmPath))
if tt.setup != "" {
os.MkdirAll(filepath.Dir(clmPath), 0o755)
os.WriteFile(clmPath, []byte(tt.setup), 0o644)
}
if err := v.Edit(launchModelsFromNames(tt.models)); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(clmPath)
tt.validate(t, data)
})
}
}
func TestVSCodeEditCleansUpOldSettings(t *testing.T) {
v := &VSCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
settingsPath := testVSCodePath(t, tmpDir, "settings.json")
// Create settings.json with old byok setting
os.MkdirAll(filepath.Dir(settingsPath), 0o755)
os.WriteFile(settingsPath, []byte(`{"github.copilot.chat.byok.ollamaEndpoint": "http://old:11434", "ollama.launch.configured": true, "editor.fontSize": 14}`), 0o644)
if err := v.Edit(testLaunchModels("llama3.2")); err != nil {
t.Fatal(err)
}
// Verify old settings were removed
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatal(err)
}
var settings map[string]any
json.Unmarshal(data, &settings)
if _, ok := settings["github.copilot.chat.byok.ollamaEndpoint"]; ok {
t.Error("github.copilot.chat.byok.ollamaEndpoint should have been removed")
}
if _, ok := settings["ollama.launch.configured"]; ok {
t.Error("ollama.launch.configured should have been removed")
}
if settings["editor.fontSize"] != float64(14) {
t.Error("editor.fontSize should have been preserved")
}
}
func TestVSCodeEdit_CreatesDistinctBackupsForManagedFiles(t *testing.T) {
v := &VSCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
clmPath := testVSCodePath(t, tmpDir, "chatLanguageModels.json")
settingsPath := testVSCodePath(t, tmpDir, "settings.json")
backupDir := fileutil.BackupDir()
if err := os.MkdirAll(filepath.Dir(clmPath), 0o755); err != nil {
t.Fatal(err)
}
clmOriginal := `[{"vendor":"ollama","name":"Ollama","url":"http://old:11434"}]`
settingsOriginal := `{"github.copilot.chat.byok.ollamaEndpoint":"http://old:11434","ollama.launch.configured":true,"editor.fontSize":14}`
if err := os.WriteFile(clmPath, []byte(clmOriginal), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(settingsPath, []byte(settingsOriginal), 0o644); err != nil {
t.Fatal(err)
}
if err := v.Edit(testLaunchModels("llama3.2")); err != nil {
t.Fatal(err)
}
assertBackupMatches := func(pattern, want string) {
t.Helper()
backups, err := filepath.Glob(filepath.Join(backupDir, pattern))
if err != nil {
t.Fatalf("glob %q failed: %v", pattern, err)
}
for _, backup := range backups {
data, err := os.ReadFile(backup)
if err == nil && string(data) == want {
return
}
}
t.Fatalf("backup matching %q with expected content not found", pattern)
}
assertBackupMatches(filepath.Join("vscode", "chatLanguageModels.json.*"), clmOriginal)
assertBackupMatches(filepath.Join("vscode", "settings.json.*"), settingsOriginal)
}
func TestVSCodePaths(t *testing.T) {
v := &VSCode{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
clmPath := testVSCodePath(t, tmpDir, "chatLanguageModels.json")
t.Run("no file returns nil", func(t *testing.T) {
os.Remove(clmPath)
if paths := v.Paths(); paths != nil {
t.Errorf("expected nil, got %v", paths)
}
})
t.Run("existing file returns path", func(t *testing.T) {
os.MkdirAll(filepath.Dir(clmPath), 0o755)
os.WriteFile(clmPath, []byte(`[]`), 0o644)
if paths := v.Paths(); len(paths) != 1 {
t.Errorf("expected 1 path, got %d", len(paths))
}
})
}
// testVSCodePath returns the expected VS Code config path for the given file in tests.
func testVSCodePath(t *testing.T, tmpDir, filename string) string {
t.Helper()
switch runtime.GOOS {
case "darwin":
return filepath.Join(tmpDir, "Library", "Application Support", "Code", "User", filename)
case "windows":
t.Setenv("APPDATA", tmpDir)
return filepath.Join(tmpDir, "Code", "User", filename)
default:
return filepath.Join(tmpDir, ".config", "Code", "User", filename)
}
}
func assertOllamaVendorConfigured(t *testing.T, data []byte) {
t.Helper()
var entries []map[string]any
if err := json.Unmarshal(data, &entries); err != nil {
t.Fatalf("invalid JSON: %v", err)
}
for _, entry := range entries {
if vendor, _ := entry["vendor"].(string); vendor == "ollama" {
if name, _ := entry["name"].(string); name != "Ollama" {
t.Errorf("expected name \"Ollama\", got %q", name)
}
if url, _ := entry["url"].(string); url == "" {
t.Error("url not set")
}
return
}
}
t.Error("no ollama vendor entry found")
}
func TestShowInModelPicker(t *testing.T) {
v := &VSCode{}
// helper to create a state DB with optional seed data
setupDB := func(t *testing.T, tmpDir string, seedPrefs map[string]bool, seedCache []map[string]any) string {
t.Helper()
dbDir := filepath.Join(tmpDir, "globalStorage")
os.MkdirAll(dbDir, 0o755)
dbPath := filepath.Join(dbDir, "state.vscdb")
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
t.Fatal(err)
}
defer db.Close()
if _, err := db.Exec("CREATE TABLE ItemTable (key TEXT UNIQUE ON CONFLICT REPLACE, value BLOB)"); err != nil {
t.Fatal(err)
}
if seedPrefs != nil {
data, _ := json.Marshal(seedPrefs)
db.Exec("INSERT INTO ItemTable (key, value) VALUES ('chatModelPickerPreferences', ?)", string(data))
}
if seedCache != nil {
data, _ := json.Marshal(seedCache)
db.Exec("INSERT INTO ItemTable (key, value) VALUES ('chat.cachedLanguageModels.v2', ?)", string(data))
}
return dbPath
}
// helper to read prefs back from DB
readPrefs := func(t *testing.T, dbPath string) map[string]bool {
t.Helper()
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
t.Fatal(err)
}
defer db.Close()
var raw string
if err := db.QueryRow("SELECT value FROM ItemTable WHERE key = 'chatModelPickerPreferences'").Scan(&raw); err != nil {
t.Fatal(err)
}
prefs := make(map[string]bool)
json.Unmarshal([]byte(raw), &prefs)
return prefs
}
t.Run("fresh DB creates table and shows models", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
if runtime.GOOS == "windows" {
t.Setenv("APPDATA", tmpDir)
}
err := v.ShowInModelPicker([]string{"llama3.2"})
if err != nil {
t.Fatal(err)
}
dbPath := testVSCodePath(t, tmpDir, filepath.Join("globalStorage", "state.vscdb"))
prefs := readPrefs(t, dbPath)
if !prefs["ollama/Ollama/llama3.2"] {
t.Error("expected llama3.2 to be shown")
}
if !prefs["ollama/Ollama/llama3.2:latest"] {
t.Error("expected llama3.2:latest to be shown")
}
})
t.Run("configured models are shown", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), nil, nil)
err := v.ShowInModelPicker([]string{"llama3.2", "qwen3:8b"})
if err != nil {
t.Fatal(err)
}
prefs := readPrefs(t, dbPath)
if !prefs["ollama/Ollama/llama3.2"] {
t.Error("expected llama3.2 to be shown")
}
if !prefs["ollama/Ollama/qwen3:8b"] {
t.Error("expected qwen3:8b to be shown")
}
})
t.Run("removed models are hidden", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), map[string]bool{
"ollama/Ollama/llama3.2": true,
"ollama/Ollama/llama3.2:latest": true,
"ollama/Ollama/mistral": true,
"ollama/Ollama/mistral:latest": true,
}, nil)
// Only configure llama3.2 — mistral should get hidden
err := v.ShowInModelPicker([]string{"llama3.2"})
if err != nil {
t.Fatal(err)
}
prefs := readPrefs(t, dbPath)
if !prefs["ollama/Ollama/llama3.2"] {
t.Error("expected llama3.2 to stay shown")
}
if prefs["ollama/Ollama/mistral"] {
t.Error("expected mistral to be hidden")
}
if prefs["ollama/Ollama/mistral:latest"] {
t.Error("expected mistral:latest to be hidden")
}
})
t.Run("non-ollama prefs are preserved", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), map[string]bool{
"copilot/gpt-4o": true,
}, nil)
err := v.ShowInModelPicker([]string{"llama3.2"})
if err != nil {
t.Fatal(err)
}
prefs := readPrefs(t, dbPath)
if !prefs["copilot/gpt-4o"] {
t.Error("expected copilot/gpt-4o to stay shown")
}
})
t.Run("uses cached numeric IDs when available", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
cache := []map[string]any{
{
"identifier": "ollama/Ollama/4",
"metadata": map[string]any{"vendor": "ollama", "name": "llama3.2"},
},
}
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), nil, cache)
err := v.ShowInModelPicker([]string{"llama3.2"})
if err != nil {
t.Fatal(err)
}
prefs := readPrefs(t, dbPath)
if !prefs["ollama/Ollama/4"] {
t.Error("expected numeric ID ollama/Ollama/4 to be shown")
}
// Name-based fallback should also be set
if !prefs["ollama/Ollama/llama3.2"] {
t.Error("expected name-based ID to also be shown")
}
})
t.Run("empty models is no-op", func(t *testing.T) {
err := v.ShowInModelPicker([]string{})
if err != nil {
t.Fatal(err)
}
})
t.Run("previously hidden model is re-shown when configured", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("XDG_CONFIG_HOME", "")
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), map[string]bool{
"ollama/Ollama/llama3.2": false,
"ollama/Ollama/llama3.2:latest": false,
}, nil)
// Ollama config is authoritative — should override the hidden state
err := v.ShowInModelPicker([]string{"llama3.2"})
if err != nil {
t.Fatal(err)
}
prefs := readPrefs(t, dbPath)
if !prefs["ollama/Ollama/llama3.2"] {
t.Error("expected llama3.2 to be re-shown")
}
})
}
func TestParseCopilotChatVersion(t *testing.T) {
tests := []struct {
name string
output string
wantInstalled bool
wantVersion string
}{
{
name: "found among other extensions",
output: "ms-python.python@2024.1.1\ngithub.copilot-chat@0.40.1\ngithub.copilot@1.200.0\n",
wantInstalled: true,
wantVersion: "0.40.1",
},
{
name: "only extension",
output: "GitHub.copilot-chat@0.41.0\n",
wantInstalled: true,
wantVersion: "0.41.0",
},
{
name: "not installed",
output: "ms-python.python@2024.1.1\ngithub.copilot@1.200.0\n",
wantInstalled: false,
},
{
name: "empty output",
output: "",
wantInstalled: false,
},
{
name: "case insensitive match",
output: "GitHub.Copilot-Chat@0.39.0\n",
wantInstalled: true,
wantVersion: "0.39.0",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
installed, version := parseCopilotChatVersion(tt.output)
if installed != tt.wantInstalled {
t.Errorf("installed = %v, want %v", installed, tt.wantInstalled)
}
if installed && version != tt.wantVersion {
t.Errorf("version = %q, want %q", version, tt.wantVersion)
}
})
}
}
func TestCompareVersions(t *testing.T) {
tests := []struct {
a, b string
want int
}{
{"0.40.1", "0.40.1", 0},
{"0.40.2", "0.40.1", 1},
{"0.40.0", "0.40.1", -1},
{"0.41.0", "0.40.1", 1},
{"0.39.9", "0.40.1", -1},
{"1.0.0", "0.40.1", 1},
{"0.40", "0.40.1", -1},
{"0.40.1.1", "0.40.1", 1},
}
for _, tt := range tests {
t.Run(tt.a+"_vs_"+tt.b, func(t *testing.T) {
got := compareVersions(tt.a, tt.b)
if got != tt.want {
t.Errorf("compareVersions(%q, %q) = %d, want %d", tt.a, tt.b, got, tt.want)
}
})
}
}

15
cmd/runner/main.go Normal file
View File

@@ -0,0 +1,15 @@
package main
import (
"fmt"
"os"
"github.com/ollama/ollama/runner"
)
func main() {
if err := runner.Execute(os.Args[1:]); err != nil {
fmt.Fprintf(os.Stderr, "error: %s\n", err)
os.Exit(1)
}
}

27
cmd/start.go Normal file
View File

@@ -0,0 +1,27 @@
//go:build darwin || windows
package cmd
import (
"context"
"errors"
"time"
"github.com/ollama/ollama/api"
)
func waitForServer(ctx context.Context, client *api.Client) error {
// wait for the server to start
timeout := time.After(5 * time.Second)
tick := time.Tick(500 * time.Millisecond)
for {
select {
case <-timeout:
return errors.New("timed out waiting for server to start")
case <-tick:
if err := client.Heartbeat(ctx); err == nil {
return nil // server has started
}
}
}
}

33
cmd/start_darwin.go Normal file
View File

@@ -0,0 +1,33 @@
package cmd
import (
"context"
"errors"
"os"
"os/exec"
"regexp"
"github.com/ollama/ollama/api"
)
var errNotRunning = errors.New("could not connect to ollama server, run 'ollama serve' to start it")
func startApp(ctx context.Context, client *api.Client) error {
exe, err := os.Executable()
if err != nil {
return errNotRunning
}
link, err := os.Readlink(exe)
if err != nil {
return errNotRunning
}
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
m := r.FindStringSubmatch(link)
if len(m) != 1 {
return errNotRunning
}
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
return err
}
return waitForServer(ctx, client)
}

14
cmd/start_default.go Normal file
View File

@@ -0,0 +1,14 @@
//go:build !windows && !darwin
package cmd
import (
"context"
"errors"
"github.com/ollama/ollama/api"
)
func startApp(ctx context.Context, client *api.Client) error {
return errors.New("could not connect to ollama server, run 'ollama serve' to start it")
}

112
cmd/start_windows.go Normal file
View File

@@ -0,0 +1,112 @@
package cmd
import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"os/exec"
"path"
"path/filepath"
"strings"
"syscall"
"unsafe"
"github.com/ollama/ollama/api"
"golang.org/x/sys/windows"
)
const (
Installer = "OllamaSetup.exe"
)
func startApp(ctx context.Context, client *api.Client) error {
if len(isProcRunning(Installer)) > 0 {
return fmt.Errorf("upgrade in progress...")
}
AppName := "ollama app.exe"
exe, err := os.Executable()
if err != nil {
return err
}
appExe := filepath.Join(filepath.Dir(exe), AppName)
_, err = os.Stat(appExe)
if errors.Is(err, os.ErrNotExist) {
// Try the standard install location
localAppData := os.Getenv("LOCALAPPDATA")
appExe = filepath.Join(localAppData, "Ollama", AppName)
_, err := os.Stat(appExe)
if errors.Is(err, os.ErrNotExist) {
// Finally look in the path
appExe, err = exec.LookPath(AppName)
if err != nil {
return errors.New("could not locate ollama app")
}
}
}
cmd_path := "c:\\Windows\\system32\\cmd.exe"
cmd := exec.Command(cmd_path, "/c", appExe, "--hide", "--fast-startup")
cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: 0x08000000, HideWindow: true}
cmd.Stdin = strings.NewReader("")
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
return fmt.Errorf("unable to start ollama app %w", err)
}
if cmd.Process != nil {
defer cmd.Process.Release() //nolint:errcheck
}
return waitForServer(ctx, client)
}
func isProcRunning(procName string) []uint32 {
pids := make([]uint32, 2048)
var ret uint32
if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
slog.Debug("failed to check for running installers", "error", err)
return nil
}
if ret > uint32(len(pids)) {
pids = make([]uint32, ret+10)
if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
slog.Debug("failed to check for running installers", "error", err)
return nil
}
}
if ret < uint32(len(pids)) {
pids = pids[:ret]
}
var matches []uint32
for _, pid := range pids {
if pid == 0 {
continue
}
hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_VM_READ, false, pid)
if err != nil {
continue
}
defer windows.CloseHandle(hProcess)
var module windows.Handle
var cbNeeded uint32
cb := (uint32)(unsafe.Sizeof(module))
if err := windows.EnumProcessModules(hProcess, &module, cb, &cbNeeded); err != nil {
continue
}
var sz uint32 = 1024 * 8
moduleName := make([]uint16, sz)
cb = uint32(len(moduleName)) * (uint32)(unsafe.Sizeof(uint16(0)))
if err := windows.GetModuleBaseName(hProcess, module, &moduleName[0], cb); err != nil && err != syscall.ERROR_INSUFFICIENT_BUFFER {
continue
}
exeFile := path.Base(strings.ToLower(syscall.UTF16ToString(moduleName)))
if strings.EqualFold(exeFile, procName) {
matches = append(matches, pid)
}
}
return matches
}

133
cmd/tui/confirm.go Normal file
View File

@@ -0,0 +1,133 @@
package tui
import (
"fmt"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/cmd/launch"
)
var (
confirmActiveStyle = lipgloss.NewStyle().
Bold(true).
Background(lipgloss.AdaptiveColor{Light: "254", Dark: "236"})
confirmInactiveStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
)
type confirmModel struct {
prompt string
yesLabel string
noLabel string
yes bool
confirmed bool
cancelled bool
width int
}
type ConfirmOptions = launch.ConfirmOptions
func (m confirmModel) Init() tea.Cmd {
return nil
}
func (m confirmModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
wasSet := m.width > 0
m.width = msg.Width
if wasSet {
return m, tea.EnterAltScreen
}
return m, nil
case tea.KeyMsg:
switch msg.String() {
case "ctrl+c", "esc":
m.cancelled = true
return m, tea.Quit
case "enter":
m.confirmed = true
return m, tea.Quit
case "left":
m.yes = true
case "right":
m.yes = false
}
}
return m, nil
}
func (m confirmModel) View() string {
if m.confirmed || m.cancelled {
return ""
}
var yesBtn, noBtn string
yesLabel := m.yesLabel
if yesLabel == "" {
yesLabel = "Yes"
}
noLabel := m.noLabel
if noLabel == "" {
noLabel = "No"
}
if m.yes {
yesBtn = confirmActiveStyle.Render(" " + yesLabel + " ")
noBtn = confirmInactiveStyle.Render(" " + noLabel + " ")
} else {
yesBtn = confirmInactiveStyle.Render(" " + yesLabel + " ")
noBtn = confirmActiveStyle.Render(" " + noLabel + " ")
}
s := selectorTitleStyle.Render(m.prompt) + "\n\n"
s += " " + yesBtn + " " + noBtn + "\n\n"
s += selectorHelpStyle.Render("←/→ navigate • enter confirm • esc cancel")
if m.width > 0 {
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
}
return s
}
// RunConfirm shows a bubbletea yes/no confirmation prompt.
// Returns true if the user confirmed, false if cancelled.
func RunConfirm(prompt string) (bool, error) {
return RunConfirmWithOptions(prompt, ConfirmOptions{})
}
// RunConfirmWithOptions shows a bubbletea yes/no confirmation prompt with
// optional custom button labels.
func RunConfirmWithOptions(prompt string, options ConfirmOptions) (bool, error) {
yesLabel := options.YesLabel
if yesLabel == "" {
yesLabel = "Yes"
}
noLabel := options.NoLabel
if noLabel == "" {
noLabel = "No"
}
m := confirmModel{
prompt: prompt,
yesLabel: yesLabel,
noLabel: noLabel,
yes: true, // default to yes
}
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return false, fmt.Errorf("error running confirm: %w", err)
}
fm := finalModel.(confirmModel)
if fm.cancelled {
return false, ErrCancelled
}
return fm.yes, nil
}

224
cmd/tui/confirm_test.go Normal file
View File

@@ -0,0 +1,224 @@
package tui
import (
"strings"
"testing"
tea "github.com/charmbracelet/bubbletea"
)
func TestConfirmModel_DefaultsToYes(t *testing.T) {
m := confirmModel{prompt: "Download test?", yes: true}
if !m.yes {
t.Error("should default to yes")
}
}
func TestConfirmModel_View_ContainsPrompt(t *testing.T) {
m := confirmModel{prompt: "Download qwen3:8b?", yes: true}
got := m.View()
if !strings.Contains(got, "Download qwen3:8b?") {
t.Error("should contain the prompt text")
}
}
func TestConfirmModel_View_ContainsButtons(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: true}
got := m.View()
if !strings.Contains(got, "Yes") {
t.Error("should contain Yes button")
}
if !strings.Contains(got, "No") {
t.Error("should contain No button")
}
}
func TestConfirmModel_View_ContainsCustomButtons(t *testing.T) {
m := confirmModel{
prompt: "Connect a messaging app now?",
yesLabel: "Yes",
noLabel: "Set up later",
yes: true,
}
got := m.View()
if !strings.Contains(got, "Yes") {
t.Error("should contain custom yes button")
}
if !strings.Contains(got, "Set up later") {
t.Error("should contain custom no button")
}
}
func TestConfirmModel_View_ContainsHelp(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: true}
got := m.View()
if !strings.Contains(got, "enter confirm") {
t.Error("should contain help text")
}
}
func TestConfirmModel_View_ClearsAfterConfirm(t *testing.T) {
m := confirmModel{prompt: "Download?", confirmed: true}
if m.View() != "" {
t.Error("View should return empty string after confirmation")
}
}
func TestConfirmModel_View_ClearsAfterCancel(t *testing.T) {
m := confirmModel{prompt: "Download?", cancelled: true}
if m.View() != "" {
t.Error("View should return empty string after cancellation")
}
}
func TestConfirmModel_EnterConfirmsYes(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: true}
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
fm := updated.(confirmModel)
if !fm.confirmed {
t.Error("enter should set confirmed=true")
}
if !fm.yes {
t.Error("enter with yes selected should keep yes=true")
}
if cmd == nil {
t.Error("enter should return tea.Quit")
}
}
func TestConfirmModel_EnterConfirmsNo(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: false}
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
fm := updated.(confirmModel)
if !fm.confirmed {
t.Error("enter should set confirmed=true")
}
if fm.yes {
t.Error("enter with no selected should keep yes=false")
}
if cmd == nil {
t.Error("enter should return tea.Quit")
}
}
func TestConfirmModel_EscCancels(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: true}
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEsc})
fm := updated.(confirmModel)
if !fm.cancelled {
t.Error("esc should set cancelled=true")
}
if cmd == nil {
t.Error("esc should return tea.Quit")
}
}
func TestConfirmModel_CtrlCCancels(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: true}
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyCtrlC})
fm := updated.(confirmModel)
if !fm.cancelled {
t.Error("ctrl+c should set cancelled=true")
}
if cmd == nil {
t.Error("ctrl+c should return tea.Quit")
}
}
func TestConfirmModel_NDoesNothing(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: true}
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'n'}})
fm := updated.(confirmModel)
if fm.cancelled {
t.Error("'n' should not cancel")
}
if fm.confirmed {
t.Error("'n' should not confirm")
}
if cmd != nil {
t.Error("'n' should not quit")
}
}
func TestConfirmModel_YDoesNothing(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: false}
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'y'}})
fm := updated.(confirmModel)
if fm.confirmed {
t.Error("'y' should not confirm")
}
if fm.yes {
t.Error("'y' should not change selection")
}
if cmd != nil {
t.Error("'y' should not quit")
}
}
func TestConfirmModel_ArrowKeysNavigate(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: true}
// Right moves to No
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRight})
fm := updated.(confirmModel)
if fm.yes {
t.Error("right should move to No")
}
if fm.confirmed || fm.cancelled {
t.Error("navigation should not confirm or cancel")
}
// Left moves back to Yes
updated, _ = fm.Update(tea.KeyMsg{Type: tea.KeyLeft})
fm = updated.(confirmModel)
if !fm.yes {
t.Error("left should move to Yes")
}
}
func TestConfirmModel_TabDoesNothing(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: true}
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
fm := updated.(confirmModel)
if !fm.yes {
t.Error("tab should not change selection")
}
if fm.confirmed || fm.cancelled {
t.Error("tab should not confirm or cancel")
}
}
func TestConfirmModel_WindowSizeUpdatesWidth(t *testing.T) {
m := confirmModel{prompt: "Download?"}
updated, _ := m.Update(tea.WindowSizeMsg{Width: 100, Height: 40})
fm := updated.(confirmModel)
if fm.width != 100 {
t.Errorf("expected width 100, got %d", fm.width)
}
}
func TestConfirmModel_ResizeEntersAltScreen(t *testing.T) {
m := confirmModel{prompt: "Download?", width: 80}
_, cmd := m.Update(tea.WindowSizeMsg{Width: 100, Height: 40})
if cmd == nil {
t.Error("resize (width already set) should return a command")
}
}
func TestConfirmModel_InitialWindowSizeNoAltScreen(t *testing.T) {
m := confirmModel{prompt: "Download?"}
_, cmd := m.Update(tea.WindowSizeMsg{Width: 80, Height: 40})
if cmd != nil {
t.Error("initial WindowSizeMsg should not return a command")
}
}
func TestConfirmModel_ViewMaxWidth(t *testing.T) {
m := confirmModel{prompt: "Download?", yes: true, width: 40}
got := m.View()
// Just ensure it doesn't panic and returns content
if got == "" {
t.Error("View with width set should still return content")
}
}

993
cmd/tui/selector.go Normal file
View File

@@ -0,0 +1,993 @@
package tui
import (
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/cmd/launch"
)
var (
selectorTitleStyle = lipgloss.NewStyle().
Bold(true)
selectorItemStyle = lipgloss.NewStyle().
PaddingLeft(4)
selectorSelectedItemStyle = lipgloss.NewStyle().
PaddingLeft(2).
Bold(true).
Background(lipgloss.AdaptiveColor{Light: "254", Dark: "236"})
selectorDescStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
selectorDescLineStyle = selectorDescStyle.
PaddingLeft(6)
selectorFilterStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
Italic(true)
selectorInputStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "235", Dark: "252"})
selectorDefaultTagStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
selectorHelpStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "244", Dark: "244"})
selectorMoreStyle = lipgloss.NewStyle().
PaddingLeft(6).
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
Italic(true)
sectionHeaderStyle = lipgloss.NewStyle().
PaddingLeft(2).
Bold(true).
Foreground(lipgloss.AdaptiveColor{Light: "240", Dark: "249"})
)
const maxSelectorItems = 10
// ErrCancelled is returned when the user cancels the selection.
var ErrCancelled = launch.ErrCancelled
type SelectItem struct {
Name string
Description string
Recommended bool
AvailabilityBadge string
}
type selectorItemsUpdatedMsg struct {
items []SelectItem
}
func waitForSelectorItems(updates <-chan []SelectItem) tea.Cmd {
if updates == nil {
return nil
}
return func() tea.Msg {
items, ok := <-updates
if !ok {
return nil
}
return selectorItemsUpdatedMsg{items: items}
}
}
// ConvertItems converts launch.SelectionItem slice to SelectItem slice.
func ConvertItems(items []launch.SelectionItem) []SelectItem {
out := make([]SelectItem, len(items))
for i, item := range items {
out[i] = SelectItem{
Name: item.Name,
Description: item.Description,
Recommended: item.Recommended,
AvailabilityBadge: item.AvailabilityBadge,
}
}
return out
}
// ReorderItems returns a copy with recommended items first, then non-recommended,
// preserving relative order within each group. This ensures the data order matches
// the visual section layout (Recommended / More).
func ReorderItems(items []SelectItem) []SelectItem {
var rec, other []SelectItem
for _, item := range items {
if item.Recommended {
rec = append(rec, item)
} else {
other = append(other, item)
}
}
return append(rec, other...)
}
// selectorModel is the bubbletea model for single selection.
type selectorModel struct {
title string
items []SelectItem
updates <-chan []SelectItem
filter string
cursor int
scrollOffset int
selected string
cancelled bool
helpText string
width int
}
func selectorModelWithCurrent(title string, items []SelectItem, current string) selectorModel {
m := selectorModel{
title: title,
items: items,
cursor: cursorForCurrent(items, current),
}
m.updateScroll(m.otherStart())
return m
}
func currentItemName(items []SelectItem, cursor int) string {
if cursor < 0 || cursor >= len(items) {
return ""
}
return items[cursor].Name
}
func cursorForItemName(items []SelectItem, name string, fallback int) int {
if len(items) == 0 {
return 0
}
if name != "" {
for i, item := range items {
if item.Name == name {
return i
}
}
}
if fallback < 0 {
return 0
}
if fallback >= len(items) {
return len(items) - 1
}
return fallback
}
func (m selectorModel) filteredItems() []SelectItem {
if m.filter == "" {
return m.items
}
filterLower := strings.ToLower(m.filter)
var result []SelectItem
for _, item := range m.items {
if strings.Contains(strings.ToLower(item.Name), filterLower) {
result = append(result, item)
}
}
return result
}
func (m selectorModel) Init() tea.Cmd {
return waitForSelectorItems(m.updates)
}
// otherStart returns the index of the first non-recommended item in the filtered list.
// When filtering, all items scroll together so this returns 0.
func (m selectorModel) otherStart() int {
if m.filter != "" {
return 0
}
filtered := m.filteredItems()
for i, item := range filtered {
if !item.Recommended {
return i
}
}
return len(filtered)
}
// updateNavigation handles navigation keys (up/down/pgup/pgdown/filter/backspace).
// It does NOT handle Enter, Esc, or CtrlC. This is used by both the standalone
// selector and the TUI modal (which intercepts Enter/Esc for its own logic).
func (m *selectorModel) updateNavigation(msg tea.KeyMsg) {
filtered := m.filteredItems()
otherStart := m.otherStart()
switch msg.Type {
case tea.KeyUp:
if m.cursor > 0 {
m.cursor--
m.updateScroll(otherStart)
}
case tea.KeyDown:
if m.cursor < len(filtered)-1 {
m.cursor++
m.updateScroll(otherStart)
}
case tea.KeyPgUp:
m.cursor -= maxSelectorItems
if m.cursor < 0 {
m.cursor = 0
}
m.updateScroll(otherStart)
case tea.KeyPgDown:
m.cursor += maxSelectorItems
if m.cursor >= len(filtered) {
m.cursor = len(filtered) - 1
}
m.updateScroll(otherStart)
case tea.KeyBackspace:
if len(m.filter) > 0 {
m.filter = m.filter[:len(m.filter)-1]
m.cursor = 0
m.scrollOffset = 0
}
case tea.KeyRunes:
m.filter += string(msg.Runes)
m.cursor = 0
m.scrollOffset = 0
}
}
// updateScroll adjusts scrollOffset based on cursor position.
// When not filtering, scrollOffset is relative to the "More" (non-recommended) section.
// When filtering, it's relative to the full filtered list.
func (m *selectorModel) updateScroll(otherStart int) {
if m.filter != "" {
if m.cursor < m.scrollOffset {
m.scrollOffset = m.cursor
}
if m.cursor >= m.scrollOffset+maxSelectorItems {
m.scrollOffset = m.cursor - maxSelectorItems + 1
}
return
}
// Cursor is in recommended section — reset "More" scroll to top
if m.cursor < otherStart {
m.scrollOffset = 0
return
}
// Cursor is in "More" section — scroll relative to others
posInOthers := m.cursor - otherStart
maxOthers := maxSelectorItems - otherStart
if maxOthers < 3 {
maxOthers = 3
}
if posInOthers < m.scrollOffset {
m.scrollOffset = posInOthers
}
if posInOthers >= m.scrollOffset+maxOthers {
m.scrollOffset = posInOthers - maxOthers + 1
}
}
func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
wasSet := m.width > 0
m.width = msg.Width
if wasSet {
return m, tea.EnterAltScreen
}
return m, nil
case selectorItemsUpdatedMsg:
current := currentItemName(m.filteredItems(), m.cursor)
m.items = msg.items
m.cursor = cursorForItemName(m.filteredItems(), current, m.cursor)
m.updateScroll(m.otherStart())
return m, waitForSelectorItems(m.updates)
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
m.cancelled = true
return m, tea.Quit
case tea.KeyLeft:
m.cancelled = true
return m, tea.Quit
case tea.KeyEnter:
filtered := m.filteredItems()
if len(filtered) > 0 && m.cursor < len(filtered) {
m.selected = filtered[m.cursor].Name
}
return m, tea.Quit
default:
m.updateNavigation(msg)
}
}
return m, nil
}
func cursorItemSuffix(item SelectItem) string {
if item.AvailabilityBadge == "" {
return ""
}
return " " + selectorDefaultTagStyle.Render("("+item.AvailabilityBadge+")")
}
func (m selectorModel) renderItem(s *strings.Builder, item SelectItem, idx int) {
if idx == m.cursor {
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
s.WriteString(cursorItemSuffix(item))
} else {
s.WriteString(selectorItemStyle.Render(item.Name))
}
s.WriteString("\n")
if item.Description != "" {
s.WriteString(selectorDescLineStyle.Render(item.Description))
s.WriteString("\n")
}
}
// renderContent renders the selector content (title, items, help text) without
// checking the cancelled/selected state. This is used by both View() (standalone mode)
// and by the TUI modal which embeds a selectorModel.
func (m selectorModel) renderContent() string {
var s strings.Builder
s.WriteString(selectorTitleStyle.Render(m.title))
s.WriteString(" ")
if m.filter == "" {
s.WriteString(selectorFilterStyle.Render("Type to filter..."))
} else {
s.WriteString(selectorInputStyle.Render(m.filter))
}
s.WriteString("\n\n")
filtered := m.filteredItems()
if len(filtered) == 0 {
s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
s.WriteString("\n")
} else if m.filter != "" {
s.WriteString(sectionHeaderStyle.Render("Top Results"))
s.WriteString("\n")
displayCount := min(len(filtered), maxSelectorItems)
for i := range displayCount {
idx := m.scrollOffset + i
if idx >= len(filtered) {
break
}
m.renderItem(&s, filtered[idx], idx)
}
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
s.WriteString("\n")
}
} else {
// Split into pinned recommended and scrollable others
var recItems, otherItems []int
for i, item := range filtered {
if item.Recommended {
recItems = append(recItems, i)
} else {
otherItems = append(otherItems, i)
}
}
// Always render all recommended items (pinned)
if len(recItems) > 0 {
s.WriteString(sectionHeaderStyle.Render("Recommended"))
s.WriteString("\n")
for _, idx := range recItems {
m.renderItem(&s, filtered[idx], idx)
}
}
if len(otherItems) > 0 {
s.WriteString("\n")
s.WriteString(sectionHeaderStyle.Render("More"))
s.WriteString("\n")
maxOthers := maxSelectorItems - len(recItems)
if maxOthers < 3 {
maxOthers = 3
}
displayCount := min(len(otherItems), maxOthers)
for i := range displayCount {
idx := m.scrollOffset + i
if idx >= len(otherItems) {
break
}
m.renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
}
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
s.WriteString("\n")
}
}
}
s.WriteString("\n")
help := "↑/↓ navigate • enter select • ← back"
if m.helpText != "" {
help = m.helpText
}
s.WriteString(selectorHelpStyle.Render(help))
return s.String()
}
func (m selectorModel) View() string {
if m.cancelled || m.selected != "" {
return ""
}
s := m.renderContent()
if m.width > 0 {
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
}
return s
}
// cursorForCurrent returns the item index matching current, or 0 if not found.
func cursorForCurrent(items []SelectItem, current string) int {
if current == "" {
return 0
}
// Prefer exact name matches before tag-prefix fallback so "qwen3.5" does not
// incorrectly select "qwen3.5:cloud" (and vice versa) based on list order.
for i, item := range items {
if item.Name == current {
return i
}
}
for i, item := range items {
if strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
return i
}
}
return 0
}
func SelectSingle(title string, items []SelectItem, current string) (string, error) {
return SelectSingleWithUpdates(title, items, current, nil)
}
func SelectSingleWithUpdates(title string, items []SelectItem, current string, updates <-chan []SelectItem) (string, error) {
if len(items) == 0 {
return "", fmt.Errorf("no items to select from")
}
m := selectorModelWithCurrent(title, items, current)
m.updates = updates
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return "", fmt.Errorf("error running selector: %w", err)
}
fm := finalModel.(selectorModel)
if fm.cancelled {
return "", ErrCancelled
}
return fm.selected, nil
}
// multiSelectorModel is the bubbletea model for multi selection.
type multiSelectorModel struct {
title string
items []SelectItem
updates <-chan []SelectItem
itemIndex map[string]int
filter string
cursor int
scrollOffset int
checked map[int]bool
checkOrder []int
cancelled bool
confirmed bool
width int
// multi enables full multi-select editing mode. The zero value (false)
// shows a single-select picker where Enter adds the chosen model to
// the existing list. Tab toggles between modes.
multi bool
singleAdd string // model picked in single mode
}
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
m := multiSelectorModel{
title: title,
items: items,
itemIndex: make(map[string]int, len(items)),
checked: make(map[int]bool),
}
for i, item := range items {
m.itemIndex[item.Name] = i
}
// Reverse order so preChecked[0] (the current default) ends up last
// in checkOrder, matching the "last checked = default" convention.
for i := len(preChecked) - 1; i >= 0; i-- {
if idx, ok := m.itemIndex[preChecked[i]]; ok {
m.checked[idx] = true
m.checkOrder = append(m.checkOrder, idx)
}
}
// Position cursor on the current default model
if len(preChecked) > 0 {
if idx, ok := m.itemIndex[preChecked[0]]; ok {
m.cursor = idx
m.updateScroll(m.otherStart())
}
}
return m
}
func (m *multiSelectorModel) rebuildItemIndex() {
m.itemIndex = make(map[string]int, len(m.items))
for i, item := range m.items {
m.itemIndex[item.Name] = i
}
}
func (m *multiSelectorModel) replaceItems(items []SelectItem) {
current := currentItemName(m.filteredItems(), m.cursor)
checkedNames := make([]string, 0, len(m.checkOrder))
for _, idx := range m.checkOrder {
if idx >= 0 && idx < len(m.items) {
checkedNames = append(checkedNames, m.items[idx].Name)
}
}
m.items = items
m.rebuildItemIndex()
m.checked = make(map[int]bool, len(checkedNames))
m.checkOrder = nil
for _, name := range checkedNames {
if idx, ok := m.itemIndex[name]; ok {
m.checked[idx] = true
m.checkOrder = append(m.checkOrder, idx)
}
}
m.cursor = cursorForItemName(m.filteredItems(), current, m.cursor)
m.updateScroll(m.otherStart())
}
func (m multiSelectorModel) filteredItems() []SelectItem {
if m.filter == "" {
return m.items
}
filterLower := strings.ToLower(m.filter)
var result []SelectItem
for _, item := range m.items {
if strings.Contains(strings.ToLower(item.Name), filterLower) {
result = append(result, item)
}
}
return result
}
// otherStart returns the index of the first non-recommended item in the filtered list.
func (m multiSelectorModel) otherStart() int {
if m.filter != "" {
return 0
}
filtered := m.filteredItems()
for i, item := range filtered {
if !item.Recommended {
return i
}
}
return len(filtered)
}
// updateScroll adjusts scrollOffset for section-based scrolling (matches single-select).
func (m *multiSelectorModel) updateScroll(otherStart int) {
if m.filter != "" {
if m.cursor < m.scrollOffset {
m.scrollOffset = m.cursor
}
if m.cursor >= m.scrollOffset+maxSelectorItems {
m.scrollOffset = m.cursor - maxSelectorItems + 1
}
return
}
if m.cursor < otherStart {
m.scrollOffset = 0
return
}
posInOthers := m.cursor - otherStart
maxOthers := maxSelectorItems - otherStart
if maxOthers < 3 {
maxOthers = 3
}
if posInOthers < m.scrollOffset {
m.scrollOffset = posInOthers
}
if posInOthers >= m.scrollOffset+maxOthers {
m.scrollOffset = posInOthers - maxOthers + 1
}
}
func (m *multiSelectorModel) toggleItem() {
filtered := m.filteredItems()
if len(filtered) == 0 || m.cursor >= len(filtered) {
return
}
item := filtered[m.cursor]
origIdx := m.itemIndex[item.Name]
if m.checked[origIdx] {
wasDefault := len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx
delete(m.checked, origIdx)
for i, idx := range m.checkOrder {
if idx == origIdx {
m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
break
}
}
if wasDefault {
// When removing the default, pick the nearest checked model above it
// (or below if none above) so default fallback follows list order.
newDefault := -1
for i := origIdx - 1; i >= 0; i-- {
if m.checked[i] {
newDefault = i
break
}
}
if newDefault == -1 {
for i := origIdx + 1; i < len(m.items); i++ {
if m.checked[i] {
newDefault = i
break
}
}
}
if newDefault != -1 {
for i, idx := range m.checkOrder {
if idx == newDefault {
m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
break
}
}
m.checkOrder = append(m.checkOrder, newDefault)
}
}
} else {
m.checked[origIdx] = true
m.checkOrder = append(m.checkOrder, origIdx)
}
}
func (m multiSelectorModel) selectedCount() int {
return len(m.checkOrder)
}
func (m multiSelectorModel) Init() tea.Cmd {
return waitForSelectorItems(m.updates)
}
func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
wasSet := m.width > 0
m.width = msg.Width
if wasSet {
return m, tea.EnterAltScreen
}
return m, nil
case selectorItemsUpdatedMsg:
m.replaceItems(msg.items)
return m, waitForSelectorItems(m.updates)
case tea.KeyMsg:
filtered := m.filteredItems()
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
m.cancelled = true
return m, tea.Quit
case tea.KeyLeft:
m.cancelled = true
return m, tea.Quit
case tea.KeyTab:
m.multi = !m.multi
case tea.KeyEnter:
if !m.multi {
if len(filtered) > 0 && m.cursor < len(filtered) {
m.singleAdd = filtered[m.cursor].Name
m.confirmed = true
return m, tea.Quit
}
} else if len(m.checkOrder) > 0 {
m.confirmed = true
return m, tea.Quit
}
case tea.KeySpace:
if m.multi {
m.toggleItem()
}
case tea.KeyUp:
if m.cursor > 0 {
m.cursor--
m.updateScroll(m.otherStart())
}
case tea.KeyDown:
if m.cursor < len(filtered)-1 {
m.cursor++
m.updateScroll(m.otherStart())
}
case tea.KeyPgUp:
m.cursor -= maxSelectorItems
if m.cursor < 0 {
m.cursor = 0
}
m.updateScroll(m.otherStart())
case tea.KeyPgDown:
m.cursor += maxSelectorItems
if m.cursor >= len(filtered) {
m.cursor = len(filtered) - 1
}
m.updateScroll(m.otherStart())
case tea.KeyBackspace:
if len(m.filter) > 0 {
m.filter = m.filter[:len(m.filter)-1]
m.cursor = 0
m.scrollOffset = 0
}
case tea.KeyRunes:
// On some terminals (e.g. Windows PowerShell), space arrives as
// KeyRunes instead of KeySpace. Intercept it so toggle still works.
if len(msg.Runes) == 1 && msg.Runes[0] == ' ' {
if m.multi {
m.toggleItem()
}
} else {
m.filter += string(msg.Runes)
m.cursor = 0
m.scrollOffset = 0
}
}
}
return m, nil
}
func (m multiSelectorModel) renderSingleItem(s *strings.Builder, item SelectItem, idx int) {
if idx == m.cursor {
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
s.WriteString(cursorItemSuffix(item))
} else {
s.WriteString(selectorItemStyle.Render(item.Name))
}
s.WriteString("\n")
if item.Description != "" {
s.WriteString(selectorDescLineStyle.Render(item.Description))
s.WriteString("\n")
}
}
func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem, idx int) {
origIdx := m.itemIndex[item.Name]
var check string
if m.checked[origIdx] {
check = "[x] "
} else {
check = "[ ] "
}
suffix := ""
if len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx {
suffix = " " + selectorDefaultTagStyle.Render("(default)")
}
if idx == m.cursor {
s.WriteString(selectorSelectedItemStyle.Render("▸ " + check + item.Name))
s.WriteString(cursorItemSuffix(item))
} else {
s.WriteString(selectorItemStyle.Render(check + item.Name))
}
s.WriteString(suffix)
s.WriteString("\n")
if item.Description != "" {
s.WriteString(selectorDescLineStyle.Render(item.Description))
s.WriteString("\n")
}
}
func (m multiSelectorModel) View() string {
if m.cancelled || m.confirmed {
return ""
}
renderItem := m.renderSingleItem
if m.multi {
renderItem = m.renderMultiItem
}
var s strings.Builder
s.WriteString(selectorTitleStyle.Render(m.title))
s.WriteString(" ")
if m.filter == "" {
s.WriteString(selectorFilterStyle.Render("Type to filter..."))
} else {
s.WriteString(selectorInputStyle.Render(m.filter))
}
s.WriteString("\n\n")
filtered := m.filteredItems()
if len(filtered) == 0 {
s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
s.WriteString("\n")
} else if m.filter != "" {
// Filtering: flat scroll through all matches
displayCount := min(len(filtered), maxSelectorItems)
for i := range displayCount {
idx := m.scrollOffset + i
if idx >= len(filtered) {
break
}
renderItem(&s, filtered[idx], idx)
}
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
s.WriteString("\n")
}
} else {
// Split into pinned recommended and scrollable others (matches single-select layout)
var recItems, otherItems []int
for i, item := range filtered {
if item.Recommended {
recItems = append(recItems, i)
} else {
otherItems = append(otherItems, i)
}
}
// Always render all recommended items (pinned)
if len(recItems) > 0 {
s.WriteString(sectionHeaderStyle.Render("Recommended"))
s.WriteString("\n")
for _, idx := range recItems {
renderItem(&s, filtered[idx], idx)
}
}
if len(otherItems) > 0 {
s.WriteString("\n")
s.WriteString(sectionHeaderStyle.Render("More"))
s.WriteString("\n")
maxOthers := maxSelectorItems - len(recItems)
if maxOthers < 3 {
maxOthers = 3
}
displayCount := min(len(otherItems), maxOthers)
for i := range displayCount {
idx := m.scrollOffset + i
if idx >= len(otherItems) {
break
}
renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
}
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
s.WriteString("\n")
}
}
}
s.WriteString("\n")
count := m.selectedCount()
if !m.multi {
if count > 0 {
s.WriteString(sectionHeaderStyle.Render(fmt.Sprintf("%d models selected - press tab to edit", count)))
s.WriteString("\n\n")
}
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • ← back"))
} else {
if count == 0 {
s.WriteString(sectionHeaderStyle.Render("Select at least one model."))
} else {
s.WriteString(sectionHeaderStyle.Render(fmt.Sprintf("%d models selected - press enter to continue", count)))
}
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • ← back"))
}
result := s.String()
if m.width > 0 {
return lipgloss.NewStyle().MaxWidth(m.width).Render(result)
}
return result
}
func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]string, error) {
return SelectMultipleWithUpdates(title, items, preChecked, nil)
}
func SelectMultipleWithUpdates(title string, items []SelectItem, preChecked []string, updates <-chan []SelectItem) ([]string, error) {
if len(items) == 0 {
return nil, fmt.Errorf("no items to select from")
}
m := newMultiSelectorModel(title, items, preChecked)
m.updates = updates
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return nil, fmt.Errorf("error running selector: %w", err)
}
fm := finalModel.(multiSelectorModel)
if fm.cancelled || !fm.confirmed {
return nil, ErrCancelled
}
// Single-add mode: prepend the picked model, keep existing models deduped
if fm.singleAdd != "" {
result := []string{fm.singleAdd}
for _, name := range preChecked {
if name != fm.singleAdd {
result = append(result, name)
}
}
return result, nil
}
// Multi-edit mode: last checked is default (first in result)
last := fm.checkOrder[len(fm.checkOrder)-1]
result := []string{fm.items[last].Name}
for _, idx := range fm.checkOrder {
if idx != last {
result = append(result, fm.items[idx].Name)
}
}
return result, nil
}

1024
cmd/tui/selector_test.go Normal file

File diff suppressed because it is too large Load Diff

340
cmd/tui/signin.go Normal file
View File

@@ -0,0 +1,340 @@
package tui
import (
"context"
"fmt"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/launch"
)
type signInTickMsg struct{}
type signInCheckMsg struct {
signedIn bool
userName string
}
type upgradeTickMsg struct{}
type upgradeCheckMsg struct {
upgraded bool
plan string
err error
}
type signInModel struct {
modelName string
signInURL string
spinner int
width int
userName string
cancelled bool
}
type upgradeModel struct {
modelName string
requiredPlan string
spinner int
width int
openNow bool
polling bool
plan string
cancelled bool
err error
}
func (m signInModel) Init() tea.Cmd {
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
})
}
func (m signInModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
wasSet := m.width > 0
m.width = msg.Width
if wasSet {
return m, tea.EnterAltScreen
}
return m, nil
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
m.cancelled = true
return m, tea.Quit
}
case signInTickMsg:
m.spinner++
if m.spinner%5 == 0 {
return m, tea.Batch(
tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
}),
checkSignIn,
)
}
return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return signInTickMsg{}
})
case signInCheckMsg:
if msg.signedIn {
m.userName = msg.userName
return m, tea.Quit
}
}
return m, nil
}
func (m signInModel) View() string {
if m.userName != "" {
return ""
}
return renderSignIn(m.modelName, m.signInURL, m.spinner, m.width)
}
func (m upgradeModel) Init() tea.Cmd {
if m.polling {
return upgradeTickCmd()
}
return nil
}
func (m upgradeModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
wasSet := m.width > 0
m.width = msg.Width
if wasSet {
return m, tea.EnterAltScreen
}
return m, nil
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
m.cancelled = true
return m, tea.Quit
case tea.KeyLeft:
if !m.polling {
m.openNow = true
}
case tea.KeyRight:
if !m.polling {
m.openNow = false
}
case tea.KeyEnter:
if !m.polling {
if !m.openNow {
m.cancelled = true
return m, tea.Quit
}
launch.OpenBrowser(launch.DefaultUpgradeURL)
m.polling = true
return m, upgradeTickCmd()
}
}
case upgradeTickMsg:
if !m.polling {
return m, nil
}
m.spinner++
if m.spinner%5 == 0 {
return m, tea.Batch(
upgradeTickCmd(),
checkUpgrade(m.requiredPlan),
)
}
return m, upgradeTickCmd()
case upgradeCheckMsg:
if msg.err != nil {
m.err = msg.err
return m, tea.Quit
}
if msg.upgraded {
m.plan = msg.plan
return m, tea.Quit
}
}
return m, nil
}
func (m upgradeModel) View() string {
if m.plan != "" {
return ""
}
if m.err != nil {
return ""
}
return renderUpgrade(m.modelName, m.spinner, m.width, m.polling, m.openNow)
}
func renderSignIn(modelName, signInURL string, spinner, width int) string {
spinnerFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
frame := spinnerFrames[spinner%len(spinnerFrames)]
urlColor := lipgloss.NewStyle().
Foreground(lipgloss.Color("117"))
urlWrap := lipgloss.NewStyle().PaddingLeft(2)
if width > 4 {
urlWrap = urlWrap.Width(width - 4)
}
var s strings.Builder
fmt.Fprintf(&s, "To use %s, please sign in.\n\n", selectorSelectedItemStyle.Render(modelName))
s.WriteString("Navigate to:\n")
s.WriteString(urlWrap.Render(urlColor.Render(signInURL)))
s.WriteString("\n\n")
s.WriteString(lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).Render(
frame + " Waiting for sign in to complete..."))
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("esc cancel"))
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
}
func upgradeTickCmd() tea.Cmd {
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
return upgradeTickMsg{}
})
}
func renderUpgrade(modelName string, spinner, width int, polling, openNow bool) string {
spinnerFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
frame := spinnerFrames[spinner%len(spinnerFrames)]
urlColor := lipgloss.NewStyle().
Foreground(lipgloss.Color("117"))
urlWrap := lipgloss.NewStyle().PaddingLeft(2)
if width > 4 {
urlWrap = urlWrap.Width(width - 4)
}
var s strings.Builder
fmt.Fprintf(&s, "To use %s, upgrade your Ollama plan.\n\n", selectorSelectedItemStyle.Render(modelName))
s.WriteString("Navigate to:\n")
s.WriteString(urlWrap.Render(urlColor.Render(launch.DefaultUpgradeURL)))
s.WriteString("\n\n")
if !polling {
var yesBtn, noBtn string
if openNow {
yesBtn = confirmActiveStyle.Render(" Yes ")
noBtn = confirmInactiveStyle.Render(" No ")
} else {
yesBtn = confirmInactiveStyle.Render(" Yes ")
noBtn = confirmActiveStyle.Render(" No ")
}
s.WriteString("Open now?\n")
s.WriteString(" " + yesBtn + " " + noBtn)
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("←/→ navigate • enter confirm • esc cancel"))
} else {
s.WriteString(lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).Render(
frame + " Waiting for upgrade to complete..."))
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("esc cancel"))
}
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
}
func checkSignIn() tea.Msg {
client, err := api.ClientFromEnvironment()
if err != nil {
return signInCheckMsg{signedIn: false}
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
user, err := client.Whoami(ctx)
if err == nil && user != nil && user.Name != "" {
return signInCheckMsg{signedIn: true, userName: user.Name}
}
return signInCheckMsg{signedIn: false}
}
func checkUpgrade(requiredPlan string) tea.Cmd {
return func() tea.Msg {
client, err := api.ClientFromEnvironment()
if err != nil {
return upgradeCheckMsg{err: launch.ErrPlanVerificationUnavailable}
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
user, err := client.Whoami(ctx)
if err != nil {
return upgradeCheckMsg{err: launch.ErrPlanVerificationUnavailable}
}
if err == nil && user != nil && user.Name != "" && launch.PlanSatisfies(user.Plan, requiredPlan) {
return upgradeCheckMsg{upgraded: true, plan: user.Plan}
}
return upgradeCheckMsg{upgraded: false}
}
}
// RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels.
func RunSignIn(modelName, signInURL string) (string, error) {
launch.OpenBrowser(signInURL)
m := signInModel{
modelName: modelName,
signInURL: signInURL,
}
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return "", fmt.Errorf("error running sign-in: %w", err)
}
fm := finalModel.(signInModel)
if fm.cancelled {
return "", ErrCancelled
}
return fm.userName, nil
}
// RunUpgrade shows a bubbletea upgrade dialog and polls until the user's plan is updated or cancelled.
func RunUpgrade(modelName, requiredPlan string) (string, error) {
m := upgradeModel{
modelName: modelName,
requiredPlan: requiredPlan,
openNow: true,
}
p := tea.NewProgram(m)
finalModel, err := p.Run()
if err != nil {
return "", fmt.Errorf("error running upgrade: %w", err)
}
fm := finalModel.(upgradeModel)
if fm.cancelled {
return "", ErrCancelled
}
if fm.err != nil {
return "", fm.err
}
return fm.plan, nil
}

217
cmd/tui/signin_test.go Normal file
View File

@@ -0,0 +1,217 @@
package tui
import (
"strings"
"testing"
tea "github.com/charmbracelet/bubbletea"
"github.com/ollama/ollama/cmd/launch"
)
func TestRenderSignIn_ContainsModelName(t *testing.T) {
got := renderSignIn("glm-4.7:cloud", "https://example.com/signin", 0, 80)
if !strings.Contains(got, "glm-4.7:cloud") {
t.Error("should contain model name")
}
if !strings.Contains(got, "please sign in") {
t.Error("should contain sign-in prompt")
}
}
func TestRenderSignIn_ContainsURL(t *testing.T) {
url := "https://ollama.com/connect?key=abc123"
got := renderSignIn("test:cloud", url, 0, 120)
if !strings.Contains(got, url) {
t.Errorf("should contain URL %q", url)
}
}
func TestRenderSignIn_ContainsSpinner(t *testing.T) {
got := renderSignIn("test:cloud", "https://example.com", 0, 80)
if !strings.Contains(got, "Waiting for sign in to complete") {
t.Error("should contain waiting message")
}
if !strings.Contains(got, "⠋") {
t.Error("should contain first spinner frame at spinner=0")
}
}
func TestRenderSignIn_SpinnerAdvances(t *testing.T) {
got0 := renderSignIn("test:cloud", "https://example.com", 0, 80)
got1 := renderSignIn("test:cloud", "https://example.com", 1, 80)
if got0 == got1 {
t.Error("different spinner values should produce different output")
}
}
func TestRenderSignIn_ContainsEscHelp(t *testing.T) {
got := renderSignIn("test:cloud", "https://example.com", 0, 80)
if !strings.Contains(got, "esc cancel") {
t.Error("should contain esc cancel help text")
}
}
func TestRenderUpgrade_AsksBeforeOpening(t *testing.T) {
got := renderUpgrade("kimi-k2.6:cloud", 0, 80, false, true)
if !strings.Contains(got, "kimi-k2.6:cloud") {
t.Error("should contain model name")
}
if !strings.Contains(got, launch.DefaultUpgradeURL) {
t.Error("should contain upgrade URL")
}
if !strings.Contains(got, "Open now?") {
t.Error("should ask before opening")
}
if !strings.Contains(got, "Yes") || !strings.Contains(got, "No") {
t.Error("should show yes/no selector")
}
if strings.Contains(got, "Waiting for upgrade to complete") {
t.Error("should not start waiting before open choice is confirmed")
}
}
func TestRenderUpgrade_PollingShowsWaiting(t *testing.T) {
got := renderUpgrade("kimi-k2.6:cloud", 0, 80, true, true)
if !strings.Contains(got, "Waiting for upgrade to complete") {
t.Error("should contain waiting message")
}
if strings.Contains(got, "Open now?") {
t.Error("should not show open prompt while polling")
}
}
func TestSignInModel_EscCancels(t *testing.T) {
m := signInModel{
modelName: "test:cloud",
signInURL: "https://example.com",
}
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEsc})
fm := updated.(signInModel)
if !fm.cancelled {
t.Error("esc should set cancelled=true")
}
if cmd == nil {
t.Error("esc should return tea.Quit")
}
}
func TestUpgradeModel_NoCancelsWithoutPolling(t *testing.T) {
m := upgradeModel{
modelName: "kimi-k2.6:cloud",
requiredPlan: "pro",
openNow: true,
}
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRight})
fm := updated.(upgradeModel)
if fm.openNow {
t.Error("right should select no")
}
if fm.polling {
t.Error("right should not start polling")
}
updated, cmd := fm.Update(tea.KeyMsg{Type: tea.KeyEnter})
fm = updated.(upgradeModel)
if !fm.cancelled {
t.Error("enter on no should cancel")
}
if fm.polling {
t.Error("enter on no should not start polling")
}
if cmd == nil {
t.Error("enter on no should quit")
}
}
func TestSignInModel_CtrlCCancels(t *testing.T) {
m := signInModel{
modelName: "test:cloud",
signInURL: "https://example.com",
}
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyCtrlC})
fm := updated.(signInModel)
if !fm.cancelled {
t.Error("ctrl+c should set cancelled=true")
}
if cmd == nil {
t.Error("ctrl+c should return tea.Quit")
}
}
func TestSignInModel_SignedInQuitsClean(t *testing.T) {
m := signInModel{
modelName: "test:cloud",
signInURL: "https://example.com",
}
updated, cmd := m.Update(signInCheckMsg{signedIn: true, userName: "alice"})
fm := updated.(signInModel)
if fm.userName != "alice" {
t.Errorf("expected userName 'alice', got %q", fm.userName)
}
if cmd == nil {
t.Error("successful sign-in should return tea.Quit")
}
}
func TestSignInModel_SignedInViewClears(t *testing.T) {
m := signInModel{
modelName: "test:cloud",
signInURL: "https://example.com",
userName: "alice",
}
got := m.View()
if got != "" {
t.Errorf("View should return empty string after sign-in, got %q", got)
}
}
func TestSignInModel_NotSignedInContinues(t *testing.T) {
m := signInModel{
modelName: "test:cloud",
signInURL: "https://example.com",
}
updated, _ := m.Update(signInCheckMsg{signedIn: false})
fm := updated.(signInModel)
if fm.userName != "" {
t.Error("should not set userName when not signed in")
}
if fm.cancelled {
t.Error("should not cancel when check returns not signed in")
}
}
func TestSignInModel_WindowSizeUpdatesWidth(t *testing.T) {
m := signInModel{
modelName: "test:cloud",
signInURL: "https://example.com",
}
updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
fm := updated.(signInModel)
if fm.width != 120 {
t.Errorf("expected width 120, got %d", fm.width)
}
}
func TestSignInModel_TickAdvancesSpinner(t *testing.T) {
m := signInModel{
modelName: "test:cloud",
signInURL: "https://example.com",
spinner: 0,
}
updated, cmd := m.Update(signInTickMsg{})
fm := updated.(signInModel)
if fm.spinner != 1 {
t.Errorf("expected spinner=1, got %d", fm.spinner)
}
if cmd == nil {
t.Error("tick should return a command")
}
}

393
cmd/tui/tui.go Normal file
View File

@@ -0,0 +1,393 @@
package tui
import (
"fmt"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/cmd/launch"
"github.com/ollama/ollama/version"
)
var (
versionStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "243", Dark: "250"})
menuItemStyle = lipgloss.NewStyle().
PaddingLeft(2)
menuSelectedItemStyle = lipgloss.NewStyle().
Bold(true).
Background(lipgloss.AdaptiveColor{Light: "254", Dark: "236"})
menuDescStyle = selectorDescStyle.
PaddingLeft(4)
greyedStyle = menuItemStyle.
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
greyedSelectedStyle = menuSelectedItemStyle.
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
modelStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "243", Dark: "250"})
notInstalledStyle = lipgloss.NewStyle().
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
Italic(true)
)
type menuItem struct {
title string
description string
integration string
isRunModel bool
isOthers bool
}
const pinnedIntegrationCount = 4
var runModelMenuItem = menuItem{
title: "Chat with a model",
description: "Start an interactive chat with a model",
isRunModel: true,
}
var othersMenuItem = menuItem{
title: "More...",
description: "Show additional integrations",
isOthers: true,
}
type model struct {
state *launch.LauncherState
items []menuItem
cursor int
showOthers bool
width int
quitting bool
selected bool
action TUIAction
}
func newModel(state *launch.LauncherState) model {
m := model{
state: state,
}
m.showOthers = shouldExpandOthers(state)
m.items = buildMenuItems(state, m.showOthers)
m.cursor = initialCursor(state, m.items)
return m
}
func shouldExpandOthers(state *launch.LauncherState) bool {
if state == nil {
return false
}
for _, item := range otherIntegrationItems(state) {
if item.integration == state.LastSelection {
return true
}
}
return false
}
func buildMenuItems(state *launch.LauncherState, showOthers bool) []menuItem {
items := []menuItem{runModelMenuItem}
items = append(items, pinnedIntegrationItems(state)...)
otherItems := otherIntegrationItems(state)
switch {
case showOthers:
items = append(items, otherItems...)
case len(otherItems) > 0:
items = append(items, othersMenuItem)
}
return items
}
func integrationMenuItem(state launch.LauncherIntegrationState) menuItem {
description := state.Description
if description == "" {
description = "Open " + state.DisplayName + " integration"
}
return menuItem{
title: "Launch " + state.DisplayName,
description: description,
integration: state.Name,
}
}
func otherIntegrationItems(state *launch.LauncherState) []menuItem {
ordered := orderedIntegrationItems(state)
if len(ordered) <= pinnedIntegrationCount {
return nil
}
return ordered[pinnedIntegrationCount:]
}
func pinnedIntegrationItems(state *launch.LauncherState) []menuItem {
ordered := orderedIntegrationItems(state)
if len(ordered) <= pinnedIntegrationCount {
return ordered
}
return ordered[:pinnedIntegrationCount]
}
func orderedIntegrationItems(state *launch.LauncherState) []menuItem {
if state == nil {
return nil
}
items := make([]menuItem, 0, len(state.Integrations))
for _, info := range launch.ListIntegrationInfos() {
integrationState, ok := state.Integrations[info.Name]
if !ok {
continue
}
items = append(items, integrationMenuItem(integrationState))
}
return items
}
func primaryMenuItemCount(state *launch.LauncherState) int {
return 1 + len(pinnedIntegrationItems(state))
}
func initialCursor(state *launch.LauncherState, items []menuItem) int {
if state == nil || state.LastSelection == "" {
return 0
}
for i, item := range items {
if state.LastSelection == "run" && item.isRunModel {
return i
}
if item.integration == state.LastSelection {
return i
}
}
return 0
}
func (m model) Init() tea.Cmd {
return nil
}
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.width = msg.Width
return m, nil
case tea.KeyMsg:
switch msg.String() {
case "ctrl+c", "q", "esc":
m.quitting = true
return m, tea.Quit
case "up", "k":
if m.cursor > 0 {
m.cursor--
}
if m.showOthers && m.cursor < primaryMenuItemCount(m.state) {
m.showOthers = false
m.items = buildMenuItems(m.state, false)
m.cursor = min(m.cursor, len(m.items)-1)
}
return m, nil
case "down", "j":
if m.cursor < len(m.items)-1 {
m.cursor++
}
if m.cursor < len(m.items) && m.items[m.cursor].isOthers && !m.showOthers {
m.showOthers = true
m.items = buildMenuItems(m.state, true)
}
return m, nil
case "enter", " ":
if m.selectableItem(m.items[m.cursor]) {
m.selected = true
m.action = actionForMenuItem(m.items[m.cursor], false)
m.quitting = true
return m, tea.Quit
}
return m, nil
case "right", "l":
item := m.items[m.cursor]
if item.isRunModel || m.changeableItem(item) {
m.selected = true
m.action = actionForMenuItem(item, true)
m.quitting = true
return m, tea.Quit
}
return m, nil
}
}
return m, nil
}
func (m model) selectableItem(item menuItem) bool {
if item.isRunModel {
return true
}
if item.integration == "" || item.isOthers {
return false
}
state, ok := m.state.Integrations[item.integration]
return ok && state.Selectable
}
func (m model) changeableItem(item menuItem) bool {
if item.integration == "" || item.isOthers {
return false
}
state, ok := m.state.Integrations[item.integration]
return ok && state.Changeable
}
func (m model) View() string {
if m.quitting {
return ""
}
s := selectorTitleStyle.Render("Ollama "+versionStyle.Render(version.Version)) + "\n\n"
for i, item := range m.items {
s += m.renderMenuItem(i, item)
}
s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → configure • esc quit")
if m.width > 0 {
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
}
return s
}
func (m model) renderMenuItem(index int, item menuItem) string {
cursor := ""
style := menuItemStyle
title := item.title
description := item.description
modelSuffix := ""
if m.cursor == index {
cursor = "▸ "
}
if item.isRunModel {
if m.cursor == index && m.state.RunModel != "" {
modelSuffix = " " + modelStyle.Render("("+m.state.RunModel+")")
}
if m.cursor == index {
style = menuSelectedItemStyle
}
} else if item.isOthers {
if m.cursor == index {
style = menuSelectedItemStyle
}
} else {
integrationState := m.state.Integrations[item.integration]
if !integrationState.Selectable {
if m.cursor == index {
style = greyedSelectedStyle
} else {
style = greyedStyle
}
} else if m.cursor == index {
style = menuSelectedItemStyle
}
if m.cursor == index && integrationState.CurrentModel != "" {
modelSuffix = " " + modelStyle.Render("("+integrationState.CurrentModel+")")
}
if !integrationState.Installed {
if integrationState.AutoInstallable {
title += " " + notInstalledStyle.Render("(install)")
} else {
title += " " + notInstalledStyle.Render("(not installed)")
}
if m.cursor == index {
if integrationState.AutoInstallable {
description = "Press enter to install"
} else if integrationState.InstallHint != "" {
description = integrationState.InstallHint
} else {
description = "not installed"
}
}
}
}
return style.Render(cursor+title) + modelSuffix + "\n" + menuDescStyle.Render(description) + "\n\n"
}
type TUIActionKind int
const (
TUIActionNone TUIActionKind = iota
TUIActionRunModel
TUIActionLaunchIntegration
)
type TUIAction struct {
Kind TUIActionKind
Integration string
ForceConfigure bool
}
func (a TUIAction) LastSelection() string {
switch a.Kind {
case TUIActionRunModel:
return "run"
case TUIActionLaunchIntegration:
return a.Integration
default:
return ""
}
}
func (a TUIAction) RunModelRequest() launch.RunModelRequest {
return launch.RunModelRequest{ForcePicker: a.ForceConfigure}
}
func (a TUIAction) IntegrationLaunchRequest() launch.IntegrationLaunchRequest {
return launch.IntegrationLaunchRequest{
Name: a.Integration,
ForceConfigure: a.ForceConfigure,
}
}
func actionForMenuItem(item menuItem, forceConfigure bool) TUIAction {
switch {
case item.isRunModel:
return TUIAction{Kind: TUIActionRunModel, ForceConfigure: forceConfigure}
case item.integration != "":
return TUIAction{Kind: TUIActionLaunchIntegration, Integration: item.integration, ForceConfigure: forceConfigure}
default:
return TUIAction{Kind: TUIActionNone}
}
}
func RunMenu(state *launch.LauncherState) (TUIAction, error) {
menu := newModel(state)
program := tea.NewProgram(menu)
finalModel, err := program.Run()
if err != nil {
return TUIAction{Kind: TUIActionNone}, fmt.Errorf("error running TUI: %w", err)
}
finalMenu := finalModel.(model)
if !finalMenu.selected {
return TUIAction{Kind: TUIActionNone}, nil
}
return finalMenu.action, nil
}

296
cmd/tui/tui_test.go Normal file
View File

@@ -0,0 +1,296 @@
package tui
import (
"strings"
"testing"
tea "github.com/charmbracelet/bubbletea"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/cmd/launch"
)
func launcherTestState() *launch.LauncherState {
return &launch.LauncherState{
LastSelection: "run",
RunModel: "qwen3:8b",
Integrations: map[string]launch.LauncherIntegrationState{
"claude": {
Name: "claude",
DisplayName: "Claude Code",
Description: "Anthropic's coding tool with subagents",
Selectable: true,
Changeable: true,
CurrentModel: "glm-5:cloud",
},
"codex": {
Name: "codex",
DisplayName: "Codex",
Description: "OpenAI's open-source coding agent",
Selectable: true,
Changeable: true,
},
"codex-app": {
Name: "codex-app",
DisplayName: "Codex App",
Description: "An AI agent you can delegate real work to, by OpenAI",
Selectable: true,
Changeable: true,
},
"openclaw": {
Name: "openclaw",
DisplayName: "OpenClaw",
Description: "Personal AI with 100+ skills",
Selectable: true,
Changeable: true,
AutoInstallable: true,
},
"opencode": {
Name: "opencode",
DisplayName: "OpenCode",
Description: "Anomaly's open-source coding agent",
Selectable: true,
Changeable: true,
},
"hermes": {
Name: "hermes",
DisplayName: "Hermes Agent",
Description: "Self-improving AI agent built by Nous Research",
Selectable: true,
Changeable: true,
},
"droid": {
Name: "droid",
DisplayName: "Droid",
Description: "Factory's coding agent across terminal and IDEs",
Selectable: true,
Changeable: true,
},
"pi": {
Name: "pi",
DisplayName: "Pi",
Description: "Minimal AI agent toolkit with plugin support",
Selectable: true,
Changeable: true,
},
},
}
}
func findMenuCursorByIntegration(items []menuItem, name string) int {
for i, item := range items {
if item.integration == name {
return i
}
}
return -1
}
func integrationSequence(items []menuItem) []string {
sequence := make([]string, 0, len(items))
for _, item := range items {
switch {
case item.isRunModel:
sequence = append(sequence, "run")
case item.isOthers:
sequence = append(sequence, "more")
case item.integration != "":
sequence = append(sequence, item.integration)
}
}
return sequence
}
func compareStrings(got, want []string) string {
return cmp.Diff(want, got)
}
func expectedCollapsedSequence(state *launch.LauncherState) []string {
sequence := []string{"run"}
for _, item := range pinnedIntegrationItems(state) {
sequence = append(sequence, item.integration)
}
if len(otherIntegrationItems(state)) > 0 {
sequence = append(sequence, "more")
}
return sequence
}
func expectedExpandedSequence(state *launch.LauncherState) []string {
sequence := []string{"run"}
for _, item := range pinnedIntegrationItems(state) {
sequence = append(sequence, item.integration)
}
for _, item := range otherIntegrationItems(state) {
sequence = append(sequence, item.integration)
}
return sequence
}
func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
state := launcherTestState()
menu := newModel(state)
wantPrefix := []string{"run", "claude", "codex-app", "hermes", "openclaw"}
if findMenuCursorByIntegration(menu.items, "codex-app") == -1 {
wantPrefix = []string{"run", "claude", "hermes", "openclaw", "opencode"}
}
if got := integrationSequence(menu.items); len(got) < len(wantPrefix) {
t.Fatalf("expected at least %d menu items, got %v", len(wantPrefix), got)
} else if diff := compareStrings(got[:len(wantPrefix)], wantPrefix); diff != "" {
t.Fatalf("unexpected primary TUI order: %s", diff)
}
view := menu.View()
for _, want := range []string{"Chat with a model", "Launch Claude Code", "Launch Hermes Agent", "Launch OpenClaw", "More..."} {
if !strings.Contains(view, want) {
t.Fatalf("expected menu view to contain %q\n%s", want, view)
}
}
if findMenuCursorByIntegration(menu.items, "codex-app") != -1 && !strings.Contains(view, "Launch Codex App") {
t.Fatalf("expected menu view to contain Codex App\n%s", view)
}
if strings.Contains(view, "Launch Claude Desktop") {
t.Fatalf("expected hidden Claude Desktop to be absent\n%s", view)
}
wantOrder := expectedCollapsedSequence(state)
if diff := compareStrings(integrationSequence(menu.items), wantOrder); diff != "" {
t.Fatalf("unexpected pinned order: %s", diff)
}
}
func TestMenuExpandsOthersFromLastSelection(t *testing.T) {
state := launcherTestState()
overflow := otherIntegrationItems(state)
if len(overflow) == 0 {
t.Fatal("expected at least one overflow integration")
}
state.LastSelection = overflow[0].integration
menu := newModel(state)
if !menu.showOthers {
t.Fatal("expected others section to expand when last selection is in the overflow list")
}
view := menu.View()
if !strings.Contains(view, overflow[0].title) {
t.Fatalf("expected expanded view to contain overflow integration\n%s", view)
}
if strings.Contains(view, "More...") {
t.Fatalf("expected expanded view to replace More... item\n%s", view)
}
wantOrder := expectedExpandedSequence(state)
if diff := compareStrings(integrationSequence(menu.items), wantOrder); diff != "" {
t.Fatalf("unexpected expanded order: %s", diff)
}
}
func TestMenuEnterOnRunSelectsRun(t *testing.T) {
menu := newModel(launcherTestState())
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
got := updated.(model)
want := TUIAction{Kind: TUIActionRunModel}
if !got.selected || got.action != want {
t.Fatalf("expected enter on run to select run action, got selected=%v action=%v", got.selected, got.action)
}
}
func TestMenuRightOnRunSelectsChangeRun(t *testing.T) {
menu := newModel(launcherTestState())
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
got := updated.(model)
want := TUIAction{Kind: TUIActionRunModel, ForceConfigure: true}
if !got.selected || got.action != want {
t.Fatalf("expected right on run to select change-run action, got selected=%v action=%v", got.selected, got.action)
}
}
func TestMenuEnterOnIntegrationSelectsLaunch(t *testing.T) {
menu := newModel(launcherTestState())
menu.cursor = findMenuCursorByIntegration(menu.items, "claude")
if menu.cursor == -1 {
t.Fatal("expected claude menu item")
}
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
got := updated.(model)
want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude"}
if !got.selected || got.action != want {
t.Fatalf("expected enter on integration to launch, got selected=%v action=%v", got.selected, got.action)
}
}
func TestMenuRightOnIntegrationSelectsConfigure(t *testing.T) {
menu := newModel(launcherTestState())
menu.cursor = findMenuCursorByIntegration(menu.items, "claude")
if menu.cursor == -1 {
t.Fatal("expected claude menu item")
}
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
got := updated.(model)
want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true}
if !got.selected || got.action != want {
t.Fatalf("expected right on integration to configure, got selected=%v action=%v", got.selected, got.action)
}
}
func TestMenuIgnoresDisabledActions(t *testing.T) {
state := launcherTestState()
claude := state.Integrations["claude"]
claude.Selectable = false
claude.Changeable = false
state.Integrations["claude"] = claude
menu := newModel(state)
menu.cursor = findMenuCursorByIntegration(menu.items, "claude")
if menu.cursor == -1 {
t.Fatal("expected claude menu item")
}
updatedEnter, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
if updatedEnter.(model).selected {
t.Fatal("expected non-selectable integration to ignore enter")
}
updatedRight, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
if updatedRight.(model).selected {
t.Fatal("expected non-changeable integration to ignore right")
}
}
func TestMenuShowsCurrentModelSuffixes(t *testing.T) {
menu := newModel(launcherTestState())
runView := menu.View()
if !strings.Contains(runView, "(qwen3:8b)") {
t.Fatalf("expected run row to show current model suffix\n%s", runView)
}
menu.cursor = findMenuCursorByIntegration(menu.items, "claude")
if menu.cursor == -1 {
t.Fatal("expected claude menu item")
}
integrationView := menu.View()
if !strings.Contains(integrationView, "(glm-5:cloud)") {
t.Fatalf("expected integration row to show current model suffix\n%s", integrationView)
}
}
func TestMenuShowsInstallStatusAndHint(t *testing.T) {
state := launcherTestState()
codex := state.Integrations["codex"]
codex.Installed = false
codex.Selectable = false
codex.Changeable = false
codex.InstallHint = "Install from https://example.com/codex"
state.Integrations["codex"] = codex
state.LastSelection = "codex"
menu := newModel(state)
menu.cursor = findMenuCursorByIntegration(menu.items, "codex")
if menu.cursor == -1 {
t.Fatal("expected codex menu item in overflow section")
}
view := menu.View()
if !strings.Contains(view, "(not installed)") {
t.Fatalf("expected not-installed marker\n%s", view)
}
if !strings.Contains(view, codex.InstallHint) {
t.Fatalf("expected install hint in description\n%s", view)
}
}

63
cmd/warn_thinking_test.go Normal file
View File

@@ -0,0 +1,63 @@
package cmd
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
// Test that a warning is printed when thinking is requested but not supported.
func TestWarnMissingThinking(t *testing.T) {
cases := []struct {
capabilities []model.Capability
expectWarn bool
}{
{capabilities: []model.Capability{model.CapabilityThinking}, expectWarn: false},
{capabilities: []model.Capability{}, expectWarn: true},
}
for _, tc := range cases {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/show" || r.Method != http.MethodPost {
t.Fatalf("unexpected request to %s %s", r.URL.Path, r.Method)
}
var req api.ShowRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("decode request: %v", err)
}
resp := api.ShowResponse{Capabilities: tc.capabilities}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("encode response: %v", err)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
client, err := api.ClientFromEnvironment()
if err != nil {
t.Fatal(err)
}
oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
ensureThinkingSupport(t.Context(), client, "m")
w.Close()
os.Stderr = oldStderr
out, _ := io.ReadAll(r)
warned := strings.Contains(string(out), "warning:")
if tc.expectWarn && !warned {
t.Errorf("expected warning, got none")
}
if !tc.expectWarn && warned {
t.Errorf("did not expect warning, got: %s", string(out))
}
}
}