ollama source for Momentry Core verification
This commit is contained in:
13
cmd/background_unix.go
Normal file
13
cmd/background_unix.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build !windows
|
||||
|
||||
package cmd
|
||||
|
||||
import "syscall"
|
||||
|
||||
// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Unix.
|
||||
// Setpgid prevents the server from being killed when the parent process exits.
|
||||
func backgroundServerSysProcAttr() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{
|
||||
Setpgid: true,
|
||||
}
|
||||
}
|
||||
12
cmd/background_windows.go
Normal file
12
cmd/background_windows.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package cmd
|
||||
|
||||
import "syscall"
|
||||
|
||||
// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Windows.
|
||||
// CREATE_NO_WINDOW (0x08000000) prevents a console window from appearing.
|
||||
func backgroundServerSysProcAttr() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{
|
||||
CreationFlags: 0x08000000,
|
||||
HideWindow: true,
|
||||
}
|
||||
}
|
||||
143
cmd/bench/README.md
Normal file
143
cmd/bench/README.md
Normal file
@@ -0,0 +1,143 @@
|
||||
Ollama Benchmark Tool
|
||||
---------------------
|
||||
|
||||
A Go-based command-line tool for benchmarking Ollama models with configurable parameters, warmup phases, TTFT tracking, VRAM monitoring, and benchstat/CSV output.
|
||||
|
||||
## Features
|
||||
|
||||
* Benchmark multiple models in a single run
|
||||
* Support for both text and image prompts
|
||||
* Configurable generation parameters (temperature, max tokens, seed, etc.)
|
||||
* Warmup phase before timed epochs to stabilize measurements
|
||||
* Time-to-first-token (TTFT) tracking per epoch
|
||||
* Model metadata display (parameter size, quantization level, family)
|
||||
* VRAM and CPU memory usage tracking via running process info
|
||||
* Controlled prompt token length for reproducible benchmarks
|
||||
* Benchstat and CSV output formats
|
||||
|
||||
## Building from Source
|
||||
|
||||
```
|
||||
go build -o ollama-bench ./cmd/bench
|
||||
./ollama-bench -model gemma3 -epochs 6 -format csv
|
||||
```
|
||||
|
||||
Using Go Run (without building)
|
||||
|
||||
```
|
||||
go run ./cmd/bench -model gemma3 -epochs 3
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Example
|
||||
|
||||
```
|
||||
./ollama-bench -model gemma3 -epochs 6
|
||||
```
|
||||
|
||||
### Benchmark Multiple Models
|
||||
|
||||
```
|
||||
./ollama-bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
|
||||
benchstat -col /name gemma.bench
|
||||
```
|
||||
|
||||
### With Image Prompt
|
||||
|
||||
```
|
||||
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
|
||||
```
|
||||
|
||||
### Controlled Prompt Length
|
||||
|
||||
```
|
||||
./ollama-bench -model gemma3 -epochs 6 -prompt-tokens 512
|
||||
```
|
||||
|
||||
### Advanced Example
|
||||
|
||||
```
|
||||
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -warmup 2 -format csv -output results.csv
|
||||
```
|
||||
|
||||
## Command Line Options
|
||||
|
||||
| Option | Description | Default |
|
||||
|----------|-------------|---------|
|
||||
| -model | Comma-separated list of models to benchmark | (required) |
|
||||
| -epochs | Number of iterations per model | 6 |
|
||||
| -max-tokens | Maximum tokens for model response | 200 |
|
||||
| -temperature | Temperature parameter | 0.0 |
|
||||
| -seed | Random seed | 0 (random) |
|
||||
| -timeout | Timeout in seconds | 300 |
|
||||
| -p | Prompt text | (default story prompt) |
|
||||
| -image | Image file to include in prompt | |
|
||||
| -k | Keep-alive duration in seconds | 0 |
|
||||
| -format | Output format (benchstat, csv) | benchstat |
|
||||
| -output | Output file for results | "" (stdout) |
|
||||
| -warmup | Number of warmup requests before timing | 1 |
|
||||
| -prompt-tokens | Generate prompt targeting ~N tokens (0 = use -p) | 0 |
|
||||
| -v | Verbose mode | false |
|
||||
| -debug | Show debug information | false |
|
||||
|
||||
## Output Formats
|
||||
|
||||
### Benchstat Format (default)
|
||||
|
||||
Compatible with Go's benchstat tool for statistical analysis. Uses one value/unit pair per line, standard `ns/op` for timing metrics, and `ns/token` for throughput. Each epoch produces one set of lines -- benchstat aggregates across repeated runs to compute statistics.
|
||||
|
||||
```
|
||||
# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931
|
||||
BenchmarkModel/name=gemma3/step=prefill 1 78125.00 ns/token 12800.00 token/sec
|
||||
BenchmarkModel/name=gemma3/step=generate 1 19531.25 ns/token 51200.00 token/sec
|
||||
BenchmarkModel/name=gemma3/step=ttft 1 45123000 ns/op
|
||||
BenchmarkModel/name=gemma3/step=load 1 1500000000 ns/op
|
||||
BenchmarkModel/name=gemma3/step=total 1 2861047625 ns/op
|
||||
```
|
||||
|
||||
Use with benchstat:
|
||||
```
|
||||
./ollama-bench -model gemma3 -epochs 6 > gemma3.bench
|
||||
benchstat -col /step gemma3.bench
|
||||
```
|
||||
|
||||
Compare two runs:
|
||||
```
|
||||
./ollama-bench -model gemma3 -epochs 6 > before.bench
|
||||
# ... make changes ...
|
||||
./ollama-bench -model gemma3 -epochs 6 > after.bench
|
||||
benchstat before.bench after.bench
|
||||
```
|
||||
|
||||
### CSV Format
|
||||
|
||||
Machine-readable comma-separated values:
|
||||
|
||||
```
|
||||
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
|
||||
# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931
|
||||
gemma3,prefill,128,78125.00,12800.00
|
||||
gemma3,generate,512,19531.25,51200.00
|
||||
gemma3,ttft,1,45123000,0
|
||||
gemma3,load,1,1500000000,0
|
||||
gemma3,total,1,2861047625,0
|
||||
```
|
||||
|
||||
## Metrics Explained
|
||||
|
||||
The tool reports the following metrics for each epoch:
|
||||
|
||||
* **prefill**: Time spent processing the prompt (ns/token)
|
||||
* **generate**: Time spent generating the response (ns/token)
|
||||
* **ttft**: Time to first token -- latency from request start to first response content
|
||||
* **load**: Model loading time (one-time cost)
|
||||
* **total**: Total request duration
|
||||
|
||||
Additionally, the model info comment line (displayed once per model before epochs) includes:
|
||||
|
||||
* **Params**: Model parameter count (e.g., 4.3B)
|
||||
* **Quant**: Quantization level (e.g., Q4_K_M)
|
||||
* **Family**: Model family (e.g., gemma3)
|
||||
* **Size**: Total model memory in bytes
|
||||
* **VRAM**: GPU memory used by the loaded model (when Size > VRAM, the difference is CPU spill)
|
||||
561
cmd/bench/bench.go
Normal file
561
cmd/bench/bench.go
Normal file
@@ -0,0 +1,561 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type flagOptions struct {
|
||||
models *string
|
||||
epochs *int
|
||||
maxTokens *int
|
||||
temperature *float64
|
||||
seed *int
|
||||
timeout *int
|
||||
prompt *string
|
||||
imageFile *string
|
||||
keepAlive *float64
|
||||
format *string
|
||||
outputFile *string
|
||||
debug *bool
|
||||
verbose *bool
|
||||
warmup *int
|
||||
promptTokens *int
|
||||
numCtx *int
|
||||
}
|
||||
|
||||
type Metrics struct {
|
||||
Model string
|
||||
Step string
|
||||
Count int
|
||||
Duration time.Duration
|
||||
}
|
||||
|
||||
type ModelInfo struct {
|
||||
Name string
|
||||
ParameterSize string
|
||||
QuantizationLevel string
|
||||
Family string
|
||||
SizeBytes int64
|
||||
VRAMBytes int64
|
||||
NumCtx int64
|
||||
}
|
||||
|
||||
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
|
||||
|
||||
// Word list for generating prompts targeting a specific token count.
|
||||
var promptWordList = []string{
|
||||
"the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog",
|
||||
"a", "bright", "sunny", "day", "in", "the", "meadow", "where",
|
||||
"flowers", "bloom", "and", "birds", "sing", "their", "morning",
|
||||
"songs", "while", "gentle", "breeze", "carries", "sweet", "scent",
|
||||
"of", "pine", "trees", "across", "rolling", "hills", "toward",
|
||||
"distant", "mountains", "covered", "with", "fresh", "snow",
|
||||
"beneath", "clear", "blue", "sky", "children", "play", "near",
|
||||
"old", "stone", "bridge", "that", "crosses", "winding", "river",
|
||||
}
|
||||
|
||||
// tokensPerWord is the calibrated ratio of tokens to words for the current model.
|
||||
// Initialized with a heuristic, then updated during warmup based on actual tokenization.
|
||||
var tokensPerWord = 1.3
|
||||
|
||||
func generatePromptForTokenCount(targetTokens int, epoch int) string {
|
||||
targetWords := int(float64(targetTokens) / tokensPerWord)
|
||||
if targetWords < 1 {
|
||||
targetWords = 1
|
||||
}
|
||||
|
||||
// Vary the starting offset by epoch to defeat KV cache prefix matching
|
||||
offset := epoch * 7 // stride by a prime to get good distribution
|
||||
n := len(promptWordList)
|
||||
words := make([]string, targetWords)
|
||||
for i := range words {
|
||||
words[i] = promptWordList[((i+offset)%n+n)%n]
|
||||
}
|
||||
return strings.Join(words, " ")
|
||||
}
|
||||
|
||||
// calibratePromptTokens adjusts tokensPerWord based on actual tokenization from a warmup run.
|
||||
func calibratePromptTokens(targetTokens, actualTokens, wordCount int) {
|
||||
if actualTokens <= 0 || wordCount <= 0 {
|
||||
return
|
||||
}
|
||||
tokensPerWord = float64(actualTokens) / float64(wordCount)
|
||||
newWords := int(float64(targetTokens) / tokensPerWord)
|
||||
fmt.Fprintf(os.Stderr, "bench: calibrated %.2f tokens/word (target=%d, got=%d, words=%d → %d)\n",
|
||||
tokensPerWord, targetTokens, actualTokens, wordCount, newWords)
|
||||
}
|
||||
|
||||
func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData, epoch int) *api.GenerateRequest {
|
||||
options := make(map[string]interface{})
|
||||
if *fOpt.maxTokens > 0 {
|
||||
options["num_predict"] = *fOpt.maxTokens
|
||||
}
|
||||
options["temperature"] = *fOpt.temperature
|
||||
if fOpt.seed != nil && *fOpt.seed > 0 {
|
||||
options["seed"] = *fOpt.seed
|
||||
}
|
||||
if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
|
||||
options["num_ctx"] = *fOpt.numCtx
|
||||
}
|
||||
|
||||
var keepAliveDuration *api.Duration
|
||||
if *fOpt.keepAlive > 0 {
|
||||
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
|
||||
keepAliveDuration = &duration
|
||||
}
|
||||
|
||||
prompt := *fOpt.prompt
|
||||
if *fOpt.promptTokens > 0 {
|
||||
prompt = generatePromptForTokenCount(*fOpt.promptTokens, epoch)
|
||||
} else {
|
||||
// Vary the prompt per epoch to defeat KV cache prefix matching
|
||||
prompt = fmt.Sprintf("[%d] %s", epoch, prompt)
|
||||
}
|
||||
|
||||
req := &api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
Raw: true,
|
||||
Options: options,
|
||||
KeepAlive: keepAliveDuration,
|
||||
}
|
||||
|
||||
if imgData != nil {
|
||||
req.Images = []api.ImageData{imgData}
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func fetchModelInfo(ctx context.Context, client *api.Client, model string) ModelInfo {
|
||||
info := ModelInfo{Name: model}
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "WARNING: Could not fetch model info for '%s': %v\n", model, err)
|
||||
return info
|
||||
}
|
||||
info.ParameterSize = resp.Details.ParameterSize
|
||||
info.QuantizationLevel = resp.Details.QuantizationLevel
|
||||
info.Family = resp.Details.Family
|
||||
return info
|
||||
}
|
||||
|
||||
func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (size, vram int64) {
|
||||
resp, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||
fmt.Fprintf(os.Stderr, "WARNING: Could not fetch memory usage: %v\n", err)
|
||||
}
|
||||
return 0, 0
|
||||
}
|
||||
for _, m := range resp.Models {
|
||||
if m.Name == model || m.Model == model {
|
||||
return m.Size, m.SizeVRAM
|
||||
}
|
||||
}
|
||||
for _, m := range resp.Models {
|
||||
if strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
|
||||
return m.Size, m.SizeVRAM
|
||||
}
|
||||
}
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
func fetchContextLength(ctx context.Context, client *api.Client, model string) int64 {
|
||||
resp, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
for _, m := range resp.Models {
|
||||
if m.Name == model || m.Model == model || strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) {
|
||||
return int64(m.ContextLength)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func outputFormatHeader(w io.Writer, format string, verbose bool) {
|
||||
switch format {
|
||||
case "benchstat":
|
||||
if verbose {
|
||||
fmt.Fprintf(w, "goos: %s\n", runtime.GOOS)
|
||||
fmt.Fprintf(w, "goarch: %s\n", runtime.GOARCH)
|
||||
}
|
||||
case "csv":
|
||||
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
|
||||
fmt.Fprintln(w, strings.Join(headings, ","))
|
||||
}
|
||||
}
|
||||
|
||||
func outputModelInfo(w io.Writer, format string, info ModelInfo) {
|
||||
params := cmp.Or(info.ParameterSize, "unknown")
|
||||
quant := cmp.Or(info.QuantizationLevel, "unknown")
|
||||
family := cmp.Or(info.Family, "unknown")
|
||||
|
||||
memStr := ""
|
||||
if info.SizeBytes > 0 {
|
||||
memStr = fmt.Sprintf(" | Size: %d | VRAM: %d", info.SizeBytes, info.VRAMBytes)
|
||||
}
|
||||
ctxStr := ""
|
||||
if info.NumCtx > 0 {
|
||||
ctxStr = fmt.Sprintf(" | NumCtx: %d", info.NumCtx)
|
||||
}
|
||||
fmt.Fprintf(w, "# Model: %s | Params: %s | Quant: %s | Family: %s%s%s\n",
|
||||
info.Name, params, quant, family, memStr, ctxStr)
|
||||
}
|
||||
|
||||
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
|
||||
switch format {
|
||||
case "benchstat":
|
||||
for _, m := range metrics {
|
||||
if m.Step == "generate" || m.Step == "prefill" {
|
||||
if m.Count > 0 {
|
||||
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
||||
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %.2f ns/token %.2f token/sec\n",
|
||||
m.Model, m.Step, nsPerToken, tokensPerSec)
|
||||
} else {
|
||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 0 ns/token 0 token/sec\n",
|
||||
m.Model, m.Step)
|
||||
}
|
||||
} else if m.Step == "ttft" {
|
||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=ttft 1 %d ns/op\n",
|
||||
m.Model, m.Duration.Nanoseconds())
|
||||
} else {
|
||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %d ns/op\n",
|
||||
m.Model, m.Step, m.Duration.Nanoseconds())
|
||||
}
|
||||
}
|
||||
case "csv":
|
||||
for _, m := range metrics {
|
||||
if m.Step == "generate" || m.Step == "prefill" {
|
||||
var nsPerToken float64
|
||||
var tokensPerSec float64
|
||||
if m.Count > 0 {
|
||||
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
||||
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
||||
}
|
||||
fmt.Fprintf(w, "%s,%s,%d,%.2f,%.2f\n", m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
|
||||
} else {
|
||||
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
|
||||
}
|
||||
}
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkModel(fOpt flagOptions) error {
|
||||
models := strings.Split(*fOpt.models, ",")
|
||||
|
||||
var imgData api.ImageData
|
||||
var err error
|
||||
if *fOpt.imageFile != "" {
|
||||
imgData, err = readImage(*fOpt.imageFile)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Couldn't read image '%s': %v\n", *fOpt.imageFile, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if *fOpt.debug && imgData != nil {
|
||||
fmt.Fprintf(os.Stderr, "Read file '%s'\n", *fOpt.imageFile)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Couldn't create ollama client: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
var out io.Writer = os.Stdout
|
||||
if fOpt.outputFile != nil && *fOpt.outputFile != "" {
|
||||
f, err := os.OpenFile(*fOpt.outputFile, os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: cannot open output file %s: %v\n", *fOpt.outputFile, err)
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
out = f
|
||||
}
|
||||
|
||||
outputFormatHeader(out, *fOpt.format, *fOpt.verbose)
|
||||
|
||||
// Log prompt-tokens info in debug mode
|
||||
if *fOpt.debug && *fOpt.promptTokens > 0 {
|
||||
prompt := generatePromptForTokenCount(*fOpt.promptTokens, 0)
|
||||
wordCount := len(strings.Fields(prompt))
|
||||
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (%d words, varied per epoch)\n", *fOpt.promptTokens, wordCount)
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
// Fetch model info
|
||||
infoCtx, infoCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
info := fetchModelInfo(infoCtx, client, model)
|
||||
infoCancel()
|
||||
|
||||
// Warmup phase (uses negative epoch numbers to avoid colliding with timed epochs)
|
||||
for i := range *fOpt.warmup {
|
||||
req := buildGenerateRequest(model, fOpt, imgData, -(i + 1))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
||||
|
||||
var warmupMetrics *api.Metrics
|
||||
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||
if resp.Done {
|
||||
warmupMetrics = &resp.Metrics
|
||||
}
|
||||
return nil
|
||||
})
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "WARNING: Warmup %d/%d for %s failed: %v\n", i+1, *fOpt.warmup, model, err)
|
||||
} else {
|
||||
if *fOpt.debug {
|
||||
fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model)
|
||||
}
|
||||
// Calibrate prompt token count on last warmup run
|
||||
if i == *fOpt.warmup-1 && *fOpt.promptTokens > 0 && warmupMetrics != nil {
|
||||
prompt := generatePromptForTokenCount(*fOpt.promptTokens, -(i + 1))
|
||||
wordCount := len(strings.Fields(prompt))
|
||||
calibratePromptTokens(*fOpt.promptTokens, warmupMetrics.PromptEvalCount, wordCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch memory/context info once after warmup (model is loaded and stable)
|
||||
memCtx, memCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
info.SizeBytes, info.VRAMBytes = fetchMemoryUsage(memCtx, client, model)
|
||||
if fOpt.numCtx != nil && *fOpt.numCtx > 0 {
|
||||
info.NumCtx = int64(*fOpt.numCtx)
|
||||
} else {
|
||||
info.NumCtx = fetchContextLength(memCtx, client, model)
|
||||
}
|
||||
memCancel()
|
||||
|
||||
outputModelInfo(out, *fOpt.format, info)
|
||||
|
||||
// Timed epoch loop
|
||||
shortCount := 0
|
||||
for epoch := range *fOpt.epochs {
|
||||
var responseMetrics *api.Metrics
|
||||
var ttft time.Duration
|
||||
short := false
|
||||
|
||||
// Retry loop: if the model hits a stop token before max-tokens,
|
||||
// retry with a different prompt (up to maxRetries times).
|
||||
const maxRetries = 3
|
||||
for attempt := range maxRetries + 1 {
|
||||
responseMetrics = nil
|
||||
ttft = 0
|
||||
var ttftOnce sync.Once
|
||||
|
||||
req := buildGenerateRequest(model, fOpt, imgData, epoch+attempt*1000)
|
||||
requestStart := time.Now()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
||||
|
||||
err = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||
if *fOpt.debug {
|
||||
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Thinking, resp.Response))
|
||||
}
|
||||
|
||||
// Capture TTFT on first content
|
||||
ttftOnce.Do(func() {
|
||||
if resp.Response != "" || resp.Thinking != "" {
|
||||
ttft = time.Since(requestStart)
|
||||
}
|
||||
})
|
||||
|
||||
if resp.Done {
|
||||
responseMetrics = &resp.Metrics
|
||||
}
|
||||
return nil
|
||||
})
|
||||
cancel()
|
||||
|
||||
if *fOpt.debug {
|
||||
fmt.Fprintln(os.Stderr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Request timed out with model '%s' after %vs\n", model, *fOpt.timeout)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Couldn't generate with model '%s': %v\n", model, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if responseMetrics == nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
|
||||
break
|
||||
}
|
||||
|
||||
// Check if the response was shorter than requested
|
||||
short = *fOpt.maxTokens > 0 && responseMetrics.EvalCount < *fOpt.maxTokens
|
||||
if !short || attempt == maxRetries {
|
||||
break
|
||||
}
|
||||
|
||||
if *fOpt.debug {
|
||||
fmt.Fprintf(os.Stderr, "Short response (%d/%d tokens), retrying with different prompt (attempt %d/%d)\n",
|
||||
responseMetrics.EvalCount, *fOpt.maxTokens, attempt+1, maxRetries)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil || responseMetrics == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if short {
|
||||
shortCount++
|
||||
if *fOpt.debug {
|
||||
fmt.Fprintf(os.Stderr, "WARNING: Short response (%d/%d tokens) after %d retries for epoch %d\n",
|
||||
responseMetrics.EvalCount, *fOpt.maxTokens, maxRetries, epoch+1)
|
||||
}
|
||||
}
|
||||
|
||||
metrics := []Metrics{
|
||||
{
|
||||
Model: model,
|
||||
Step: "prefill",
|
||||
Count: responseMetrics.PromptEvalCount,
|
||||
Duration: responseMetrics.PromptEvalDuration,
|
||||
},
|
||||
{
|
||||
Model: model,
|
||||
Step: "generate",
|
||||
Count: responseMetrics.EvalCount,
|
||||
Duration: responseMetrics.EvalDuration,
|
||||
},
|
||||
{
|
||||
Model: model,
|
||||
Step: "ttft",
|
||||
Count: 1,
|
||||
Duration: ttft,
|
||||
},
|
||||
{
|
||||
Model: model,
|
||||
Step: "load",
|
||||
Count: 1,
|
||||
Duration: responseMetrics.LoadDuration,
|
||||
},
|
||||
{
|
||||
Model: model,
|
||||
Step: "total",
|
||||
Count: 1,
|
||||
Duration: responseMetrics.TotalDuration,
|
||||
},
|
||||
}
|
||||
|
||||
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
|
||||
|
||||
if *fOpt.debug && *fOpt.promptTokens > 0 {
|
||||
fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (actual: %d)\n",
|
||||
*fOpt.promptTokens, responseMetrics.PromptEvalCount)
|
||||
}
|
||||
|
||||
if *fOpt.keepAlive > 0 {
|
||||
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
if shortCount > 0 {
|
||||
fmt.Fprintf(os.Stderr, "WARNING: %d/%d epochs for '%s' had short responses (<%d tokens). Generation metrics may be unreliable.\n",
|
||||
shortCount, *fOpt.epochs, model, *fOpt.maxTokens)
|
||||
}
|
||||
|
||||
// Unload model before moving to the next one
|
||||
unloadModel(client, model, *fOpt.timeout)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func unloadModel(client *api.Client, model string, timeout int) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
zero := api.Duration{Duration: 0}
|
||||
req := &api.GenerateRequest{
|
||||
Model: model,
|
||||
KeepAlive: &zero,
|
||||
}
|
||||
_ = client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func readImage(filePath string) (api.ImageData, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
data, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return api.ImageData(data), nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
fOpt := flagOptions{
|
||||
models: flag.String("model", "", "Model to benchmark"),
|
||||
epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
|
||||
maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"),
|
||||
temperature: flag.Float64("temperature", 0, "Temperature parameter"),
|
||||
seed: flag.Int("seed", 0, "Random seed"),
|
||||
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
|
||||
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
|
||||
imageFile: flag.String("image", "", "Filename for an image to include"),
|
||||
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
|
||||
format: flag.String("format", "benchstat", "Output format [benchstat|csv]"),
|
||||
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
|
||||
verbose: flag.Bool("v", false, "Show system information"),
|
||||
debug: flag.Bool("debug", false, "Show debug information"),
|
||||
warmup: flag.Int("warmup", 1, "Number of warmup requests before timing"),
|
||||
promptTokens: flag.Int("prompt-tokens", 0, "Generate prompt targeting ~N tokens (0 = use -p prompt)"),
|
||||
numCtx: flag.Int("num-ctx", 0, "Context size (0 = server default)"),
|
||||
}
|
||||
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s [OPTIONS]\n\n", os.Args[0])
|
||||
fmt.Fprintf(os.Stderr, "Description:\n")
|
||||
fmt.Fprintf(os.Stderr, " Model benchmarking tool with configurable parameters\n\n")
|
||||
fmt.Fprintf(os.Stderr, "Options:\n")
|
||||
flag.PrintDefaults()
|
||||
fmt.Fprintf(os.Stderr, "\nExamples:\n")
|
||||
fmt.Fprintf(os.Stderr, " bench -model gemma3,llama3 -epochs 6\n")
|
||||
fmt.Fprintf(os.Stderr, " bench -model gemma3 -epochs 6 -prompt-tokens 512 -format csv\n")
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
if !slices.Contains([]string{"benchstat", "csv"}, *fOpt.format) {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(*fOpt.models) == 0 {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: No model(s) specified to benchmark.\n")
|
||||
flag.Usage()
|
||||
return
|
||||
}
|
||||
|
||||
BenchmarkModel(fOpt)
|
||||
}
|
||||
1410
cmd/bench/bench_test.go
Normal file
1410
cmd/bench/bench_test.go
Normal file
File diff suppressed because it is too large
Load Diff
2583
cmd/cmd.go
Normal file
2583
cmd/cmd.go
Normal file
File diff suppressed because it is too large
Load Diff
305
cmd/cmd_launcher_test.go
Normal file
305
cmd/cmd_launcher_test.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
"github.com/ollama/ollama/cmd/tui"
|
||||
)
|
||||
|
||||
func setCmdTestHome(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
t.Setenv("HOME", dir)
|
||||
t.Setenv("USERPROFILE", dir)
|
||||
}
|
||||
|
||||
func unexpectedRunModelResolution(t *testing.T) func(context.Context, launch.RunModelRequest) (string, error) {
|
||||
t.Helper()
|
||||
return func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||
t.Fatalf("did not expect run-model resolution: %+v", req)
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
|
||||
func unexpectedIntegrationLaunch(t *testing.T) func(context.Context, launch.IntegrationLaunchRequest) error {
|
||||
t.Helper()
|
||||
return func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
t.Fatalf("did not expect integration launch: %+v", req)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func unexpectedModelLaunch(t *testing.T) func(*cobra.Command, string) error {
|
||||
t.Helper()
|
||||
return func(cmd *cobra.Command, model string) error {
|
||||
t.Fatalf("did not expect chat launch: %s", model)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunInteractiveTUI_RunModelActionsUseResolveRunModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
action tui.TUIAction
|
||||
wantForce bool
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "enter uses saved model flow",
|
||||
action: tui.TUIAction{Kind: tui.TUIActionRunModel},
|
||||
wantModel: "qwen3:8b",
|
||||
},
|
||||
{
|
||||
name: "right forces picker",
|
||||
action: tui.TUIAction{Kind: tui.TUIActionRunModel, ForceConfigure: true},
|
||||
wantForce: true,
|
||||
wantModel: "glm-5:cloud",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
var menuCalls int
|
||||
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||
menuCalls++
|
||||
if menuCalls == 1 {
|
||||
return tt.action, nil
|
||||
}
|
||||
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
|
||||
}
|
||||
|
||||
var gotReq launch.RunModelRequest
|
||||
var launched string
|
||||
prefetchedAccount := &launch.AccountState{}
|
||||
accountUpdates := func(context.Context) <-chan *launch.AccountState { return nil }
|
||||
deps := launcherDeps{
|
||||
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
|
||||
return &launch.LauncherState{}, nil
|
||||
},
|
||||
runMenu: func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||
if state.AccountState != prefetchedAccount {
|
||||
t.Fatalf("prefetched account state was not piped to menu state")
|
||||
}
|
||||
return runMenu(state)
|
||||
},
|
||||
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||
gotReq = req
|
||||
return tt.wantModel, nil
|
||||
},
|
||||
launchIntegration: unexpectedIntegrationLaunch(t),
|
||||
runModel: func(cmd *cobra.Command, model string) error {
|
||||
launched = model
|
||||
return nil
|
||||
},
|
||||
accountState: func() *launch.AccountState {
|
||||
return prefetchedAccount
|
||||
},
|
||||
accountStateUpdates: accountUpdates,
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
for {
|
||||
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected step error: %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if gotReq.ForcePicker != tt.wantForce {
|
||||
t.Fatalf("expected ForcePicker=%v, got %v", tt.wantForce, gotReq.ForcePicker)
|
||||
}
|
||||
if gotReq.AccountState != prefetchedAccount {
|
||||
t.Fatalf("expected prefetched account state to be passed to run model request")
|
||||
}
|
||||
if gotReq.AccountStateUpdates == nil {
|
||||
t.Fatalf("expected account state updates to be passed to run model request")
|
||||
}
|
||||
if launched != tt.wantModel {
|
||||
t.Fatalf("expected interactive launcher to run %q, got %q", tt.wantModel, launched)
|
||||
}
|
||||
if got := config.LastSelection(); got != "run" {
|
||||
t.Fatalf("expected last selection to be run, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunInteractiveTUI_IntegrationActionsUseLaunchIntegration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
action tui.TUIAction
|
||||
wantForce bool
|
||||
}{
|
||||
{
|
||||
name: "enter launches integration",
|
||||
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"},
|
||||
},
|
||||
{
|
||||
name: "right forces configure",
|
||||
action: tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true},
|
||||
wantForce: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
var menuCalls int
|
||||
runMenu := func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||
menuCalls++
|
||||
if menuCalls == 1 {
|
||||
return tt.action, nil
|
||||
}
|
||||
return tui.TUIAction{Kind: tui.TUIActionNone}, nil
|
||||
}
|
||||
|
||||
var gotReq launch.IntegrationLaunchRequest
|
||||
prefetchedAccount := &launch.AccountState{}
|
||||
accountUpdates := func(context.Context) <-chan *launch.AccountState { return nil }
|
||||
deps := launcherDeps{
|
||||
buildState: func(ctx context.Context) (*launch.LauncherState, error) {
|
||||
return &launch.LauncherState{}, nil
|
||||
},
|
||||
runMenu: func(state *launch.LauncherState) (tui.TUIAction, error) {
|
||||
if state.AccountState != prefetchedAccount {
|
||||
t.Fatalf("prefetched account state was not piped to menu state")
|
||||
}
|
||||
return runMenu(state)
|
||||
},
|
||||
resolveRunModel: unexpectedRunModelResolution(t),
|
||||
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
gotReq = req
|
||||
return nil
|
||||
},
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
accountState: func() *launch.AccountState {
|
||||
return prefetchedAccount
|
||||
},
|
||||
accountStateUpdates: accountUpdates,
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
for {
|
||||
continueLoop, err := runInteractiveTUIStep(cmd, deps)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected step error: %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if gotReq.Name != "claude" {
|
||||
t.Fatalf("expected integration name to be passed through, got %q", gotReq.Name)
|
||||
}
|
||||
if gotReq.ForceConfigure != tt.wantForce {
|
||||
t.Fatalf("expected ForceConfigure=%v, got %v", tt.wantForce, gotReq.ForceConfigure)
|
||||
}
|
||||
if gotReq.AccountState != prefetchedAccount {
|
||||
t.Fatalf("expected prefetched account state to be passed to integration request")
|
||||
}
|
||||
if gotReq.AccountStateUpdates == nil {
|
||||
t.Fatalf("expected account state updates to be passed to integration request")
|
||||
}
|
||||
if got := config.LastSelection(); got != "claude" {
|
||||
t.Fatalf("expected last selection to be claude, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunLauncherAction_RunModelContinuesAfterCancellation(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
|
||||
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionRunModel}, launcherDeps{
|
||||
buildState: nil,
|
||||
runMenu: nil,
|
||||
resolveRunModel: func(ctx context.Context, req launch.RunModelRequest) (string, error) {
|
||||
return "", launch.ErrCancelled
|
||||
},
|
||||
launchIntegration: unexpectedIntegrationLaunch(t),
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error on cancellation, got %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
t.Fatal("expected cancellation to continue the menu loop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunLauncherAction_GUIAppsExitTUILoop(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
|
||||
for _, integration := range []string{"codex-app", "vscode"} {
|
||||
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: integration}, launcherDeps{
|
||||
resolveRunModel: unexpectedRunModelResolution(t),
|
||||
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
return nil
|
||||
},
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error for %s, got %v", integration, err)
|
||||
}
|
||||
if continueLoop {
|
||||
t.Fatalf("expected %s launch to exit the TUI loop (return false)", integration)
|
||||
}
|
||||
}
|
||||
|
||||
// Other integrations should continue the TUI loop (return true).
|
||||
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
|
||||
resolveRunModel: unexpectedRunModelResolution(t),
|
||||
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
return nil
|
||||
},
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
t.Fatal("expected non-vscode integration to continue the TUI loop (return true)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunLauncherAction_IntegrationContinuesAfterCancellation(t *testing.T) {
|
||||
setCmdTestHome(t, t.TempDir())
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.Background())
|
||||
|
||||
continueLoop, err := runLauncherAction(cmd, tui.TUIAction{Kind: tui.TUIActionLaunchIntegration, Integration: "claude"}, launcherDeps{
|
||||
buildState: nil,
|
||||
runMenu: nil,
|
||||
resolveRunModel: unexpectedRunModelResolution(t),
|
||||
launchIntegration: func(ctx context.Context, req launch.IntegrationLaunchRequest) error {
|
||||
return launch.ErrCancelled
|
||||
},
|
||||
runModel: unexpectedModelLaunch(t),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error on cancellation, got %v", err)
|
||||
}
|
||||
if !continueLoop {
|
||||
t.Fatal("expected cancellation to continue the menu loop")
|
||||
}
|
||||
}
|
||||
2361
cmd/cmd_test.go
Normal file
2361
cmd/cmd_test.go
Normal file
File diff suppressed because it is too large
Load Diff
284
cmd/config/config.go
Normal file
284
cmd/config/config.go
Normal file
@@ -0,0 +1,284 @@
|
||||
// Package config provides integration configuration for external coding tools
|
||||
// (Claude Code, Codex, Droid, OpenCode) to use Ollama models.
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
)
|
||||
|
||||
type integration struct {
|
||||
Models []string `json:"models"`
|
||||
Aliases map[string]string `json:"aliases,omitempty"`
|
||||
Onboarded bool `json:"onboarded,omitempty"`
|
||||
}
|
||||
|
||||
// IntegrationConfig is the persisted config for one integration.
|
||||
type IntegrationConfig = integration
|
||||
|
||||
type config struct {
|
||||
Integrations map[string]*integration `json:"integrations"`
|
||||
LastModel string `json:"last_model,omitempty"`
|
||||
LastSelection string `json:"last_selection,omitempty"` // "run" or integration name
|
||||
}
|
||||
|
||||
func configPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "config.json"), nil
|
||||
}
|
||||
|
||||
func legacyConfigPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "config", "config.json"), nil
|
||||
}
|
||||
|
||||
// migrateConfig moves the config from the legacy path to ~/.ollama/config.json
|
||||
func migrateConfig() (bool, error) {
|
||||
oldPath, err := legacyConfigPath()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
oldData, err := os.ReadFile(oldPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Ignore legacy files with invalid JSON and continue startup.
|
||||
if !json.Valid(oldData) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
newPath, err := configPath()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := os.WriteFile(newPath, oldData, 0o644); err != nil {
|
||||
return false, fmt.Errorf("write new config: %w", err)
|
||||
}
|
||||
|
||||
_ = os.Remove(oldPath)
|
||||
_ = os.Remove(filepath.Dir(oldPath)) // clean up empty directory
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func load() (*config, error) {
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil && os.IsNotExist(err) {
|
||||
if migrated, merr := migrateConfig(); merr == nil && migrated {
|
||||
data, err = os.ReadFile(path)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return &config{Integrations: make(map[string]*integration)}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cfg config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config: %w, at: %s", err, path)
|
||||
}
|
||||
if cfg.Integrations == nil {
|
||||
cfg.Integrations = make(map[string]*integration)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func save(cfg *config) error {
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fileutil.WriteWithBackup(path, data)
|
||||
}
|
||||
|
||||
func SaveIntegration(appName string, models []string) error {
|
||||
if appName == "" {
|
||||
return errors.New("app name cannot be empty")
|
||||
}
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
var aliases map[string]string
|
||||
var onboarded bool
|
||||
if existing != nil {
|
||||
aliases = existing.Aliases
|
||||
onboarded = existing.Onboarded
|
||||
}
|
||||
|
||||
cfg.Integrations[key] = &integration{
|
||||
Models: models,
|
||||
Aliases: aliases,
|
||||
Onboarded: onboarded,
|
||||
}
|
||||
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// MarkIntegrationOnboarded marks an integration as onboarded in Ollama's config.
|
||||
func MarkIntegrationOnboarded(appName string) error {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
if existing == nil {
|
||||
existing = &integration{}
|
||||
}
|
||||
existing.Onboarded = true
|
||||
cfg.Integrations[key] = existing
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
|
||||
func IntegrationModel(appName string) string {
|
||||
integrationConfig, err := LoadIntegration(appName)
|
||||
if err != nil || len(integrationConfig.Models) == 0 {
|
||||
return ""
|
||||
}
|
||||
return integrationConfig.Models[0]
|
||||
}
|
||||
|
||||
// IntegrationModels returns all configured models for an integration, or nil.
|
||||
func IntegrationModels(appName string) []string {
|
||||
integrationConfig, err := LoadIntegration(appName)
|
||||
if err != nil || len(integrationConfig.Models) == 0 {
|
||||
return nil
|
||||
}
|
||||
return integrationConfig.Models
|
||||
}
|
||||
|
||||
// LastModel returns the last model that was run, or empty string if none.
|
||||
func LastModel() string {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cfg.LastModel
|
||||
}
|
||||
|
||||
// SetLastModel saves the last model that was run.
|
||||
func SetLastModel(model string) error {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.LastModel = model
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// LastSelection returns the last menu selection ("run" or integration name), or empty string if none.
|
||||
func LastSelection() string {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cfg.LastSelection
|
||||
}
|
||||
|
||||
// SetLastSelection saves the last menu selection ("run" or integration name).
|
||||
func SetLastSelection(selection string) error {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.LastSelection = selection
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// LoadIntegration returns the saved config for one integration.
|
||||
func LoadIntegration(appName string) (*integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
integrationConfig, ok := cfg.Integrations[strings.ToLower(appName)]
|
||||
if !ok {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
return integrationConfig, nil
|
||||
}
|
||||
|
||||
// SaveAliases replaces the saved aliases for one integration.
|
||||
func SaveAliases(appName string, aliases map[string]string) error {
|
||||
if appName == "" {
|
||||
return errors.New("app name cannot be empty")
|
||||
}
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
if existing == nil {
|
||||
existing = &integration{}
|
||||
}
|
||||
|
||||
// Replace aliases entirely (not merge) so deletions are persisted
|
||||
existing.Aliases = aliases
|
||||
|
||||
cfg.Integrations[key] = existing
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
func listIntegrations() ([]integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]integration, 0, len(cfg.Integrations))
|
||||
for _, integrationConfig := range cfg.Integrations {
|
||||
result = append(result, *integrationConfig)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
641
cmd/config/config_cloud_test.go
Normal file
641
cmd/config/config_cloud_test.go
Normal file
@@ -0,0 +1,641 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSetAliases_CloudModel(t *testing.T) {
|
||||
// Test the SetAliases logic by checking the alias map behavior
|
||||
aliases := map[string]string{
|
||||
"primary": "kimi-k2.5:cloud",
|
||||
"fast": "kimi-k2.5:cloud",
|
||||
}
|
||||
|
||||
// Verify fast is set (cloud model behavior)
|
||||
if aliases["fast"] == "" {
|
||||
t.Error("cloud model should have fast alias set")
|
||||
}
|
||||
if aliases["fast"] != aliases["primary"] {
|
||||
t.Errorf("fast should equal primary for auto-set, got fast=%q primary=%q", aliases["fast"], aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetAliases_LocalModel(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "llama3.2:latest",
|
||||
}
|
||||
// Simulate local model behavior: fast should be empty
|
||||
delete(aliases, "fast")
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Error("local model should have empty fast alias")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// First save with both primary and fast
|
||||
initial := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}
|
||||
if err := SaveAliases("claude", initial); err != nil {
|
||||
t.Fatalf("failed to save initial aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify both are saved
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("expected fast=cloud-model, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
|
||||
// Now save without fast (simulating switch to local model)
|
||||
updated := map[string]string{
|
||||
"primary": "local-model",
|
||||
// fast intentionally missing
|
||||
}
|
||||
if err := SaveAliases("claude", updated); err != nil {
|
||||
t.Fatalf("failed to save updated aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify fast is GONE (not merged/preserved)
|
||||
loaded, err = LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
if loaded.Aliases["fast"] != "" {
|
||||
t.Errorf("fast should be removed after saving without it, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be updated to local-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAliases_PreservesModels(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// First save integration with models
|
||||
if err := SaveIntegration("claude", []string{"model1", "model2"}); err != nil {
|
||||
t.Fatalf("failed to save integration: %v", err)
|
||||
}
|
||||
|
||||
// Then update aliases
|
||||
aliases := map[string]string{"primary": "new-model"}
|
||||
if err := SaveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify models are preserved
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Models) != 2 || loaded.Models[0] != "model1" {
|
||||
t.Errorf("models should be preserved, got %v", loaded.Models)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_EmptyMap clears all aliases
|
||||
func TestSaveAliases_EmptyMap(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save empty map
|
||||
if err := SaveAliases("claude", map[string]string{}); err != nil {
|
||||
t.Fatalf("failed to save empty: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Aliases) != 0 {
|
||||
t.Errorf("aliases should be empty, got %v", loaded.Aliases)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_NilMap handles nil gracefully
|
||||
func TestSaveAliases_NilMap(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases first
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save nil map - should clear aliases
|
||||
if err := SaveAliases("claude", nil); err != nil {
|
||||
t.Fatalf("failed to save nil: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Aliases) > 0 {
|
||||
t.Errorf("aliases should be nil or empty, got %v", loaded.Aliases)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_EmptyAppName returns error
|
||||
func TestSaveAliases_EmptyAppName(t *testing.T) {
|
||||
err := SaveAliases("", map[string]string{"primary": "model"})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty app name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := SaveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Load with different case
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model1" {
|
||||
t.Errorf("expected primary=model1, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
|
||||
// Update with different case
|
||||
if err := SaveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to update: %v", err)
|
||||
}
|
||||
|
||||
loaded, err = LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model2" {
|
||||
t.Errorf("expected primary=model2, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_CreatesIntegration creates integration if it doesn't exist
|
||||
func TestSaveAliases_CreatesIntegration(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases for non-existent integration
|
||||
if err := SaveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := LoadIntegration("newintegration")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model" {
|
||||
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureAliases_AliasMap(t *testing.T) {
|
||||
t.Run("cloud model auto-sets fast to primary", func(t *testing.T) {
|
||||
aliases := make(map[string]string)
|
||||
aliases["primary"] = "cloud-model"
|
||||
|
||||
// Simulate cloud model behavior
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("expected fast=cloud-model, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model preserves custom fast", func(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "custom-fast-model",
|
||||
}
|
||||
|
||||
// Simulate cloud model behavior - should preserve existing fast
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "custom-fast-model" {
|
||||
t.Errorf("expected fast=custom-fast-model (preserved), got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local model clears fast", func(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
"fast": "should-be-cleared",
|
||||
}
|
||||
|
||||
// Simulate local model behavior
|
||||
isCloud := false
|
||||
if !isCloud {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Errorf("expected fast to be cleared, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("switching cloud to local clears fast", func(t *testing.T) {
|
||||
// Start with cloud config
|
||||
aliases := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}
|
||||
|
||||
// Switch to local
|
||||
aliases["primary"] = "local-model"
|
||||
isCloud := false
|
||||
if !isCloud {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Errorf("fast should be cleared when switching to local, got %q", aliases["fast"])
|
||||
}
|
||||
if aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be updated, got %q", aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("switching local to cloud sets fast", func(t *testing.T) {
|
||||
// Start with local config (no fast)
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
}
|
||||
|
||||
// Switch to cloud
|
||||
aliases["primary"] = "cloud-model"
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("fast should be set when switching to cloud, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetAliases_PrefixMapping(t *testing.T) {
|
||||
// This tests the expected mapping without needing a real client
|
||||
aliases := map[string]string{
|
||||
"primary": "my-cloud-model",
|
||||
"fast": "my-fast-model",
|
||||
}
|
||||
|
||||
expectedMappings := map[string]string{
|
||||
"claude-sonnet-": aliases["primary"],
|
||||
"claude-haiku-": aliases["fast"],
|
||||
}
|
||||
|
||||
if expectedMappings["claude-sonnet-"] != "my-cloud-model" {
|
||||
t.Errorf("claude-sonnet- should map to primary")
|
||||
}
|
||||
if expectedMappings["claude-haiku-"] != "my-fast-model" {
|
||||
t.Errorf("claude-haiku- should map to fast")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetAliases_LocalDeletesPrefixes(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
// fast is empty/missing - indicates local model
|
||||
}
|
||||
|
||||
prefixesToDelete := []string{"claude-sonnet-", "claude-haiku-"}
|
||||
|
||||
// Verify the logic: when fast is empty, we should delete
|
||||
if aliases["fast"] != "" {
|
||||
t.Error("fast should be empty for local model")
|
||||
}
|
||||
|
||||
// Verify we have the right prefixes to delete
|
||||
if len(prefixesToDelete) != 2 {
|
||||
t.Errorf("expected 2 prefixes to delete, got %d", len(prefixesToDelete))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAtomicUpdate_ServerFailsConfigNotSaved simulates atomic update behavior
|
||||
func TestAtomicUpdate_ServerFailsConfigNotSaved(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate: server fails, config should NOT be saved
|
||||
serverErr := errors.New("server unavailable")
|
||||
|
||||
if serverErr == nil {
|
||||
t.Error("config should NOT be saved when server fails")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAtomicUpdate_ServerSucceedsConfigSaved simulates successful atomic update
|
||||
func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate: server succeeds, config should be saved
|
||||
var serverErr error
|
||||
if serverErr != nil {
|
||||
t.Fatal("server should succeed")
|
||||
}
|
||||
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("saveAliases failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was actually saved
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model" {
|
||||
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigFile_PreservesUnknownFields(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Write config with extra fields
|
||||
configPath := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
os.MkdirAll(filepath.Dir(configPath), 0o755)
|
||||
|
||||
// Note: Our config struct only has Integrations, so top-level unknown fields
|
||||
// won't be preserved by our current implementation. This test documents that.
|
||||
initialConfig := `{
|
||||
"integrations": {
|
||||
"claude": {
|
||||
"models": ["model1"],
|
||||
"aliases": {"primary": "model1"},
|
||||
"unknownField": "should be lost"
|
||||
}
|
||||
},
|
||||
"topLevelUnknown": "will be lost"
|
||||
}`
|
||||
os.WriteFile(configPath, []byte(initialConfig), 0o644)
|
||||
|
||||
// Update aliases
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Read raw file to check
|
||||
data, _ := os.ReadFile(configPath)
|
||||
content := string(data)
|
||||
|
||||
// models should be preserved
|
||||
if !contains(content, "model1") {
|
||||
t.Error("models should be preserved")
|
||||
}
|
||||
|
||||
// primary should be updated
|
||||
if !contains(content, "model2") {
|
||||
t.Error("primary should be updated to model2")
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
|
||||
}
|
||||
|
||||
func containsHelper(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestModelNameEdgeCases(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
model string
|
||||
}{
|
||||
{"simple", "llama3.2"},
|
||||
{"with tag", "llama3.2:latest"},
|
||||
{"with cloud tag", "kimi-k2.5:cloud"},
|
||||
{"with namespace", "library/llama3.2"},
|
||||
{"with dots", "glm-4.7-flash"},
|
||||
{"with numbers", "qwen3:8b"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
aliases := map[string]string{"primary": tc.model}
|
||||
if err := SaveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save model %q: %v", tc.model, err)
|
||||
}
|
||||
|
||||
loaded, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != tc.model {
|
||||
t.Errorf("expected primary=%q, got %q", tc.model, loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchingScenarios(t *testing.T) {
|
||||
t.Run("cloud to local removes fast", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to local (no fast)
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "" {
|
||||
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be local-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local to cloud adds fast", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial local config
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to cloud (with fast)
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud to different cloud updates both", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-1",
|
||||
"fast": "cloud-model-1",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to different cloud
|
||||
if err := SaveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-2",
|
||||
"fast": "cloud-model-2",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != "cloud-model-2" {
|
||||
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
if loaded.Aliases["fast"] != "cloud-model-2" {
|
||||
t.Errorf("fast should be cloud-model-2, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases with one model
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Save integration with same model (this is the pattern we use)
|
||||
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != loaded.Models[0] {
|
||||
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("out of sync config is detectable", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate out-of-sync state (like manual edit or bug)
|
||||
if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
|
||||
// They should be different (this is the bug state)
|
||||
if loaded.Models[0] == loaded.Aliases["primary"] {
|
||||
t.Error("expected out-of-sync state for this test")
|
||||
}
|
||||
|
||||
// The fix: when updating aliases, also update models
|
||||
if err := SaveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ = LoadIntegration("claude")
|
||||
if loaded.Models[0] != loaded.Aliases["primary"] {
|
||||
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
|
||||
loaded.Models[0], loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updating primary alias updates models too", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial state
|
||||
if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Update aliases AND models together
|
||||
newAliases := map[string]string{"primary": "updated-model"}
|
||||
if err := SaveAliases("claude", newAliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := LoadIntegration("claude")
|
||||
if loaded.Models[0] != "updated-model" {
|
||||
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "updated-model" {
|
||||
t.Errorf("aliases.primary should be updated-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
}
|
||||
530
cmd/config/config_test.go
Normal file
530
cmd/config/config_test.go
Normal file
@@ -0,0 +1,530 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
|
||||
func setTestHome(t *testing.T, dir string) {
|
||||
t.Setenv("HOME", dir)
|
||||
t.Setenv("TMPDIR", dir)
|
||||
t.Setenv("USERPROFILE", dir)
|
||||
}
|
||||
|
||||
func TestIntegrationConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("save and load round-trip", func(t *testing.T) {
|
||||
models := []string{"llama3.2", "mistral", "qwen2.5"}
|
||||
if err := SaveIntegration("claude", models); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(config.Models) != len(models) {
|
||||
t.Errorf("expected %d models, got %d", len(models), len(config.Models))
|
||||
}
|
||||
for i, m := range models {
|
||||
if config.Models[i] != m {
|
||||
t.Errorf("model %d: expected %s, got %s", i, m, config.Models[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("save and load aliases", func(t *testing.T) {
|
||||
models := []string{"llama3.2"}
|
||||
if err := SaveIntegration("claude", models); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
aliases := map[string]string{
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
}
|
||||
if err := SaveAliases("claude", aliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if config.Aliases == nil {
|
||||
t.Fatal("expected aliases to be saved")
|
||||
}
|
||||
for k, v := range aliases {
|
||||
if config.Aliases[k] != v {
|
||||
t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("saveIntegration preserves aliases", func(t *testing.T) {
|
||||
if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := SaveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if config.Aliases["primary"] != "model-a" {
|
||||
t.Errorf("expected aliases to be preserved, got %v", config.Aliases)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||
SaveIntegration("codex", []string{"model-a", "model-b"})
|
||||
|
||||
config, _ := LoadIntegration("codex")
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
}
|
||||
if defaultModel != "model-a" {
|
||||
t.Errorf("expected model-a, got %s", defaultModel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("defaultModel returns empty for no models", func(t *testing.T) {
|
||||
config := &integration{Models: []string{}}
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
}
|
||||
if defaultModel != "" {
|
||||
t.Errorf("expected empty string, got %s", defaultModel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("app name is case-insensitive", func(t *testing.T) {
|
||||
SaveIntegration("Claude", []string{"model-x"})
|
||||
|
||||
config, err := LoadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defaultModel := ""
|
||||
if len(config.Models) > 0 {
|
||||
defaultModel = config.Models[0]
|
||||
}
|
||||
if defaultModel != "model-x" {
|
||||
t.Errorf("expected model-x, got %s", defaultModel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple integrations in single file", func(t *testing.T) {
|
||||
SaveIntegration("app1", []string{"model-1"})
|
||||
SaveIntegration("app2", []string{"model-2"})
|
||||
|
||||
config1, _ := LoadIntegration("app1")
|
||||
config2, _ := LoadIntegration("app2")
|
||||
|
||||
defaultModel1 := ""
|
||||
if len(config1.Models) > 0 {
|
||||
defaultModel1 = config1.Models[0]
|
||||
}
|
||||
defaultModel2 := ""
|
||||
if len(config2.Models) > 0 {
|
||||
defaultModel2 = config2.Models[0]
|
||||
}
|
||||
if defaultModel1 != "model-1" {
|
||||
t.Errorf("expected model-1, got %s", defaultModel1)
|
||||
}
|
||||
if defaultModel2 != "model-2" {
|
||||
t.Errorf("expected model-2, got %s", defaultModel2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestListIntegrations(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns empty when no integrations", func(t *testing.T) {
|
||||
configs, err := listIntegrations()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(configs) != 0 {
|
||||
t.Errorf("expected 0 integrations, got %d", len(configs))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns all saved integrations", func(t *testing.T) {
|
||||
SaveIntegration("claude", []string{"model-1"})
|
||||
SaveIntegration("droid", []string{"model-2"})
|
||||
|
||||
configs, err := listIntegrations()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(configs) != 2 {
|
||||
t.Errorf("expected 2 integrations, got %d", len(configs))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
dir := filepath.Join(tmpDir, ".ollama")
|
||||
os.MkdirAll(dir, 0o755)
|
||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||
|
||||
_, err := LoadIntegration("test")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration in corrupted file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveIntegration_NilModels(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := SaveIntegration("test", nil); err != nil {
|
||||
t.Fatalf("saveIntegration with nil models failed: %v", err)
|
||||
}
|
||||
|
||||
config, err := LoadIntegration("test")
|
||||
if err != nil {
|
||||
t.Fatalf("loadIntegration failed: %v", err)
|
||||
}
|
||||
|
||||
if config.Models == nil {
|
||||
// nil is acceptable
|
||||
} else if len(config.Models) != 0 {
|
||||
t.Errorf("expected empty or nil models, got %v", config.Models)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveIntegration_EmptyAppName(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
err := SaveIntegration("", []string{"model"})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty app name, got nil")
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "app name cannot be empty") {
|
||||
t.Errorf("expected 'app name cannot be empty' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
_, err := LoadIntegration("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration, got nil")
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
t.Logf("error type is os.ErrNotExist as expected: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigPath(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
if path != expected {
|
||||
t.Errorf("expected %s, got %s", expected, path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns empty config when file does not exist", func(t *testing.T) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg == nil {
|
||||
t.Fatal("expected non-nil config")
|
||||
}
|
||||
if cfg.Integrations == nil {
|
||||
t.Error("expected non-nil Integrations map")
|
||||
}
|
||||
if len(cfg.Integrations) != 0 {
|
||||
t.Errorf("expected empty Integrations, got %d", len(cfg.Integrations))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("loads existing config", func(t *testing.T) {
|
||||
path, _ := configPath()
|
||||
os.MkdirAll(filepath.Dir(path), 0o755)
|
||||
os.WriteFile(path, []byte(`{"integrations":{"test":{"models":["model-a"]}}}`), 0o644)
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg.Integrations["test"] == nil {
|
||||
t.Fatal("expected test integration")
|
||||
}
|
||||
if len(cfg.Integrations["test"].Models) != 1 {
|
||||
t.Errorf("expected 1 model, got %d", len(cfg.Integrations["test"].Models))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error for corrupted JSON", func(t *testing.T) {
|
||||
path, _ := configPath()
|
||||
os.MkdirAll(filepath.Dir(path), 0o755)
|
||||
os.WriteFile(path, []byte(`{corrupted`), 0o644)
|
||||
|
||||
_, err := load()
|
||||
if err == nil {
|
||||
t.Error("expected error for corrupted JSON")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrateConfig(t *testing.T) {
|
||||
t.Run("migrates legacy file to new location", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
data := []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), data, 0o644)
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected migration to occur")
|
||||
}
|
||||
|
||||
newPath, _ := configPath()
|
||||
got, err := os.ReadFile(newPath)
|
||||
if err != nil {
|
||||
t.Fatalf("new config not found: %v", err)
|
||||
}
|
||||
if string(got) != string(data) {
|
||||
t.Errorf("content mismatch: got %s", got)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
|
||||
t.Error("legacy file should have been removed")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(legacyDir); !os.IsNotExist(err) {
|
||||
t.Error("legacy directory should have been removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no-op when no legacy file exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated {
|
||||
t.Error("expected no migration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skips corrupt legacy file", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{corrupt`), 0o644)
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated {
|
||||
t.Error("should not migrate corrupt file")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); os.IsNotExist(err) {
|
||||
t.Error("corrupt legacy file should not have been deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("new path takes precedence over legacy", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"old":{"models":["old-model"]}}}`), 0o644)
|
||||
|
||||
newDir := filepath.Join(tmpDir, ".ollama")
|
||||
os.WriteFile(filepath.Join(newDir, "config.json"), []byte(`{"integrations":{"new":{"models":["new-model"]}}}`), 0o644)
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := cfg.Integrations["new"]; !ok {
|
||||
t.Error("expected new-path integration to be loaded")
|
||||
}
|
||||
if _, ok := cfg.Integrations["old"]; ok {
|
||||
t.Error("legacy integration should not have been loaded")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("idempotent when called twice", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
|
||||
|
||||
if _, err := migrateConfig(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated {
|
||||
t.Error("second migration should be a no-op")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("legacy directory preserved if not empty", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
|
||||
os.WriteFile(filepath.Join(legacyDir, "other-file.txt"), []byte("keep me"), 0o644)
|
||||
|
||||
if _, err := migrateConfig(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(legacyDir); os.IsNotExist(err) {
|
||||
t.Error("directory with other files should not have been removed")
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "other-file.txt")); os.IsNotExist(err) {
|
||||
t.Error("other files in legacy directory should be untouched")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("save writes to new path after migration", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
|
||||
|
||||
// load triggers migration, then save should write to new path
|
||||
if err := SaveIntegration("codex", []string{"qwen2.5"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newPath := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
if _, err := os.Stat(newPath); os.IsNotExist(err) {
|
||||
t.Error("save should write to new path")
|
||||
}
|
||||
|
||||
// old path should not be recreated
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
|
||||
t.Error("save should not recreate legacy path")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load triggers migration transparently", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg.Integrations["claude"] == nil || cfg.Integrations["claude"].Models[0] != "llama3.2" {
|
||||
t.Error("migration via load() did not preserve data")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSave(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("creates config file", func(t *testing.T) {
|
||||
cfg := &config{
|
||||
Integrations: map[string]*integration{
|
||||
"test": {Models: []string{"model-a", "model-b"}},
|
||||
},
|
||||
}
|
||||
|
||||
if err := save(cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
path, _ := configPath()
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
t.Error("config file was not created")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("round-trip preserves data", func(t *testing.T) {
|
||||
cfg := &config{
|
||||
Integrations: map[string]*integration{
|
||||
"claude": {Models: []string{"llama3.2", "mistral"}},
|
||||
"codex": {Models: []string{"qwen2.5"}},
|
||||
},
|
||||
}
|
||||
|
||||
if err := save(cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(loaded.Integrations) != 2 {
|
||||
t.Errorf("expected 2 integrations, got %d", len(loaded.Integrations))
|
||||
}
|
||||
if loaded.Integrations["claude"] == nil {
|
||||
t.Error("missing claude integration")
|
||||
}
|
||||
if len(loaded.Integrations["claude"].Models) != 2 {
|
||||
t.Errorf("expected 2 models for claude, got %d", len(loaded.Integrations["claude"].Models))
|
||||
}
|
||||
})
|
||||
}
|
||||
5
cmd/editor_unix.go
Normal file
5
cmd/editor_unix.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build !windows
|
||||
|
||||
package cmd
|
||||
|
||||
const defaultEditor = "vi"
|
||||
5
cmd/editor_windows.go
Normal file
5
cmd/editor_windows.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
const defaultEditor = "edit"
|
||||
735
cmd/interactive.go
Normal file
735
cmd/interactive.go
Normal file
@@ -0,0 +1,735 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type MultilineState int
|
||||
|
||||
const (
|
||||
MultilineNone MultilineState = iota
|
||||
MultilinePrompt
|
||||
MultilineSystem
|
||||
)
|
||||
|
||||
func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
usage := func() {
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /set Set session variables")
|
||||
fmt.Fprintln(os.Stderr, " /show Show model information")
|
||||
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
|
||||
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
|
||||
fmt.Fprintln(os.Stderr, " /clear Clear session context")
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
|
||||
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
|
||||
|
||||
if opts.MultiModal {
|
||||
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, .webp images, or .wav audio files.\n", filepath.FromSlash("/path/to/file"))
|
||||
}
|
||||
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
usageSet := func() {
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter")
|
||||
fmt.Fprintln(os.Stderr, " /set system <string> Set system message")
|
||||
fmt.Fprintln(os.Stderr, " /set history Enable history")
|
||||
fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
|
||||
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
|
||||
fmt.Fprintln(os.Stderr, " /set nowordwrap Disable wordwrap")
|
||||
fmt.Fprintln(os.Stderr, " /set format json Enable JSON mode")
|
||||
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
|
||||
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
|
||||
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
|
||||
fmt.Fprintln(os.Stderr, " /set think Enable thinking")
|
||||
fmt.Fprintln(os.Stderr, " /set nothink Disable thinking")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
usageShortcuts := func() {
|
||||
fmt.Fprintln(os.Stderr, "Available keyboard shortcuts:")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl + a Move to the beginning of the line (Home)")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl + e Move to the end of the line (End)")
|
||||
fmt.Fprintln(os.Stderr, " Alt + b Move back (left) one word")
|
||||
fmt.Fprintln(os.Stderr, " Alt + f Move forward (right) one word")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl + k Delete the sentence after the cursor")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl + u Delete the sentence before the cursor")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl + w Delete the word before the cursor")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl + l Clear the screen")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl + g Open default editor to compose a prompt")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl + c Stop the model from responding")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl + d Exit ollama (/bye)")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
usageShow := func() {
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /show info Show details for this model")
|
||||
fmt.Fprintln(os.Stderr, " /show license Show model license")
|
||||
fmt.Fprintln(os.Stderr, " /show modelfile Show Modelfile for this model")
|
||||
fmt.Fprintln(os.Stderr, " /show parameters Show parameters for this model")
|
||||
fmt.Fprintln(os.Stderr, " /show system Show system message")
|
||||
fmt.Fprintln(os.Stderr, " /show template Show prompt template")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
// only list out the most common parameters
|
||||
usageParameters := func() {
|
||||
fmt.Fprintln(os.Stderr, "Available Parameters:")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter seed <int> Random number seed")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter num_predict <int> Max number of tokens to predict")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter top_k <int> Pick from top k num of tokens")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter top_p <float> Pick token based on sum of probabilities")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter min_p <float> Pick token based on top token probability * min_p")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter num_ctx <int> Set the context size")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter temperature <float> Set creativity level")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter repeat_last_n <int> Set how far back to look for repetitions")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter num_gpu <int> The number of layers to send to the GPU")
|
||||
fmt.Fprintln(os.Stderr, " /set parameter stop <string> <string> ... Set the stop parameters")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
scanner, err := readline.New(readline.Prompt{
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
Placeholder: "Send a message (/? for help)",
|
||||
AltPlaceholder: "Press Enter to send",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if envconfig.NoHistory() {
|
||||
scanner.HistoryDisable()
|
||||
}
|
||||
|
||||
fmt.Print(readline.StartBracketedPaste)
|
||||
defer fmt.Printf(readline.EndBracketedPaste)
|
||||
|
||||
var sb strings.Builder
|
||||
var multiline MultilineState
|
||||
var thinkExplicitlySet bool = opts.Think != nil
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
fmt.Println()
|
||||
return nil
|
||||
case errors.Is(err, readline.ErrInterrupt):
|
||||
if line == "" {
|
||||
fmt.Println("\nUse Ctrl + d or /bye to exit.")
|
||||
}
|
||||
|
||||
scanner.Prompt.UseAlt = false
|
||||
sb.Reset()
|
||||
|
||||
continue
|
||||
case errors.Is(err, readline.ErrEditPrompt):
|
||||
sb.Reset()
|
||||
content, err := editInExternalEditor(line)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(content) == "" {
|
||||
continue
|
||||
}
|
||||
scanner.Prefill = content
|
||||
continue
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case multiline != MultilineNone:
|
||||
// check if there's a multiline terminating string
|
||||
before, ok := strings.CutSuffix(line, `"""`)
|
||||
sb.WriteString(before)
|
||||
if !ok {
|
||||
fmt.Fprintln(&sb)
|
||||
scanner.Prompt.UseAlt = true
|
||||
continue
|
||||
}
|
||||
|
||||
switch multiline {
|
||||
case MultilineSystem:
|
||||
opts.System = sb.String()
|
||||
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
}
|
||||
|
||||
multiline = MultilineNone
|
||||
scanner.Prompt.UseAlt = false
|
||||
case strings.HasPrefix(line, `"""`):
|
||||
line := strings.TrimPrefix(line, `"""`)
|
||||
line, ok := strings.CutSuffix(line, `"""`)
|
||||
sb.WriteString(line)
|
||||
if !ok {
|
||||
// no multiline terminating string; need more input
|
||||
fmt.Fprintln(&sb)
|
||||
multiline = MultilinePrompt
|
||||
scanner.Prompt.UseAlt = true
|
||||
}
|
||||
case scanner.Pasting:
|
||||
fmt.Fprintln(&sb, line)
|
||||
continue
|
||||
case strings.HasPrefix(line, "/list"):
|
||||
args := strings.Fields(line)
|
||||
if err := ListHandler(cmd, args[1:]); err != nil {
|
||||
return err
|
||||
}
|
||||
case strings.HasPrefix(line, "/load"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) != 2 {
|
||||
fmt.Println("Usage:\n /load <modelname>")
|
||||
continue
|
||||
}
|
||||
origOpts := opts.Copy()
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return err
|
||||
}
|
||||
|
||||
opts.Model = args[1]
|
||||
opts.Messages = []api.Message{}
|
||||
opts.LoadedMessages = nil
|
||||
fmt.Printf("Loading model '%s'\n", opts.Model)
|
||||
info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: opts.Model})
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
|
||||
opts = origOpts.Copy()
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
applyShowResponseToRunOptions(&opts, info)
|
||||
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkExplicitlySet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
|
||||
opts = origOpts.Copy()
|
||||
continue
|
||||
}
|
||||
if strings.Contains(err.Error(), "does not support thinking") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
continue
|
||||
case strings.HasPrefix(line, "/save"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) != 2 {
|
||||
fmt.Println("Usage:\n /save <modelname>")
|
||||
continue
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return err
|
||||
}
|
||||
|
||||
req := NewCreateRequest(args[1], opts)
|
||||
fn := func(resp api.ProgressResponse) error { return nil }
|
||||
err = client.Create(cmd.Context(), req, fn)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), errtypes.InvalidModelNameErrMsg) {
|
||||
fmt.Printf("error: The model name '%s' is invalid\n", args[1])
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
fmt.Printf("Created new model '%s'\n", args[1])
|
||||
continue
|
||||
case strings.HasPrefix(line, "/clear"):
|
||||
opts.Messages = []api.Message{}
|
||||
if opts.System != "" {
|
||||
newMessage := api.Message{Role: "system", Content: opts.System}
|
||||
opts.Messages = append(opts.Messages, newMessage)
|
||||
}
|
||||
fmt.Println("Cleared session context")
|
||||
continue
|
||||
case strings.HasPrefix(line, "/set"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
switch args[1] {
|
||||
case "history":
|
||||
scanner.HistoryEnable()
|
||||
case "nohistory":
|
||||
scanner.HistoryDisable()
|
||||
case "wordwrap":
|
||||
opts.WordWrap = true
|
||||
fmt.Println("Set 'wordwrap' mode.")
|
||||
case "nowordwrap":
|
||||
opts.WordWrap = false
|
||||
fmt.Println("Set 'nowordwrap' mode.")
|
||||
case "verbose":
|
||||
if err := cmd.Flags().Set("verbose", "true"); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("Set 'verbose' mode.")
|
||||
case "quiet":
|
||||
if err := cmd.Flags().Set("verbose", "false"); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("Set 'quiet' mode.")
|
||||
case "think":
|
||||
thinkValue := api.ThinkValue{Value: true}
|
||||
var maybeLevel string
|
||||
if len(args) > 2 {
|
||||
maybeLevel = args[2]
|
||||
}
|
||||
if maybeLevel != "" {
|
||||
// TODO(drifkin): validate the level, could be model dependent
|
||||
// though... It will also be validated on the server once a call is
|
||||
// made.
|
||||
thinkValue.Value = maybeLevel
|
||||
}
|
||||
opts.Think = &thinkValue
|
||||
thinkExplicitlySet = true
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
ensureThinkingSupport(cmd.Context(), client, opts.Model)
|
||||
}
|
||||
if maybeLevel != "" {
|
||||
fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel)
|
||||
} else {
|
||||
fmt.Println("Set 'think' mode.")
|
||||
}
|
||||
case "nothink":
|
||||
opts.Think = &api.ThinkValue{Value: false}
|
||||
thinkExplicitlySet = true
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
ensureThinkingSupport(cmd.Context(), client, opts.Model)
|
||||
}
|
||||
fmt.Println("Set 'nothink' mode.")
|
||||
case "format":
|
||||
if len(args) < 3 || args[2] != "json" {
|
||||
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
|
||||
} else {
|
||||
opts.Format = args[2]
|
||||
fmt.Printf("Set format to '%s' mode.\n", args[2])
|
||||
}
|
||||
case "noformat":
|
||||
opts.Format = ""
|
||||
fmt.Println("Disabled format.")
|
||||
case "parameter":
|
||||
if len(args) < 4 {
|
||||
usageParameters()
|
||||
continue
|
||||
}
|
||||
params := args[3:]
|
||||
fp, err := api.FormatParams(map[string][]string{args[2]: params})
|
||||
if err != nil {
|
||||
fmt.Printf("Couldn't set parameter: %q\n", err)
|
||||
continue
|
||||
}
|
||||
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
|
||||
opts.Options[args[2]] = fp[args[2]]
|
||||
case "system":
|
||||
if len(args) < 3 {
|
||||
usageSet()
|
||||
continue
|
||||
}
|
||||
|
||||
multiline = MultilineSystem
|
||||
|
||||
line := strings.Join(args[2:], " ")
|
||||
line, ok := strings.CutPrefix(line, `"""`)
|
||||
if !ok {
|
||||
multiline = MultilineNone
|
||||
} else {
|
||||
// only cut suffix if the line is multiline
|
||||
line, ok = strings.CutSuffix(line, `"""`)
|
||||
if ok {
|
||||
multiline = MultilineNone
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(line)
|
||||
if multiline != MultilineNone {
|
||||
scanner.Prompt.UseAlt = true
|
||||
continue
|
||||
}
|
||||
|
||||
opts.System = sb.String() // for display in modelfile
|
||||
newMessage := api.Message{Role: "system", Content: sb.String()}
|
||||
// Check if the slice is not empty and the last message is from 'system'
|
||||
if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
|
||||
// Replace the last message
|
||||
opts.Messages[len(opts.Messages)-1] = newMessage
|
||||
} else {
|
||||
opts.Messages = append(opts.Messages, newMessage)
|
||||
}
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
continue
|
||||
default:
|
||||
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
|
||||
}
|
||||
} else {
|
||||
usageSet()
|
||||
}
|
||||
case strings.HasPrefix(line, "/show"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return err
|
||||
}
|
||||
req := &api.ShowRequest{
|
||||
Name: opts.Model,
|
||||
System: opts.System,
|
||||
Options: opts.Options,
|
||||
}
|
||||
resp, err := client.Show(cmd.Context(), req)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get model")
|
||||
return err
|
||||
}
|
||||
|
||||
switch args[1] {
|
||||
case "info":
|
||||
_ = showInfo(resp, false, os.Stderr)
|
||||
case "license":
|
||||
if resp.License == "" {
|
||||
fmt.Println("No license was specified for this model.")
|
||||
} else {
|
||||
fmt.Println(resp.License)
|
||||
}
|
||||
case "modelfile":
|
||||
fmt.Println(resp.Modelfile)
|
||||
case "parameters":
|
||||
fmt.Println("Model defined parameters:")
|
||||
if resp.Parameters == "" {
|
||||
fmt.Println(" No additional parameters were specified for this model.")
|
||||
} else {
|
||||
for _, l := range strings.Split(resp.Parameters, "\n") {
|
||||
fmt.Printf(" %s\n", l)
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
if len(opts.Options) > 0 {
|
||||
fmt.Println("User defined parameters:")
|
||||
for k, v := range opts.Options {
|
||||
fmt.Printf(" %-*s %v\n", 30, k, v)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
case "system":
|
||||
switch {
|
||||
case opts.System != "":
|
||||
fmt.Println(opts.System + "\n")
|
||||
case resp.System != "":
|
||||
fmt.Println(resp.System + "\n")
|
||||
default:
|
||||
fmt.Println("No system message was specified for this model.")
|
||||
}
|
||||
case "template":
|
||||
if resp.Template != "" {
|
||||
fmt.Println(resp.Template)
|
||||
} else {
|
||||
fmt.Println("No prompt template was specified for this model.")
|
||||
}
|
||||
default:
|
||||
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
|
||||
}
|
||||
} else {
|
||||
usageShow()
|
||||
}
|
||||
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
switch args[1] {
|
||||
case "set", "/set":
|
||||
usageSet()
|
||||
case "show", "/show":
|
||||
usageShow()
|
||||
case "shortcut", "shortcuts":
|
||||
usageShortcuts()
|
||||
}
|
||||
} else {
|
||||
usage()
|
||||
}
|
||||
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
|
||||
return nil
|
||||
case strings.HasPrefix(line, "/"):
|
||||
args := strings.Fields(line)
|
||||
isFile := false
|
||||
|
||||
if opts.MultiModal {
|
||||
for _, f := range extractFileNames(line) {
|
||||
if strings.HasPrefix(f, args[0]) {
|
||||
isFile = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !isFile {
|
||||
fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0])
|
||||
continue
|
||||
}
|
||||
|
||||
sb.WriteString(line)
|
||||
default:
|
||||
sb.WriteString(line)
|
||||
}
|
||||
|
||||
if sb.Len() > 0 && multiline == MultilineNone {
|
||||
newMessage := api.Message{Role: "user", Content: sb.String()}
|
||||
|
||||
if opts.MultiModal {
|
||||
msg, images, err := extractFileData(sb.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newMessage.Content = msg
|
||||
newMessage.Images = images
|
||||
}
|
||||
|
||||
opts.Messages = append(opts.Messages, newMessage)
|
||||
|
||||
assistant, err := chat(cmd, opts)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "does not support thinking") ||
|
||||
strings.Contains(err.Error(), "invalid think value") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
sb.Reset()
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
if assistant != nil {
|
||||
opts.Messages = append(opts.Messages, *assistant)
|
||||
}
|
||||
|
||||
sb.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
|
||||
parentModel := opts.ParentModel
|
||||
|
||||
modelName := model.ParseName(parentModel)
|
||||
if !modelName.IsValid() {
|
||||
parentModel = ""
|
||||
}
|
||||
|
||||
// Preserve explicit cloud intent for sessions started with `:cloud`.
|
||||
// Cloud model metadata can return a source-less parent_model (for example
|
||||
// "qwen3.5"), which would otherwise make `/save` create a local derivative.
|
||||
if modelref.HasExplicitCloudSource(opts.Model) && !modelref.HasExplicitCloudSource(parentModel) {
|
||||
parentModel = ""
|
||||
}
|
||||
|
||||
req := &api.CreateRequest{
|
||||
Model: name,
|
||||
From: cmp.Or(parentModel, opts.Model),
|
||||
}
|
||||
|
||||
if opts.System != "" {
|
||||
req.System = opts.System
|
||||
}
|
||||
|
||||
if len(opts.Options) > 0 {
|
||||
req.Parameters = opts.Options
|
||||
}
|
||||
|
||||
messages := slices.Clone(opts.LoadedMessages)
|
||||
messages = append(messages, opts.Messages...)
|
||||
if len(messages) > 0 {
|
||||
req.Messages = messages
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func normalizeFilePath(fp string) string {
|
||||
return strings.NewReplacer(
|
||||
"\\ ", " ", // Escaped space
|
||||
"\\(", "(", // Escaped left parenthesis
|
||||
"\\)", ")", // Escaped right parenthesis
|
||||
"\\[", "[", // Escaped left square bracket
|
||||
"\\]", "]", // Escaped right square bracket
|
||||
"\\{", "{", // Escaped left curly brace
|
||||
"\\}", "}", // Escaped right curly brace
|
||||
"\\$", "$", // Escaped dollar sign
|
||||
"\\&", "&", // Escaped ampersand
|
||||
"\\;", ";", // Escaped semicolon
|
||||
"\\'", "'", // Escaped single quote
|
||||
"\\\\", "\\", // Escaped backslash
|
||||
"\\*", "*", // Escaped asterisk
|
||||
"\\?", "?", // Escaped question mark
|
||||
"\\~", "~", // Escaped tilde
|
||||
).Replace(fp)
|
||||
}
|
||||
|
||||
func extractFileNames(input string) []string {
|
||||
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
|
||||
// and followed by more characters and a file extension
|
||||
// This will capture non filename strings, but we'll check for file existence to remove mismatches
|
||||
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp|wav)\b`
|
||||
re := regexp.MustCompile(regexPattern)
|
||||
|
||||
return re.FindAllString(input, -1)
|
||||
}
|
||||
|
||||
func extractFileData(input string) (string, []api.ImageData, error) {
|
||||
filePaths := extractFileNames(input)
|
||||
var imgs []api.ImageData
|
||||
|
||||
for _, fp := range filePaths {
|
||||
nfp := normalizeFilePath(fp)
|
||||
data, err := getImageData(nfp)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Couldn't process file: %q\n", err)
|
||||
return "", imgs, err
|
||||
}
|
||||
ext := strings.ToLower(filepath.Ext(nfp))
|
||||
switch ext {
|
||||
case ".wav":
|
||||
fmt.Fprintf(os.Stderr, "Added audio '%s'\n", nfp)
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
|
||||
}
|
||||
input = strings.ReplaceAll(input, "'"+nfp+"'", "")
|
||||
input = strings.ReplaceAll(input, "'"+fp+"'", "")
|
||||
input = strings.ReplaceAll(input, fp, "")
|
||||
imgs = append(imgs, data)
|
||||
}
|
||||
return strings.TrimSpace(input), imgs, nil
|
||||
}
|
||||
|
||||
func editInExternalEditor(content string) (string, error) {
|
||||
editor := envconfig.Editor()
|
||||
if editor == "" {
|
||||
editor = os.Getenv("VISUAL")
|
||||
}
|
||||
if editor == "" {
|
||||
editor = os.Getenv("EDITOR")
|
||||
}
|
||||
if editor == "" {
|
||||
editor = defaultEditor
|
||||
}
|
||||
|
||||
// Check that the editor binary exists
|
||||
name := strings.Fields(editor)[0]
|
||||
if _, err := exec.LookPath(name); err != nil {
|
||||
return "", fmt.Errorf("editor %q not found, set OLLAMA_EDITOR to the path of your preferred editor", name)
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp("", "ollama-prompt-*.txt")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating temp file: %w", err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
if content != "" {
|
||||
if _, err := tmpFile.WriteString(content); err != nil {
|
||||
tmpFile.Close()
|
||||
return "", fmt.Errorf("writing to temp file: %w", err)
|
||||
}
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
args := strings.Fields(editor)
|
||||
args = append(args, tmpFile.Name())
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("editor exited with error: %w", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(tmpFile.Name())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading temp file: %w", err)
|
||||
}
|
||||
|
||||
return strings.TrimRight(string(data), "\n"), nil
|
||||
}
|
||||
|
||||
func getImageData(filePath string) ([]byte, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
buf := make([]byte, 512)
|
||||
_, err = file.Read(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
contentType := http.DetectContentType(buf)
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp", "audio/wave"}
|
||||
if !slices.Contains(allowedTypes, contentType) {
|
||||
return nil, fmt.Errorf("invalid file type: %s", contentType)
|
||||
}
|
||||
|
||||
info, err := file.Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var maxSize int64 = 100 * 1024 * 1024 // 100MB
|
||||
if info.Size() > maxSize {
|
||||
return nil, errors.New("file size exceeds maximum limit (100MB)")
|
||||
}
|
||||
|
||||
buf = make([]byte, info.Size())
|
||||
_, err = file.Seek(0, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = io.ReadFull(file, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
116
cmd/interactive_test.go
Normal file
116
cmd/interactive_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestExtractFilenames(t *testing.T) {
|
||||
// Unix style paths
|
||||
input := ` some preamble
|
||||
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg
|
||||
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG
|
||||
/unescaped space /six.webp inbetween6 /valid\ path/dir/seven.WEBP`
|
||||
res := extractFileNames(input)
|
||||
assert.Len(t, res, 7)
|
||||
assert.Contains(t, res[0], "one.png")
|
||||
assert.Contains(t, res[1], "two.jpg")
|
||||
assert.Contains(t, res[2], "three.jpeg")
|
||||
assert.Contains(t, res[3], "four.png")
|
||||
assert.Contains(t, res[4], "five.JPG")
|
||||
assert.Contains(t, res[5], "six.webp")
|
||||
assert.Contains(t, res[6], "seven.WEBP")
|
||||
assert.NotContains(t, res[4], '"')
|
||||
assert.NotContains(t, res, "inbetween1")
|
||||
assert.NotContains(t, res, "./1.svg")
|
||||
|
||||
// Windows style paths
|
||||
input = ` some preamble
|
||||
c:/users/jdoe/one.png inbetween1 c:/program files/someplace/two.jpg inbetween2
|
||||
/absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4
|
||||
./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6
|
||||
d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
|
||||
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG
|
||||
c:/users/jdoe/eleven.webp inbetween11 c:/program files/someplace/twelve.WebP inbetween12
|
||||
d:\path with\spaces\thirteen.WEBP some ending
|
||||
`
|
||||
res = extractFileNames(input)
|
||||
assert.Len(t, res, 13)
|
||||
assert.NotContains(t, res, "inbetween2")
|
||||
assert.Contains(t, res[0], "one.png")
|
||||
assert.Contains(t, res[0], "c:")
|
||||
assert.Contains(t, res[1], "two.jpg")
|
||||
assert.Contains(t, res[1], "c:")
|
||||
assert.Contains(t, res[2], "three.jpeg")
|
||||
assert.Contains(t, res[3], "four.png")
|
||||
assert.Contains(t, res[4], "five.JPG")
|
||||
assert.Contains(t, res[5], "six.png")
|
||||
assert.Contains(t, res[6], "seven.JPEG")
|
||||
assert.Contains(t, res[6], "d:")
|
||||
assert.Contains(t, res[7], "eight.png")
|
||||
assert.Contains(t, res[7], "c:")
|
||||
assert.Contains(t, res[8], "nine.png")
|
||||
assert.Contains(t, res[8], "d:")
|
||||
assert.Contains(t, res[9], "ten.PNG")
|
||||
assert.Contains(t, res[9], "E:")
|
||||
assert.Contains(t, res[10], "eleven.webp")
|
||||
assert.Contains(t, res[10], "c:")
|
||||
assert.Contains(t, res[11], "twelve.WebP")
|
||||
assert.Contains(t, res[11], "c:")
|
||||
assert.Contains(t, res[12], "thirteen.WEBP")
|
||||
assert.Contains(t, res[12], "d:")
|
||||
}
|
||||
|
||||
// Ensure that file paths wrapped in single quotes are removed with the quotes.
|
||||
func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
fp := filepath.Join(dir, "img.jpg")
|
||||
data := make([]byte, 600)
|
||||
copy(data, []byte{
|
||||
0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10, 'J', 'F', 'I', 'F',
|
||||
0x00, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0xff, 0xd9,
|
||||
})
|
||||
if err := os.WriteFile(fp, data, 0o600); err != nil {
|
||||
t.Fatalf("failed to write test image: %v", err)
|
||||
}
|
||||
|
||||
input := "before '" + fp + "' after"
|
||||
cleaned, imgs, err := extractFileData(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, imgs, 1)
|
||||
assert.Equal(t, cleaned, "before after")
|
||||
}
|
||||
|
||||
func TestExtractFileDataWAV(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
fp := filepath.Join(dir, "sample.wav")
|
||||
data := make([]byte, 600)
|
||||
copy(data[:44], []byte{
|
||||
'R', 'I', 'F', 'F',
|
||||
0x58, 0x02, 0x00, 0x00, // file size - 8
|
||||
'W', 'A', 'V', 'E',
|
||||
'f', 'm', 't', ' ',
|
||||
0x10, 0x00, 0x00, 0x00, // fmt chunk size
|
||||
0x01, 0x00, // PCM
|
||||
0x01, 0x00, // mono
|
||||
0x80, 0x3e, 0x00, 0x00, // 16000 Hz
|
||||
0x00, 0x7d, 0x00, 0x00, // byte rate
|
||||
0x02, 0x00, // block align
|
||||
0x10, 0x00, // 16-bit
|
||||
'd', 'a', 't', 'a',
|
||||
0x34, 0x02, 0x00, 0x00, // data size
|
||||
})
|
||||
if err := os.WriteFile(fp, data, 0o600); err != nil {
|
||||
t.Fatalf("failed to write test audio: %v", err)
|
||||
}
|
||||
|
||||
input := "before " + fp + " after"
|
||||
cleaned, imgs, err := extractFileData(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, imgs, 1)
|
||||
assert.Equal(t, "before after", cleaned)
|
||||
}
|
||||
176
cmd/internal/fileutil/files.go
Normal file
176
cmd/internal/fileutil/files.go
Normal file
@@ -0,0 +1,176 @@
|
||||
// Package fileutil provides small shared helpers for reading JSON files
|
||||
// and writing config files with backup-on-overwrite semantics.
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Keep a bounded number of backups per file so config backups do not grow
|
||||
// without limit. We keep the 5 most recent backups and do not pin the oldest.
|
||||
const maxBackupsPerFile = 5
|
||||
|
||||
// ReadJSON reads a JSON object file into a generic map.
|
||||
func ReadJSON(path string) (map[string]any, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func copyFile(src, dst string) error {
|
||||
info, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := os.ReadFile(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(dst, data, info.Mode().Perm())
|
||||
}
|
||||
|
||||
// BackupDir returns the shared backup root used before overwriting files.
|
||||
func BackupDir() string {
|
||||
if home, err := os.UserHomeDir(); err == nil && home != "" {
|
||||
return filepath.Join(home, ".ollama", "backup")
|
||||
}
|
||||
return filepath.Join(os.TempDir(), "ollama-backup")
|
||||
}
|
||||
|
||||
func writeBackupCopy(srcPath string, integration string) (string, error) {
|
||||
dir := BackupDir()
|
||||
name := filepath.Base(srcPath)
|
||||
if integration != "" {
|
||||
dir = filepath.Join(dir, integration)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
backupPath := filepath.Join(dir, fmt.Sprintf("%s.%d", name, time.Now().Unix()))
|
||||
if err := copyFile(srcPath, backupPath); err != nil {
|
||||
return "", err
|
||||
}
|
||||
pruneOldBackups(dir, name, maxBackupsPerFile)
|
||||
return backupPath, nil
|
||||
}
|
||||
|
||||
// WriteWithBackup writes data to path via temp file + rename, backing up any
|
||||
// existing file first. Callers may optionally pass one integration name to
|
||||
// store backups under BackupDir()/.../<integration>/.
|
||||
func WriteWithBackup(path string, data []byte, integration ...string) error {
|
||||
backupIntegration := ""
|
||||
if len(integration) > 0 {
|
||||
backupIntegration = integration[0]
|
||||
}
|
||||
|
||||
var backupPath string
|
||||
// backup must be created before any writes to the target file
|
||||
if existingContent, err := os.ReadFile(path); err == nil {
|
||||
if bytes.Equal(existingContent, data) {
|
||||
return nil
|
||||
}
|
||||
backupPath, err = writeBackupCopy(path, backupIntegration)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backup failed: %w", err)
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("read existing file: %w", err)
|
||||
}
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
tmp, err := os.CreateTemp(dir, ".tmp-*")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp failed: %w", err)
|
||||
}
|
||||
tmpPath := tmp.Name()
|
||||
|
||||
if _, err := tmp.Write(data); err != nil {
|
||||
_ = tmp.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("write failed: %w", err)
|
||||
}
|
||||
if err := tmp.Sync(); err != nil {
|
||||
_ = tmp.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("sync failed: %w", err)
|
||||
}
|
||||
if err := tmp.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("close failed: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, path); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
if backupPath != "" {
|
||||
_ = copyFile(backupPath, path)
|
||||
}
|
||||
return fmt.Errorf("rename failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func pruneOldBackups(dir, name string, keep int) {
|
||||
if keep < 1 {
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
type backupEntry struct {
|
||||
name string
|
||||
timestamp int64
|
||||
}
|
||||
|
||||
prefix := name + "."
|
||||
backups := make([]backupEntry, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasPrefix(entry.Name(), prefix) {
|
||||
continue
|
||||
}
|
||||
|
||||
timestamp, err := strconv.ParseInt(strings.TrimPrefix(entry.Name(), prefix), 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
backups = append(backups, backupEntry{
|
||||
name: entry.Name(),
|
||||
timestamp: timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
if len(backups) <= keep {
|
||||
return
|
||||
}
|
||||
|
||||
sort.Slice(backups, func(i, j int) bool {
|
||||
if backups[i].timestamp != backups[j].timestamp {
|
||||
return backups[i].timestamp > backups[j].timestamp
|
||||
}
|
||||
return backups[i].name > backups[j].name
|
||||
})
|
||||
|
||||
for _, backup := range backups[keep:] {
|
||||
_ = os.Remove(filepath.Join(dir, backup.name))
|
||||
}
|
||||
}
|
||||
627
cmd/internal/fileutil/files_test.go
Normal file
627
cmd/internal/fileutil/files_test.go
Normal file
@@ -0,0 +1,627 @@
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
tmpRoot, err := os.MkdirTemp("", "fileutil-test-*")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := os.Setenv("TMPDIR", tmpRoot); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := os.Setenv("HOME", tmpRoot); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := os.Setenv("USERPROFILE", tmpRoot); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
_ = os.RemoveAll(tmpRoot)
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func mustMarshal(t *testing.T, v any) []byte {
|
||||
t.Helper()
|
||||
data, err := json.MarshalIndent(v, "", " ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func isolatedTempDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
return t.TempDir()
|
||||
}
|
||||
|
||||
func TestWriteWithBackup(t *testing.T) {
|
||||
tmpDir := isolatedTempDir(t)
|
||||
|
||||
t.Run("uses ollama directory under home", func(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
t.Setenv("HOME", home)
|
||||
t.Setenv("USERPROFILE", home)
|
||||
|
||||
want := filepath.Join(home, ".ollama", "backup")
|
||||
if got := BackupDir(); got != want {
|
||||
t.Fatalf("BackupDir() = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates file", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "new.json")
|
||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var result map[string]string
|
||||
if err := json.Unmarshal(content, &result); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result["key"] != "value" {
|
||||
t.Errorf("expected value, got %s", result["key"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates backup in the shared backup directory", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "backup.json")
|
||||
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
|
||||
data := mustMarshal(t, map[string]bool{"updated": true})
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(BackupDir())
|
||||
if err != nil {
|
||||
t.Fatal("backup directory not created")
|
||||
}
|
||||
|
||||
var foundBackup bool
|
||||
for _, entry := range entries {
|
||||
if filepath.Ext(entry.Name()) != ".json" {
|
||||
name := entry.Name()
|
||||
if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
|
||||
backupPath := filepath.Join(BackupDir(), name)
|
||||
backup, err := os.ReadFile(backupPath)
|
||||
if err == nil {
|
||||
var backupData map[string]bool
|
||||
json.Unmarshal(backup, &backupData)
|
||||
if backupData["original"] {
|
||||
foundBackup = true
|
||||
os.Remove(backupPath)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundBackup {
|
||||
t.Error("backup file not created in backup directory")
|
||||
}
|
||||
|
||||
current, _ := os.ReadFile(path)
|
||||
var currentData map[string]bool
|
||||
json.Unmarshal(current, ¤tData)
|
||||
if !currentData["updated"] {
|
||||
t.Error("file doesn't contain updated data")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("stores hinted backups under a subdirectory", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "hinted.json")
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
|
||||
data := mustMarshal(t, map[string]bool{"updated": true})
|
||||
if err := WriteWithBackup(path, data, "openclaw"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(filepath.Join(BackupDir(), "openclaw"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var found bool
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if len(name) > len("hinted.json.") && name[:len("hinted.json.")] == "hinted.json." {
|
||||
found = true
|
||||
_ = os.Remove(filepath.Join(BackupDir(), "openclaw", name))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Error("backup file was not created under hint directory")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no backup for new file", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "nobak.json")
|
||||
|
||||
data := mustMarshal(t, map[string]string{"new": "file"})
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, _ := os.ReadDir(BackupDir())
|
||||
for _, entry := range entries {
|
||||
if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
|
||||
t.Error("backup should not exist for new file")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no backup when content unchanged", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "unchanged.json")
|
||||
|
||||
data := mustMarshal(t, map[string]string{"key": "value"})
|
||||
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries1, _ := os.ReadDir(BackupDir())
|
||||
countBefore := 0
|
||||
for _, e := range entries1 {
|
||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||
countBefore++
|
||||
}
|
||||
}
|
||||
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries2, _ := os.ReadDir(BackupDir())
|
||||
countAfter := 0
|
||||
for _, e := range entries2 {
|
||||
if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
|
||||
countAfter++
|
||||
}
|
||||
}
|
||||
|
||||
if countAfter != countBefore {
|
||||
t.Errorf("backup was created when content unchanged (before=%d, after=%d)", countBefore, countAfter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("backup filename contains unix timestamp", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "timestamped.json")
|
||||
|
||||
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||
data := mustMarshal(t, map[string]int{"v": 2})
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
entries, _ := os.ReadDir(BackupDir())
|
||||
var found bool
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if len(name) > len("timestamped.json.") && name[:len("timestamped.json.")] == "timestamped.json." {
|
||||
timestamp := name[len("timestamped.json."):]
|
||||
for _, c := range timestamp {
|
||||
if c < '0' || c > '9' {
|
||||
t.Errorf("backup filename timestamp contains non-numeric character: %s", name)
|
||||
}
|
||||
}
|
||||
found = true
|
||||
os.Remove(filepath.Join(BackupDir(), name))
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("backup file with timestamp not found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("retains only the five newest backups per file", func(t *testing.T) {
|
||||
path := filepath.Join(tmpDir, "pruned.json")
|
||||
if err := os.WriteFile(path, []byte(`{"v": 0}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for i := 1; i <= maxBackupsPerFile; i++ {
|
||||
backupPath := filepath.Join(BackupDir(), fmt.Sprintf("pruned.json.%d", i))
|
||||
if err := os.WriteFile(backupPath, []byte(fmt.Sprintf(`{"v": %d}`, i)), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := WriteWithBackup(path, []byte(`{"v": 1}`)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
backups, err := filepath.Glob(filepath.Join(BackupDir(), "pruned.json.*"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(backups) != maxBackupsPerFile {
|
||||
t.Fatalf("expected %d backups after pruning, got %d", maxBackupsPerFile, len(backups))
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(BackupDir(), "pruned.json.1")); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected oldest backup to be pruned, stat err = %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Edge case tests for files.go
|
||||
|
||||
// TestWriteWithBackup_FailsIfBackupFails documents critical behavior: if backup fails, we must not proceed.
|
||||
// User could lose their config with no way to recover.
|
||||
func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := isolatedTempDir(t)
|
||||
path := filepath.Join(tmpDir, "config.json")
|
||||
|
||||
// Create original file
|
||||
originalContent := []byte(`{"original": true}`)
|
||||
os.WriteFile(path, originalContent, 0o644)
|
||||
|
||||
// Make backup directory read-only to force backup failure
|
||||
backupDir := BackupDir()
|
||||
os.MkdirAll(backupDir, 0o755)
|
||||
os.Chmod(backupDir, 0o444) // Read-only
|
||||
defer os.Chmod(backupDir, 0o755)
|
||||
|
||||
newContent := []byte(`{"updated": true}`)
|
||||
err := WriteWithBackup(path, newContent)
|
||||
|
||||
// Should fail because backup couldn't be created
|
||||
if err == nil {
|
||||
t.Error("expected error when backup fails, got nil")
|
||||
}
|
||||
|
||||
// Original file should be preserved
|
||||
current, _ := os.ReadFile(path)
|
||||
if string(current) != string(originalContent) {
|
||||
t.Errorf("original file was modified despite backup failure: got %s", string(current))
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_PermissionDenied verifies clear error when target file has wrong permissions.
|
||||
// Common issue when config owned by root or wrong perms.
|
||||
func TestWriteWithBackup_PermissionDenied(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := isolatedTempDir(t)
|
||||
|
||||
// Create a read-only directory
|
||||
readOnlyDir := filepath.Join(tmpDir, "readonly")
|
||||
os.MkdirAll(readOnlyDir, 0o755)
|
||||
os.Chmod(readOnlyDir, 0o444)
|
||||
defer os.Chmod(readOnlyDir, 0o755)
|
||||
|
||||
path := filepath.Join(readOnlyDir, "config.json")
|
||||
err := WriteWithBackup(path, []byte(`{"test": true}`))
|
||||
|
||||
if err == nil {
|
||||
t.Error("expected permission error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteWithBackup_UnchangedContentIsNoOp(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := isolatedTempDir(t)
|
||||
path := filepath.Join(tmpDir, "unchanged-noop.json")
|
||||
data := []byte(`{"same":true}`)
|
||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tmpDir, 0o555); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Chmod(tmpDir, 0o755)
|
||||
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatalf("expected unchanged write to be a no-op, got %v", err)
|
||||
}
|
||||
|
||||
backups, err := filepath.Glob(filepath.Join(BackupDir(), "unchanged-noop.json.*"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(backups) != 0 {
|
||||
t.Fatalf("expected no backups for unchanged content, got %d", len(backups))
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
|
||||
// writeWithBackup doesn't create directories - caller is responsible.
|
||||
func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
|
||||
tmpDir := isolatedTempDir(t)
|
||||
path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
|
||||
|
||||
err := WriteWithBackup(path, []byte(`{"test": true}`))
|
||||
|
||||
// Should fail because directory doesn't exist
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_SymlinkTarget documents behavior when target is a symlink.
|
||||
// Documents what happens if user symlinks their config file.
|
||||
func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("symlink tests may require admin on Windows")
|
||||
}
|
||||
|
||||
tmpDir := isolatedTempDir(t)
|
||||
realFile := filepath.Join(tmpDir, "real.json")
|
||||
symlink := filepath.Join(tmpDir, "link.json")
|
||||
|
||||
// Create real file and symlink
|
||||
os.WriteFile(realFile, []byte(`{"v": 1}`), 0o644)
|
||||
os.Symlink(realFile, symlink)
|
||||
|
||||
// Write through symlink
|
||||
err := WriteWithBackup(symlink, []byte(`{"v": 2}`))
|
||||
if err != nil {
|
||||
t.Fatalf("writeWithBackup through symlink failed: %v", err)
|
||||
}
|
||||
|
||||
// The real file should be updated (symlink followed for temp file creation)
|
||||
content, _ := os.ReadFile(symlink)
|
||||
if string(content) != `{"v": 2}` {
|
||||
t.Errorf("symlink target not updated correctly: got %s", string(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
|
||||
// User may have config files with unusual names.
|
||||
func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
|
||||
tmpDir := isolatedTempDir(t)
|
||||
|
||||
// File with spaces and special chars
|
||||
path := filepath.Join(tmpDir, "my config (backup).json")
|
||||
os.WriteFile(path, []byte(`{"test": true}`), 0o644)
|
||||
|
||||
backupPath, err := writeBackupCopy(path, "")
|
||||
if err != nil {
|
||||
t.Fatalf("writeBackupCopy with special chars failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify backup exists and has correct content
|
||||
content, err := os.ReadFile(backupPath)
|
||||
if err != nil {
|
||||
t.Fatalf("could not read backup: %v", err)
|
||||
}
|
||||
if string(content) != `{"test": true}` {
|
||||
t.Errorf("backup content mismatch: got %s", string(content))
|
||||
}
|
||||
|
||||
os.Remove(backupPath)
|
||||
}
|
||||
|
||||
// TestCopyFile_PreservesPermissions verifies that copyFile preserves file permissions.
|
||||
func TestCopyFile_PreservesPermissions(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission preservation tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := isolatedTempDir(t)
|
||||
src := filepath.Join(tmpDir, "src.json")
|
||||
dst := filepath.Join(tmpDir, "dst.json")
|
||||
|
||||
// Create source with specific permissions
|
||||
os.WriteFile(src, []byte(`{"test": true}`), 0o600)
|
||||
|
||||
err := copyFile(src, dst)
|
||||
if err != nil {
|
||||
t.Fatalf("copyFile failed: %v", err)
|
||||
}
|
||||
|
||||
srcInfo, _ := os.Stat(src)
|
||||
dstInfo, _ := os.Stat(dst)
|
||||
|
||||
if srcInfo.Mode().Perm() != dstInfo.Mode().Perm() {
|
||||
t.Errorf("permissions not preserved: src=%v, dst=%v", srcInfo.Mode().Perm(), dstInfo.Mode().Perm())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
|
||||
func TestCopyFile_SourceNotFound(t *testing.T) {
|
||||
tmpDir := isolatedTempDir(t)
|
||||
src := filepath.Join(tmpDir, "nonexistent.json")
|
||||
dst := filepath.Join(tmpDir, "dst.json")
|
||||
|
||||
err := copyFile(src, dst)
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent source, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
|
||||
func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
|
||||
tmpDir := isolatedTempDir(t)
|
||||
dirPath := filepath.Join(tmpDir, "actualdir")
|
||||
os.MkdirAll(dirPath, 0o755)
|
||||
|
||||
err := WriteWithBackup(dirPath, []byte(`{"test": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when target is a directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
|
||||
func TestWriteWithBackup_EmptyData(t *testing.T) {
|
||||
tmpDir := isolatedTempDir(t)
|
||||
path := filepath.Join(tmpDir, "empty.json")
|
||||
|
||||
err := WriteWithBackup(path, []byte{})
|
||||
if err != nil {
|
||||
t.Fatalf("writeWithBackup with empty data failed: %v", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("could not read file: %v", err)
|
||||
}
|
||||
if len(content) != 0 {
|
||||
t.Errorf("expected empty file, got %d bytes", len(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_FileUnreadableButDirWritable verifies behavior when existing file
|
||||
// cannot be read (for backup comparison) but directory is writable.
|
||||
func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := isolatedTempDir(t)
|
||||
path := filepath.Join(tmpDir, "unreadable.json")
|
||||
|
||||
// Create file and make it unreadable
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
os.Chmod(path, 0o000)
|
||||
defer os.Chmod(path, 0o644)
|
||||
|
||||
// Should fail because we can't read the file to compare/backup
|
||||
err := WriteWithBackup(path, []byte(`{"updated": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when file is unreadable, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
|
||||
// within the same second (timestamp collision scenario).
|
||||
func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
|
||||
tmpDir := isolatedTempDir(t)
|
||||
path := filepath.Join(tmpDir, "rapid.json")
|
||||
|
||||
// Create initial file
|
||||
os.WriteFile(path, []byte(`{"v": 0}`), 0o644)
|
||||
|
||||
// Rapid successive writes
|
||||
for i := 1; i <= 3; i++ {
|
||||
data := []byte(fmt.Sprintf(`{"v": %d}`, i))
|
||||
if err := WriteWithBackup(path, data); err != nil {
|
||||
t.Fatalf("write %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify final content
|
||||
content, _ := os.ReadFile(path)
|
||||
if string(content) != `{"v": 3}` {
|
||||
t.Errorf("expected final content {\"v\": 3}, got %s", string(content))
|
||||
}
|
||||
|
||||
// Verify at least one backup exists
|
||||
entries, _ := os.ReadDir(BackupDir())
|
||||
var backupCount int
|
||||
for _, e := range entries {
|
||||
if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
|
||||
backupCount++
|
||||
}
|
||||
}
|
||||
if backupCount == 0 {
|
||||
t.Error("expected at least one backup file from rapid writes")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_BackupDirIsFile verifies error when backup directory path is a file.
|
||||
func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("test modifies system temp directory")
|
||||
}
|
||||
|
||||
tmpDir := isolatedTempDir(t)
|
||||
// Create a file at the backup directory path
|
||||
backupPath := BackupDir()
|
||||
// Clean up any existing directory first
|
||||
os.RemoveAll(backupPath)
|
||||
// Create a file instead of directory
|
||||
os.WriteFile(backupPath, []byte("not a directory"), 0o644)
|
||||
defer func() {
|
||||
os.Remove(backupPath)
|
||||
os.MkdirAll(backupPath, 0o755)
|
||||
}()
|
||||
|
||||
path := filepath.Join(tmpDir, "test.json")
|
||||
os.WriteFile(path, []byte(`{"original": true}`), 0o644)
|
||||
|
||||
err := WriteWithBackup(path, []byte(`{"updated": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error when backup dir is a file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteWithBackup_NoOrphanTempFiles verifies temp files are cleaned up on failure.
|
||||
func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission tests unreliable on Windows")
|
||||
}
|
||||
|
||||
tmpDir := isolatedTempDir(t)
|
||||
|
||||
// Count existing temp files
|
||||
countTempFiles := func() int {
|
||||
entries, _ := os.ReadDir(tmpDir)
|
||||
count := 0
|
||||
for _, e := range entries {
|
||||
if len(e.Name()) > 4 && e.Name()[:4] == ".tmp" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
before := countTempFiles()
|
||||
|
||||
// Create a file, then make directory read-only to cause rename failure
|
||||
path := filepath.Join(tmpDir, "orphan.json")
|
||||
os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
|
||||
|
||||
// Make a subdirectory and try to write there after making parent read-only
|
||||
subDir := filepath.Join(tmpDir, "subdir")
|
||||
os.MkdirAll(subDir, 0o755)
|
||||
subPath := filepath.Join(subDir, "config.json")
|
||||
os.WriteFile(subPath, []byte(`{"v": 1}`), 0o644)
|
||||
|
||||
// Make subdir read-only after creating temp file would succeed but rename would fail
|
||||
// This is tricky to test - the temp file is created in the same dir, so if we can't
|
||||
// rename, we also couldn't create. Let's just verify normal failure cleanup works.
|
||||
|
||||
// Force a failure by making the target a directory
|
||||
badPath := filepath.Join(tmpDir, "isdir")
|
||||
os.MkdirAll(badPath, 0o755)
|
||||
|
||||
_ = WriteWithBackup(badPath, []byte(`{"test": true}`))
|
||||
|
||||
after := countTempFiles()
|
||||
if after > before {
|
||||
t.Errorf("orphan temp files left behind: before=%d, after=%d", before, after)
|
||||
}
|
||||
}
|
||||
371
cmd/launch/account.go
Normal file
371
cmd/launch/account.go
Normal file
@@ -0,0 +1,371 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultUpgradeURL is the fixed destination for subscription upgrades.
|
||||
DefaultUpgradeURL = "https://ollama.com/upgrade"
|
||||
|
||||
accountCheckTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPlanVerificationUnavailable = errors.New("Could not verify your plan. Try again in a moment.")
|
||||
errUpgradeCancelled = errors.New("upgrade cancelled")
|
||||
)
|
||||
|
||||
type accountStateStatus int
|
||||
|
||||
const (
|
||||
accountStateUnknown accountStateStatus = iota
|
||||
accountStateSignedOut
|
||||
accountStateSignedIn
|
||||
)
|
||||
|
||||
type AccountState struct {
|
||||
Status accountStateStatus
|
||||
Plan string
|
||||
}
|
||||
|
||||
type AccountStatePrefetch struct {
|
||||
done chan struct{}
|
||||
state AccountState
|
||||
}
|
||||
|
||||
func StartAccountStatePrefetch(ctx context.Context) *AccountStatePrefetch {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
p := &AccountStatePrefetch{done: make(chan struct{})}
|
||||
go func() {
|
||||
state := AccountState{Status: accountStateUnknown}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err == nil {
|
||||
prefetchCtx, cancel := context.WithTimeout(ctx, accountCheckTimeout)
|
||||
defer cancel()
|
||||
if disabled, known := cloudStatusDisabled(prefetchCtx, client); !known || !disabled {
|
||||
state = launchAccountState(prefetchCtx, client)
|
||||
}
|
||||
}
|
||||
p.state = state
|
||||
close(p.done)
|
||||
}()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *AccountStatePrefetch) StateIfReady() *AccountState {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-p.done:
|
||||
state := p.state
|
||||
return &state
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AccountStatePrefetch) StateUpdates(ctx context.Context) <-chan *AccountState {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
out := make(chan *AccountState, 1)
|
||||
go func() {
|
||||
defer close(out)
|
||||
select {
|
||||
case <-p.done:
|
||||
if p.state.Status == accountStateUnknown {
|
||||
return
|
||||
}
|
||||
state := p.state
|
||||
select {
|
||||
case out <- &state:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
return out
|
||||
}
|
||||
|
||||
func launchAccountState(ctx context.Context, client *api.Client) AccountState {
|
||||
if client == nil {
|
||||
return AccountState{Status: accountStateUnknown}
|
||||
}
|
||||
|
||||
user, err := whoamiWithTimeout(ctx, client)
|
||||
if err != nil {
|
||||
var authErr api.AuthorizationError
|
||||
if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
|
||||
return AccountState{Status: accountStateSignedOut}
|
||||
}
|
||||
return AccountState{Status: accountStateUnknown}
|
||||
}
|
||||
if user == nil || strings.TrimSpace(user.Name) == "" {
|
||||
return AccountState{Status: accountStateSignedOut}
|
||||
}
|
||||
return AccountState{
|
||||
Status: accountStateSignedIn,
|
||||
Plan: strings.TrimSpace(user.Plan),
|
||||
}
|
||||
}
|
||||
|
||||
func whoamiWithTimeout(ctx context.Context, client *api.Client) (*api.UserResponse, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
checkCtx, cancel := context.WithTimeout(ctx, accountCheckTimeout)
|
||||
defer cancel()
|
||||
return client.Whoami(checkCtx)
|
||||
}
|
||||
|
||||
func ApplyAccountStateToSelectionItems(items []ModelItem, state AccountState) []SelectionItem {
|
||||
out := make([]SelectionItem, len(items))
|
||||
for i, item := range items {
|
||||
out[i] = SelectionItem{
|
||||
Name: item.Name,
|
||||
Description: item.Description,
|
||||
Recommended: item.Recommended,
|
||||
AvailabilityBadge: availabilityBadge(item, state),
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func SelectionItemsWithAccountState(items []ModelItem, state *AccountState) []SelectionItem {
|
||||
if state == nil || !selectionItemsNeedAccountState(items) {
|
||||
return ApplyAccountStateToSelectionItems(items, AccountState{Status: accountStateUnknown})
|
||||
}
|
||||
return ApplyAccountStateToSelectionItems(items, *state)
|
||||
}
|
||||
|
||||
func selectionItemsNeedAccountState(items []ModelItem) bool {
|
||||
for _, item := range items {
|
||||
if isCloudModelName(item.Name) && itemHasRecommendationMetadata(item) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *launcherClient) selectionItemUpdates(ctx context.Context, items []ModelItem, state *AccountState) <-chan []SelectionItem {
|
||||
if !selectionItemsNeedAccountState(items) || state != nil {
|
||||
return nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
stateUpdates := c.accountStateUpdateSource(ctx)
|
||||
if stateUpdates == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make(chan []SelectionItem, 1)
|
||||
go func() {
|
||||
defer close(out)
|
||||
select {
|
||||
case state, ok := <-stateUpdates:
|
||||
if !ok || state == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case out <- SelectionItemsWithAccountState(items, state):
|
||||
case <-ctx.Done():
|
||||
}
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *launcherClient) accountStateUpdateSource(ctx context.Context) <-chan *AccountState {
|
||||
if c.accountStateUpdates != nil {
|
||||
return c.accountStateUpdates(ctx)
|
||||
}
|
||||
if c.apiClient == nil {
|
||||
return nil
|
||||
}
|
||||
out := make(chan *AccountState, 1)
|
||||
go func() {
|
||||
defer close(out)
|
||||
state := launchAccountState(ctx, c.apiClient)
|
||||
if state.Status == accountStateUnknown {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case out <- &state:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
return out
|
||||
}
|
||||
|
||||
func availabilityBadge(item ModelItem, state AccountState) string {
|
||||
if !isCloudModelName(item.Name) {
|
||||
return ""
|
||||
}
|
||||
switch state.Status {
|
||||
case accountStateSignedOut:
|
||||
if itemHasRecommendationMetadata(item) {
|
||||
return "Sign in required"
|
||||
}
|
||||
case accountStateSignedIn:
|
||||
if item.RequiredPlan != "" && !PlanSatisfies(state.Plan, item.RequiredPlan) {
|
||||
return "Upgrade required"
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func itemHasRecommendationMetadata(item ModelItem) bool {
|
||||
return item.Recommended || strings.TrimSpace(item.RequiredPlan) != ""
|
||||
}
|
||||
|
||||
func (c *launcherClient) ensureCloudModelAccess(ctx context.Context, model string) error {
|
||||
item, ok := c.modelRecommendationItem(ctx, model)
|
||||
if !ok || strings.TrimSpace(item.RequiredPlan) == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
state := launchAccountState(ctx, c.apiClient)
|
||||
if state.Status != accountStateUnknown {
|
||||
c.accountState = &state
|
||||
}
|
||||
if state.Status == accountStateUnknown {
|
||||
return ErrPlanVerificationUnavailable
|
||||
}
|
||||
|
||||
if state.Status == accountStateSignedOut {
|
||||
if err := ensureCloudAuth(ctx, c.apiClient, model); err != nil {
|
||||
return err
|
||||
}
|
||||
state = launchAccountState(ctx, c.apiClient)
|
||||
if state.Status != accountStateUnknown {
|
||||
c.accountState = &state
|
||||
}
|
||||
if state.Status == accountStateUnknown {
|
||||
return ErrPlanVerificationUnavailable
|
||||
}
|
||||
}
|
||||
|
||||
if PlanSatisfies(state.Plan, item.RequiredPlan) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.runUpgradeFlow(ctx, item); err != nil {
|
||||
return err
|
||||
}
|
||||
state = launchAccountState(ctx, c.apiClient)
|
||||
if state.Status == accountStateUnknown {
|
||||
return ErrPlanVerificationUnavailable
|
||||
}
|
||||
if state.Status != accountStateSignedIn || !PlanSatisfies(state.Plan, item.RequiredPlan) {
|
||||
return errUpgradeCancelled
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) modelRecommendationItem(ctx context.Context, model string) (ModelItem, bool) {
|
||||
for _, item := range c.recommendations(ctx) {
|
||||
if item.Name == model {
|
||||
return item, true
|
||||
}
|
||||
}
|
||||
return ModelItem{}, false
|
||||
}
|
||||
|
||||
func (c *launcherClient) runUpgradeFlow(ctx context.Context, item ModelItem) error {
|
||||
if DefaultUpgrade != nil {
|
||||
if _, err := DefaultUpgrade(item.Name, item.RequiredPlan); err != nil {
|
||||
if errors.Is(err, ErrCancelled) {
|
||||
return errUpgradeCancelled
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
yes, err := ConfirmPrompt(fmt.Sprintf("Upgrade to use %s?", item.Name))
|
||||
if errors.Is(err, ErrCancelled) {
|
||||
return errUpgradeCancelled
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !yes {
|
||||
return errUpgradeCancelled
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nTo upgrade, navigate to:\n %s\n\n", DefaultUpgradeURL)
|
||||
openNow, err := ConfirmPrompt("Open now?")
|
||||
if errors.Is(err, ErrCancelled) {
|
||||
return errUpgradeCancelled
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if openNow {
|
||||
OpenBrowser(DefaultUpgradeURL)
|
||||
} else {
|
||||
return errUpgradeCancelled
|
||||
}
|
||||
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for upgrade to complete... %s\033[0m", spinnerFrames[0])
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for upgrade to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
if frame%10 != 0 {
|
||||
continue
|
||||
}
|
||||
state := launchAccountState(ctx, c.apiClient)
|
||||
if state.Status == accountStateUnknown {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return ErrPlanVerificationUnavailable
|
||||
}
|
||||
if state.Status == accountStateSignedIn && PlanSatisfies(state.Plan, item.RequiredPlan) {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1mplan updated\033[0m\n")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PlanSatisfies reports whether currentPlan can use a model that has a requiredPlan.
|
||||
func PlanSatisfies(currentPlan, requiredPlan string) bool {
|
||||
required := normalizePlan(requiredPlan)
|
||||
if required == "" || required == "free" {
|
||||
return true
|
||||
}
|
||||
current := normalizePlan(currentPlan)
|
||||
return current != "" && current != "free"
|
||||
}
|
||||
|
||||
func normalizePlan(plan string) string {
|
||||
return strings.ToLower(strings.TrimSpace(plan))
|
||||
}
|
||||
87
cmd/launch/claude.go
Normal file
87
cmd/launch/claude.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Claude implements Runner for Claude Code integration.
|
||||
type Claude struct{}
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string, extra []string) []string {
|
||||
var args []string
|
||||
if model != "" {
|
||||
args = append(args, "--model", model)
|
||||
}
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Claude) findPath() (string, error) {
|
||||
if p, err := exec.LookPath("claude"); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fallback := filepath.Join(home, ".claude", "local", name)
|
||||
if _, err := os.Stat(fallback); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func (c *Claude) Run(model string, _ []LaunchModel, args []string) error {
|
||||
claudePath, err := c.findPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||
}
|
||||
|
||||
cmd := exec.Command(claudePath, c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
env := append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
"CLAUDE_CODE_ATTRIBUTION_HEADER=0",
|
||||
)
|
||||
|
||||
env = append(env, c.modelEnvVars(model)...)
|
||||
|
||||
cmd.Env = env
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
|
||||
func (c *Claude) modelEnvVars(model string) []string {
|
||||
env := []string{
|
||||
"ANTHROPIC_DEFAULT_OPUS_MODEL=" + model,
|
||||
"ANTHROPIC_DEFAULT_SONNET_MODEL=" + model,
|
||||
"ANTHROPIC_DEFAULT_HAIKU_MODEL=" + model,
|
||||
"CLAUDE_CODE_SUBAGENT_MODEL=" + model,
|
||||
}
|
||||
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
env = append(env, "CLAUDE_CODE_AUTO_COMPACT_WINDOW="+strconv.Itoa(l.Context))
|
||||
}
|
||||
}
|
||||
|
||||
return env
|
||||
}
|
||||
888
cmd/launch/claude_desktop.go
Normal file
888
cmd/launch/claude_desktop.go
Normal file
@@ -0,0 +1,888 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
const (
|
||||
claudeDesktopIntegrationName = "claude-desktop"
|
||||
claudeDesktopProfileName = "Ollama"
|
||||
claudeDesktopProfileID = "00000000-0000-4000-8000-000000000114"
|
||||
claudeDesktopGatewayBaseURL = "https://ollama.com"
|
||||
claudeDesktopAPIKeyURL = "https://ollama.com/settings/keys"
|
||||
claudeDesktopModelLabel = "Ollama Cloud"
|
||||
claudeDesktopUnsupported = "Claude Desktop is no longer supported. Existing installations can be restored with 'ollama launch claude-desktop --restore'."
|
||||
claudeDesktopSuccessMessage = "Claude Desktop profile changed to Ollama Cloud."
|
||||
claudeDesktopRestoreMessage = "To restore the usual Claude profile, run: ollama launch claude-desktop --restore"
|
||||
claudeDesktopRestoredMessage = "Claude Desktop restored to the usual Claude profile."
|
||||
)
|
||||
|
||||
var (
|
||||
claudeDesktopGOOS = runtime.GOOS
|
||||
claudeDesktopUserHome = os.UserHomeDir
|
||||
claudeDesktopStat = os.Stat
|
||||
claudeDesktopOpenApp = defaultClaudeDesktopOpenApp
|
||||
claudeDesktopOpenAppPath = defaultClaudeDesktopOpenAppPath
|
||||
claudeDesktopQuitApp = defaultClaudeDesktopQuitApp
|
||||
claudeDesktopIsRunning = defaultClaudeDesktopIsRunning
|
||||
claudeDesktopRunningAppPath = defaultClaudeDesktopRunningAppPath
|
||||
claudeDesktopGlob = filepath.Glob
|
||||
claudeDesktopSleep = time.Sleep
|
||||
claudeDesktopHTTPClient = http.DefaultClient
|
||||
claudeDesktopPromptAPIKey = promptClaudeDesktopAPIKey
|
||||
claudeDesktopValidateAPIKey = validateClaudeDesktopAPIKey
|
||||
)
|
||||
|
||||
// ClaudeDesktop configures and launches Claude Desktop in third-party
|
||||
// inference mode using Ollama Cloud as the gateway.
|
||||
type ClaudeDesktop struct{}
|
||||
|
||||
func (c *ClaudeDesktop) String() string { return "Claude Desktop" }
|
||||
|
||||
func (c *ClaudeDesktop) Supported() error { return claudeDesktopSupported() }
|
||||
|
||||
func (c *ClaudeDesktop) Paths() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClaudeDesktop) AutodiscoveredModel() string {
|
||||
return claudeDesktopModelLabel
|
||||
}
|
||||
|
||||
func (c *ClaudeDesktop) ConfigureAutodiscovery() error {
|
||||
if err := claudeDesktopSupported(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targets, err := claudeDesktopTargetPaths()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key, err := claudeDesktopValidatedAPIKey(context.Background(), claudeDesktopTargetProfilePaths(targets))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, path := range targets.normalConfigs {
|
||||
if err := writeClaudeDesktopDeploymentMode(path, "3p"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, target := range targets.thirdPartyProfiles {
|
||||
if err := writeClaudeDesktopDeploymentMode(target.desktopConfig, "3p"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeClaudeDesktopMeta(target.meta, claudeDesktopProfileID, claudeDesktopProfileName); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeClaudeDesktopGatewayProfile(target.profile, key, true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClaudeDesktop) RestoreHint() string {
|
||||
return claudeDesktopRestoreMessage
|
||||
}
|
||||
|
||||
func (c *ClaudeDesktop) ConfigurationSuccessMessage() string {
|
||||
return claudeDesktopSuccessMessage + "\n" + claudeDesktopRestoreMessage
|
||||
}
|
||||
|
||||
func (c *ClaudeDesktop) RestoreSuccessMessage() string {
|
||||
return claudeDesktopRestoredMessage
|
||||
}
|
||||
|
||||
func (c *ClaudeDesktop) AutodiscoveryConfigured() bool {
|
||||
targets, err := claudeDesktopTargetPaths()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return claudeDesktopTargetsConfigured(targets)
|
||||
}
|
||||
|
||||
func (c *ClaudeDesktop) Onboard() error {
|
||||
return config.MarkIntegrationOnboarded(claudeDesktopIntegrationName)
|
||||
}
|
||||
|
||||
func (c *ClaudeDesktop) RequiresInteractiveOnboarding() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *ClaudeDesktop) SkipModelReadiness() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *ClaudeDesktop) Run(_ string, _ []LaunchModel, _ []string) error {
|
||||
return errClaudeDesktopUnsupported()
|
||||
}
|
||||
|
||||
func (c *ClaudeDesktop) Restore() error {
|
||||
if err := claudeDesktopSupported(); err != nil {
|
||||
return err
|
||||
}
|
||||
targets, err := claudeDesktopTargetPaths()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, path := range targets.normalConfigs {
|
||||
if err := writeClaudeDesktopDeploymentMode(path, "1p"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, target := range targets.thirdPartyProfiles {
|
||||
if err := writeClaudeDesktopDeploymentMode(target.desktopConfig, "1p"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := restoreClaudeDesktopMeta(target.meta); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := restoreClaudeDesktopOllamaProfile(target.profile); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return claudeDesktopLaunchOrRestart("Restart Claude Desktop to use the usual Claude profile?")
|
||||
}
|
||||
|
||||
func errClaudeDesktopUnsupported() error {
|
||||
return errors.New(claudeDesktopUnsupported)
|
||||
}
|
||||
|
||||
func claudeDesktopSupported() error {
|
||||
switch claudeDesktopGOOS {
|
||||
case "darwin", "windows":
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("Claude Desktop launch is only supported on macOS and Windows")
|
||||
}
|
||||
}
|
||||
|
||||
func claudeDesktopInstalled() bool {
|
||||
if claudeDesktopAppPath() != "" {
|
||||
return true
|
||||
}
|
||||
if claudeDesktopGOOS == "windows" && claudeDesktopIsRunning() {
|
||||
return true
|
||||
}
|
||||
for _, dir := range claudeDesktopProfileDirCandidates(false) {
|
||||
if _, err := claudeDesktopStat(dir); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func claudeDesktopAppPath() string {
|
||||
if claudeDesktopGOOS != "darwin" && claudeDesktopGOOS != "windows" {
|
||||
return ""
|
||||
}
|
||||
for _, path := range claudeDesktopAppCandidates() {
|
||||
if _, err := claudeDesktopStat(path); err == nil {
|
||||
return path
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func claudeDesktopAppCandidates() []string {
|
||||
switch claudeDesktopGOOS {
|
||||
case "darwin":
|
||||
return claudeDesktopDarwinAppCandidates()
|
||||
case "windows":
|
||||
return claudeDesktopWindowsAppCandidates()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func claudeDesktopDarwinAppCandidates() []string {
|
||||
candidates := []string{"/Applications/Claude.app"}
|
||||
if home, err := claudeDesktopUserHome(); err == nil {
|
||||
candidates = append(candidates, filepath.Join(home, "Applications", "Claude.app"))
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
func claudeDesktopWindowsAppCandidates() []string {
|
||||
local, err := claudeDesktopLocalAppData()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
candidates := []string{
|
||||
filepath.Join(local, "Programs", "Claude", "Claude.exe"),
|
||||
filepath.Join(local, "Programs", "Claude Desktop", "Claude.exe"),
|
||||
filepath.Join(local, "Claude", "Claude.exe"),
|
||||
filepath.Join(local, "Claude Nest", "Claude.exe"),
|
||||
filepath.Join(local, "Claude Desktop", "Claude.exe"),
|
||||
filepath.Join(local, "AnthropicClaude", "Claude.exe"),
|
||||
}
|
||||
for _, pattern := range []string{
|
||||
filepath.Join(local, "AnthropicClaude", "app-*", "Claude.exe"),
|
||||
filepath.Join(local, "Programs", "Claude", "app-*", "Claude.exe"),
|
||||
filepath.Join(local, "Programs", "Claude Desktop", "app-*", "Claude.exe"),
|
||||
} {
|
||||
matches, _ := claudeDesktopGlob(pattern)
|
||||
candidates = append(candidates, matches...)
|
||||
}
|
||||
return claudeDesktopDedupePaths(candidates)
|
||||
}
|
||||
|
||||
func claudeDesktopDedupePaths(paths []string) []string {
|
||||
out := make([]string, 0, len(paths))
|
||||
seen := make(map[string]bool, len(paths))
|
||||
for _, path := range paths {
|
||||
if strings.TrimSpace(path) == "" {
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(path)
|
||||
if seen[key] {
|
||||
continue
|
||||
}
|
||||
seen[key] = true
|
||||
out = append(out, path)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type claudeDesktopPaths struct {
|
||||
normalConfig string
|
||||
desktopConfig string
|
||||
meta string
|
||||
profile string
|
||||
}
|
||||
|
||||
type claudeDesktopThirdPartyPaths struct {
|
||||
desktopConfig string
|
||||
meta string
|
||||
profile string
|
||||
}
|
||||
|
||||
type claudeDesktopTargets struct {
|
||||
normalConfigs []string
|
||||
thirdPartyProfiles []claudeDesktopThirdPartyPaths
|
||||
}
|
||||
|
||||
func claudeDesktopConfigPaths() (claudeDesktopPaths, error) {
|
||||
switch claudeDesktopGOOS {
|
||||
case "darwin":
|
||||
return claudeDesktopDarwinConfigPaths()
|
||||
case "windows":
|
||||
return claudeDesktopWindowsConfigPaths()
|
||||
default:
|
||||
return claudeDesktopPaths{}, claudeDesktopSupported()
|
||||
}
|
||||
}
|
||||
|
||||
func claudeDesktopDarwinConfigPaths() (claudeDesktopPaths, error) {
|
||||
normalRoots, thirdPartyRoots, err := claudeDesktopDarwinProfileRoots()
|
||||
if err != nil {
|
||||
return claudeDesktopPaths{}, err
|
||||
}
|
||||
normalBase := normalRoots[0]
|
||||
thirdPartyBase := thirdPartyRoots[0]
|
||||
return claudeDesktopPaths{
|
||||
normalConfig: filepath.Join(normalBase, "claude_desktop_config.json"),
|
||||
desktopConfig: filepath.Join(thirdPartyBase, "claude_desktop_config.json"),
|
||||
meta: filepath.Join(thirdPartyBase, "configLibrary", "_meta.json"),
|
||||
profile: filepath.Join(thirdPartyBase, "configLibrary", claudeDesktopProfileID+".json"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func claudeDesktopWindowsConfigPaths() (claudeDesktopPaths, error) {
|
||||
normalBase, err := claudeDesktopProfileDir(true)
|
||||
if err != nil {
|
||||
return claudeDesktopPaths{}, err
|
||||
}
|
||||
thirdPartyBase, err := claudeDesktopProfileDir(false)
|
||||
if err != nil {
|
||||
return claudeDesktopPaths{}, err
|
||||
}
|
||||
return claudeDesktopPaths{
|
||||
normalConfig: filepath.Join(normalBase, "claude_desktop_config.json"),
|
||||
desktopConfig: filepath.Join(thirdPartyBase, "claude_desktop_config.json"),
|
||||
meta: filepath.Join(thirdPartyBase, "configLibrary", "_meta.json"),
|
||||
profile: filepath.Join(thirdPartyBase, "configLibrary", claudeDesktopProfileID+".json"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func claudeDesktopProfileDir(normal bool) (string, error) {
|
||||
candidates := claudeDesktopProfileDirCandidates(normal)
|
||||
if len(candidates) == 0 {
|
||||
return "", fmt.Errorf("Claude Desktop profile directory could not be resolved")
|
||||
}
|
||||
for _, candidate := range candidates {
|
||||
if _, err := claudeDesktopStat(candidate); err == nil {
|
||||
return candidate, nil
|
||||
}
|
||||
}
|
||||
return candidates[0], nil
|
||||
}
|
||||
|
||||
func claudeDesktopProfileDirCandidates(normal bool) []string {
|
||||
if claudeDesktopGOOS != "windows" {
|
||||
return nil
|
||||
}
|
||||
normalRoots, thirdPartyRoots, err := claudeDesktopWindowsProfileRoots()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if normal {
|
||||
return normalRoots
|
||||
}
|
||||
return thirdPartyRoots
|
||||
}
|
||||
|
||||
func claudeDesktopDarwinProfileRoots() ([]string, []string, error) {
|
||||
home, err := claudeDesktopUserHome()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
base := filepath.Join(home, "Library", "Application Support")
|
||||
return []string{filepath.Join(base, "Claude")}, []string{filepath.Join(base, "Claude-3p")}, nil
|
||||
}
|
||||
|
||||
func claudeDesktopWindowsProfileRoots() ([]string, []string, error) {
|
||||
local, err := claudeDesktopLocalAppData()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
normalRoots := []string{
|
||||
filepath.Join(local, "Claude"),
|
||||
filepath.Join(local, "Claude Nest"),
|
||||
}
|
||||
thirdPartyRoots := []string{
|
||||
filepath.Join(local, "Claude-3p"),
|
||||
filepath.Join(local, "Claude Nest-3p"),
|
||||
}
|
||||
return normalRoots, thirdPartyRoots, nil
|
||||
}
|
||||
|
||||
func claudeDesktopTargetPaths() (claudeDesktopTargets, error) {
|
||||
var (
|
||||
normalRoots []string
|
||||
thirdPartyRoots []string
|
||||
err error
|
||||
)
|
||||
|
||||
switch claudeDesktopGOOS {
|
||||
case "darwin":
|
||||
normalRoots, thirdPartyRoots, err = claudeDesktopDarwinProfileRoots()
|
||||
case "windows":
|
||||
normalRoots, thirdPartyRoots, err = claudeDesktopWindowsProfileRoots()
|
||||
default:
|
||||
err = claudeDesktopSupported()
|
||||
}
|
||||
if err != nil {
|
||||
return claudeDesktopTargets{}, err
|
||||
}
|
||||
|
||||
return newClaudeDesktopTargets(normalRoots, thirdPartyRoots), nil
|
||||
}
|
||||
|
||||
func newClaudeDesktopTargets(normalRoots, thirdPartyRoots []string) claudeDesktopTargets {
|
||||
targets := claudeDesktopTargets{}
|
||||
for _, root := range claudeDesktopDedupePaths(normalRoots) {
|
||||
targets.normalConfigs = append(targets.normalConfigs, filepath.Join(root, "claude_desktop_config.json"))
|
||||
}
|
||||
for _, root := range claudeDesktopDedupePaths(thirdPartyRoots) {
|
||||
targets.thirdPartyProfiles = append(targets.thirdPartyProfiles, claudeDesktopThirdPartyPaths{
|
||||
desktopConfig: filepath.Join(root, "claude_desktop_config.json"),
|
||||
meta: filepath.Join(root, "configLibrary", "_meta.json"),
|
||||
profile: filepath.Join(root, "configLibrary", claudeDesktopProfileID+".json"),
|
||||
})
|
||||
}
|
||||
return targets
|
||||
}
|
||||
|
||||
func claudeDesktopTargetProfilePaths(targets claudeDesktopTargets) []string {
|
||||
paths := make([]string, 0, len(targets.thirdPartyProfiles))
|
||||
for _, target := range targets.thirdPartyProfiles {
|
||||
paths = append(paths, target.profile)
|
||||
}
|
||||
return paths
|
||||
}
|
||||
|
||||
func claudeDesktopLocalAppData() (string, error) {
|
||||
if local := strings.TrimSpace(os.Getenv("LOCALAPPDATA")); local != "" {
|
||||
return local, nil
|
||||
}
|
||||
if home := strings.TrimSpace(os.Getenv("USERPROFILE")); home != "" {
|
||||
return filepath.Join(home, "AppData", "Local"), nil
|
||||
}
|
||||
home, err := claudeDesktopUserHome()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, "AppData", "Local"), nil
|
||||
}
|
||||
|
||||
type claudeDesktopAPIKeySource int
|
||||
|
||||
const (
|
||||
claudeDesktopAPIKeySourceNone claudeDesktopAPIKeySource = iota
|
||||
claudeDesktopAPIKeySourceEnv
|
||||
claudeDesktopAPIKeySourceProfile
|
||||
)
|
||||
|
||||
func claudeDesktopValidatedAPIKey(ctx context.Context, profilePaths []string) (string, error) {
|
||||
key, source, err := claudeDesktopAPIKey(profilePaths)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := claudeDesktopValidateAPIKey(ctx, key); err == nil {
|
||||
return key, nil
|
||||
} else if source != claudeDesktopAPIKeySourceProfile || !canPromptClaudeDesktopAPIKey() {
|
||||
return "", err
|
||||
}
|
||||
return promptValidClaudeDesktopAPIKey(ctx)
|
||||
}
|
||||
|
||||
func claudeDesktopAPIKey(profilePaths []string) (string, claudeDesktopAPIKeySource, error) {
|
||||
if key := strings.TrimSpace(os.Getenv("OLLAMA_API_KEY")); key != "" {
|
||||
return key, claudeDesktopAPIKeySourceEnv, nil
|
||||
}
|
||||
for _, profilePath := range profilePaths {
|
||||
if key := readClaudeDesktopGatewayAPIKey(profilePath); key != "" {
|
||||
return key, claudeDesktopAPIKeySourceProfile, nil
|
||||
}
|
||||
}
|
||||
key, err := promptClaudeDesktopAPIKeyValue()
|
||||
return key, claudeDesktopAPIKeySourceNone, err
|
||||
}
|
||||
|
||||
func canPromptClaudeDesktopAPIKey() bool {
|
||||
return isInteractiveSession() && !currentLaunchConfirmPolicy.requireYesMessage
|
||||
}
|
||||
|
||||
func promptValidClaudeDesktopAPIKey(ctx context.Context) (string, error) {
|
||||
key, err := promptClaudeDesktopAPIKeyValue()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := claudeDesktopValidateAPIKey(ctx, key); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func promptClaudeDesktopAPIKeyValue() (string, error) {
|
||||
if !canPromptClaudeDesktopAPIKey() {
|
||||
return "", missingClaudeDesktopAPIKeyError()
|
||||
}
|
||||
key, err := claudeDesktopPromptAPIKey()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return "", missingClaudeDesktopAPIKeyError()
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func missingClaudeDesktopAPIKeyError() error {
|
||||
return fmt.Errorf("OLLAMA_API_KEY is required for Claude Desktop. Create an API key at %s, then re-run with OLLAMA_API_KEY set", claudeDesktopAPIKeyURL)
|
||||
}
|
||||
|
||||
func promptClaudeDesktopAPIKey() (string, error) {
|
||||
fmt.Fprint(os.Stderr, claudeDesktopAPIKeyPrompt())
|
||||
key, err := term.ReadPassword(int(os.Stdin.Fd()))
|
||||
fmt.Fprintln(os.Stderr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(key), nil
|
||||
}
|
||||
|
||||
func claudeDesktopAPIKeyPrompt() string {
|
||||
return fmt.Sprintf("Create an Ollama API key at %s\nEnter Ollama API key (input hidden): ", claudeDesktopAPIKeyURL)
|
||||
}
|
||||
|
||||
func readClaudeDesktopGatewayAPIKey(path string) string {
|
||||
cfg, err := readClaudeDesktopJSON(path)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
key, _ := cfg["inferenceGatewayApiKey"].(string)
|
||||
return strings.TrimSpace(key)
|
||||
}
|
||||
|
||||
func validateClaudeDesktopAPIKey(ctx context.Context, key string) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if claudeDesktopAPIKeyHasInvalidHeaderChars(key) {
|
||||
return claudeDesktopAPIKeyVerificationError()
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, claudeDesktopGatewayBaseURL+"/v1/models", nil)
|
||||
if err != nil {
|
||||
return claudeDesktopAPIKeyVerificationError()
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := claudeDesktopHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return claudeDesktopAPIKeyVerificationError()
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4<<10))
|
||||
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden:
|
||||
return fmt.Errorf("Ollama API key was rejected; create a valid key at %s", claudeDesktopAPIKeyURL)
|
||||
case resp.StatusCode >= 200 && resp.StatusCode < 300:
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("could not verify Ollama API key; ollama.com returned status %d, try again later", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func claudeDesktopAPIKeyHasInvalidHeaderChars(key string) bool {
|
||||
return strings.ContainsFunc(key, func(r rune) bool {
|
||||
return r < ' ' || r == 0x7f
|
||||
})
|
||||
}
|
||||
|
||||
func claudeDesktopAPIKeyVerificationError() error {
|
||||
return fmt.Errorf("could not verify Ollama API key; copy a key from %s and try again", claudeDesktopAPIKeyURL)
|
||||
}
|
||||
|
||||
func writeClaudeDesktopDeploymentMode(path, mode string) error {
|
||||
cfg, err := readClaudeDesktopJSONAllowMissing(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse Claude Desktop config: %w", err)
|
||||
}
|
||||
cfg["deploymentMode"] = mode
|
||||
return writeClaudeDesktopJSON(path, cfg)
|
||||
}
|
||||
|
||||
func writeClaudeDesktopMeta(path, id, name string) error {
|
||||
meta, err := readClaudeDesktopJSONAllowMissing(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse Claude Desktop config metadata: %w", err)
|
||||
}
|
||||
|
||||
meta["appliedId"] = id
|
||||
entries := make([]any, 0)
|
||||
for _, entry := range claudeDesktopAnySlice(meta["entries"]) {
|
||||
entryMap, _ := entry.(map[string]any)
|
||||
if entryMap == nil {
|
||||
entries = append(entries, entry)
|
||||
continue
|
||||
}
|
||||
if entryID, _ := entryMap["id"].(string); entryID == id {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, entryMap)
|
||||
}
|
||||
entries = append(entries, map[string]any{
|
||||
"id": id,
|
||||
"name": name,
|
||||
})
|
||||
meta["entries"] = entries
|
||||
return writeClaudeDesktopJSON(path, meta)
|
||||
}
|
||||
|
||||
func writeClaudeDesktopGatewayProfile(path string, apiKey string, forceChooser bool) error {
|
||||
cfg, err := readClaudeDesktopJSONAllowMissing(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse Claude Desktop Ollama profile: %w", err)
|
||||
}
|
||||
cfg["inferenceProvider"] = "gateway"
|
||||
cfg["inferenceGatewayBaseUrl"] = claudeDesktopGatewayBaseURL
|
||||
cfg["inferenceGatewayApiKey"] = apiKey
|
||||
cfg["inferenceGatewayAuthScheme"] = "bearer"
|
||||
delete(cfg, "inferenceModels")
|
||||
cfg["disableDeploymentModeChooser"] = forceChooser
|
||||
return writeClaudeDesktopJSON(path, cfg)
|
||||
}
|
||||
|
||||
func restoreClaudeDesktopMeta(path string) error {
|
||||
meta, err := readClaudeDesktopJSONAllowMissing(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse Claude Desktop config metadata: %w", err)
|
||||
}
|
||||
if len(meta) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
changed := false
|
||||
if appliedID, _ := meta["appliedId"].(string); appliedID == claudeDesktopProfileID {
|
||||
delete(meta, "appliedId")
|
||||
changed = true
|
||||
}
|
||||
|
||||
entries := claudeDesktopAnySlice(meta["entries"])
|
||||
if entries != nil {
|
||||
filtered := make([]any, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
entryMap, _ := entry.(map[string]any)
|
||||
if entryID, _ := entryMap["id"].(string); entryID == claudeDesktopProfileID {
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, entry)
|
||||
}
|
||||
meta["entries"] = filtered
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return nil
|
||||
}
|
||||
return writeClaudeDesktopJSON(path, meta)
|
||||
}
|
||||
|
||||
func restoreClaudeDesktopOllamaProfile(path string) error {
|
||||
cfg, err := readClaudeDesktopJSONAllowMissing(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse Claude Desktop Ollama profile: %w", err)
|
||||
}
|
||||
if len(cfg) == 0 {
|
||||
return nil
|
||||
}
|
||||
cfg["disableDeploymentModeChooser"] = false
|
||||
delete(cfg, "inferenceProvider")
|
||||
delete(cfg, "inferenceGatewayBaseUrl")
|
||||
delete(cfg, "inferenceGatewayAuthScheme")
|
||||
delete(cfg, "inferenceModels")
|
||||
return writeClaudeDesktopJSON(path, cfg)
|
||||
}
|
||||
|
||||
func readClaudeDesktopAppliedID(path string) string {
|
||||
meta, err := readClaudeDesktopJSON(path)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
applied, _ := meta["appliedId"].(string)
|
||||
return applied
|
||||
}
|
||||
|
||||
func readClaudeDesktopDeploymentMode(path string) string {
|
||||
cfg, err := readClaudeDesktopJSON(path)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
mode, _ := cfg["deploymentMode"].(string)
|
||||
return mode
|
||||
}
|
||||
|
||||
func claudeDesktopTargetsConfigured(targets claudeDesktopTargets) bool {
|
||||
if len(targets.normalConfigs) == 0 || len(targets.thirdPartyProfiles) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, path := range targets.normalConfigs {
|
||||
if readClaudeDesktopDeploymentMode(path) != "3p" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for _, target := range targets.thirdPartyProfiles {
|
||||
if readClaudeDesktopDeploymentMode(target.desktopConfig) != "3p" {
|
||||
return false
|
||||
}
|
||||
if !claudeDesktopThirdPartyProfileConfigured(target) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func claudeDesktopThirdPartyProfileConfigured(target claudeDesktopThirdPartyPaths) bool {
|
||||
if readClaudeDesktopAppliedID(target.meta) != claudeDesktopProfileID {
|
||||
return false
|
||||
}
|
||||
|
||||
cfg, err := readClaudeDesktopJSON(target.profile)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if s, _ := cfg["inferenceProvider"].(string); s != "gateway" {
|
||||
return false
|
||||
}
|
||||
if s, _ := cfg["inferenceGatewayBaseUrl"].(string); strings.TrimRight(s, "/") != claudeDesktopGatewayBaseURL {
|
||||
return false
|
||||
}
|
||||
if s, _ := cfg["inferenceGatewayApiKey"].(string); strings.TrimSpace(s) == "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func readClaudeDesktopJSONAllowMissing(path string) (map[string]any, error) {
|
||||
cfg, err := readClaudeDesktopJSON(path)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
func readClaudeDesktopJSON(path string) (map[string]any, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg == nil {
|
||||
cfg = map[string]any{}
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func writeClaudeDesktopJSON(path string, cfg any) error {
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data = append(data, '\n')
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
return fileutil.WriteWithBackup(path, data)
|
||||
}
|
||||
|
||||
func claudeDesktopAnySlice(value any) []any {
|
||||
switch v := value.(type) {
|
||||
case []any:
|
||||
return v
|
||||
case nil:
|
||||
return nil
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func claudeDesktopLaunchOrRestart(prompt string) error {
|
||||
if !claudeDesktopIsRunning() {
|
||||
return claudeDesktopOpenApp()
|
||||
}
|
||||
restartAppPath := ""
|
||||
if claudeDesktopGOOS == "windows" {
|
||||
restartAppPath = claudeDesktopRunningAppPath()
|
||||
}
|
||||
|
||||
restart, err := ConfirmPrompt(prompt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !restart {
|
||||
fmt.Fprintln(os.Stderr, "\nQuit and reopen Claude Desktop when you're ready for the profile change to take effect.")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := claudeDesktopQuitApp(); err != nil {
|
||||
return fmt.Errorf("quit Claude Desktop: %w", err)
|
||||
}
|
||||
if err := waitForClaudeDesktopExit(30 * time.Second); err != nil {
|
||||
return err
|
||||
}
|
||||
if restartAppPath != "" {
|
||||
return claudeDesktopOpenAppPath(restartAppPath)
|
||||
}
|
||||
return claudeDesktopOpenApp()
|
||||
}
|
||||
|
||||
func waitForClaudeDesktopExit(timeout time.Duration) error {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if !claudeDesktopIsRunning() {
|
||||
return nil
|
||||
}
|
||||
claudeDesktopSleep(200 * time.Millisecond)
|
||||
}
|
||||
return fmt.Errorf("Claude Desktop did not quit; quit it manually and re-run the command")
|
||||
}
|
||||
|
||||
func defaultClaudeDesktopIsRunning() bool {
|
||||
switch claudeDesktopGOOS {
|
||||
case "darwin":
|
||||
out, err := exec.Command("pgrep", "-f", "Claude.app/Contents/MacOS/Claude").Output()
|
||||
return err == nil && strings.TrimSpace(string(out)) != ""
|
||||
case "windows":
|
||||
out, err := exec.Command("powershell.exe", "-NoProfile", "-Command", `(Get-Process claude -ErrorAction SilentlyContinue | Where-Object { $_.MainWindowHandle -ne 0 } | Select-Object -First 1).Id`).Output()
|
||||
return err == nil && strings.TrimSpace(string(out)) != ""
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func defaultClaudeDesktopOpenApp() error {
|
||||
switch claudeDesktopGOOS {
|
||||
case "windows":
|
||||
if path := claudeDesktopAppPath(); path != "" {
|
||||
return claudeDesktopOpenAppPath(path)
|
||||
}
|
||||
if path := claudeDesktopRunningAppPath(); path != "" {
|
||||
return claudeDesktopOpenAppPath(path)
|
||||
}
|
||||
return fmt.Errorf("Claude Desktop executable was not found; open Claude Desktop manually once and re-run 'ollama launch claude-desktop --restore'")
|
||||
case "darwin":
|
||||
return openClaudeDesktopDarwin()
|
||||
default:
|
||||
return claudeDesktopSupported()
|
||||
}
|
||||
}
|
||||
|
||||
func defaultClaudeDesktopOpenAppPath(path string) error {
|
||||
switch claudeDesktopGOOS {
|
||||
case "windows":
|
||||
return exec.Command("powershell.exe", "-NoProfile", "-Command", "Start-Process -FilePath "+quotePowerShellString(path)).Run()
|
||||
case "darwin":
|
||||
return openClaudeDesktopDarwin()
|
||||
default:
|
||||
return claudeDesktopSupported()
|
||||
}
|
||||
}
|
||||
|
||||
func openClaudeDesktopDarwin() error {
|
||||
cmd := exec.Command("open", "-a", "Claude")
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func defaultClaudeDesktopRunningAppPath() string {
|
||||
if claudeDesktopGOOS != "windows" {
|
||||
return ""
|
||||
}
|
||||
script := `(Get-Process claude -ErrorAction SilentlyContinue | Where-Object { $_.MainWindowHandle -ne 0 -and $_.Path } | Select-Object -First 1 -ExpandProperty Path)`
|
||||
out, err := exec.Command("powershell.exe", "-NoProfile", "-Command", script).Output()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
func defaultClaudeDesktopQuitApp() error {
|
||||
if claudeDesktopGOOS == "windows" {
|
||||
script := `Get-Process claude -ErrorAction SilentlyContinue | Where-Object { $_.MainWindowHandle -ne 0 } | ForEach-Object { [void]$_.CloseMainWindow() }`
|
||||
return exec.Command("powershell.exe", "-NoProfile", "-Command", script).Run()
|
||||
}
|
||||
return exec.Command("osascript", "-e", `tell application "Claude" to quit`).Run()
|
||||
}
|
||||
|
||||
func quotePowerShellString(s string) string {
|
||||
return "'" + strings.ReplaceAll(s, "'", "''") + "'"
|
||||
}
|
||||
946
cmd/launch/claude_desktop_test.go
Normal file
946
cmd/launch/claude_desktop_test.go
Normal file
@@ -0,0 +1,946 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func withClaudeDesktopPlatform(t *testing.T, goos string) {
|
||||
t.Helper()
|
||||
old := claudeDesktopGOOS
|
||||
claudeDesktopGOOS = goos
|
||||
t.Cleanup(func() {
|
||||
claudeDesktopGOOS = old
|
||||
})
|
||||
}
|
||||
|
||||
func withClaudeDesktopValidation(t *testing.T, fn func(context.Context, string) error) {
|
||||
t.Helper()
|
||||
old := claudeDesktopValidateAPIKey
|
||||
claudeDesktopValidateAPIKey = fn
|
||||
t.Cleanup(func() {
|
||||
claudeDesktopValidateAPIKey = old
|
||||
})
|
||||
}
|
||||
|
||||
func withClaudeDesktopPrompt(t *testing.T, fn func() (string, error)) {
|
||||
t.Helper()
|
||||
old := claudeDesktopPromptAPIKey
|
||||
claudeDesktopPromptAPIKey = fn
|
||||
t.Cleanup(func() {
|
||||
claudeDesktopPromptAPIKey = old
|
||||
})
|
||||
}
|
||||
|
||||
func withClaudeDesktopProcessHooks(t *testing.T, running func() bool, quit func() error, open func() error) {
|
||||
t.Helper()
|
||||
oldRunning := claudeDesktopIsRunning
|
||||
oldQuit := claudeDesktopQuitApp
|
||||
oldOpen := claudeDesktopOpenApp
|
||||
oldOpenPath := claudeDesktopOpenAppPath
|
||||
oldRunningPath := claudeDesktopRunningAppPath
|
||||
oldSleep := claudeDesktopSleep
|
||||
claudeDesktopIsRunning = running
|
||||
claudeDesktopQuitApp = quit
|
||||
claudeDesktopOpenApp = open
|
||||
claudeDesktopOpenAppPath = oldOpenPath
|
||||
claudeDesktopRunningAppPath = oldRunningPath
|
||||
claudeDesktopSleep = func(time.Duration) {}
|
||||
t.Cleanup(func() {
|
||||
claudeDesktopIsRunning = oldRunning
|
||||
claudeDesktopQuitApp = oldQuit
|
||||
claudeDesktopOpenApp = oldOpen
|
||||
claudeDesktopOpenAppPath = oldOpenPath
|
||||
claudeDesktopRunningAppPath = oldRunningPath
|
||||
claudeDesktopSleep = oldSleep
|
||||
})
|
||||
}
|
||||
|
||||
func claudeDesktopReadJSON(t *testing.T, path string) map[string]any {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("read %s: %v", path, err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatalf("parse %s: %v", path, err)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func TestClaudeDesktopIntegration(t *testing.T) {
|
||||
c := &ClaudeDesktop{}
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
t.Run("implements managed autodiscovery integration", func(t *testing.T) {
|
||||
var _ ManagedAutodiscoveryIntegration = c
|
||||
})
|
||||
t.Run("does not use local Ollama Cloud auth gate", func(t *testing.T) {
|
||||
if _, ok := any(c).(ManagedAutodiscoveryCloudIntegration); ok {
|
||||
t.Fatal("Claude Desktop should validate OLLAMA_API_KEY directly instead of requiring local Ollama Cloud sign-in")
|
||||
}
|
||||
})
|
||||
t.Run("implements restore", func(t *testing.T) {
|
||||
var _ RestorableIntegration = c
|
||||
})
|
||||
t.Run("has restore hint", func(t *testing.T) {
|
||||
var _ RestoreHintIntegration = c
|
||||
if !strings.Contains(c.RestoreHint(), "--restore") {
|
||||
t.Fatalf("expected restore hint to mention --restore, got %q", c.RestoreHint())
|
||||
}
|
||||
if strings.Contains(c.RestoreHint(), "Tip:") {
|
||||
t.Fatalf("restore hint should not use Tip wording, got %q", c.RestoreHint())
|
||||
}
|
||||
})
|
||||
t.Run("has success messages", func(t *testing.T) {
|
||||
var _ ConfigurationSuccessIntegration = c
|
||||
var _ RestoreSuccessIntegration = c
|
||||
if got := c.ConfigurationSuccessMessage(); got != "Claude Desktop profile changed to Ollama Cloud.\nTo restore the usual Claude profile, run: ollama launch claude-desktop --restore" {
|
||||
t.Fatalf("configuration success message = %q", got)
|
||||
}
|
||||
if got := c.RestoreSuccessMessage(); got != "Claude Desktop restored to the usual Claude profile." {
|
||||
t.Fatalf("restore success message = %q", got)
|
||||
}
|
||||
})
|
||||
t.Run("skips local model readiness", func(t *testing.T) {
|
||||
var _ ManagedModelReadinessSkipper = c
|
||||
if !c.SkipModelReadiness() {
|
||||
t.Fatal("expected Claude Desktop to skip local model readiness")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_ClaudeDesktopLaunchReturnsUnsupported(t *testing.T) {
|
||||
for _, name := range []string{"claude-desktop", "claude-app"} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: name})
|
||||
if err == nil {
|
||||
t.Fatal("expected Claude Desktop launch to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "Claude Desktop is no longer supported") {
|
||||
t.Fatalf("expected unsupported guidance, got %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "ollama launch claude-desktop --restore") {
|
||||
t.Fatalf("expected restore guidance, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_ClaudeDesktopRestoreStillWorks(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "darwin")
|
||||
withClaudeDesktopProcessHooks(t, func() bool { return false }, func() error { return nil }, func() error { return nil })
|
||||
|
||||
if err := os.MkdirAll(filepath.Join(tmpDir, "Applications", "Claude.app"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
paths, err := claudeDesktopConfigPaths()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(paths.profile), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(paths.meta, []byte(`{"appliedId":"`+claudeDesktopProfileID+`","entries":[{"id":"`+claudeDesktopProfileID+`","name":"Ollama"}]}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(paths.profile, []byte(`{"disableDeploymentModeChooser":true,"inferenceGatewayApiKey":"keep","inferenceProvider":"gateway","inferenceGatewayBaseUrl":"https://ollama.com","inferenceGatewayAuthScheme":"bearer"}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stderr := captureStderr(t, func() {
|
||||
err = LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "claude-desktop", Restore: true})
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("LaunchIntegration restore returned error: %v", err)
|
||||
}
|
||||
if !strings.Contains(stderr, claudeDesktopRestoredMessage) {
|
||||
t.Fatalf("expected restore success message, got stderr: %q", stderr)
|
||||
}
|
||||
desktopConfig := claudeDesktopReadJSON(t, paths.desktopConfig)
|
||||
if desktopConfig["deploymentMode"] != "1p" {
|
||||
t.Fatalf("deploymentMode = %v, want 1p", desktopConfig["deploymentMode"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopConfigureWritesOllamaCloudProfile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "darwin")
|
||||
t.Setenv("OLLAMA_API_KEY", "test-api-key")
|
||||
|
||||
var validatedKey string
|
||||
withClaudeDesktopValidation(t, func(_ context.Context, key string) error {
|
||||
validatedKey = key
|
||||
return nil
|
||||
})
|
||||
|
||||
paths, err := claudeDesktopConfigPaths()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(paths.desktopConfig), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(paths.meta), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(paths.desktopConfig, []byte(`{"existing":true}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(paths.meta, []byte(`{"entries":[{"id":"custom","name":"Custom"}]}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := (&ClaudeDesktop{}).ConfigureAutodiscovery(); err != nil {
|
||||
t.Fatalf("Configure returned error: %v", err)
|
||||
}
|
||||
if validatedKey != "test-api-key" {
|
||||
t.Fatalf("validated key = %q, want test API key", validatedKey)
|
||||
}
|
||||
|
||||
desktopConfig := claudeDesktopReadJSON(t, paths.desktopConfig)
|
||||
if desktopConfig["existing"] != true {
|
||||
t.Fatalf("existing desktop config key was not preserved: %v", desktopConfig)
|
||||
}
|
||||
if desktopConfig["deploymentMode"] != "3p" {
|
||||
t.Fatalf("deploymentMode = %v, want 3p", desktopConfig["deploymentMode"])
|
||||
}
|
||||
normalConfig := claudeDesktopReadJSON(t, paths.normalConfig)
|
||||
if normalConfig["deploymentMode"] != "3p" {
|
||||
t.Fatalf("normal deploymentMode = %v, want 3p", normalConfig["deploymentMode"])
|
||||
}
|
||||
|
||||
meta := claudeDesktopReadJSON(t, paths.meta)
|
||||
if meta["appliedId"] != claudeDesktopProfileID {
|
||||
t.Fatalf("appliedId = %v, want %s", meta["appliedId"], claudeDesktopProfileID)
|
||||
}
|
||||
entries, _ := meta["entries"].([]any)
|
||||
if len(entries) != 2 {
|
||||
t.Fatalf("entries len = %d, want 2: %v", len(entries), entries)
|
||||
}
|
||||
|
||||
profile := claudeDesktopReadJSON(t, paths.profile)
|
||||
if profile["inferenceProvider"] != "gateway" {
|
||||
t.Fatalf("inferenceProvider = %v, want gateway", profile["inferenceProvider"])
|
||||
}
|
||||
if profile["inferenceGatewayBaseUrl"] != claudeDesktopGatewayBaseURL {
|
||||
t.Fatalf("base URL = %v, want %s", profile["inferenceGatewayBaseUrl"], claudeDesktopGatewayBaseURL)
|
||||
}
|
||||
if profile["inferenceGatewayApiKey"] != "test-api-key" {
|
||||
t.Fatal("expected configured API key to be written")
|
||||
}
|
||||
if profile["inferenceGatewayAuthScheme"] != "bearer" {
|
||||
t.Fatalf("auth scheme = %v, want bearer", profile["inferenceGatewayAuthScheme"])
|
||||
}
|
||||
if profile["disableDeploymentModeChooser"] != true {
|
||||
t.Fatalf("disableDeploymentModeChooser = %v, want true", profile["disableDeploymentModeChooser"])
|
||||
}
|
||||
if _, ok := profile["inferenceModels"]; ok {
|
||||
t.Fatalf("inferenceModels should be omitted so Claude can discover models, got %v", profile["inferenceModels"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopConfigureAutodiscoveryRemovesExistingModelCatalog(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "darwin")
|
||||
t.Setenv("OLLAMA_API_KEY", "test-api-key")
|
||||
withClaudeDesktopValidation(t, func(context.Context, string) error { return nil })
|
||||
|
||||
paths, err := claudeDesktopConfigPaths()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(paths.profile), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(paths.profile, []byte(`{"inferenceModels":["qwen3.5"],"inferenceGatewayApiKey":"old"}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := (&ClaudeDesktop{}).ConfigureAutodiscovery(); err != nil {
|
||||
t.Fatalf("ConfigureAutodiscovery returned error: %v", err)
|
||||
}
|
||||
|
||||
profile := claudeDesktopReadJSON(t, paths.profile)
|
||||
if _, ok := profile["inferenceModels"]; ok {
|
||||
t.Fatalf("inferenceModels should be removed, got %v", profile["inferenceModels"])
|
||||
}
|
||||
if profile["inferenceGatewayApiKey"] != "test-api-key" {
|
||||
t.Fatal("expected env API key to replace the old key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopWindowsConfigPathsUseLocalAppData(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "windows")
|
||||
t.Setenv("LOCALAPPDATA", filepath.Join(tmpDir, "LocalAppData"))
|
||||
|
||||
paths, err := claudeDesktopConfigPaths()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := filepath.Join(tmpDir, "LocalAppData", "Claude-3p", "claude_desktop_config.json"); paths.desktopConfig != want {
|
||||
t.Fatalf("desktop config = %q, want %q", paths.desktopConfig, want)
|
||||
}
|
||||
if want := filepath.Join(tmpDir, "LocalAppData", "Claude", "claude_desktop_config.json"); paths.normalConfig != want {
|
||||
t.Fatalf("normal config = %q, want %q", paths.normalConfig, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopWindowsConfigPathsFallbackToNestProfile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "windows")
|
||||
local := filepath.Join(tmpDir, "LocalAppData")
|
||||
t.Setenv("LOCALAPPDATA", local)
|
||||
if err := os.MkdirAll(filepath.Join(local, "Claude Nest-3p"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
paths, err := claudeDesktopConfigPaths()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := filepath.Join(local, "Claude Nest-3p", "claude_desktop_config.json"); paths.desktopConfig != want {
|
||||
t.Fatalf("desktop config = %q, want %q", paths.desktopConfig, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopAutodiscoveryConfiguredOnWindows(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "windows")
|
||||
t.Setenv("LOCALAPPDATA", filepath.Join(tmpDir, "LocalAppData"))
|
||||
t.Setenv("OLLAMA_API_KEY", "test-api-key")
|
||||
withClaudeDesktopValidation(t, func(context.Context, string) error { return nil })
|
||||
|
||||
c := &ClaudeDesktop{}
|
||||
if err := c.ConfigureAutodiscovery(); err != nil {
|
||||
t.Fatalf("Configure returned error: %v", err)
|
||||
}
|
||||
if !c.AutodiscoveryConfigured() {
|
||||
t.Fatal("expected Claude Desktop autodiscovery config to be detected on Windows")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopConfigureAutodiscoveryTouchesAllWindowsProfileCandidates(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "windows")
|
||||
local := filepath.Join(tmpDir, "LocalAppData")
|
||||
t.Setenv("LOCALAPPDATA", local)
|
||||
t.Setenv("OLLAMA_API_KEY", "test-api-key")
|
||||
withClaudeDesktopValidation(t, func(context.Context, string) error { return nil })
|
||||
|
||||
targets, err := claudeDesktopTargetPaths()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(targets.normalConfigs) != 2 {
|
||||
t.Fatalf("normal config target count = %d, want 2", len(targets.normalConfigs))
|
||||
}
|
||||
if len(targets.thirdPartyProfiles) != 2 {
|
||||
t.Fatalf("third-party target count = %d, want 2", len(targets.thirdPartyProfiles))
|
||||
}
|
||||
|
||||
c := &ClaudeDesktop{}
|
||||
if err := c.ConfigureAutodiscovery(); err != nil {
|
||||
t.Fatalf("ConfigureAutodiscovery returned error: %v", err)
|
||||
}
|
||||
|
||||
for _, path := range targets.normalConfigs {
|
||||
cfg := claudeDesktopReadJSON(t, path)
|
||||
if cfg["deploymentMode"] != "3p" {
|
||||
t.Fatalf("%s deploymentMode = %v, want 3p", path, cfg["deploymentMode"])
|
||||
}
|
||||
}
|
||||
for _, target := range targets.thirdPartyProfiles {
|
||||
cfg := claudeDesktopReadJSON(t, target.desktopConfig)
|
||||
if cfg["deploymentMode"] != "3p" {
|
||||
t.Fatalf("%s deploymentMode = %v, want 3p", target.desktopConfig, cfg["deploymentMode"])
|
||||
}
|
||||
meta := claudeDesktopReadJSON(t, target.meta)
|
||||
if meta["appliedId"] != claudeDesktopProfileID {
|
||||
t.Fatalf("%s appliedId = %v, want %s", target.meta, meta["appliedId"], claudeDesktopProfileID)
|
||||
}
|
||||
profile := claudeDesktopReadJSON(t, target.profile)
|
||||
if profile["inferenceProvider"] != "gateway" {
|
||||
t.Fatalf("%s inferenceProvider = %v, want gateway", target.profile, profile["inferenceProvider"])
|
||||
}
|
||||
if profile["inferenceGatewayBaseUrl"] != claudeDesktopGatewayBaseURL {
|
||||
t.Fatalf("%s base URL = %v, want %s", target.profile, profile["inferenceGatewayBaseUrl"], claudeDesktopGatewayBaseURL)
|
||||
}
|
||||
if profile["inferenceGatewayApiKey"] != "test-api-key" {
|
||||
t.Fatalf("%s should contain the configured API key", target.profile)
|
||||
}
|
||||
if _, ok := profile["inferenceModels"]; ok {
|
||||
t.Fatalf("%s inferenceModels should be omitted, got %v", target.profile, profile["inferenceModels"])
|
||||
}
|
||||
}
|
||||
if !c.AutodiscoveryConfigured() {
|
||||
t.Fatal("expected all Windows profile candidates to be considered configured")
|
||||
}
|
||||
|
||||
if err := writeClaudeDesktopDeploymentMode(targets.thirdPartyProfiles[1].desktopConfig, "1p"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if c.AutodiscoveryConfigured() {
|
||||
t.Fatal("expected a stale Windows candidate to force reconfiguration")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopInstalledOnWindowsRecognizesLocalProfileDir(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "windows")
|
||||
local := filepath.Join(tmpDir, "LocalAppData")
|
||||
t.Setenv("LOCALAPPDATA", local)
|
||||
withClaudeDesktopProcessHooks(t, func() bool { return false }, func() error { return nil }, func() error { return nil })
|
||||
if err := os.MkdirAll(filepath.Join(local, "Claude-3p"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !claudeDesktopInstalled() {
|
||||
t.Fatal("expected Claude Desktop to be installed when the Windows profile directory exists")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopWindowsAppPathFindsAnthropicClaudeInstall(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "windows")
|
||||
local := filepath.Join(tmpDir, "LocalAppData")
|
||||
t.Setenv("LOCALAPPDATA", local)
|
||||
want := filepath.Join(local, "AnthropicClaude", "app-1.2.3", "Claude.exe")
|
||||
if err := os.MkdirAll(filepath.Dir(want), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(want, []byte(""), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got := claudeDesktopAppPath(); got != want {
|
||||
t.Fatalf("claudeDesktopAppPath() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitForClaudeDesktopExitUsesRunningHook(t *testing.T) {
|
||||
withClaudeDesktopPlatform(t, "windows")
|
||||
runningChecks := 0
|
||||
withClaudeDesktopProcessHooks(t,
|
||||
func() bool {
|
||||
runningChecks++
|
||||
return runningChecks == 1
|
||||
},
|
||||
func() error { return nil },
|
||||
func() error { return nil },
|
||||
)
|
||||
|
||||
if err := waitForClaudeDesktopExit(time.Second); err != nil {
|
||||
t.Fatalf("waitForClaudeDesktopExit returned error: %v", err)
|
||||
}
|
||||
if runningChecks < 2 {
|
||||
t.Fatalf("expected running hook to be checked until the visible window exits, got %d checks", runningChecks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopWindowsRestoreRestartUsesCapturedDesktopPath(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "windows")
|
||||
t.Setenv("LOCALAPPDATA", filepath.Join(tmpDir, "LocalAppData"))
|
||||
restoreConfirm := withLaunchConfirmPolicy(launchConfirmPolicy{yes: true})
|
||||
defer restoreConfirm()
|
||||
|
||||
paths, err := claudeDesktopConfigPaths()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(paths.profile), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(paths.meta, []byte(`{"appliedId":"`+claudeDesktopProfileID+`","entries":[{"id":"`+claudeDesktopProfileID+`","name":"Ollama"}]}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(paths.profile, []byte(`{"disableDeploymentModeChooser":true,"inferenceGatewayApiKey":"keep"}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
desktopPath := `C:\Users\parth\AppData\Local\AnthropicClaude\app-1.2.3\Claude.exe`
|
||||
running := true
|
||||
var openedPath string
|
||||
withClaudeDesktopProcessHooks(t,
|
||||
func() bool { return running },
|
||||
func() error {
|
||||
running = false
|
||||
return nil
|
||||
},
|
||||
func() error {
|
||||
t.Fatal("expected restart to open the captured Desktop executable path, not the generic launcher")
|
||||
return nil
|
||||
},
|
||||
)
|
||||
claudeDesktopRunningAppPath = func() string { return desktopPath }
|
||||
claudeDesktopOpenAppPath = func(path string) error {
|
||||
openedPath = path
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := (&ClaudeDesktop{}).Restore(); err != nil {
|
||||
t.Fatalf("Restore returned error: %v", err)
|
||||
}
|
||||
if openedPath != desktopPath {
|
||||
t.Fatalf("opened path = %q, want %q", openedPath, desktopPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopWindowsOpenDoesNotFallBackToClaudeCommand(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "windows")
|
||||
t.Setenv("LOCALAPPDATA", filepath.Join(tmpDir, "LocalAppData"))
|
||||
|
||||
oldRunningPath := claudeDesktopRunningAppPath
|
||||
claudeDesktopRunningAppPath = func() string { return "" }
|
||||
t.Cleanup(func() { claudeDesktopRunningAppPath = oldRunningPath })
|
||||
|
||||
err := defaultClaudeDesktopOpenApp()
|
||||
if err == nil || !strings.Contains(err.Error(), "Claude Desktop executable was not found") {
|
||||
t.Fatalf("defaultClaudeDesktopOpenApp error = %v, want executable-not-found error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopConfigureStopsBeforeWriteWhenKeyValidationFails(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "darwin")
|
||||
t.Setenv("OLLAMA_API_KEY", "bad-key")
|
||||
withClaudeDesktopValidation(t, func(context.Context, string) error {
|
||||
return errors.New("invalid key")
|
||||
})
|
||||
|
||||
err := (&ClaudeDesktop{}).ConfigureAutodiscovery()
|
||||
if err == nil || !strings.Contains(err.Error(), "invalid key") {
|
||||
t.Fatalf("Configure error = %v, want invalid key", err)
|
||||
}
|
||||
|
||||
paths, err := claudeDesktopConfigPaths()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := os.Stat(paths.desktopConfig); !errors.Is(err, os.ErrNotExist) {
|
||||
t.Fatalf("desktop config should not be written after validation failure, stat err = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateClaudeDesktopAPIKeyUsesClaudeModelsRoute(t *testing.T) {
|
||||
oldClient := claudeDesktopHTTPClient
|
||||
var gotPath, gotAuth string
|
||||
claudeDesktopHTTPClient = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
gotPath = req.URL.Path
|
||||
gotAuth = req.Header.Get("Authorization")
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(`{"data":[]}`)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
})}
|
||||
t.Cleanup(func() {
|
||||
claudeDesktopHTTPClient = oldClient
|
||||
})
|
||||
|
||||
if err := validateClaudeDesktopAPIKey(context.Background(), "test-key"); err != nil {
|
||||
t.Fatalf("validateClaudeDesktopAPIKey returned error: %v", err)
|
||||
}
|
||||
if gotPath != "/v1/models" {
|
||||
t.Fatalf("validation path = %q, want /v1/models", gotPath)
|
||||
}
|
||||
if gotAuth != "Bearer test-key" {
|
||||
t.Fatalf("Authorization header = %q, want bearer key", gotAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateClaudeDesktopAPIKeyHidesInvalidHeaderDetails(t *testing.T) {
|
||||
err := validateClaudeDesktopAPIKey(context.Background(), "bad\nkey")
|
||||
if err == nil {
|
||||
t.Fatal("expected validation error for key with newline")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "could not verify Ollama API key") {
|
||||
t.Fatalf("validation error = %v, want friendly verification message", err)
|
||||
}
|
||||
if strings.Contains(err.Error(), "invalid header") || strings.Contains(err.Error(), "net/http") {
|
||||
t.Fatalf("validation error should not expose transport internals: %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "https://ollama.com/settings/keys") {
|
||||
t.Fatalf("validation error should include settings link: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopConfigureRequiresAPIKey(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "darwin")
|
||||
t.Setenv("OLLAMA_API_KEY", "")
|
||||
withClaudeDesktopValidation(t, func(context.Context, string) error {
|
||||
t.Fatal("validation should not run without an API key")
|
||||
return nil
|
||||
})
|
||||
|
||||
err := (&ClaudeDesktop{}).ConfigureAutodiscovery()
|
||||
if err == nil || !strings.Contains(err.Error(), "OLLAMA_API_KEY is required") {
|
||||
t.Fatalf("Configure error = %v, want missing key guidance", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopAPIKeyPromptIncludesSettingsLink(t *testing.T) {
|
||||
prompt := claudeDesktopAPIKeyPrompt()
|
||||
if !strings.Contains(prompt, "Enter Ollama API key") {
|
||||
t.Fatalf("prompt should ask for the API key, got %q", prompt)
|
||||
}
|
||||
if !strings.Contains(prompt, "https://ollama.com/settings/keys") {
|
||||
t.Fatalf("prompt should include API key settings link, got %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopConfigureReusesExistingAPIKey(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "darwin")
|
||||
t.Setenv("OLLAMA_API_KEY", "")
|
||||
|
||||
paths, err := claudeDesktopConfigPaths()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(paths.profile), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(paths.profile, []byte(`{"inferenceGatewayApiKey":"existing-key"}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var validatedKey string
|
||||
withClaudeDesktopValidation(t, func(_ context.Context, key string) error {
|
||||
validatedKey = key
|
||||
return nil
|
||||
})
|
||||
|
||||
if err := (&ClaudeDesktop{}).ConfigureAutodiscovery(); err != nil {
|
||||
t.Fatalf("ConfigureAutodiscovery returned error: %v", err)
|
||||
}
|
||||
if validatedKey != "existing-key" {
|
||||
t.Fatalf("validated key = %q, want existing-key", validatedKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopConfigureReplacesInvalidExistingAPIKey(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "darwin")
|
||||
withInteractiveSession(t, true)
|
||||
t.Setenv("OLLAMA_API_KEY", "")
|
||||
|
||||
paths, err := claudeDesktopConfigPaths()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(paths.profile), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(paths.profile, []byte(`{"inferenceGatewayApiKey":"stale-key"}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var validated []string
|
||||
withClaudeDesktopValidation(t, func(_ context.Context, key string) error {
|
||||
validated = append(validated, key)
|
||||
if key == "stale-key" {
|
||||
return errors.New("invalid key")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
withClaudeDesktopPrompt(t, func() (string, error) {
|
||||
return "replacement-key", nil
|
||||
})
|
||||
|
||||
if err := (&ClaudeDesktop{}).ConfigureAutodiscovery(); err != nil {
|
||||
t.Fatalf("ConfigureAutodiscovery returned error: %v", err)
|
||||
}
|
||||
if diff := compareStrings(validated, []string{"stale-key", "replacement-key"}); diff != "" {
|
||||
t.Fatalf("validated keys mismatch: %s", diff)
|
||||
}
|
||||
profile := claudeDesktopReadJSON(t, paths.profile)
|
||||
if profile["inferenceGatewayApiKey"] != "replacement-key" {
|
||||
t.Fatalf("configured key = %v, want replacement-key", profile["inferenceGatewayApiKey"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopConfigureReusesExistingAPIKeyFromAnyWindowsProfile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "windows")
|
||||
local := filepath.Join(tmpDir, "LocalAppData")
|
||||
t.Setenv("LOCALAPPDATA", local)
|
||||
t.Setenv("OLLAMA_API_KEY", "")
|
||||
|
||||
targets, err := claudeDesktopTargetPaths()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fallbackProfile := targets.thirdPartyProfiles[1].profile
|
||||
if err := os.MkdirAll(filepath.Dir(fallbackProfile), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(fallbackProfile, []byte(`{"inferenceGatewayApiKey":"fallback-key"}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var validatedKey string
|
||||
withClaudeDesktopValidation(t, func(_ context.Context, key string) error {
|
||||
validatedKey = key
|
||||
return nil
|
||||
})
|
||||
|
||||
if err := (&ClaudeDesktop{}).ConfigureAutodiscovery(); err != nil {
|
||||
t.Fatalf("ConfigureAutodiscovery returned error: %v", err)
|
||||
}
|
||||
if validatedKey != "fallback-key" {
|
||||
t.Fatalf("validated key = %q, want fallback-key", validatedKey)
|
||||
}
|
||||
for _, target := range targets.thirdPartyProfiles {
|
||||
profile := claudeDesktopReadJSON(t, target.profile)
|
||||
if profile["inferenceGatewayApiKey"] != "fallback-key" {
|
||||
t.Fatalf("%s should reuse fallback key, got %v", target.profile, profile["inferenceGatewayApiKey"])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopAutodiscoveryConfiguredRequiresAppliedOllamaProfile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "darwin")
|
||||
t.Setenv("OLLAMA_API_KEY", "test-api-key")
|
||||
withClaudeDesktopValidation(t, func(context.Context, string) error { return nil })
|
||||
|
||||
c := &ClaudeDesktop{}
|
||||
if err := c.ConfigureAutodiscovery(); err != nil {
|
||||
t.Fatalf("Configure returned error: %v", err)
|
||||
}
|
||||
if !c.AutodiscoveryConfigured() {
|
||||
t.Fatal("expected Claude Desktop autodiscovery config to be detected")
|
||||
}
|
||||
|
||||
paths, err := claudeDesktopConfigPaths()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(paths.meta, []byte(`{"appliedId":"custom"}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if c.AutodiscoveryConfigured() {
|
||||
t.Fatal("expected another applied profile to hide Claude Desktop autodiscovery config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopAutodiscoveryConfiguredRequiresAPIKey(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "darwin")
|
||||
t.Setenv("OLLAMA_API_KEY", "test-api-key")
|
||||
withClaudeDesktopValidation(t, func(context.Context, string) error { return nil })
|
||||
|
||||
c := &ClaudeDesktop{}
|
||||
if err := c.ConfigureAutodiscovery(); err != nil {
|
||||
t.Fatalf("Configure returned error: %v", err)
|
||||
}
|
||||
|
||||
paths, err := claudeDesktopConfigPaths()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
profile := claudeDesktopReadJSON(t, paths.profile)
|
||||
delete(profile, "inferenceGatewayApiKey")
|
||||
data, err := json.Marshal(profile)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(paths.profile, data, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if c.AutodiscoveryConfigured() {
|
||||
t.Fatal("expected missing gateway API key to force Claude Desktop reconfiguration")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopRestoreSwitchesBackToFirstPartyMode(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "darwin")
|
||||
withClaudeDesktopProcessHooks(t, func() bool { return false }, func() error { return nil }, func() error { return nil })
|
||||
|
||||
paths, err := claudeDesktopConfigPaths()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(paths.profile), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(paths.meta, []byte(`{"appliedId":"`+claudeDesktopProfileID+`","entries":[{"id":"`+claudeDesktopProfileID+`","name":"Ollama"}]}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(paths.profile, []byte(`{"disableDeploymentModeChooser":true,"inferenceGatewayApiKey":"keep","inferenceProvider":"gateway","inferenceGatewayBaseUrl":"https://ollama.com","inferenceGatewayAuthScheme":"bearer","inferenceModels":["legacy"]}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := (&ClaudeDesktop{}).Restore(); err != nil {
|
||||
t.Fatalf("Restore returned error: %v", err)
|
||||
}
|
||||
|
||||
desktopConfig := claudeDesktopReadJSON(t, paths.desktopConfig)
|
||||
if desktopConfig["deploymentMode"] != "1p" {
|
||||
t.Fatalf("deploymentMode = %v, want 1p", desktopConfig["deploymentMode"])
|
||||
}
|
||||
normalConfig := claudeDesktopReadJSON(t, paths.normalConfig)
|
||||
if normalConfig["deploymentMode"] != "1p" {
|
||||
t.Fatalf("normal deploymentMode = %v, want 1p", normalConfig["deploymentMode"])
|
||||
}
|
||||
profile := claudeDesktopReadJSON(t, paths.profile)
|
||||
if profile["disableDeploymentModeChooser"] != false {
|
||||
t.Fatalf("disableDeploymentModeChooser = %v, want false", profile["disableDeploymentModeChooser"])
|
||||
}
|
||||
if profile["inferenceGatewayApiKey"] != "keep" {
|
||||
t.Fatal("restore should leave existing Ollama profile credentials in place")
|
||||
}
|
||||
for _, key := range []string{"inferenceProvider", "inferenceGatewayBaseUrl", "inferenceGatewayAuthScheme", "inferenceModels"} {
|
||||
if _, ok := profile[key]; ok {
|
||||
t.Fatalf("restore should clear stale %s from the Ollama profile: %v", key, profile)
|
||||
}
|
||||
}
|
||||
meta := claudeDesktopReadJSON(t, paths.meta)
|
||||
if _, ok := meta["appliedId"]; ok {
|
||||
t.Fatalf("restore should clear the applied Ollama third-party profile: %v", meta)
|
||||
}
|
||||
if (&ClaudeDesktop{}).AutodiscoveryConfigured() {
|
||||
t.Fatal("restore should leave Claude Desktop autodiscovery unconfigured")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopRestoreTouchesAllWindowsProfileCandidates(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withClaudeDesktopPlatform(t, "windows")
|
||||
local := filepath.Join(tmpDir, "LocalAppData")
|
||||
t.Setenv("LOCALAPPDATA", local)
|
||||
withClaudeDesktopProcessHooks(t, func() bool { return false }, func() error { return nil }, func() error { return nil })
|
||||
|
||||
targets, err := claudeDesktopTargetPaths()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(targets.normalConfigs) != 2 {
|
||||
t.Fatalf("normal config target count = %d, want 2", len(targets.normalConfigs))
|
||||
}
|
||||
if len(targets.thirdPartyProfiles) != 2 {
|
||||
t.Fatalf("third-party target count = %d, want 2", len(targets.thirdPartyProfiles))
|
||||
}
|
||||
for _, target := range targets.thirdPartyProfiles {
|
||||
if err := os.MkdirAll(filepath.Dir(target.profile), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(target.meta, []byte(`{"appliedId":"`+claudeDesktopProfileID+`","entries":[{"id":"`+claudeDesktopProfileID+`","name":"Ollama"}]}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(target.profile, []byte(`{"disableDeploymentModeChooser":true,"inferenceGatewayApiKey":"keep","inferenceProvider":"gateway","inferenceGatewayBaseUrl":"https://ollama.com","inferenceGatewayAuthScheme":"bearer","inferenceModels":["legacy"]}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := (&ClaudeDesktop{}).Restore(); err != nil {
|
||||
t.Fatalf("Restore returned error: %v", err)
|
||||
}
|
||||
|
||||
for _, path := range targets.normalConfigs {
|
||||
cfg := claudeDesktopReadJSON(t, path)
|
||||
if cfg["deploymentMode"] != "1p" {
|
||||
t.Fatalf("%s deploymentMode = %v, want 1p", path, cfg["deploymentMode"])
|
||||
}
|
||||
}
|
||||
for _, target := range targets.thirdPartyProfiles {
|
||||
cfg := claudeDesktopReadJSON(t, target.desktopConfig)
|
||||
if cfg["deploymentMode"] != "1p" {
|
||||
t.Fatalf("%s deploymentMode = %v, want 1p", target.desktopConfig, cfg["deploymentMode"])
|
||||
}
|
||||
meta := claudeDesktopReadJSON(t, target.meta)
|
||||
if _, ok := meta["appliedId"]; ok {
|
||||
t.Fatalf("%s should not keep the Ollama applied profile: %v", target.meta, meta)
|
||||
}
|
||||
profile := claudeDesktopReadJSON(t, target.profile)
|
||||
if profile["disableDeploymentModeChooser"] != false {
|
||||
t.Fatalf("%s disableDeploymentModeChooser = %v, want false", target.profile, profile["disableDeploymentModeChooser"])
|
||||
}
|
||||
if profile["inferenceGatewayApiKey"] != "keep" {
|
||||
t.Fatalf("%s should preserve gateway API key", target.profile)
|
||||
}
|
||||
for _, key := range []string{"inferenceProvider", "inferenceGatewayBaseUrl", "inferenceGatewayAuthScheme", "inferenceModels"} {
|
||||
if _, ok := profile[key]; ok {
|
||||
t.Fatalf("%s should clear stale %s: %v", target.profile, key, profile)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeDesktopRunReturnsUnsupported(t *testing.T) {
|
||||
withClaudeDesktopPlatform(t, "darwin")
|
||||
|
||||
withClaudeDesktopProcessHooks(t,
|
||||
func() bool {
|
||||
t.Fatal("Run should not inspect Claude Desktop process state")
|
||||
return false
|
||||
},
|
||||
func() error {
|
||||
t.Fatal("Run should not quit Claude Desktop")
|
||||
return nil
|
||||
},
|
||||
func() error {
|
||||
t.Fatal("Run should not open Claude Desktop")
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
for _, args := range [][]string{nil, {"--foo"}} {
|
||||
err := (&ClaudeDesktop{}).Run("qwen3.5", nil, args)
|
||||
if err == nil {
|
||||
t.Fatal("expected Run to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "Claude Desktop is no longer supported") {
|
||||
t.Fatalf("expected unsupported guidance, got %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "ollama launch claude-desktop --restore") {
|
||||
t.Fatalf("expected restore guidance, got %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
171
cmd/launch/claude_test.go
Normal file
171
cmd/launch/claude_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClaudeIntegration(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "Claude Code" {
|
||||
t.Errorf("String() = %q, want %q", got, "Claude Code")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeFindPath(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
t.Run("finds claude in PATH", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fakeBin := filepath.Join(tmpDir, name)
|
||||
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
got, err := c.findPath()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != fakeBin {
|
||||
t.Errorf("findPath() = %q, want %q", got, fakeBin)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to ~/.claude/local/claude", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
|
||||
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fallback := filepath.Join(tmpDir, ".claude", "local", name)
|
||||
os.MkdirAll(filepath.Dir(fallback), 0o755)
|
||||
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
|
||||
|
||||
got, err := c.findPath()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != fallback {
|
||||
t.Errorf("findPath() = %q, want %q", got, fallback)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
|
||||
|
||||
_, err := c.findPath()
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeArgs(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
args []string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil, nil},
|
||||
{"with model and verbose", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
|
||||
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
|
||||
{"with allowed tools", "llama3.2", []string{"--allowedTools", "Read,Write,Bash"}, []string{"--model", "llama3.2", "--allowedTools", "Read,Write,Bash"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model, tt.args)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeModelEnvVars(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
envMap := func(envs []string) map[string]string {
|
||||
m := make(map[string]string)
|
||||
for _, e := range envs {
|
||||
k, v, _ := strings.Cut(e, "=")
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
t.Run("maps all Claude model env vars to the provided model", func(t *testing.T) {
|
||||
got := envMap(c.modelEnvVars("llama3.2"))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
|
||||
t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2" {
|
||||
t.Errorf("SONNET = %q, want llama3.2", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2" {
|
||||
t.Errorf("HAIKU = %q, want llama3.2", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
|
||||
t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
|
||||
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty for local models", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("supports empty model", func(t *testing.T) {
|
||||
got := envMap(c.modelEnvVars(""))
|
||||
if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "" {
|
||||
t.Errorf("OPUS = %q, want empty", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "" {
|
||||
t.Errorf("SONNET = %q, want empty", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
|
||||
}
|
||||
if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "" {
|
||||
t.Errorf("HAIKU = %q, want empty", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "" {
|
||||
t.Errorf("SUBAGENT = %q, want empty", got["CLAUDE_CODE_SUBAGENT_MODEL"])
|
||||
}
|
||||
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
|
||||
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets auto compact window for known cloud models", func(t *testing.T) {
|
||||
got := envMap(c.modelEnvVars("glm-5:cloud"))
|
||||
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "202752" {
|
||||
t.Errorf("AUTO_COMPACT_WINDOW = %q, want 202752", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("does not set auto compact window for unknown cloud models", func(t *testing.T) {
|
||||
got := envMap(c.modelEnvVars("unknown-model:cloud"))
|
||||
if got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"] != "" {
|
||||
t.Errorf("AUTO_COMPACT_WINDOW = %q, want empty", got["CLAUDE_CODE_AUTO_COMPACT_WINDOW"])
|
||||
}
|
||||
})
|
||||
}
|
||||
104
cmd/launch/cline.go
Normal file
104
cmd/launch/cline.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Cline implements Runner and Editor for the Cline CLI integration
|
||||
type Cline struct{}
|
||||
|
||||
func (c *Cline) String() string { return "Cline" }
|
||||
|
||||
func (c *Cline) Run(model string, _ []LaunchModel, args []string) error {
|
||||
if _, err := exec.LookPath("cline"); err != nil {
|
||||
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
|
||||
}
|
||||
|
||||
cmd := exec.Command("cline", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (c *Cline) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
p := filepath.Join(home, ".cline", "data", "globalState.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return []string{p}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cline) Edit(models []LaunchModel) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".cline", "data", "globalState.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return fmt.Errorf("failed to parse config: %w, at: %s", err, configPath)
|
||||
}
|
||||
}
|
||||
|
||||
// Set Ollama as the provider for both act and plan modes
|
||||
baseURL := envconfig.Host().String()
|
||||
config["ollamaBaseUrl"] = baseURL
|
||||
config["actModeApiProvider"] = "ollama"
|
||||
config["actModeOllamaModelId"] = models[0].Name
|
||||
config["actModeOllamaBaseUrl"] = baseURL
|
||||
config["planModeApiProvider"] = "ollama"
|
||||
config["planModeOllamaModelId"] = models[0].Name
|
||||
config["planModeOllamaBaseUrl"] = baseURL
|
||||
|
||||
config["welcomeViewCompleted"] = true
|
||||
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fileutil.WriteWithBackup(configPath, data, "cline")
|
||||
}
|
||||
|
||||
func (c *Cline) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
config, err := fileutil.ReadJSON(filepath.Join(home, ".cline", "data", "globalState.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if config["actModeApiProvider"] != "ollama" {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelID, _ := config["actModeOllamaModelId"].(string)
|
||||
if modelID == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{modelID}
|
||||
}
|
||||
204
cmd/launch/cline_test.go
Normal file
204
cmd/launch/cline_test.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClineIntegration(t *testing.T) {
|
||||
c := &Cline{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "Cline" {
|
||||
t.Errorf("String() = %q, want %q", got, "Cline")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = c
|
||||
})
|
||||
}
|
||||
|
||||
func TestClineEdit(t *testing.T) {
|
||||
c := &Cline{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||
configPath := filepath.Join(configDir, "globalState.json")
|
||||
|
||||
readConfig := func() map[string]any {
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var config map[string]any
|
||||
json.Unmarshal(data, &config)
|
||||
return config
|
||||
}
|
||||
|
||||
t.Run("creates config from scratch", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit(testLaunchModels("kimi-k2.5:cloud")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["actModeApiProvider"] != "ollama" {
|
||||
t.Errorf("actModeApiProvider = %v, want ollama", config["actModeApiProvider"])
|
||||
}
|
||||
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud", config["actModeOllamaModelId"])
|
||||
}
|
||||
if config["planModeApiProvider"] != "ollama" {
|
||||
t.Errorf("planModeApiProvider = %v, want ollama", config["planModeApiProvider"])
|
||||
}
|
||||
if config["planModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("planModeOllamaModelId = %v, want kimi-k2.5:cloud", config["planModeOllamaModelId"])
|
||||
}
|
||||
if config["welcomeViewCompleted"] != true {
|
||||
t.Errorf("welcomeViewCompleted = %v, want true", config["welcomeViewCompleted"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves existing fields", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
existing := map[string]any{
|
||||
"remoteRulesToggles": map[string]any{},
|
||||
"remoteWorkflowToggles": map[string]any{},
|
||||
"customSetting": "keep-me",
|
||||
}
|
||||
data, _ := json.Marshal(existing)
|
||||
os.WriteFile(configPath, data, 0o644)
|
||||
|
||||
if err := c.Edit(testLaunchModels("glm-5:cloud")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["customSetting"] != "keep-me" {
|
||||
t.Errorf("customSetting was not preserved")
|
||||
}
|
||||
if config["actModeOllamaModelId"] != "glm-5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updates model on re-edit", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit(testLaunchModels("kimi-k2.5:cloud")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := c.Edit(testLaunchModels("glm-5:cloud")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["actModeOllamaModelId"] != "glm-5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
|
||||
}
|
||||
if config["planModeOllamaModelId"] != "glm-5:cloud" {
|
||||
t.Errorf("planModeOllamaModelId = %v, want glm-5:cloud", config["planModeOllamaModelId"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty models is no-op", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit(nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
|
||||
t.Error("expected no config file to be created for empty models")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses first model as primary", func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||
|
||||
if err := c.Edit(testLaunchModels("kimi-k2.5:cloud", "glm-5:cloud")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := readConfig()
|
||||
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud (first model)", config["actModeOllamaModelId"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClineModels(t *testing.T) {
|
||||
c := &Cline{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||
configPath := filepath.Join(configDir, "globalState.json")
|
||||
|
||||
t.Run("returns nil when no config", func(t *testing.T) {
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("Models() = %v, want nil", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when provider is not ollama", func(t *testing.T) {
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
config := map[string]any{
|
||||
"actModeApiProvider": "anthropic",
|
||||
"actModeOllamaModelId": "some-model",
|
||||
}
|
||||
data, _ := json.Marshal(config)
|
||||
os.WriteFile(configPath, data, 0o644)
|
||||
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("Models() = %v, want nil", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns model when ollama is configured", func(t *testing.T) {
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
config := map[string]any{
|
||||
"actModeApiProvider": "ollama",
|
||||
"actModeOllamaModelId": "kimi-k2.5:cloud",
|
||||
}
|
||||
data, _ := json.Marshal(config)
|
||||
os.WriteFile(configPath, data, 0o644)
|
||||
|
||||
models := c.Models()
|
||||
if len(models) != 1 || models[0] != "kimi-k2.5:cloud" {
|
||||
t.Errorf("Models() = %v, want [kimi-k2.5:cloud]", models)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClinePaths(t *testing.T) {
|
||||
c := &Cline{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("returns nil when no config exists", func(t *testing.T) {
|
||||
if paths := c.Paths(); paths != nil {
|
||||
t.Errorf("Paths() = %v, want nil", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns path when config exists", func(t *testing.T) {
|
||||
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
configPath := filepath.Join(configDir, "globalState.json")
|
||||
os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
|
||||
paths := c.Paths()
|
||||
if len(paths) != 1 || paths[0] != configPath {
|
||||
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
671
cmd/launch/codex.go
Normal file
671
cmd/launch/codex.go
Normal file
@@ -0,0 +1,671 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/pelletier/go-toml/v2"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
// Codex implements Runner for Codex integration
|
||||
type Codex struct{}
|
||||
|
||||
func (c *Codex) String() string { return "Codex" }
|
||||
|
||||
const (
|
||||
codexProfileName = "ollama-launch"
|
||||
codexProviderName = "Ollama"
|
||||
codexFallbackContextWindow = 128_000
|
||||
|
||||
codexRootProfileKey = "profile"
|
||||
codexRootModelKey = "model"
|
||||
codexRootModelProviderKey = "model_provider"
|
||||
codexRootModelCatalogJSONKey = "model_catalog_json"
|
||||
)
|
||||
|
||||
func (c *Codex) args(model, modelCatalogPath string, extra []string) []string {
|
||||
args := []string{"--profile", codexProfileName}
|
||||
if modelCatalogPath != "" {
|
||||
args = append(args, "-c", fmt.Sprintf("%s=%q", codexRootModelCatalogJSONKey, modelCatalogPath))
|
||||
}
|
||||
if model != "" {
|
||||
args = append(args, "-m", model)
|
||||
}
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Codex) Run(model string, models []LaunchModel, args []string) error {
|
||||
if err := checkCodexVersion(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ensureCodexConfig(model, models); err != nil {
|
||||
return fmt.Errorf("failed to configure codex: %w", err)
|
||||
}
|
||||
|
||||
catalogPath, err := codexModelCatalogPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure codex: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("codex", c.args(model, catalogPath, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"OPENAI_API_KEY=ollama",
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// ensureCodexConfig writes a Codex profile and model catalog so Codex uses the
|
||||
// local Ollama server and has model metadata available.
|
||||
func ensureCodexConfig(modelName string, models []LaunchModel) error {
|
||||
configPath, err := codexConfigPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
codexDir := filepath.Dir(configPath)
|
||||
if err := os.MkdirAll(codexDir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
catalogPath := codexModelCatalogPathForConfig(configPath)
|
||||
if err := writeCodexModelCatalog(catalogPath, codexCatalogModel(modelName, models)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeCodexProfile(configPath, catalogPath)
|
||||
}
|
||||
|
||||
func codexConfigPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".codex", "config.toml"), nil
|
||||
}
|
||||
|
||||
func codexModelCatalogPath() (string, error) {
|
||||
configPath, err := codexConfigPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return codexModelCatalogPathForConfig(configPath), nil
|
||||
}
|
||||
|
||||
func codexModelCatalogPathForConfig(configPath string) string {
|
||||
return filepath.Join(filepath.Dir(configPath), "model.json")
|
||||
}
|
||||
|
||||
// writeCodexProfile ensures ~/.codex/config.toml has the ollama-launch profile
|
||||
// and model provider sections with the correct base URL.
|
||||
func writeCodexProfile(configPath string, modelCatalogPath ...string) error {
|
||||
opts := codexLaunchProfileOptions{
|
||||
forceAPIAuth: true,
|
||||
}
|
||||
if len(modelCatalogPath) > 0 {
|
||||
opts.modelCatalogPath = modelCatalogPath[0]
|
||||
}
|
||||
return writeCodexLaunchProfile(configPath, opts)
|
||||
}
|
||||
|
||||
type codexLaunchProfileOptions struct {
|
||||
activate bool
|
||||
profileName string
|
||||
forceAPIAuth bool
|
||||
setRootModelConfig bool
|
||||
model string
|
||||
modelCatalogPath string
|
||||
backupIntegration string
|
||||
}
|
||||
|
||||
func writeCodexLaunchProfile(configPath string, opts codexLaunchProfileOptions) error {
|
||||
baseURL := codexBaseURL()
|
||||
profileName := codexLaunchProfileName(opts)
|
||||
profileHeader := codexProfileHeaderFor(profileName)
|
||||
providerHeader := codexProviderHeaderFor(profileName)
|
||||
|
||||
content, readErr := os.ReadFile(configPath)
|
||||
text := ""
|
||||
if readErr == nil {
|
||||
text = string(content)
|
||||
} else if !os.IsNotExist(readErr) {
|
||||
return readErr
|
||||
}
|
||||
parsed, err := codexParseConfig(text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
model := strings.TrimSpace(opts.model)
|
||||
if model == "" {
|
||||
model = parsed.ProfileString(profileName, codexRootModelKey)
|
||||
}
|
||||
modelCatalogPath := strings.TrimSpace(opts.modelCatalogPath)
|
||||
if modelCatalogPath == "" {
|
||||
modelCatalogPath = parsed.ProfileString(profileName, codexRootModelCatalogJSONKey)
|
||||
}
|
||||
|
||||
profileLines := []string{}
|
||||
if model != "" {
|
||||
profileLines = append(profileLines, fmt.Sprintf("%s = %q", codexRootModelKey, model))
|
||||
}
|
||||
profileLines = append(profileLines,
|
||||
fmt.Sprintf("openai_base_url = %q", baseURL),
|
||||
fmt.Sprintf("%s = %q", codexRootModelProviderKey, profileName),
|
||||
)
|
||||
if opts.forceAPIAuth {
|
||||
profileLines = append(profileLines, `forced_login_method = "api"`)
|
||||
}
|
||||
if modelCatalogPath != "" {
|
||||
profileLines = append(profileLines, fmt.Sprintf("%s = %q", codexRootModelCatalogJSONKey, modelCatalogPath))
|
||||
}
|
||||
|
||||
sections := []struct {
|
||||
header string
|
||||
lines []string
|
||||
}{
|
||||
{
|
||||
header: profileHeader,
|
||||
lines: profileLines,
|
||||
},
|
||||
{
|
||||
header: providerHeader,
|
||||
lines: []string{
|
||||
fmt.Sprintf("name = %q", codexProviderName),
|
||||
fmt.Sprintf("base_url = %q", baseURL),
|
||||
`wire_api = "responses"`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if opts.activate {
|
||||
text = codexSetRootStringValue(text, codexRootProfileKey, profileName)
|
||||
}
|
||||
if opts.setRootModelConfig {
|
||||
if model != "" {
|
||||
text = codexSetRootStringValue(text, codexRootModelKey, model)
|
||||
}
|
||||
text = codexSetRootStringValue(text, codexRootModelProviderKey, profileName)
|
||||
if modelCatalogPath != "" {
|
||||
text = codexSetRootStringValue(text, codexRootModelCatalogJSONKey, modelCatalogPath)
|
||||
}
|
||||
}
|
||||
|
||||
for _, s := range sections {
|
||||
text = codexUpsertSection(text, s.header, s.lines)
|
||||
}
|
||||
parsed, err = codexParseConfig(text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := codexValidateLaunchProfileText(parsed, profileName, opts, model, modelCatalogPath, baseURL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
return fileutil.WriteWithBackup(configPath, []byte(text), opts.backupIntegration)
|
||||
}
|
||||
|
||||
func codexLaunchProfileName(opts codexLaunchProfileOptions) string {
|
||||
if name := strings.TrimSpace(opts.profileName); name != "" {
|
||||
return name
|
||||
}
|
||||
return codexProfileName
|
||||
}
|
||||
|
||||
func codexBaseURL() string {
|
||||
return strings.TrimRight(envconfig.ConnectableHost().String(), "/") + "/v1/"
|
||||
}
|
||||
|
||||
func codexProfileHeader() string {
|
||||
return codexProfileHeaderFor(codexProfileName)
|
||||
}
|
||||
|
||||
func codexProviderHeader() string {
|
||||
return codexProviderHeaderFor(codexProfileName)
|
||||
}
|
||||
|
||||
func codexProfileHeaderFor(profileName string) string {
|
||||
return fmt.Sprintf("[profiles.%s]", profileName)
|
||||
}
|
||||
|
||||
func codexProviderHeaderFor(profileName string) string {
|
||||
return fmt.Sprintf("[model_providers.%s]", profileName)
|
||||
}
|
||||
|
||||
func codexValidateLaunchProfileText(config codexParsedConfig, profileName string, opts codexLaunchProfileOptions, model, modelCatalogPath, baseURL string) error {
|
||||
for _, check := range []struct {
|
||||
path []string
|
||||
want string
|
||||
}{
|
||||
{[]string{"profiles", profileName, "openai_base_url"}, baseURL},
|
||||
{[]string{"profiles", profileName, codexRootModelProviderKey}, profileName},
|
||||
{[]string{"model_providers", profileName, "name"}, codexProviderName},
|
||||
{[]string{"model_providers", profileName, "base_url"}, baseURL},
|
||||
{[]string{"model_providers", profileName, "wire_api"}, "responses"},
|
||||
} {
|
||||
if got, ok := config.String(check.path...); !ok || got != check.want {
|
||||
return fmt.Errorf("generated Codex config missing %s = %q", strings.Join(check.path, "."), check.want)
|
||||
}
|
||||
}
|
||||
if opts.forceAPIAuth {
|
||||
if got, ok := config.String("profiles", profileName, "forced_login_method"); !ok || got != "api" {
|
||||
return fmt.Errorf("generated Codex config missing profiles.%s.forced_login_method = %q", profileName, "api")
|
||||
}
|
||||
}
|
||||
if model != "" {
|
||||
if got, ok := config.String("profiles", profileName, codexRootModelKey); !ok || got != model {
|
||||
return fmt.Errorf("generated Codex config missing profiles.%s.model = %q", profileName, model)
|
||||
}
|
||||
}
|
||||
if modelCatalogPath != "" {
|
||||
if got, ok := config.String("profiles", profileName, codexRootModelCatalogJSONKey); !ok || got != modelCatalogPath {
|
||||
return fmt.Errorf("generated Codex config missing profiles.%s.model_catalog_json = %q", profileName, modelCatalogPath)
|
||||
}
|
||||
}
|
||||
if opts.activate {
|
||||
if got := config.RootString(codexRootProfileKey); got != profileName {
|
||||
return fmt.Errorf("generated Codex config missing profile = %q", profileName)
|
||||
}
|
||||
}
|
||||
if opts.setRootModelConfig {
|
||||
if model != "" {
|
||||
if got := config.RootString(codexRootModelKey); got != model {
|
||||
return fmt.Errorf("generated Codex config missing model = %q", model)
|
||||
}
|
||||
}
|
||||
if got := config.RootString(codexRootModelProviderKey); got != profileName {
|
||||
return fmt.Errorf("generated Codex config missing model_provider = %q", profileName)
|
||||
}
|
||||
if modelCatalogPath != "" {
|
||||
if got := config.RootString(codexRootModelCatalogJSONKey); got != modelCatalogPath {
|
||||
return fmt.Errorf("generated Codex config missing model_catalog_json = %q", modelCatalogPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func codexUpsertSection(text, header string, lines []string) string {
|
||||
block := strings.Join(append([]string{header}, lines...), "\n") + "\n"
|
||||
|
||||
if targetPath, ok := codexTableHeaderPath(header); ok {
|
||||
if start, end, found := codexSectionRange(text, targetPath); found {
|
||||
return text[:start] + block + text[end:]
|
||||
}
|
||||
}
|
||||
|
||||
if text != "" && !strings.HasSuffix(text, "\n") {
|
||||
text += "\n"
|
||||
}
|
||||
if text != "" {
|
||||
text += "\n"
|
||||
}
|
||||
return text + block
|
||||
}
|
||||
|
||||
func codexRemoveSection(text, header string) string {
|
||||
targetPath, ok := codexTableHeaderPath(header)
|
||||
if !ok {
|
||||
return text
|
||||
}
|
||||
start, end, found := codexSectionRange(text, targetPath)
|
||||
if !found {
|
||||
return text
|
||||
}
|
||||
return text[:start] + text[end:]
|
||||
}
|
||||
|
||||
type codexParsedConfig struct {
|
||||
values map[string]any
|
||||
}
|
||||
|
||||
func (c codexParsedConfig) String(path ...string) (string, bool) {
|
||||
if len(path) == 0 {
|
||||
return "", false
|
||||
}
|
||||
var current any = c.values
|
||||
for _, part := range path {
|
||||
table, ok := current.(map[string]any)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
current, ok = table[part]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
value, ok := current.(string)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return value, true
|
||||
}
|
||||
|
||||
func (c codexParsedConfig) RootString(key string) string {
|
||||
value, _ := c.RootStringOK(key)
|
||||
return value
|
||||
}
|
||||
|
||||
func (c codexParsedConfig) RootStringOK(key string) (string, bool) {
|
||||
return c.String(key)
|
||||
}
|
||||
|
||||
func (c codexParsedConfig) ProfileString(profileName, key string) string {
|
||||
value, _ := c.String("profiles", profileName, key)
|
||||
return value
|
||||
}
|
||||
|
||||
func (c codexParsedConfig) ProviderString(profileName, key string) string {
|
||||
value, _ := c.String("model_providers", profileName, key)
|
||||
return value
|
||||
}
|
||||
|
||||
func codexRootStringValue(text, key string) string {
|
||||
config, err := codexParseConfig(text)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return config.RootString(key)
|
||||
}
|
||||
|
||||
func codexRootStringValueOK(text, key string) (string, bool) {
|
||||
config, err := codexParseConfig(text)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
return config.RootStringOK(key)
|
||||
}
|
||||
|
||||
func codexStringValue(text string, path ...string) (string, bool) {
|
||||
config, err := codexParseConfig(text)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
return config.String(path...)
|
||||
}
|
||||
|
||||
func codexSectionStringValue(text, header, key string) string {
|
||||
path, ok := codexTableHeaderPath(header)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
value, _ := codexStringValue(text, append(path, key)...)
|
||||
return value
|
||||
}
|
||||
|
||||
func codexParseConfig(text string) (codexParsedConfig, error) {
|
||||
values, err := codexParseConfigText(text)
|
||||
if err != nil {
|
||||
return codexParsedConfig{}, err
|
||||
}
|
||||
return codexParsedConfig{values: values}, nil
|
||||
}
|
||||
|
||||
func codexParseConfigText(text string) (map[string]any, error) {
|
||||
cfg := map[string]any{}
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return cfg, nil
|
||||
}
|
||||
if err := toml.Unmarshal([]byte(text), &cfg); err != nil {
|
||||
return nil, fmt.Errorf("invalid Codex config TOML: %w", err)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func codexValidateConfigText(text string) error {
|
||||
_, err := codexParseConfig(text)
|
||||
return err
|
||||
}
|
||||
|
||||
func codexSectionRange(text string, targetPath []string) (int, int, bool) {
|
||||
lines := strings.SplitAfter(text, "\n")
|
||||
offset := 0
|
||||
start := -1
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if !strings.HasPrefix(trimmed, "[") || strings.HasPrefix(trimmed, "#") {
|
||||
offset += len(line)
|
||||
continue
|
||||
}
|
||||
if start >= 0 {
|
||||
return start, offset, true
|
||||
}
|
||||
if path, ok := codexTableHeaderPath(trimmed); ok && codexSamePath(path, targetPath) {
|
||||
start = offset
|
||||
}
|
||||
offset += len(line)
|
||||
}
|
||||
if start >= 0 {
|
||||
return start, len(text), true
|
||||
}
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
func codexTableHeaderPath(header string) ([]string, bool) {
|
||||
trimmed := strings.TrimSpace(header)
|
||||
if !strings.HasPrefix(trimmed, "[") || strings.HasPrefix(trimmed, "[[") {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
const probeKey = "__ollama_launch_probe"
|
||||
cfg := map[string]any{}
|
||||
if err := toml.Unmarshal([]byte(trimmed+"\n"+probeKey+" = true\n"), &cfg); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return codexFindProbePath(cfg, probeKey, nil)
|
||||
}
|
||||
|
||||
func codexFindProbePath(value any, probeKey string, path []string) ([]string, bool) {
|
||||
table, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if probe, ok := table[probeKey].(bool); ok && probe {
|
||||
return path, true
|
||||
}
|
||||
for key, child := range table {
|
||||
if key == probeKey {
|
||||
continue
|
||||
}
|
||||
if childPath, ok := codexFindProbePath(child, probeKey, append(path, key)); ok {
|
||||
return childPath, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func codexSamePath(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func codexSetRootStringValue(text, key, value string) string {
|
||||
lines := strings.SplitAfter(text, "\n")
|
||||
rootEnd := len(lines)
|
||||
for i, line := range lines {
|
||||
if strings.HasPrefix(strings.TrimSpace(line), "[") {
|
||||
rootEnd = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
assignment := fmt.Sprintf("%s = %q", key, value)
|
||||
for i := range rootEnd {
|
||||
line := lines[i]
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
if codexRootLineHasKey(trimmed, key) {
|
||||
if strings.HasSuffix(line, "\n") {
|
||||
lines[i] = assignment + "\n"
|
||||
} else {
|
||||
lines[i] = assignment
|
||||
}
|
||||
return strings.Join(lines, "")
|
||||
}
|
||||
}
|
||||
|
||||
insert := assignment + "\n"
|
||||
root := strings.Join(lines[:rootEnd], "")
|
||||
rest := strings.Join(lines[rootEnd:], "")
|
||||
if root != "" && !strings.HasSuffix(root, "\n") {
|
||||
root += "\n"
|
||||
}
|
||||
if rest != "" && !strings.HasSuffix(insert, "\n\n") {
|
||||
insert += "\n"
|
||||
}
|
||||
return root + insert + rest
|
||||
}
|
||||
|
||||
func codexRemoveRootValue(text, key string) string {
|
||||
lines := strings.SplitAfter(text, "\n")
|
||||
rootEnd := len(lines)
|
||||
for i, line := range lines {
|
||||
if strings.HasPrefix(strings.TrimSpace(line), "[") {
|
||||
rootEnd = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]string, 0, len(lines))
|
||||
for i, line := range lines {
|
||||
if i < rootEnd {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed != "" && !strings.HasPrefix(trimmed, "#") && codexRootLineHasKey(trimmed, key) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
out = append(out, line)
|
||||
}
|
||||
return strings.Join(out, "")
|
||||
}
|
||||
|
||||
func codexRootLineHasKey(line, key string) bool {
|
||||
cfg := map[string]any{}
|
||||
if err := toml.Unmarshal([]byte(line+"\n"), &cfg); err != nil {
|
||||
return false
|
||||
}
|
||||
_, ok := cfg[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
func codexCatalogModel(modelName string, models []LaunchModel) LaunchModel {
|
||||
if model, ok := findLaunchModel(models, modelName); ok {
|
||||
return model.WithCloudLimits()
|
||||
}
|
||||
return fallbackLaunchModel(modelName)
|
||||
}
|
||||
|
||||
func writeCodexModelCatalog(catalogPath string, model LaunchModel) error {
|
||||
entry := buildCodexModelEntry(model)
|
||||
|
||||
catalog := map[string]any{
|
||||
"models": []any{entry},
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(catalog, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(catalogPath, data, 0o644)
|
||||
}
|
||||
|
||||
func buildCodexModelEntry(launchModel LaunchModel) map[string]any {
|
||||
modelName := launchModel.Name
|
||||
contextWindow := codexFallbackContextWindow
|
||||
systemPrompt := ""
|
||||
|
||||
if launchModel.ContextLength > 0 {
|
||||
contextWindow = launchModel.ContextLength
|
||||
} else if launchModel.Details.ContextLength > 0 {
|
||||
contextWindow = launchModel.Details.ContextLength
|
||||
}
|
||||
if l, ok := lookupCloudModelLimit(modelName); ok {
|
||||
contextWindow = l.Context
|
||||
}
|
||||
|
||||
if !isCloudModelName(modelName) && launchModel.Details.Format != "safetensors" {
|
||||
if ctxLen := envconfig.ContextLength(); ctxLen > 0 {
|
||||
contextWindow = int(ctxLen)
|
||||
}
|
||||
}
|
||||
|
||||
modalities := []string{"text"}
|
||||
if launchModel.HasCapability(model.CapabilityVision) {
|
||||
modalities = append(modalities, "image")
|
||||
}
|
||||
|
||||
truncationMode := "bytes"
|
||||
if isCloudModelName(modelName) {
|
||||
truncationMode = "tokens"
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"slug": modelName,
|
||||
"display_name": modelName,
|
||||
"context_window": contextWindow,
|
||||
"shell_type": "default",
|
||||
"visibility": "list",
|
||||
"supported_in_api": true,
|
||||
"priority": 0,
|
||||
"truncation_policy": map[string]any{"mode": truncationMode, "limit": 10000},
|
||||
"input_modalities": modalities,
|
||||
"base_instructions": systemPrompt,
|
||||
"support_verbosity": true,
|
||||
"default_verbosity": "low",
|
||||
"supports_parallel_tool_calls": false,
|
||||
"supports_reasoning_summaries": false,
|
||||
"supported_reasoning_levels": []any{},
|
||||
"experimental_supported_tools": []any{},
|
||||
}
|
||||
}
|
||||
|
||||
func checkCodexVersion() error {
|
||||
if _, err := exec.LookPath("codex"); err != nil {
|
||||
return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
|
||||
}
|
||||
|
||||
out, err := exec.Command("codex", "--version").Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get codex version: %w", err)
|
||||
}
|
||||
|
||||
// Parse output like "codex-cli 0.87.0"
|
||||
fields := strings.Fields(strings.TrimSpace(string(out)))
|
||||
if len(fields) < 2 {
|
||||
return fmt.Errorf("unexpected codex version output: %s", string(out))
|
||||
}
|
||||
|
||||
version := "v" + fields[len(fields)-1]
|
||||
minVersion := "v0.81.0"
|
||||
|
||||
if semver.Compare(version, minVersion) < 0 {
|
||||
return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
1042
cmd/launch/codex_app.go
Normal file
1042
cmd/launch/codex_app.go
Normal file
File diff suppressed because it is too large
Load Diff
1436
cmd/launch/codex_app_test.go
Normal file
1436
cmd/launch/codex_app_test.go
Normal file
File diff suppressed because it is too large
Load Diff
595
cmd/launch/codex_test.go
Normal file
595
cmd/launch/codex_test.go
Normal file
@@ -0,0 +1,595 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
modelpkg "github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestCodexArgs(t *testing.T) {
|
||||
c := &Codex{}
|
||||
catalogPath := filepath.Join("tmp", "model.json")
|
||||
catalogArg := fmt.Sprintf("%s=%q", codexRootModelCatalogJSONKey, catalogPath)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
args []string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", nil, []string{"--profile", "ollama-launch", "-c", catalogArg, "-m", "llama3.2"}},
|
||||
{"empty model", "", nil, []string{"--profile", "ollama-launch", "-c", catalogArg}},
|
||||
{"with model and extra args", "qwen3.5", []string{"-p", "myprofile"}, []string{"--profile", "ollama-launch", "-c", catalogArg, "-m", "qwen3.5", "-p", "myprofile"}},
|
||||
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--profile", "ollama-launch", "-c", catalogArg, "-m", "llama3.2", "--sandbox", "workspace-write"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model, catalogPath, tt.args)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteCodexProfile(t *testing.T) {
|
||||
t.Run("creates new file when none exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.toml")
|
||||
catalogPath := filepath.Join(tmpDir, "model.json")
|
||||
|
||||
if err := writeCodexProfile(configPath, catalogPath); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
if !strings.Contains(content, "[profiles.ollama-launch]") {
|
||||
t.Error("missing [profiles.ollama-launch] header")
|
||||
}
|
||||
if !strings.Contains(content, "openai_base_url") {
|
||||
t.Error("missing openai_base_url key")
|
||||
}
|
||||
if !strings.Contains(content, "/v1/") {
|
||||
t.Error("missing /v1/ suffix in base URL")
|
||||
}
|
||||
if !strings.Contains(content, `forced_login_method = "api"`) {
|
||||
t.Error("missing forced_login_method key")
|
||||
}
|
||||
if !strings.Contains(content, `model_provider = "ollama-launch"`) {
|
||||
t.Error("missing model_provider key")
|
||||
}
|
||||
if !strings.Contains(content, fmt.Sprintf("model_catalog_json = %q", catalogPath)) {
|
||||
t.Error("missing model_catalog_json key")
|
||||
}
|
||||
if !strings.Contains(content, "[model_providers.ollama-launch]") {
|
||||
t.Error("missing [model_providers.ollama-launch] section")
|
||||
}
|
||||
if !strings.Contains(content, `name = "Ollama"`) {
|
||||
t.Error("missing model provider name")
|
||||
}
|
||||
if err := codexValidateConfigText(content); err != nil {
|
||||
t.Fatalf("generated config should be valid TOML: %v\n%s", err, content)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("appends profile to existing file without profile", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.toml")
|
||||
catalogPath := filepath.Join(tmpDir, "model.json")
|
||||
existing := "[some_other_section]\nkey = \"value\"\n"
|
||||
os.WriteFile(configPath, []byte(existing), 0o644)
|
||||
|
||||
if err := writeCodexProfile(configPath, catalogPath); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
content := string(data)
|
||||
|
||||
if !strings.Contains(content, "[some_other_section]") {
|
||||
t.Error("existing section was removed")
|
||||
}
|
||||
if !strings.Contains(content, "[profiles.ollama-launch]") {
|
||||
t.Error("missing [profiles.ollama-launch] header")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("replaces existing profile section", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.toml")
|
||||
catalogPath := filepath.Join(tmpDir, "model.json")
|
||||
existing := "[profiles.ollama-launch]\nopenai_base_url = \"http://old:1234/v1/\"\n\n[model_providers.ollama-launch]\nname = \"Ollama\"\nbase_url = \"http://old:1234/v1/\"\n"
|
||||
os.WriteFile(configPath, []byte(existing), 0o644)
|
||||
|
||||
if err := writeCodexProfile(configPath, catalogPath); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
content := string(data)
|
||||
|
||||
if strings.Contains(content, "old:1234") {
|
||||
t.Error("old URL was not replaced")
|
||||
}
|
||||
if strings.Count(content, "[profiles.ollama-launch]") != 1 {
|
||||
t.Errorf("expected exactly one [profiles.ollama-launch] section, got %d", strings.Count(content, "[profiles.ollama-launch]"))
|
||||
}
|
||||
if strings.Count(content, "[model_providers.ollama-launch]") != 1 {
|
||||
t.Errorf("expected exactly one [model_providers.ollama-launch] section, got %d", strings.Count(content, "[model_providers.ollama-launch]"))
|
||||
}
|
||||
if err := codexValidateConfigText(content); err != nil {
|
||||
t.Fatalf("generated config should be valid TOML: %v\n%s", err, content)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("replaces equivalent quoted profile table", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.toml")
|
||||
existing := "" +
|
||||
`profile = "default"` + "\n\n" +
|
||||
`[profiles."ollama-launch"]` + "\n" +
|
||||
`openai_base_url = "http://old:1234/v1/"` + "\n\n" +
|
||||
`[model_providers."ollama-launch"]` + "\n" +
|
||||
`name = "Old"` + "\n" +
|
||||
`base_url = "http://old:1234/v1/"` + "\n\n" +
|
||||
`[profiles.default]` + "\n" +
|
||||
`model = "gpt-5.5"` + "\n"
|
||||
os.WriteFile(configPath, []byte(existing), 0o644)
|
||||
|
||||
if err := writeCodexProfile(configPath); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
content := string(data)
|
||||
|
||||
if strings.Contains(content, `profiles."ollama-launch"`) {
|
||||
t.Fatalf("quoted profile table should be replaced, got:\n%s", content)
|
||||
}
|
||||
if strings.Contains(content, "old:1234") {
|
||||
t.Fatalf("old URL was not replaced, got:\n%s", content)
|
||||
}
|
||||
if got := codexSectionStringValue(content, codexProfileHeader(), "model_provider"); got != codexProfileName {
|
||||
t.Fatalf("profile model_provider = %q, want %q", got, codexProfileName)
|
||||
}
|
||||
if got := codexSectionStringValue(content, codexProviderHeader(), "base_url"); !strings.Contains(got, "/v1/") {
|
||||
t.Fatalf("provider base_url = %q, want /v1/ URL", got)
|
||||
}
|
||||
if err := codexValidateConfigText(content); err != nil {
|
||||
t.Fatalf("generated config should be valid TOML: %v\n%s", err, content)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects invalid existing toml without writing", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.toml")
|
||||
existing := "profile = \n"
|
||||
os.WriteFile(configPath, []byte(existing), 0o644)
|
||||
|
||||
err := writeCodexProfile(configPath)
|
||||
if err == nil || !strings.Contains(err.Error(), "invalid Codex config TOML") {
|
||||
t.Fatalf("writeCodexProfile error = %v, want invalid TOML", err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
if string(data) != existing {
|
||||
t.Fatalf("invalid config should be left untouched, got:\n%s", data)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects malformed existing toml variants without writing", func(t *testing.T) {
|
||||
tests := map[string]string{
|
||||
"duplicate root key": "profile = \"default\"\nprofile = \"other\"\n",
|
||||
"unterminated string": "model = \"gpt-5.5\n",
|
||||
"bad table": "[profiles.ollama-launch\nmodel = \"llama3.2\"\n",
|
||||
"duplicate table key": "[profiles.ollama-launch]\nmodel = \"a\"\nmodel = \"b\"\n",
|
||||
}
|
||||
for name, existing := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.toml")
|
||||
if err := os.WriteFile(configPath, []byte(existing), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err := writeCodexProfile(configPath)
|
||||
if err == nil || !strings.Contains(err.Error(), "invalid Codex config TOML") {
|
||||
t.Fatalf("writeCodexProfile error = %v, want invalid TOML", err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
if string(data) != existing {
|
||||
t.Fatalf("invalid config should be left untouched, got:\n%s", data)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("backs up previous config before overwrite", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configPath := filepath.Join(tmpDir, ".codex", "config.toml")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
existing := "# original-codex-backup-marker\n[profiles.default]\nmodel = \"gpt-5.5\"\n"
|
||||
if err := os.WriteFile(configPath, []byte(existing), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := writeCodexProfile(configPath); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assertBackupContains(t, filepath.Join(fileutil.BackupDir(), "config.toml.*"), "original-codex-backup-marker")
|
||||
})
|
||||
|
||||
t.Run("updates equivalent quoted root keys", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.toml")
|
||||
existing := "" +
|
||||
`"profile" = "default"` + "\n" +
|
||||
`"model" = "gpt-5.5"` + "\n" +
|
||||
`"model_provider" = "openai"` + "\n\n" +
|
||||
`[profiles.default]` + "\n" +
|
||||
`model = "gpt-5.5"` + "\n"
|
||||
os.WriteFile(configPath, []byte(existing), 0o644)
|
||||
|
||||
err := writeCodexLaunchProfile(configPath, codexLaunchProfileOptions{
|
||||
activate: true,
|
||||
setRootModelConfig: true,
|
||||
model: "llama3.2",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
content := string(data)
|
||||
for key, want := range map[string]string{
|
||||
"profile": codexProfileName,
|
||||
"model": "llama3.2",
|
||||
"model_provider": codexProfileName,
|
||||
} {
|
||||
if got := codexRootStringValue(content, key); got != want {
|
||||
t.Fatalf("root %s = %q, want %q in:\n%s", key, got, want, content)
|
||||
}
|
||||
}
|
||||
if strings.Contains(content, `"profile"`) || strings.Contains(content, `"model_provider"`) {
|
||||
t.Fatalf("quoted root keys should be rewritten once, got:\n%s", content)
|
||||
}
|
||||
if err := codexValidateConfigText(content); err != nil {
|
||||
t.Fatalf("generated config should be valid TOML: %v\n%s", err, content)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("replaces profile while preserving following sections", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.toml")
|
||||
catalogPath := filepath.Join(tmpDir, "model.json")
|
||||
existing := "[profiles.ollama-launch]\nopenai_base_url = \"http://old:1234/v1/\"\n[another_section]\nfoo = \"bar\"\n"
|
||||
os.WriteFile(configPath, []byte(existing), 0o644)
|
||||
|
||||
if err := writeCodexProfile(configPath, catalogPath); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
content := string(data)
|
||||
|
||||
if strings.Contains(content, "old:1234") {
|
||||
t.Error("old URL was not replaced")
|
||||
}
|
||||
if !strings.Contains(content, "[another_section]") {
|
||||
t.Error("following section was removed")
|
||||
}
|
||||
if !strings.Contains(content, "foo = \"bar\"") {
|
||||
t.Error("following section content was removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("appends newline to file not ending with newline", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.toml")
|
||||
catalogPath := filepath.Join(tmpDir, "model.json")
|
||||
existing := "[other]\nkey = \"val\""
|
||||
os.WriteFile(configPath, []byte(existing), 0o644)
|
||||
|
||||
if err := writeCodexProfile(configPath, catalogPath); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
content := string(data)
|
||||
|
||||
if !strings.Contains(content, "[profiles.ollama-launch]") {
|
||||
t.Error("missing [profiles.ollama-launch] header")
|
||||
}
|
||||
// Should not have double blank lines from missing trailing newline
|
||||
if strings.Contains(content, "\n\n\n") {
|
||||
t.Error("unexpected triple newline in output")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses custom OLLAMA_HOST", func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", "http://myhost:9999")
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.toml")
|
||||
catalogPath := filepath.Join(tmpDir, "model.json")
|
||||
|
||||
if err := writeCodexProfile(configPath, catalogPath); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
content := string(data)
|
||||
|
||||
if !strings.Contains(content, "myhost:9999/v1/") {
|
||||
t.Errorf("expected custom host in URL, got:\n%s", content)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses connectable host for unspecified bind address", func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", "http://0.0.0.0:11434")
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.toml")
|
||||
|
||||
if err := writeCodexProfile(configPath); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
content := string(data)
|
||||
|
||||
if strings.Contains(content, "0.0.0.0") {
|
||||
t.Fatalf("config should not write bind-only host, got:\n%s", content)
|
||||
}
|
||||
if !strings.Contains(content, "127.0.0.1:11434/v1/") {
|
||||
t.Fatalf("expected connectable loopback URL, got:\n%s", content)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEnsureCodexConfig(t *testing.T) {
|
||||
t.Run("creates .codex dir and config.toml", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := ensureCodexConfig("llama3.2", launchModelsFromNames([]string{"llama3.2"})); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
configPath := filepath.Join(tmpDir, ".codex", "config.toml")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("config.toml not created: %v", err)
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
if !strings.Contains(content, "[profiles.ollama-launch]") {
|
||||
t.Error("missing [profiles.ollama-launch] header")
|
||||
}
|
||||
if !strings.Contains(content, "openai_base_url") {
|
||||
t.Error("missing openai_base_url key")
|
||||
}
|
||||
|
||||
catalogPath := filepath.Join(tmpDir, ".codex", "model.json")
|
||||
data, err = os.ReadFile(catalogPath)
|
||||
if err != nil {
|
||||
t.Fatalf("model.json not created: %v", err)
|
||||
}
|
||||
if !strings.Contains(string(data), `"slug": "llama3.2"`) {
|
||||
t.Error("missing model catalog entry for selected model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("is idempotent", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := ensureCodexConfig("llama3.2", launchModelsFromNames([]string{"llama3.2"})); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := ensureCodexConfig("llama3.2", launchModelsFromNames([]string{"llama3.2"})); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
configPath := filepath.Join(tmpDir, ".codex", "config.toml")
|
||||
data, _ := os.ReadFile(configPath)
|
||||
content := string(data)
|
||||
|
||||
if strings.Count(content, "[profiles.ollama-launch]") != 1 {
|
||||
t.Errorf("expected exactly one [profiles.ollama-launch] section after two calls, got %d", strings.Count(content, "[profiles.ollama-launch]"))
|
||||
}
|
||||
if strings.Count(content, "[model_providers.ollama-launch]") != 1 {
|
||||
t.Errorf("expected exactly one [model_providers.ollama-launch] section after two calls, got %d", strings.Count(content, "[model_providers.ollama-launch]"))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func assertBackupContains(t *testing.T, pattern, marker string) {
|
||||
t.Helper()
|
||||
backups, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, backupPath := range backups {
|
||||
data, err := os.ReadFile(backupPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if strings.Contains(string(data), marker) {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("backup matching %q with marker %q not found", pattern, marker)
|
||||
}
|
||||
|
||||
func TestModelInfoContextLength(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelInfo map[string]any
|
||||
want int
|
||||
}{
|
||||
{"float64 value", map[string]any{"qwen3_5_moe.context_length": float64(262144)}, 262144},
|
||||
{"int value", map[string]any{"llama.context_length": 131072}, 131072},
|
||||
{"no context_length key", map[string]any{"llama.embedding_length": float64(4096)}, 0},
|
||||
{"empty map", map[string]any{}, 0},
|
||||
{"nil map", nil, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, _ := modelInfoContextLength(tt.modelInfo)
|
||||
if got != tt.want {
|
||||
t.Errorf("modelInfoContextLength() = %d, want %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexModelEntryContextWindow(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model LaunchModel
|
||||
envContextLen string
|
||||
wantContext int
|
||||
}{
|
||||
{
|
||||
name: "inventory context length as fallback",
|
||||
model: LaunchModel{
|
||||
Name: "llama3.2",
|
||||
ContextLength: 131072,
|
||||
Details: api.ModelDetails{Format: "gguf"},
|
||||
},
|
||||
wantContext: 131072,
|
||||
},
|
||||
{
|
||||
name: "details context length is used when model context is empty",
|
||||
model: LaunchModel{
|
||||
Name: "llama3.2",
|
||||
Details: api.ModelDetails{Format: "gguf", ContextLength: 131072},
|
||||
},
|
||||
wantContext: 131072,
|
||||
},
|
||||
{
|
||||
name: "OLLAMA_CONTEXT_LENGTH overrides local gguf inventory context",
|
||||
model: LaunchModel{
|
||||
Name: "llama3.2",
|
||||
ContextLength: 131072,
|
||||
Details: api.ModelDetails{Format: "gguf"},
|
||||
},
|
||||
envContextLen: "64000",
|
||||
wantContext: 64000,
|
||||
},
|
||||
{
|
||||
name: "safetensors uses inventory context only",
|
||||
model: LaunchModel{
|
||||
Name: "llama3.2",
|
||||
ContextLength: 131072,
|
||||
Details: api.ModelDetails{Format: "safetensors"},
|
||||
},
|
||||
envContextLen: "64000",
|
||||
wantContext: 131072,
|
||||
},
|
||||
{
|
||||
name: "cloud model uses hardcoded limits",
|
||||
model: LaunchModel{
|
||||
Name: "qwen3.5:cloud",
|
||||
ContextLength: 131072,
|
||||
Details: api.ModelDetails{Format: "gguf"},
|
||||
},
|
||||
envContextLen: "64000",
|
||||
wantContext: 262144,
|
||||
},
|
||||
{
|
||||
name: "unknown cloud model without metadata uses fallback context",
|
||||
model: LaunchModel{
|
||||
Name: "deepseek-v4-pro:cloud",
|
||||
},
|
||||
envContextLen: "64000",
|
||||
wantContext: codexFallbackContextWindow,
|
||||
},
|
||||
{
|
||||
name: "vision capability without reasoning advertisement",
|
||||
model: LaunchModel{
|
||||
Name: "llama3.2",
|
||||
ContextLength: 131072,
|
||||
Details: api.ModelDetails{Format: "gguf"},
|
||||
Capabilities: []modelpkg.Capability{modelpkg.CapabilityVision, modelpkg.CapabilityThinking},
|
||||
},
|
||||
wantContext: 131072,
|
||||
},
|
||||
{
|
||||
name: "missing metadata uses fallback context",
|
||||
model: LaunchModel{Name: "llama3.2"},
|
||||
wantContext: codexFallbackContextWindow,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.envContextLen != "" {
|
||||
t.Setenv("OLLAMA_CONTEXT_LENGTH", tt.envContextLen)
|
||||
} else {
|
||||
t.Setenv("OLLAMA_CONTEXT_LENGTH", "")
|
||||
}
|
||||
|
||||
entry := buildCodexModelEntry(tt.model)
|
||||
|
||||
gotContext, _ := entry["context_window"].(int)
|
||||
if gotContext != tt.wantContext {
|
||||
t.Errorf("context_window = %d, want %d", gotContext, tt.wantContext)
|
||||
}
|
||||
|
||||
if tt.name == "vision capability without reasoning advertisement" {
|
||||
modalities, _ := entry["input_modalities"].([]string)
|
||||
if !slices.Contains(modalities, "image") {
|
||||
t.Error("expected image in input_modalities")
|
||||
}
|
||||
levels, _ := entry["supported_reasoning_levels"].([]any)
|
||||
if len(levels) != 0 {
|
||||
t.Errorf("supported_reasoning_levels length = %d, want 0", len(levels))
|
||||
}
|
||||
if got, _ := entry["supports_reasoning_summaries"].(bool); got {
|
||||
t.Error("supports_reasoning_summaries = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
if tt.name == "cloud model uses hardcoded limits" {
|
||||
truncationPolicy, _ := entry["truncation_policy"].(map[string]any)
|
||||
if mode, _ := truncationPolicy["mode"].(string); mode != "tokens" {
|
||||
t.Errorf("truncation_policy mode = %q, want %q", mode, "tokens")
|
||||
}
|
||||
}
|
||||
|
||||
requiredKeys := []string{"slug", "display_name", "shell_type"}
|
||||
for _, key := range requiredKeys {
|
||||
if _, ok := entry[key]; !ok {
|
||||
t.Errorf("missing required key %q", key)
|
||||
}
|
||||
}
|
||||
if _, ok := entry["apply_patch_tool_type"]; ok {
|
||||
t.Error("apply_patch_tool_type should be omitted so Codex CLI defaults can handle schema changes")
|
||||
}
|
||||
|
||||
if _, err := json.Marshal(entry); err != nil {
|
||||
t.Errorf("entry is not JSON serializable: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
689
cmd/launch/command_test.go
Normal file
689
cmd/launch/command_test.go
Normal file
@@ -0,0 +1,689 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func captureStderr(t *testing.T, fn func()) string {
|
||||
t.Helper()
|
||||
|
||||
oldStderr := os.Stderr
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create stderr pipe: %v", err)
|
||||
}
|
||||
os.Stderr = w
|
||||
defer func() {
|
||||
os.Stderr = oldStderr
|
||||
}()
|
||||
|
||||
done := make(chan string, 1)
|
||||
go func() {
|
||||
var buf bytes.Buffer
|
||||
_, _ = io.Copy(&buf, r)
|
||||
done <- buf.String()
|
||||
}()
|
||||
|
||||
fn()
|
||||
|
||||
_ = w.Close()
|
||||
return <-done
|
||||
}
|
||||
|
||||
func TestLaunchCmd(t *testing.T) {
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
mockTUI := func(cmd *cobra.Command) {}
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
|
||||
t.Run("command structure", func(t *testing.T) {
|
||||
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
|
||||
}
|
||||
if cmd.Short == "" {
|
||||
t.Error("Short description should not be empty")
|
||||
}
|
||||
if cmd.Long == "" {
|
||||
t.Error("Long description should not be empty")
|
||||
}
|
||||
if !strings.Contains(cmd.Long, "hermes") {
|
||||
t.Error("Long description should mention hermes")
|
||||
}
|
||||
if !strings.Contains(cmd.Long, "kimi") {
|
||||
t.Error("Long description should mention kimi")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("flags exist", func(t *testing.T) {
|
||||
if cmd.Flags().Lookup("model") == nil {
|
||||
t.Error("--model flag should exist")
|
||||
}
|
||||
if cmd.Flags().Lookup("config") == nil {
|
||||
t.Error("--config flag should exist")
|
||||
}
|
||||
if cmd.Flags().Lookup("restore") == nil {
|
||||
t.Error("--restore flag should exist")
|
||||
}
|
||||
if cmd.Flags().Lookup("yes") == nil {
|
||||
t.Error("--yes flag should exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PreRunE is set", func(t *testing.T) {
|
||||
if cmd.PreRunE == nil {
|
||||
t.Error("PreRunE should be set to checkServerHeartbeat")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLaunchCmdTUICallback(t *testing.T) {
|
||||
mockCheck := func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
t.Run("no args calls TUI", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{})
|
||||
_ = cmd.Execute()
|
||||
|
||||
if !tuiCalled {
|
||||
t.Error("TUI callback should be called when no args provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("integration arg bypasses TUI", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.NotFoundHandler())
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"claude"})
|
||||
_ = cmd.Execute()
|
||||
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when integration arg provided")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--model flag without integration returns error", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--model", "test-model"})
|
||||
err := cmd.Execute()
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected --model without an integration to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "require an integration name") {
|
||||
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||
}
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when --model is provided without an integration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--config flag without integration returns error", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--config"})
|
||||
err := cmd.Execute()
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected --config without an integration to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "require an integration name") {
|
||||
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||
}
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when --config is provided without an integration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--yes flag without integration returns error", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--yes"})
|
||||
err := cmd.Execute()
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected --yes without an integration to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "require an integration name") {
|
||||
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||
}
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when --yes is provided without an integration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("extra args without integration return error", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--model", "test-model", "--", "--sandbox", "workspace-write"})
|
||||
err := cmd.Execute()
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected flags and extra args without an integration to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "require an integration name") {
|
||||
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||
}
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when flags or extra args are provided without an integration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--restore flag without integration returns error", func(t *testing.T) {
|
||||
tuiCalled := false
|
||||
mockTUI := func(cmd *cobra.Command) {
|
||||
tuiCalled = true
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(mockCheck, mockTUI)
|
||||
cmd.SetArgs([]string{"--restore"})
|
||||
err := cmd.Execute()
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected --restore without an integration to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "require an integration name") {
|
||||
t.Fatalf("expected integration-name guidance, got %v", err)
|
||||
}
|
||||
if tuiCalled {
|
||||
t.Error("TUI callback should NOT be called when --restore is provided without an integration")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLaunchCmdClaudeDesktopLaunchReturnsUnsupported(t *testing.T) {
|
||||
for _, name := range []string{"claude-desktop", "claude-app"} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error {
|
||||
t.Fatal("heartbeat check should not run before Claude Desktop unsupported error")
|
||||
return nil
|
||||
}, func(cmd *cobra.Command) {
|
||||
t.Fatal("TUI callback should not run for direct integration launch")
|
||||
})
|
||||
cmd.SetArgs([]string{name})
|
||||
|
||||
err := cmd.Execute()
|
||||
if err == nil {
|
||||
t.Fatal("expected Claude Desktop launch command to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "Claude Desktop is no longer supported") {
|
||||
t.Fatalf("expected unsupported guidance, got %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "ollama launch claude-desktop --restore") {
|
||||
t.Fatalf("expected restore guidance, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdNilHeartbeat(t *testing.T) {
|
||||
cmd := LaunchCmd(nil, nil)
|
||||
if cmd == nil {
|
||||
t.Fatal("LaunchCmd returned nil")
|
||||
}
|
||||
if cmd.PreRunE != nil {
|
||||
t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
|
||||
if err := config.SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil {
|
||||
t.Fatalf("failed to seed saved config: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||
case "/api/show":
|
||||
fmt.Fprintf(w, `{"model":"llama3.2"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherEditorRunner{}
|
||||
restore := OverrideIntegration("stubeditor", stub)
|
||||
defer restore()
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command failed: %v", err)
|
||||
}
|
||||
|
||||
saved, err := config.LoadIntegration("stubeditor")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
|
||||
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if stub.ranModel != "llama3.2" {
|
||||
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdModelFlagClearsDisabledCloudOverride(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherSingleRunner{}
|
||||
restore := OverrideIntegration("stubapp", stub)
|
||||
defer restore()
|
||||
|
||||
oldSelector := DefaultSingleSelector
|
||||
defer func() { DefaultSingleSelector = oldSelector }()
|
||||
|
||||
var selectorCalls int
|
||||
var gotCurrent string
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
selectorCalls++
|
||||
gotCurrent = current
|
||||
return "llama3.2", nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubapp", "--model", "glm-5:cloud"})
|
||||
stderr := captureStderr(t, func() {
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
if selectorCalls != 1 {
|
||||
t.Fatalf("expected disabled cloud override to fall back to selector, got %d calls", selectorCalls)
|
||||
}
|
||||
if gotCurrent != "" {
|
||||
t.Fatalf("expected disabled override to be cleared before selection, got current %q", gotCurrent)
|
||||
}
|
||||
if stub.ranModel != "llama3.2" {
|
||||
t.Fatalf("expected launch to run with replacement local model, got %q", stub.ranModel)
|
||||
}
|
||||
if !strings.Contains(stderr, "Warning: ignoring --model glm-5:cloud because cloud is disabled") {
|
||||
t.Fatalf("expected disabled-cloud warning, got stderr: %q", stderr)
|
||||
}
|
||||
|
||||
saved, err := config.LoadIntegration("stubapp")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
|
||||
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdAutodiscoveryDefaultLaunchDoesNotForceConfigure(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withInteractiveSession(t, true)
|
||||
withLauncherHooks(t)
|
||||
|
||||
runner := &launcherManagedAutodiscoveryRunner{
|
||||
autodiscoveryConfigured: true,
|
||||
}
|
||||
restore := OverrideIntegration("stubauto", runner)
|
||||
defer restore()
|
||||
|
||||
if err := config.SaveIntegration("stubauto", []string{"Ollama Cloud"}); err != nil {
|
||||
t.Fatalf("failed to save managed integration config: %v", err)
|
||||
}
|
||||
if err := config.MarkIntegrationOnboarded("stubauto"); err != nil {
|
||||
t.Fatalf("failed to mark integration onboarded: %v", err)
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {
|
||||
t.Fatal("TUI callback should not run for direct integration launch")
|
||||
})
|
||||
cmd.SetArgs([]string{"stubauto"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command failed: %v", err)
|
||||
}
|
||||
|
||||
if runner.autodiscoveryConfigures != 0 {
|
||||
t.Fatalf("expected default autodiscovery launch to reuse existing config, got %d configures", runner.autodiscoveryConfigures)
|
||||
}
|
||||
if runner.ranModel != "Ollama Cloud" {
|
||||
t.Fatalf("expected launch to run autodiscovery label, got %q", runner.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdYes_AutoConfirmsLaunchPromptPath(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
withInteractiveSession(t, false)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(w, `{"error":"not found"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
|
||||
restore := OverrideIntegration("stubeditor", stub)
|
||||
defer restore()
|
||||
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
t.Fatalf("unexpected prompt with --yes: %q", prompt)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2", "--yes"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command with --yes failed: %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
|
||||
t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if stub.ranModel != "llama3.2" {
|
||||
t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdHeadlessWithYes_AutoPullsMissingLocalModel(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
withInteractiveSession(t, false)
|
||||
|
||||
var pullCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(w, `{"error":"model not found"}`)
|
||||
case "/api/pull":
|
||||
pullCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, `{"status":"success"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherSingleRunner{}
|
||||
restore := OverrideIntegration("stubapp", stub)
|
||||
defer restore()
|
||||
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
t.Fatalf("unexpected prompt with --yes in headless autopull path: %q", prompt)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubapp", "--model", "missing-model", "--yes"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command with --yes failed: %v", err)
|
||||
}
|
||||
|
||||
if !pullCalled {
|
||||
t.Fatal("expected missing local model to be auto-pulled with --yes in headless mode")
|
||||
}
|
||||
if stub.ranModel != "missing-model" {
|
||||
t.Fatalf("expected launch to run with pulled model, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdHeadlessWithoutYes_AllowsConfiguredLaunch(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
withInteractiveSession(t, false)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||
case "/api/status":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(w, `{"error":"not found"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherEditorRunner{paths: []string{"/tmp/stubeditor.json"}}
|
||||
restore := OverrideIntegration("stubeditor", stub)
|
||||
defer restore()
|
||||
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
t.Fatalf("unexpected prompt in headless non-yes mode: %q", prompt)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
|
||||
err := cmd.Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("expected launch command to succeed without --yes when an explicit model is provided, got %v", err)
|
||||
}
|
||||
if diff := compareStringSlices(stub.edited, [][]string{{"llama3.2"}}); diff != "" {
|
||||
t.Fatalf("unexpected editor writes (-want +got):\n%s", diff)
|
||||
}
|
||||
if stub.ranModel != "llama3.2" {
|
||||
t.Fatalf("expected launch to run configured model, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
|
||||
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
|
||||
t.Fatalf("failed to seed saved config: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`)
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model":"qwen3:8b"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherSingleRunner{}
|
||||
restore := OverrideIntegration("stubapp", stub)
|
||||
defer restore()
|
||||
|
||||
oldSelector := DefaultSingleSelector
|
||||
defer func() { DefaultSingleSelector = oldSelector }()
|
||||
|
||||
var gotCurrent string
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
gotCurrent = current
|
||||
return "qwen3:8b", nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubapp"})
|
||||
if err := cmd.Execute(); err != nil {
|
||||
t.Fatalf("launch command failed: %v", err)
|
||||
}
|
||||
|
||||
if gotCurrent != "llama3.2" {
|
||||
t.Fatalf("expected selector current model to be saved model llama3.2, got %q", gotCurrent)
|
||||
}
|
||||
if stub.ranModel != "qwen3:8b" {
|
||||
t.Fatalf("expected launch to run selected model qwen3:8b, got %q", stub.ranModel)
|
||||
}
|
||||
|
||||
saved, err := config.LoadIntegration("stubapp")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload integration config: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff([]string{"qwen3:8b"}, saved.Models); diff != "" {
|
||||
t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdHeadlessYes_IntegrationRequiresModelEvenWhenSaved(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
withInteractiveSession(t, false)
|
||||
|
||||
if err := config.SaveIntegration("stubapp", []string{"llama3.2"}); err != nil {
|
||||
t.Fatalf("failed to seed saved config: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model":"llama3.2"}`)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherSingleRunner{}
|
||||
restore := OverrideIntegration("stubapp", stub)
|
||||
defer restore()
|
||||
|
||||
oldSelector := DefaultSingleSelector
|
||||
defer func() { DefaultSingleSelector = oldSelector }()
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
t.Fatal("selector should not be called for headless --yes saved-model launch")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubapp", "--yes"})
|
||||
err := cmd.Execute()
|
||||
if err == nil {
|
||||
t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "requires --model <model>") {
|
||||
t.Fatalf("expected actionable --model guidance, got %v", err)
|
||||
}
|
||||
if stub.ranModel != "" {
|
||||
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchCmdHeadlessYes_IntegrationWithoutSavedModelReturnsError(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
withInteractiveSession(t, false)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
stub := &launcherSingleRunner{}
|
||||
restore := OverrideIntegration("stubapp", stub)
|
||||
defer restore()
|
||||
|
||||
oldSelector := DefaultSingleSelector
|
||||
defer func() { DefaultSingleSelector = oldSelector }()
|
||||
DefaultSingleSelector = func(title string, items []SelectionItem, current string) (string, error) {
|
||||
t.Fatal("selector should not be called for headless --yes without saved model")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
|
||||
cmd.SetArgs([]string{"stubapp", "--yes"})
|
||||
err := cmd.Execute()
|
||||
if err == nil {
|
||||
t.Fatal("expected launch command to fail when --yes is used headlessly without --model")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "requires --model <model>") {
|
||||
t.Fatalf("expected actionable --model guidance, got %v", err)
|
||||
}
|
||||
if stub.ranModel != "" {
|
||||
t.Fatalf("expected launch to abort before run, got %q", stub.ranModel)
|
||||
}
|
||||
}
|
||||
76
cmd/launch/copilot.go
Normal file
76
cmd/launch/copilot.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Copilot implements Runner for GitHub Copilot CLI integration.
|
||||
type Copilot struct{}
|
||||
|
||||
func (c *Copilot) String() string { return "Copilot CLI" }
|
||||
|
||||
func (c *Copilot) args(model string, extra []string) []string {
|
||||
var args []string
|
||||
if model != "" {
|
||||
args = append(args, "--model", model)
|
||||
}
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Copilot) findPath() (string, error) {
|
||||
if p, err := exec.LookPath("copilot"); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
name := "copilot"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "copilot.exe"
|
||||
}
|
||||
fallback := filepath.Join(home, ".local", "bin", name)
|
||||
if _, err := os.Stat(fallback); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func (c *Copilot) Run(model string, _ []LaunchModel, args []string) error {
|
||||
copilotPath, err := c.findPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("copilot is not installed, install from https://docs.github.com/en/copilot/how-tos/set-up/install-copilot-cli")
|
||||
}
|
||||
|
||||
cmd := exec.Command(copilotPath, c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
cmd.Env = append(os.Environ(), c.envVars(model)...)
|
||||
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// envVars returns the environment variables that configure Copilot CLI
|
||||
// to use Ollama as its model provider.
|
||||
func (c *Copilot) envVars(model string) []string {
|
||||
env := []string{
|
||||
"COPILOT_PROVIDER_BASE_URL=" + envconfig.Host().String() + "/v1",
|
||||
"COPILOT_PROVIDER_API_KEY=",
|
||||
"COPILOT_PROVIDER_WIRE_API=responses",
|
||||
}
|
||||
|
||||
if model != "" {
|
||||
env = append(env, "COPILOT_MODEL="+model)
|
||||
}
|
||||
|
||||
return env
|
||||
}
|
||||
161
cmd/launch/copilot_test.go
Normal file
161
cmd/launch/copilot_test.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCopilotIntegration(t *testing.T) {
|
||||
c := &Copilot{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "Copilot CLI" {
|
||||
t.Errorf("String() = %q, want %q", got, "Copilot CLI")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
}
|
||||
|
||||
func TestCopilotFindPath(t *testing.T) {
|
||||
c := &Copilot{}
|
||||
|
||||
t.Run("finds copilot in PATH", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
name := "copilot"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "copilot.exe"
|
||||
}
|
||||
fakeBin := filepath.Join(tmpDir, name)
|
||||
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
got, err := c.findPath()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != fakeBin {
|
||||
t.Errorf("findPath() = %q, want %q", got, fakeBin)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error when not in PATH", func(t *testing.T) {
|
||||
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
|
||||
|
||||
_, err := c.findPath()
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to ~/.local/bin/copilot", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
|
||||
|
||||
name := "copilot"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "copilot.exe"
|
||||
}
|
||||
fallback := filepath.Join(tmpDir, ".local", "bin", name)
|
||||
os.MkdirAll(filepath.Dir(fallback), 0o755)
|
||||
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
|
||||
|
||||
got, err := c.findPath()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != fallback {
|
||||
t.Errorf("findPath() = %q, want %q", got, fallback)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
|
||||
|
||||
_, err := c.findPath()
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCopilotArgs(t *testing.T) {
|
||||
c := &Copilot{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
args []string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil, nil},
|
||||
{"with model and extra", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
|
||||
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model, tt.args)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopilotEnvVars(t *testing.T) {
|
||||
c := &Copilot{}
|
||||
|
||||
envMap := func(envs []string) map[string]string {
|
||||
m := make(map[string]string)
|
||||
for _, e := range envs {
|
||||
k, v, _ := strings.Cut(e, "=")
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
t.Run("sets required provider env vars with model", func(t *testing.T) {
|
||||
got := envMap(c.envVars("llama3.2"))
|
||||
if got["COPILOT_PROVIDER_BASE_URL"] == "" {
|
||||
t.Error("COPILOT_PROVIDER_BASE_URL should be set")
|
||||
}
|
||||
if !strings.HasSuffix(got["COPILOT_PROVIDER_BASE_URL"], "/v1") {
|
||||
t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want /v1 suffix", got["COPILOT_PROVIDER_BASE_URL"])
|
||||
}
|
||||
if _, ok := got["COPILOT_PROVIDER_API_KEY"]; !ok {
|
||||
t.Error("COPILOT_PROVIDER_API_KEY should be set (empty)")
|
||||
}
|
||||
if got["COPILOT_PROVIDER_WIRE_API"] != "responses" {
|
||||
t.Errorf("COPILOT_PROVIDER_WIRE_API = %q, want %q", got["COPILOT_PROVIDER_WIRE_API"], "responses")
|
||||
}
|
||||
if got["COPILOT_MODEL"] != "llama3.2" {
|
||||
t.Errorf("COPILOT_MODEL = %q, want %q", got["COPILOT_MODEL"], "llama3.2")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("omits COPILOT_MODEL when model is empty", func(t *testing.T) {
|
||||
got := envMap(c.envVars(""))
|
||||
if _, ok := got["COPILOT_MODEL"]; ok {
|
||||
t.Errorf("COPILOT_MODEL should not be set for empty model, got %q", got["COPILOT_MODEL"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses custom OLLAMA_HOST", func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", "http://myhost:9999")
|
||||
got := envMap(c.envVars("test"))
|
||||
if !strings.Contains(got["COPILOT_PROVIDER_BASE_URL"], "myhost:9999") {
|
||||
t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want custom host", got["COPILOT_PROVIDER_BASE_URL"])
|
||||
}
|
||||
})
|
||||
}
|
||||
188
cmd/launch/droid.go
Normal file
188
cmd/launch/droid.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Droid implements Runner and Editor for Droid integration
|
||||
type Droid struct{}
|
||||
|
||||
// droidSettings represents the Droid settings.json file (only fields we use)
|
||||
type droidSettings struct {
|
||||
CustomModels []modelEntry `json:"customModels"`
|
||||
SessionDefaultSettings sessionSettings `json:"sessionDefaultSettings"`
|
||||
}
|
||||
|
||||
type sessionSettings struct {
|
||||
Model string `json:"model"`
|
||||
ReasoningEffort string `json:"reasoningEffort"`
|
||||
}
|
||||
|
||||
type modelEntry struct {
|
||||
Model string `json:"model"`
|
||||
DisplayName string `json:"displayName"`
|
||||
BaseURL string `json:"baseUrl"`
|
||||
APIKey string `json:"apiKey"`
|
||||
Provider string `json:"provider"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens"`
|
||||
SupportsImages bool `json:"supportsImages"`
|
||||
ID string `json:"id"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
func (d *Droid) String() string { return "Droid" }
|
||||
|
||||
func (d *Droid) Run(model string, _ []LaunchModel, args []string) error {
|
||||
if _, err := exec.LookPath("droid"); err != nil {
|
||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||
}
|
||||
|
||||
cmd := exec.Command("droid", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (d *Droid) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
p := filepath.Join(home, ".factory", "settings.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return []string{p}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Droid) Edit(models []LaunchModel) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
settingsPath := filepath.Join(home, ".factory", "settings.json")
|
||||
if err := os.MkdirAll(filepath.Dir(settingsPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read file once, unmarshal twice:
|
||||
// map preserves unknown fields for writing back (including extra fields in model entries)
|
||||
settingsMap := make(map[string]any)
|
||||
var settings droidSettings
|
||||
if data, err := os.ReadFile(settingsPath); err == nil {
|
||||
if err := json.Unmarshal(data, &settingsMap); err != nil {
|
||||
return fmt.Errorf("failed to parse settings file: %w, at: %s", err, settingsPath)
|
||||
}
|
||||
json.Unmarshal(data, &settings) // ignore error, zero values are fine
|
||||
}
|
||||
|
||||
settingsMap = updateDroidSettings(settingsMap, settings, models)
|
||||
|
||||
data, err := json.MarshalIndent(settingsMap, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fileutil.WriteWithBackup(settingsPath, data, "droid")
|
||||
}
|
||||
|
||||
func updateDroidSettings(settingsMap map[string]any, settings droidSettings, models []LaunchModel) map[string]any {
|
||||
// Keep only non-Ollama models from the raw map (preserves extra fields)
|
||||
// Rebuild Ollama models
|
||||
var nonOllamaModels []any
|
||||
if rawModels, ok := settingsMap["customModels"].([]any); ok {
|
||||
for _, raw := range rawModels {
|
||||
if m, ok := raw.(map[string]any); ok {
|
||||
if m["apiKey"] != "ollama" {
|
||||
nonOllamaModels = append(nonOllamaModels, raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
|
||||
|
||||
var newModels []any
|
||||
var defaultModelID string
|
||||
for i, model := range models {
|
||||
maxOutput := 64000
|
||||
if model.MaxOutputTokens > 0 {
|
||||
maxOutput = model.MaxOutputTokens
|
||||
}
|
||||
modelID := fmt.Sprintf("custom:%s-%d", model.Name, i)
|
||||
newModels = append(newModels, modelEntry{
|
||||
Model: model.Name,
|
||||
DisplayName: model.Name,
|
||||
BaseURL: envconfig.Host().String() + "/v1",
|
||||
APIKey: "ollama",
|
||||
Provider: "generic-chat-completion-api",
|
||||
MaxOutputTokens: maxOutput,
|
||||
SupportsImages: model.HasCapability("vision"),
|
||||
ID: modelID,
|
||||
Index: i,
|
||||
})
|
||||
if i == 0 {
|
||||
defaultModelID = modelID
|
||||
}
|
||||
}
|
||||
|
||||
settingsMap["customModels"] = append(newModels, nonOllamaModels...)
|
||||
|
||||
// Update session default settings (preserve unknown fields in the nested object)
|
||||
sessionSettings, ok := settingsMap["sessionDefaultSettings"].(map[string]any)
|
||||
if !ok {
|
||||
sessionSettings = make(map[string]any)
|
||||
}
|
||||
sessionSettings["model"] = defaultModelID
|
||||
|
||||
if !isValidReasoningEffort(settings.SessionDefaultSettings.ReasoningEffort) {
|
||||
sessionSettings["reasoningEffort"] = "none"
|
||||
}
|
||||
|
||||
settingsMap["sessionDefaultSettings"] = sessionSettings
|
||||
return settingsMap
|
||||
}
|
||||
|
||||
func (d *Droid) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(home, ".factory", "settings.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var settings droidSettings
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []string
|
||||
for _, m := range settings.CustomModels {
|
||||
if m.APIKey == "ollama" {
|
||||
result = append(result, m.Model)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
var validReasoningEfforts = []string{"high", "medium", "low", "none"}
|
||||
|
||||
func isValidReasoningEffort(effort string) bool {
|
||||
return slices.Contains(validReasoningEfforts, effort)
|
||||
}
|
||||
1345
cmd/launch/droid_test.go
Normal file
1345
cmd/launch/droid_test.go
Normal file
File diff suppressed because it is too large
Load Diff
679
cmd/launch/hermes.go
Normal file
679
cmd/launch/hermes.go
Normal file
@@ -0,0 +1,679 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
const (
|
||||
hermesInstallScript = "curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash -s -- --skip-setup"
|
||||
hermesProviderName = "Ollama"
|
||||
hermesProviderKey = "ollama-launch"
|
||||
hermesLegacyKey = "ollama"
|
||||
hermesPlaceholderKey = "ollama"
|
||||
hermesGatewaySetupHint = "hermes gateway setup"
|
||||
hermesGatewaySetupTitle = "Connect a messaging app now?"
|
||||
)
|
||||
|
||||
var (
|
||||
hermesGOOS = runtime.GOOS
|
||||
hermesLookPath = exec.LookPath
|
||||
hermesCommand = exec.Command
|
||||
hermesUserHome = os.UserHomeDir
|
||||
hermesOllamaURL = envconfig.ConnectableHost
|
||||
)
|
||||
|
||||
var hermesMessagingEnvGroups = [][]string{
|
||||
{"TELEGRAM_BOT_TOKEN"},
|
||||
{"DISCORD_BOT_TOKEN"},
|
||||
{"SLACK_BOT_TOKEN"},
|
||||
{"SIGNAL_ACCOUNT"},
|
||||
{"EMAIL_ADDRESS"},
|
||||
{"TWILIO_ACCOUNT_SID"},
|
||||
{"MATRIX_ACCESS_TOKEN", "MATRIX_PASSWORD"},
|
||||
{"MATTERMOST_TOKEN"},
|
||||
{"WHATSAPP_PHONE_NUMBER_ID"},
|
||||
{"DINGTALK_CLIENT_ID"},
|
||||
{"FEISHU_APP_ID"},
|
||||
{"WECOM_BOT_ID"},
|
||||
{"WEIXIN_ACCOUNT_ID"},
|
||||
{"BLUEBUBBLES_SERVER_URL"},
|
||||
{"WEBHOOK_ENABLED"},
|
||||
}
|
||||
|
||||
// Hermes is intentionally not an Editor integration: launch owns one primary
|
||||
// model and the local Ollama endpoint, while Hermes keeps its own discovery and
|
||||
// switching UX after startup.
|
||||
type Hermes struct{}
|
||||
|
||||
func (h *Hermes) String() string { return "Hermes Agent" }
|
||||
|
||||
func (h *Hermes) Run(_ string, _ []LaunchModel, args []string) error {
|
||||
// Hermes reads its primary model from config.yaml. launch configures that
|
||||
// default model ahead of time so we can keep runtime invocation simple and
|
||||
// still let Hermes discover additional models later via its own UX.
|
||||
bin, err := h.binary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.runGatewaySetupPreflight(args, func() error {
|
||||
return hermesAttachedCommand(bin, "gateway", "setup").Run()
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return hermesAttachedCommand(bin, args...).Run()
|
||||
}
|
||||
|
||||
func (h *Hermes) Paths() []string {
|
||||
configPath, err := hermesConfigPath()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return []string{configPath}
|
||||
}
|
||||
|
||||
func (h *Hermes) Configure(model string) error {
|
||||
configPath, err := hermesConfigPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg := map[string]any{}
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("parse hermes config: %w", err)
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
modelSection, _ := cfg["model"].(map[string]any)
|
||||
if modelSection == nil {
|
||||
modelSection = make(map[string]any)
|
||||
}
|
||||
models := h.listModels(model)
|
||||
applyHermesManagedProviders(cfg, hermesBaseURL(), model, models)
|
||||
|
||||
// launch writes the minimum provider/default-model settings needed to
|
||||
// bootstrap Hermes against Ollama. The active provider stays on a
|
||||
// launch-owned key so /model stays aligned with the launcher-managed entry,
|
||||
// and the Ollama endpoint lives in providers: so the picker shows one row.
|
||||
modelSection["provider"] = hermesProviderKey
|
||||
modelSection["default"] = model
|
||||
modelSection["base_url"] = hermesBaseURL()
|
||||
modelSection["api_key"] = hermesPlaceholderKey
|
||||
cfg["model"] = modelSection
|
||||
|
||||
// use Hermes' built-in web toolset for now.
|
||||
// TODO(parthsareen): move this to using Ollama web search
|
||||
cfg["toolsets"] = mergeHermesToolsets(cfg["toolsets"])
|
||||
|
||||
data, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
return fileutil.WriteWithBackup(configPath, data, "hermes")
|
||||
}
|
||||
|
||||
func (h *Hermes) CurrentModel() string {
|
||||
configPath, err := hermesConfigPath()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
cfg := map[string]any{}
|
||||
if yaml.Unmarshal(data, &cfg) != nil {
|
||||
return ""
|
||||
}
|
||||
return hermesManagedCurrentModel(cfg, hermesBaseURL())
|
||||
}
|
||||
|
||||
func (h *Hermes) Onboard() error {
|
||||
return config.MarkIntegrationOnboarded("hermes")
|
||||
}
|
||||
|
||||
func (h *Hermes) RequiresInteractiveOnboarding() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *Hermes) RefreshRuntimeAfterConfigure() error {
|
||||
running, err := h.gatewayRunning()
|
||||
if err != nil {
|
||||
return fmt.Errorf("check Hermes gateway status: %w", err)
|
||||
}
|
||||
if !running {
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sRefreshing Hermes messaging gateway...%s\n", ansiGray, ansiReset)
|
||||
if err := h.restartGateway(); err != nil {
|
||||
return fmt.Errorf("restart Hermes gateway: %w", err)
|
||||
}
|
||||
fmt.Fprintln(os.Stderr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hermes) installed() bool {
|
||||
_, err := h.binary()
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (h *Hermes) ensureInstalled() error {
|
||||
if h.installed() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if hermesGOOS == "windows" {
|
||||
return hermesWindowsHint()
|
||||
}
|
||||
|
||||
var missing []string
|
||||
for _, dep := range []string{"bash", "curl", "git"} {
|
||||
if _, err := hermesLookPath(dep); err != nil {
|
||||
missing = append(missing, dep)
|
||||
}
|
||||
}
|
||||
if len(missing) > 0 {
|
||||
return fmt.Errorf("Hermes is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch hermes", strings.Join(missing, "\n "))
|
||||
}
|
||||
|
||||
ok, err := ConfirmPrompt("Hermes is not installed. Install now?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return fmt.Errorf("hermes installation cancelled")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nInstalling Hermes...\n")
|
||||
if err := hermesAttachedCommand("bash", "-lc", hermesInstallScript).Run(); err != nil {
|
||||
return fmt.Errorf("failed to install hermes: %w", err)
|
||||
}
|
||||
|
||||
if !h.installed() {
|
||||
return fmt.Errorf("hermes was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sHermes installed successfully%s\n\n", ansiGreen, ansiReset)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hermes) listModels(defaultModel string) []string {
|
||||
client := hermesOllamaClient()
|
||||
resp, err := client.List(context.Background())
|
||||
if err != nil {
|
||||
return []string{defaultModel}
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(resp.Models)+1)
|
||||
seen := make(map[string]struct{}, len(resp.Models)+1)
|
||||
add := func(name string) {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
return
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
models = append(models, name)
|
||||
}
|
||||
|
||||
add(defaultModel)
|
||||
for _, entry := range resp.Models {
|
||||
add(entry.Name)
|
||||
}
|
||||
if len(models) == 0 {
|
||||
return []string{defaultModel}
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
func (h *Hermes) binary() (string, error) {
|
||||
if path, err := hermesLookPath("hermes"); err == nil {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
if hermesGOOS == "windows" {
|
||||
return "", hermesWindowsHint()
|
||||
}
|
||||
|
||||
home, err := hermesUserHome()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
fallback := filepath.Join(home, ".local", "bin", "hermes")
|
||||
if _, err := os.Stat(fallback); err == nil {
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("hermes is not installed")
|
||||
}
|
||||
|
||||
func hermesConfigPath() (string, error) {
|
||||
home, err := hermesUserHome()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".hermes", "config.yaml"), nil
|
||||
}
|
||||
|
||||
func hermesBaseURL() string {
|
||||
return strings.TrimRight(hermesOllamaURL().String(), "/") + "/v1"
|
||||
}
|
||||
|
||||
func hermesEnvPath() (string, error) {
|
||||
home, err := hermesUserHome()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".hermes", ".env"), nil
|
||||
}
|
||||
|
||||
func (h *Hermes) runGatewaySetupPreflight(args []string, runSetup func() error) error {
|
||||
if len(args) > 0 || !isInteractiveSession() || currentLaunchConfirmPolicy.yes || currentLaunchConfirmPolicy.requireYesMessage {
|
||||
return nil
|
||||
}
|
||||
if h.messagingConfigured() {
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nHermes can message you on Telegram, Discord, Slack, and more.\n\n")
|
||||
ok, err := ConfirmPromptWithOptions(hermesGatewaySetupTitle, ConfirmOptions{
|
||||
YesLabel: "Yes",
|
||||
NoLabel: "Set up later",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if err := runSetup(); err != nil {
|
||||
return fmt.Errorf("hermes messaging setup failed: %w\n\nTry running: %s", err, hermesGatewaySetupHint)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hermes) messagingConfigured() bool {
|
||||
envVars, err := h.gatewayEnvVars()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, group := range hermesMessagingEnvGroups {
|
||||
for _, key := range group {
|
||||
if strings.TrimSpace(envVars[key]) != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *Hermes) gatewayEnvVars() (map[string]string, error) {
|
||||
envVars := make(map[string]string)
|
||||
|
||||
envFilePath, err := hermesEnvPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch data, err := os.ReadFile(envFilePath); {
|
||||
case err == nil:
|
||||
for key, value := range hermesParseEnvFile(data) {
|
||||
envVars[key] = value
|
||||
}
|
||||
case os.IsNotExist(err):
|
||||
// nothing persisted yet
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, group := range hermesMessagingEnvGroups {
|
||||
for _, key := range group {
|
||||
if value, ok := os.LookupEnv(key); ok {
|
||||
envVars[key] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return envVars, nil
|
||||
}
|
||||
|
||||
func (h *Hermes) gatewayRunning() (bool, error) {
|
||||
status, err := h.gatewayStatusOutput()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return hermesGatewayStatusRunning(status), nil
|
||||
}
|
||||
|
||||
func (h *Hermes) gatewayStatusOutput() (string, error) {
|
||||
bin, err := h.binary()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
out, err := hermesCommand(bin, "gateway", "status").CombinedOutput()
|
||||
return string(out), err
|
||||
}
|
||||
|
||||
func (h *Hermes) restartGateway() error {
|
||||
bin, err := h.binary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return hermesAttachedCommand(bin, "gateway", "restart").Run()
|
||||
}
|
||||
|
||||
func hermesGatewayStatusRunning(output string) bool {
|
||||
status := strings.ToLower(output)
|
||||
switch {
|
||||
case strings.Contains(status, "gateway is not running"):
|
||||
return false
|
||||
case strings.Contains(status, "gateway service is stopped"):
|
||||
return false
|
||||
case strings.Contains(status, "gateway service is not loaded"):
|
||||
return false
|
||||
case strings.Contains(status, "gateway is running"):
|
||||
return true
|
||||
case strings.Contains(status, "gateway service is running"):
|
||||
return true
|
||||
case strings.Contains(status, "gateway service is loaded"):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func hermesParseEnvFile(data []byte) map[string]string {
|
||||
out := make(map[string]string)
|
||||
scanner := bufio.NewScanner(bytes.NewReader(data))
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(strings.TrimPrefix(scanner.Text(), "\ufeff"))
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "export ") {
|
||||
line = strings.TrimSpace(strings.TrimPrefix(line, "export "))
|
||||
}
|
||||
|
||||
key, value, ok := strings.Cut(line, "=")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
value = strings.TrimSpace(value)
|
||||
if len(value) >= 2 {
|
||||
switch {
|
||||
case value[0] == '"' && value[len(value)-1] == '"':
|
||||
if unquoted, err := strconv.Unquote(value); err == nil {
|
||||
value = unquoted
|
||||
}
|
||||
case value[0] == '\'' && value[len(value)-1] == '\'':
|
||||
value = value[1 : len(value)-1]
|
||||
}
|
||||
}
|
||||
|
||||
out[key] = value
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func hermesOllamaClient() *api.Client {
|
||||
// Hermes queries the same launch-resolved Ollama host that launch writes
|
||||
// into config, so model discovery follows the configured endpoint.
|
||||
return api.NewClient(hermesOllamaURL(), http.DefaultClient)
|
||||
}
|
||||
|
||||
func applyHermesManagedProviders(cfg map[string]any, baseURL string, model string, models []string) {
|
||||
providers := hermesUserProviders(cfg["providers"])
|
||||
entry := hermesManagedProviderEntry(providers)
|
||||
if entry == nil {
|
||||
entry = make(map[string]any)
|
||||
}
|
||||
entry["name"] = hermesProviderName
|
||||
entry["api"] = baseURL
|
||||
entry["default_model"] = model
|
||||
entry["models"] = hermesStringListAny(models)
|
||||
providers[hermesProviderKey] = entry
|
||||
delete(providers, hermesLegacyKey)
|
||||
cfg["providers"] = providers
|
||||
|
||||
customProviders := hermesWithoutManagedCustomProviders(cfg["custom_providers"])
|
||||
if len(customProviders) == 0 {
|
||||
delete(cfg, "custom_providers")
|
||||
return
|
||||
}
|
||||
cfg["custom_providers"] = customProviders
|
||||
}
|
||||
|
||||
func hermesManagedCurrentModel(cfg map[string]any, baseURL string) string {
|
||||
modelCfg, _ := cfg["model"].(map[string]any)
|
||||
if modelCfg == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
provider, _ := modelCfg["provider"].(string)
|
||||
if strings.TrimSpace(strings.ToLower(provider)) != hermesProviderKey {
|
||||
return ""
|
||||
}
|
||||
|
||||
configBaseURL, _ := modelCfg["base_url"].(string)
|
||||
if hermesNormalizeURL(configBaseURL) != hermesNormalizeURL(baseURL) {
|
||||
return ""
|
||||
}
|
||||
|
||||
current, _ := modelCfg["default"].(string)
|
||||
current = strings.TrimSpace(current)
|
||||
if current == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
providers := hermesUserProviders(cfg["providers"])
|
||||
entry, _ := providers[hermesProviderKey].(map[string]any)
|
||||
if entry == nil {
|
||||
return ""
|
||||
}
|
||||
if hermesHasManagedCustomProvider(cfg["custom_providers"]) {
|
||||
return ""
|
||||
}
|
||||
|
||||
apiURL, _ := entry["api"].(string)
|
||||
if hermesNormalizeURL(apiURL) != hermesNormalizeURL(baseURL) {
|
||||
return ""
|
||||
}
|
||||
|
||||
defaultModel, _ := entry["default_model"].(string)
|
||||
if strings.TrimSpace(defaultModel) != current {
|
||||
return ""
|
||||
}
|
||||
|
||||
return current
|
||||
}
|
||||
|
||||
func hermesUserProviders(current any) map[string]any {
|
||||
switch existing := current.(type) {
|
||||
case map[string]any:
|
||||
out := make(map[string]any, len(existing))
|
||||
for key, value := range existing {
|
||||
out[key] = value
|
||||
}
|
||||
return out
|
||||
case map[any]any:
|
||||
out := make(map[string]any, len(existing))
|
||||
for key, value := range existing {
|
||||
if s, ok := key.(string); ok {
|
||||
out[s] = value
|
||||
}
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return make(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
func hermesCustomProviders(current any) []any {
|
||||
switch existing := current.(type) {
|
||||
case []any:
|
||||
return append([]any(nil), existing...)
|
||||
case []map[string]any:
|
||||
out := make([]any, 0, len(existing))
|
||||
for _, entry := range existing {
|
||||
out = append(out, entry)
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func hermesManagedProviderEntry(providers map[string]any) map[string]any {
|
||||
for _, key := range []string{hermesProviderKey, hermesLegacyKey} {
|
||||
if entry, _ := providers[key].(map[string]any); entry != nil {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func hermesWithoutManagedCustomProviders(current any) []any {
|
||||
customProviders := hermesCustomProviders(current)
|
||||
preserved := make([]any, 0, len(customProviders))
|
||||
|
||||
for _, item := range customProviders {
|
||||
entry, _ := item.(map[string]any)
|
||||
if entry == nil {
|
||||
preserved = append(preserved, item)
|
||||
continue
|
||||
}
|
||||
if hermesManagedCustomProvider(entry) {
|
||||
continue
|
||||
}
|
||||
preserved = append(preserved, entry)
|
||||
}
|
||||
|
||||
return preserved
|
||||
}
|
||||
|
||||
func hermesHasManagedCustomProvider(current any) bool {
|
||||
for _, item := range hermesCustomProviders(current) {
|
||||
entry, _ := item.(map[string]any)
|
||||
if entry != nil && hermesManagedCustomProvider(entry) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hermesManagedCustomProvider(entry map[string]any) bool {
|
||||
name, _ := entry["name"].(string)
|
||||
return strings.EqualFold(strings.TrimSpace(name), hermesProviderName)
|
||||
}
|
||||
|
||||
func hermesNormalizeURL(raw string) string {
|
||||
return strings.TrimRight(strings.TrimSpace(raw), "/")
|
||||
}
|
||||
|
||||
func hermesStringListAny(models []string) []any {
|
||||
out := make([]any, 0, len(models))
|
||||
for _, model := range dedupeModelList(models) {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, model)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func mergeHermesToolsets(current any) any {
|
||||
added := false
|
||||
switch existing := current.(type) {
|
||||
case []any:
|
||||
out := make([]any, 0, len(existing)+1)
|
||||
for _, item := range existing {
|
||||
out = append(out, item)
|
||||
if s, _ := item.(string); s == "web" {
|
||||
added = true
|
||||
}
|
||||
}
|
||||
if !added {
|
||||
out = append(out, "web")
|
||||
}
|
||||
return out
|
||||
case []string:
|
||||
out := append([]string(nil), existing...)
|
||||
if !slices.Contains(out, "web") {
|
||||
out = append(out, "web")
|
||||
}
|
||||
asAny := make([]any, 0, len(out))
|
||||
for _, item := range out {
|
||||
asAny = append(asAny, item)
|
||||
}
|
||||
return asAny
|
||||
case string:
|
||||
if strings.TrimSpace(existing) == "" {
|
||||
return []any{"hermes-cli", "web"}
|
||||
}
|
||||
parts := strings.Split(existing, ",")
|
||||
out := make([]any, 0, len(parts)+1)
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
if part == "web" {
|
||||
added = true
|
||||
}
|
||||
out = append(out, part)
|
||||
}
|
||||
if !added {
|
||||
out = append(out, "web")
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return []any{"hermes-cli", "web"}
|
||||
}
|
||||
}
|
||||
|
||||
func hermesAttachedCommand(name string, args ...string) *exec.Cmd {
|
||||
cmd := hermesCommand(name, args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd
|
||||
}
|
||||
|
||||
func hermesWindowsHint() error {
|
||||
return fmt.Errorf("Hermes on Windows requires WSL2. Install WSL with: wsl --install\n" +
|
||||
"Then run 'ollama launch hermes' from inside your WSL shell.\n" +
|
||||
"Docs: https://hermes-agent.nousresearch.com/docs/getting-started/installation/")
|
||||
}
|
||||
1109
cmd/launch/hermes_test.go
Normal file
1109
cmd/launch/hermes_test.go
Normal file
File diff suppressed because it is too large
Load Diff
2071
cmd/launch/integrations_test.go
Normal file
2071
cmd/launch/integrations_test.go
Normal file
File diff suppressed because it is too large
Load Diff
315
cmd/launch/kimi.go
Normal file
315
cmd/launch/kimi.go
Normal file
@@ -0,0 +1,315 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Kimi implements Runner for Kimi Code CLI integration.
|
||||
type Kimi struct{}
|
||||
|
||||
const (
|
||||
kimiDefaultModelAlias = "ollama"
|
||||
kimiDefaultMaxContextSize = 32768
|
||||
)
|
||||
|
||||
var (
|
||||
kimiGOOS = runtime.GOOS
|
||||
kimiModelShowTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
func (k *Kimi) String() string { return "Kimi Code CLI" }
|
||||
|
||||
func (k *Kimi) args(config string, extra []string) []string {
|
||||
args := []string{"--config", config}
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (k *Kimi) Run(model string, _ []LaunchModel, args []string) error {
|
||||
if strings.TrimSpace(model) == "" {
|
||||
return fmt.Errorf("model is required")
|
||||
}
|
||||
if err := validateKimiPassthroughArgs(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := buildKimiInlineConfig(model, resolveKimiMaxContextSize(model))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build kimi config: %w", err)
|
||||
}
|
||||
|
||||
bin, err := ensureKimiInstalled()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd := exec.Command(bin, k.args(config, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func findKimiBinary() (string, error) {
|
||||
if path, err := exec.LookPath("kimi"); err == nil {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
home, _ := os.UserHomeDir()
|
||||
|
||||
var candidates []string
|
||||
switch kimiGOOS {
|
||||
case "windows":
|
||||
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(home, ".local", "bin"))
|
||||
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(home, "bin"))
|
||||
|
||||
if appData := strings.TrimSpace(os.Getenv("APPDATA")); appData != "" {
|
||||
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(appData, "uv", "bin"))
|
||||
}
|
||||
if localAppData := strings.TrimSpace(os.Getenv("LOCALAPPDATA")); localAppData != "" {
|
||||
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(localAppData, "uv", "bin"))
|
||||
}
|
||||
default:
|
||||
candidates = append(candidates,
|
||||
filepath.Join(home, ".local", "bin", "kimi"),
|
||||
filepath.Join(home, "bin", "kimi"),
|
||||
filepath.Join(home, ".local", "share", "uv", "tools", "kimi-cli", "bin", "kimi"),
|
||||
filepath.Join(home, ".local", "share", "uv", "tools", "kimi", "bin", "kimi"),
|
||||
)
|
||||
|
||||
if xdgDataHome := strings.TrimSpace(os.Getenv("XDG_DATA_HOME")); xdgDataHome != "" {
|
||||
candidates = append(candidates,
|
||||
filepath.Join(xdgDataHome, "uv", "tools", "kimi-cli", "bin", "kimi"),
|
||||
filepath.Join(xdgDataHome, "uv", "tools", "kimi", "bin", "kimi"),
|
||||
)
|
||||
}
|
||||
|
||||
// WSL users can inherit Windows env vars while launching from Linux shells.
|
||||
if profile := windowsPathToWSL(os.Getenv("USERPROFILE")); profile != "" {
|
||||
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(profile, ".local", "bin"))
|
||||
}
|
||||
if appData := windowsPathToWSL(os.Getenv("APPDATA")); appData != "" {
|
||||
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(appData, "uv", "bin"))
|
||||
}
|
||||
if localAppData := windowsPathToWSL(os.Getenv("LOCALAPPDATA")); localAppData != "" {
|
||||
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(localAppData, "uv", "bin"))
|
||||
}
|
||||
}
|
||||
|
||||
for _, candidate := range candidates {
|
||||
if info, err := os.Stat(candidate); err == nil && !info.IsDir() {
|
||||
return candidate, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("kimi binary not found")
|
||||
}
|
||||
|
||||
func appendWindowsKimiCandidates(candidates []string, dir string) []string {
|
||||
if strings.TrimSpace(dir) == "" {
|
||||
return candidates
|
||||
}
|
||||
|
||||
return append(candidates,
|
||||
filepath.Join(dir, "kimi.exe"),
|
||||
filepath.Join(dir, "kimi.cmd"),
|
||||
filepath.Join(dir, "kimi.bat"),
|
||||
)
|
||||
}
|
||||
|
||||
func windowsPathToWSL(path string) string {
|
||||
trimmed := strings.TrimSpace(path)
|
||||
if len(trimmed) < 3 || trimmed[1] != ':' {
|
||||
return ""
|
||||
}
|
||||
|
||||
drive := strings.ToLower(string(trimmed[0]))
|
||||
rest := strings.ReplaceAll(trimmed[2:], "\\", "/")
|
||||
rest = strings.TrimPrefix(rest, "/")
|
||||
if rest == "" {
|
||||
return filepath.Join("/mnt", drive)
|
||||
}
|
||||
|
||||
return filepath.Join("/mnt", drive, rest)
|
||||
}
|
||||
|
||||
func validateKimiPassthroughArgs(args []string) error {
|
||||
for _, arg := range args {
|
||||
switch {
|
||||
case arg == "--config", strings.HasPrefix(arg, "--config="):
|
||||
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --config", arg)
|
||||
case arg == "--config-file", strings.HasPrefix(arg, "--config-file="):
|
||||
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --config-file", arg)
|
||||
case arg == "--model", strings.HasPrefix(arg, "--model="):
|
||||
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --model", arg)
|
||||
case arg == "-m", strings.HasPrefix(arg, "-m="):
|
||||
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages -m/--model", arg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildKimiInlineConfig(model string, maxContextSize int) (string, error) {
|
||||
cfg := map[string]any{
|
||||
"default_model": kimiDefaultModelAlias,
|
||||
"providers": map[string]any{
|
||||
kimiDefaultModelAlias: map[string]any{
|
||||
"type": "openai_legacy",
|
||||
"base_url": envconfig.ConnectableHost().String() + "/v1",
|
||||
"api_key": "ollama",
|
||||
},
|
||||
},
|
||||
"models": map[string]any{
|
||||
kimiDefaultModelAlias: map[string]any{
|
||||
"provider": kimiDefaultModelAlias,
|
||||
"model": model,
|
||||
"max_context_size": maxContextSize,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
func resolveKimiMaxContextSize(model string) int {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
return l.Context
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return kimiDefaultMaxContextSize
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), kimiModelShowTimeout)
|
||||
defer cancel()
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
|
||||
if err != nil {
|
||||
return kimiDefaultMaxContextSize
|
||||
}
|
||||
|
||||
if n, ok := modelInfoContextLength(resp.ModelInfo); ok {
|
||||
return n
|
||||
}
|
||||
|
||||
return kimiDefaultMaxContextSize
|
||||
}
|
||||
|
||||
func modelInfoContextLength(modelInfo map[string]any) (int, bool) {
|
||||
for key, val := range modelInfo {
|
||||
if !strings.HasSuffix(key, ".context_length") {
|
||||
continue
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case float64:
|
||||
if v > 0 {
|
||||
return int(v), true
|
||||
}
|
||||
case int:
|
||||
if v > 0 {
|
||||
return v, true
|
||||
}
|
||||
case int64:
|
||||
if v > 0 {
|
||||
return int(v), true
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func ensureKimiInstalled() (string, error) {
|
||||
if path, err := findKimiBinary(); err == nil {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
if err := checkKimiInstallerDependencies(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ok, err := ConfirmPrompt("Kimi is not installed. Install now?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !ok {
|
||||
return "", fmt.Errorf("kimi installation cancelled")
|
||||
}
|
||||
|
||||
bin, args, err := kimiInstallerCommand(kimiGOOS)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nInstalling Kimi...\n")
|
||||
cmd := exec.Command(bin, args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("failed to install kimi: %w", err)
|
||||
}
|
||||
|
||||
path, err := findKimiBinary()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("kimi was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sKimi installed successfully%s\n\n", ansiGreen, ansiReset)
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func checkKimiInstallerDependencies() error {
|
||||
switch kimiGOOS {
|
||||
case "windows":
|
||||
if _, err := exec.LookPath("powershell"); err != nil {
|
||||
return fmt.Errorf("kimi is not installed and required dependencies are missing\n\nInstall the following first:\n PowerShell: https://learn.microsoft.com/powershell/\n\nThen re-run:\n ollama launch kimi")
|
||||
}
|
||||
default:
|
||||
var missing []string
|
||||
if _, err := exec.LookPath("curl"); err != nil {
|
||||
missing = append(missing, "curl: https://curl.se/")
|
||||
}
|
||||
if _, err := exec.LookPath("bash"); err != nil {
|
||||
missing = append(missing, "bash: https://www.gnu.org/software/bash/")
|
||||
}
|
||||
if len(missing) > 0 {
|
||||
return fmt.Errorf("kimi is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch kimi", strings.Join(missing, "\n "))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func kimiInstallerCommand(goos string) (string, []string, error) {
|
||||
switch goos {
|
||||
case "windows":
|
||||
return "powershell", []string{
|
||||
"-NoProfile",
|
||||
"-ExecutionPolicy",
|
||||
"Bypass",
|
||||
"-Command",
|
||||
"Invoke-RestMethod https://code.kimi.com/install.ps1 | Invoke-Expression",
|
||||
}, nil
|
||||
case "darwin", "linux":
|
||||
return "bash", []string{
|
||||
"-c",
|
||||
"curl -LsSf https://code.kimi.com/install.sh | bash",
|
||||
}, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("unsupported platform for kimi install: %s", goos)
|
||||
}
|
||||
}
|
||||
636
cmd/launch/kimi_test.go
Normal file
636
cmd/launch/kimi_test.go
Normal file
@@ -0,0 +1,636 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func assertKimiBinPath(t *testing.T, bin string) {
|
||||
t.Helper()
|
||||
base := strings.ToLower(filepath.Base(bin))
|
||||
if !strings.HasPrefix(base, "kimi") {
|
||||
t.Fatalf("bin = %q, want path to kimi executable", bin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKimiIntegration(t *testing.T) {
|
||||
k := &Kimi{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := k.String(); got != "Kimi Code CLI" {
|
||||
t.Errorf("String() = %q, want %q", got, "Kimi Code CLI")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = k
|
||||
})
|
||||
}
|
||||
|
||||
func TestKimiArgs(t *testing.T) {
|
||||
k := &Kimi{}
|
||||
|
||||
got := k.args(`{"foo":"bar"}`, []string{"--quiet", "--print"})
|
||||
want := []string{"--config", `{"foo":"bar"}`, "--quiet", "--print"}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("args() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsPathToWSL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
valid bool
|
||||
}{
|
||||
{
|
||||
name: "user profile path",
|
||||
in: `C:\Users\parth`,
|
||||
want: filepath.Join("/mnt", "c", "Users", "parth"),
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "path with trailing slash",
|
||||
in: `D:\tools\bin\`,
|
||||
want: filepath.Join("/mnt", "d", "tools", "bin"),
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "non windows path",
|
||||
in: "/home/parth",
|
||||
valid: false,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
in: "",
|
||||
valid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := windowsPathToWSL(tt.in)
|
||||
if !tt.valid {
|
||||
if got != "" {
|
||||
t.Fatalf("windowsPathToWSL(%q) = %q, want empty", tt.in, got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Fatalf("windowsPathToWSL(%q) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindKimiBinaryFallbacks(t *testing.T) {
|
||||
oldGOOS := kimiGOOS
|
||||
t.Cleanup(func() { kimiGOOS = oldGOOS })
|
||||
|
||||
t.Run("linux/ubuntu uv tool path", func(t *testing.T) {
|
||||
homeDir := t.TempDir()
|
||||
setTestHome(t, homeDir)
|
||||
t.Setenv("PATH", t.TempDir())
|
||||
kimiGOOS = "linux"
|
||||
|
||||
target := filepath.Join(homeDir, ".local", "share", "uv", "tools", "kimi-cli", "bin", "kimi")
|
||||
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
|
||||
t.Fatalf("failed to create candidate dir: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(target, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
|
||||
t.Fatalf("failed to write kimi candidate: %v", err)
|
||||
}
|
||||
|
||||
got, err := findKimiBinary()
|
||||
if err != nil {
|
||||
t.Fatalf("findKimiBinary() error = %v", err)
|
||||
}
|
||||
if got != target {
|
||||
t.Fatalf("findKimiBinary() = %q, want %q", got, target)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("windows appdata uv bin", func(t *testing.T) {
|
||||
setTestHome(t, t.TempDir())
|
||||
t.Setenv("PATH", t.TempDir())
|
||||
kimiGOOS = "windows"
|
||||
|
||||
appDataDir := t.TempDir()
|
||||
t.Setenv("APPDATA", appDataDir)
|
||||
t.Setenv("LOCALAPPDATA", "")
|
||||
|
||||
target := filepath.Join(appDataDir, "uv", "bin", "kimi.cmd")
|
||||
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
|
||||
t.Fatalf("failed to create candidate dir: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(target, []byte("@echo off\r\nexit /b 0\r\n"), 0o755); err != nil {
|
||||
t.Fatalf("failed to write kimi candidate: %v", err)
|
||||
}
|
||||
|
||||
got, err := findKimiBinary()
|
||||
if err != nil {
|
||||
t.Fatalf("findKimiBinary() error = %v", err)
|
||||
}
|
||||
if got != target {
|
||||
t.Fatalf("findKimiBinary() = %q, want %q", got, target)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateKimiPassthroughArgs_RejectsConflicts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
want string
|
||||
}{
|
||||
{name: "--config", args: []string{"--config", "{}"}, want: "--config"},
|
||||
{name: "--config=", args: []string{"--config={}"}, want: "--config={"},
|
||||
{name: "--config-file", args: []string{"--config-file", "x.toml"}, want: "--config-file"},
|
||||
{name: "--config-file=", args: []string{"--config-file=x.toml"}, want: "--config-file=x.toml"},
|
||||
{name: "--model", args: []string{"--model", "foo"}, want: "--model"},
|
||||
{name: "--model=", args: []string{"--model=foo"}, want: "--model=foo"},
|
||||
{name: "-m", args: []string{"-m", "foo"}, want: "-m"},
|
||||
{name: "-m=", args: []string{"-m=foo"}, want: "-m=foo"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateKimiPassthroughArgs(tt.args)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for args %v", tt.args)
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.want) {
|
||||
t.Fatalf("error %q does not contain %q", err.Error(), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildKimiInlineConfig(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", "http://127.0.0.1:11434")
|
||||
|
||||
cfg, err := buildKimiInlineConfig("llama3.2", 65536)
|
||||
if err != nil {
|
||||
t.Fatalf("buildKimiInlineConfig() error = %v", err)
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
|
||||
t.Fatalf("config is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if parsed["default_model"] != "ollama" {
|
||||
t.Fatalf("default_model = %v, want ollama", parsed["default_model"])
|
||||
}
|
||||
|
||||
providers, ok := parsed["providers"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("providers missing or wrong type: %T", parsed["providers"])
|
||||
}
|
||||
ollamaProvider, ok := providers["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("providers.ollama missing or wrong type: %T", providers["ollama"])
|
||||
}
|
||||
if ollamaProvider["type"] != "openai_legacy" {
|
||||
t.Fatalf("provider type = %v, want openai_legacy", ollamaProvider["type"])
|
||||
}
|
||||
if ollamaProvider["base_url"] != "http://127.0.0.1:11434/v1" {
|
||||
t.Fatalf("provider base_url = %v, want http://127.0.0.1:11434/v1", ollamaProvider["base_url"])
|
||||
}
|
||||
if ollamaProvider["api_key"] != "ollama" {
|
||||
t.Fatalf("provider api_key = %v, want ollama", ollamaProvider["api_key"])
|
||||
}
|
||||
|
||||
models, ok := parsed["models"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("models missing or wrong type: %T", parsed["models"])
|
||||
}
|
||||
ollamaModel, ok := models["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("models.ollama missing or wrong type: %T", models["ollama"])
|
||||
}
|
||||
if ollamaModel["provider"] != "ollama" {
|
||||
t.Fatalf("model provider = %v, want ollama", ollamaModel["provider"])
|
||||
}
|
||||
if ollamaModel["model"] != "llama3.2" {
|
||||
t.Fatalf("model model = %v, want llama3.2", ollamaModel["model"])
|
||||
}
|
||||
if ollamaModel["max_context_size"] != float64(65536) {
|
||||
t.Fatalf("model max_context_size = %v, want 65536", ollamaModel["max_context_size"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildKimiInlineConfig_UsesConnectableHostForUnspecifiedBind(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", "http://0.0.0.0:11434")
|
||||
|
||||
cfg, err := buildKimiInlineConfig("llama3.2", 65536)
|
||||
if err != nil {
|
||||
t.Fatalf("buildKimiInlineConfig() error = %v", err)
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
|
||||
t.Fatalf("config is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
providers, ok := parsed["providers"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("providers missing or wrong type: %T", parsed["providers"])
|
||||
}
|
||||
|
||||
ollamaProvider, ok := providers["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("providers.ollama missing or wrong type: %T", providers["ollama"])
|
||||
}
|
||||
if got, _ := ollamaProvider["base_url"].(string); got != "http://127.0.0.1:11434/v1" {
|
||||
t.Fatalf("provider base_url = %q, want %q", got, "http://127.0.0.1:11434/v1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveKimiMaxContextSize(t *testing.T) {
|
||||
t.Run("uses cloud limit when known", func(t *testing.T) {
|
||||
got := resolveKimiMaxContextSize("kimi-k2.5:cloud")
|
||||
if got != 262_144 {
|
||||
t.Fatalf("resolveKimiMaxContextSize() = %d, want 262144", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses model show context length for local models", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/show" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
fmt.Fprint(w, `{"model_info":{"llama.context_length":131072}}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
got := resolveKimiMaxContextSize("llama3.2")
|
||||
if got != 131_072 {
|
||||
t.Fatalf("resolveKimiMaxContextSize() = %d, want 131072", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to default when show fails", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.NotFoundHandler())
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
oldTimeout := kimiModelShowTimeout
|
||||
kimiModelShowTimeout = 100 * 1000 * 1000 // 100ms
|
||||
t.Cleanup(func() { kimiModelShowTimeout = oldTimeout })
|
||||
|
||||
got := resolveKimiMaxContextSize("llama3.2")
|
||||
if got != kimiDefaultMaxContextSize {
|
||||
t.Fatalf("resolveKimiMaxContextSize() = %d, want %d", got, kimiDefaultMaxContextSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestKimiRun_RejectsConflictingArgsBeforeInstall(t *testing.T) {
|
||||
k := &Kimi{}
|
||||
|
||||
oldConfirm := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
t.Fatalf("did not expect install prompt, got %q", prompt)
|
||||
return false, nil
|
||||
}
|
||||
t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
|
||||
|
||||
err := k.Run("llama3.2", nil, []string{"--model", "other"})
|
||||
if err == nil || !strings.Contains(err.Error(), "--model") {
|
||||
t.Fatalf("expected conflict error mentioning --model, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKimiRun_PassesInlineConfigAndExtraArgs(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses POSIX shell fake binary")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
logPath := filepath.Join(tmpDir, "kimi-args.log")
|
||||
script := fmt.Sprintf(`#!/bin/sh
|
||||
for arg in "$@"; do
|
||||
printf "%%s\n" "$arg" >> %q
|
||||
done
|
||||
exit 0
|
||||
`, logPath)
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "kimi"), []byte(script), 0o755); err != nil {
|
||||
t.Fatalf("failed to write fake kimi: %v", err)
|
||||
}
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
srv := httptest.NewServer(http.NotFoundHandler())
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
k := &Kimi{}
|
||||
if err := k.Run("llama3.2", nil, []string{"--quiet", "--print"}); err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(logPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read args log: %v", err)
|
||||
}
|
||||
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
|
||||
if len(lines) < 4 {
|
||||
t.Fatalf("expected at least 4 args, got %v", lines)
|
||||
}
|
||||
if lines[0] != "--config" {
|
||||
t.Fatalf("first arg = %q, want --config", lines[0])
|
||||
}
|
||||
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal([]byte(lines[1]), &cfg); err != nil {
|
||||
t.Fatalf("config arg is not valid JSON: %v", err)
|
||||
}
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollamaProvider := providers["ollama"].(map[string]any)
|
||||
if ollamaProvider["type"] != "openai_legacy" {
|
||||
t.Fatalf("provider type = %v, want openai_legacy", ollamaProvider["type"])
|
||||
}
|
||||
|
||||
if lines[2] != "--quiet" || lines[3] != "--print" {
|
||||
t.Fatalf("extra args = %v, want [--quiet --print]", lines[2:])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureKimiInstalled(t *testing.T) {
|
||||
oldGOOS := kimiGOOS
|
||||
t.Cleanup(func() { kimiGOOS = oldGOOS })
|
||||
|
||||
withConfirm := func(t *testing.T, fn func(prompt string) (bool, error)) {
|
||||
t.Helper()
|
||||
oldConfirm := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
return fn(prompt)
|
||||
}
|
||||
t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
|
||||
}
|
||||
|
||||
t.Run("already installed", func(t *testing.T) {
|
||||
setTestHome(t, t.TempDir())
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("PATH", tmpDir)
|
||||
writeFakeBinary(t, tmpDir, "kimi")
|
||||
kimiGOOS = runtime.GOOS
|
||||
|
||||
withConfirm(t, func(prompt string) (bool, error) {
|
||||
t.Fatalf("did not expect prompt, got %q", prompt)
|
||||
return false, nil
|
||||
})
|
||||
|
||||
bin, err := ensureKimiInstalled()
|
||||
if err != nil {
|
||||
t.Fatalf("ensureKimiInstalled() error = %v", err)
|
||||
}
|
||||
assertKimiBinPath(t, bin)
|
||||
})
|
||||
|
||||
t.Run("missing dependencies", func(t *testing.T) {
|
||||
setTestHome(t, t.TempDir())
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("PATH", tmpDir)
|
||||
kimiGOOS = "linux"
|
||||
|
||||
withConfirm(t, func(prompt string) (bool, error) {
|
||||
t.Fatalf("did not expect prompt, got %q", prompt)
|
||||
return false, nil
|
||||
})
|
||||
|
||||
_, err := ensureKimiInstalled()
|
||||
if err == nil || !strings.Contains(err.Error(), "required dependencies are missing") {
|
||||
t.Fatalf("expected missing dependency error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing and user declines install", func(t *testing.T) {
|
||||
setTestHome(t, t.TempDir())
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("PATH", tmpDir)
|
||||
writeFakeBinary(t, tmpDir, "curl")
|
||||
writeFakeBinary(t, tmpDir, "bash")
|
||||
kimiGOOS = "linux"
|
||||
|
||||
withConfirm(t, func(prompt string) (bool, error) {
|
||||
if !strings.Contains(prompt, "Kimi is not installed.") {
|
||||
t.Fatalf("unexpected prompt: %q", prompt)
|
||||
}
|
||||
return false, nil
|
||||
})
|
||||
|
||||
_, err := ensureKimiInstalled()
|
||||
if err == nil || !strings.Contains(err.Error(), "installation cancelled") {
|
||||
t.Fatalf("expected cancellation error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing and user confirms install succeeds", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses POSIX shell fake binaries")
|
||||
}
|
||||
|
||||
setTestHome(t, t.TempDir())
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("PATH", tmpDir)
|
||||
kimiGOOS = "linux"
|
||||
|
||||
writeFakeBinary(t, tmpDir, "curl")
|
||||
|
||||
installLog := filepath.Join(tmpDir, "bash.log")
|
||||
kimiPath := filepath.Join(tmpDir, "kimi")
|
||||
bashScript := fmt.Sprintf(`#!/bin/sh
|
||||
echo "$@" >> %q
|
||||
if [ "$1" = "-c" ]; then
|
||||
/bin/cat > %q <<'EOS'
|
||||
#!/bin/sh
|
||||
exit 0
|
||||
EOS
|
||||
/bin/chmod +x %q
|
||||
fi
|
||||
exit 0
|
||||
`, installLog, kimiPath, kimiPath)
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte(bashScript), 0o755); err != nil {
|
||||
t.Fatalf("failed to write fake bash: %v", err)
|
||||
}
|
||||
|
||||
withConfirm(t, func(prompt string) (bool, error) {
|
||||
return true, nil
|
||||
})
|
||||
|
||||
bin, err := ensureKimiInstalled()
|
||||
if err != nil {
|
||||
t.Fatalf("ensureKimiInstalled() error = %v", err)
|
||||
}
|
||||
assertKimiBinPath(t, bin)
|
||||
|
||||
logData, err := os.ReadFile(installLog)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read install log: %v", err)
|
||||
}
|
||||
if !strings.Contains(string(logData), "https://code.kimi.com/install.sh") {
|
||||
t.Fatalf("expected install.sh command in log, got:\n%s", string(logData))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("install succeeds and kimi is in home local bin without PATH update", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses POSIX shell fake binaries")
|
||||
}
|
||||
|
||||
homeDir := t.TempDir()
|
||||
setTestHome(t, homeDir)
|
||||
|
||||
tmpBin := t.TempDir()
|
||||
t.Setenv("PATH", tmpBin)
|
||||
kimiGOOS = "linux"
|
||||
writeFakeBinary(t, tmpBin, "curl")
|
||||
|
||||
installedKimi := filepath.Join(homeDir, ".local", "bin", "kimi")
|
||||
bashScript := fmt.Sprintf(`#!/bin/sh
|
||||
if [ "$1" = "-c" ]; then
|
||||
/bin/mkdir -p %q
|
||||
/bin/cat > %q <<'EOS'
|
||||
#!/bin/sh
|
||||
exit 0
|
||||
EOS
|
||||
/bin/chmod +x %q
|
||||
fi
|
||||
exit 0
|
||||
`, filepath.Dir(installedKimi), installedKimi, installedKimi)
|
||||
if err := os.WriteFile(filepath.Join(tmpBin, "bash"), []byte(bashScript), 0o755); err != nil {
|
||||
t.Fatalf("failed to write fake bash: %v", err)
|
||||
}
|
||||
|
||||
withConfirm(t, func(prompt string) (bool, error) {
|
||||
return true, nil
|
||||
})
|
||||
|
||||
bin, err := ensureKimiInstalled()
|
||||
if err != nil {
|
||||
t.Fatalf("ensureKimiInstalled() error = %v", err)
|
||||
}
|
||||
if bin != installedKimi {
|
||||
t.Fatalf("bin = %q, want %q", bin, installedKimi)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("install command fails", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses POSIX shell fake binaries")
|
||||
}
|
||||
|
||||
setTestHome(t, t.TempDir())
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("PATH", tmpDir)
|
||||
kimiGOOS = "linux"
|
||||
writeFakeBinary(t, tmpDir, "curl")
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte("#!/bin/sh\nexit 1\n"), 0o755); err != nil {
|
||||
t.Fatalf("failed to write fake bash: %v", err)
|
||||
}
|
||||
|
||||
withConfirm(t, func(prompt string) (bool, error) {
|
||||
return true, nil
|
||||
})
|
||||
|
||||
_, err := ensureKimiInstalled()
|
||||
if err == nil || !strings.Contains(err.Error(), "failed to install kimi") {
|
||||
t.Fatalf("expected install failure error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("install succeeds but binary missing on PATH", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses POSIX shell fake binaries")
|
||||
}
|
||||
|
||||
setTestHome(t, t.TempDir())
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("PATH", tmpDir)
|
||||
kimiGOOS = "linux"
|
||||
writeFakeBinary(t, tmpDir, "curl")
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
|
||||
t.Fatalf("failed to write fake bash: %v", err)
|
||||
}
|
||||
|
||||
withConfirm(t, func(prompt string) (bool, error) {
|
||||
return true, nil
|
||||
})
|
||||
|
||||
_, err := ensureKimiInstalled()
|
||||
if err == nil || !strings.Contains(err.Error(), "binary was not found on PATH") {
|
||||
t.Fatalf("expected PATH guidance error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestKimiInstallerCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
goos string
|
||||
wantBin string
|
||||
wantParts []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "linux",
|
||||
goos: "linux",
|
||||
wantBin: "bash",
|
||||
wantParts: []string{"-c", "install.sh"},
|
||||
},
|
||||
{
|
||||
name: "darwin",
|
||||
goos: "darwin",
|
||||
wantBin: "bash",
|
||||
wantParts: []string{"-c", "install.sh"},
|
||||
},
|
||||
{
|
||||
name: "windows",
|
||||
goos: "windows",
|
||||
wantBin: "powershell",
|
||||
wantParts: []string{"-Command", "install.ps1"},
|
||||
},
|
||||
{
|
||||
name: "unsupported",
|
||||
goos: "freebsd",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
bin, args, err := kimiInstallerCommand(tt.goos)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("kimiInstallerCommand() error = %v", err)
|
||||
}
|
||||
if bin != tt.wantBin {
|
||||
t.Fatalf("bin = %q, want %q", bin, tt.wantBin)
|
||||
}
|
||||
joined := strings.Join(args, " ")
|
||||
for _, part := range tt.wantParts {
|
||||
if !strings.Contains(joined, part) {
|
||||
t.Fatalf("args %q missing %q", joined, part)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
1495
cmd/launch/launch.go
Normal file
1495
cmd/launch/launch.go
Normal file
File diff suppressed because it is too large
Load Diff
3539
cmd/launch/launch_test.go
Normal file
3539
cmd/launch/launch_test.go
Normal file
File diff suppressed because it is too large
Load Diff
201
cmd/launch/model_inventory.go
Normal file
201
cmd/launch/model_inventory.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
modelpkg "github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// LaunchModel is the model metadata Launch passes to integration config
|
||||
// writers after resolving selected model names through the per-run inventory.
|
||||
type LaunchModel struct {
|
||||
Name string
|
||||
Remote bool
|
||||
ToolCapable bool
|
||||
Capabilities []modelpkg.Capability
|
||||
ContextLength int
|
||||
MaxOutputTokens int
|
||||
EmbeddingLength int
|
||||
Size int64
|
||||
Details api.ModelDetails
|
||||
}
|
||||
|
||||
type modelInfo = LaunchModel
|
||||
|
||||
// ModelInfo re-exports launcher model inventory details for callers.
|
||||
type ModelInfo = LaunchModel
|
||||
|
||||
func (m LaunchModel) HasCapability(capability modelpkg.Capability) bool {
|
||||
return slices.Contains(m.Capabilities, capability)
|
||||
}
|
||||
|
||||
func (m LaunchModel) WithCloudLimits() LaunchModel {
|
||||
if limit, ok := lookupCloudModelLimit(m.Name); ok {
|
||||
if m.ContextLength <= 0 {
|
||||
m.ContextLength = limit.Context
|
||||
}
|
||||
if m.MaxOutputTokens <= 0 {
|
||||
m.MaxOutputTokens = limit.Output
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
type modelInventory struct {
|
||||
client *api.Client
|
||||
|
||||
mu sync.Mutex
|
||||
loaded bool
|
||||
models []LaunchModel
|
||||
err error
|
||||
}
|
||||
|
||||
func newModelInventory(client *api.Client) *modelInventory {
|
||||
return &modelInventory{client: client}
|
||||
}
|
||||
|
||||
func (i *modelInventory) Load(ctx context.Context) ([]LaunchModel, error) {
|
||||
return i.load(ctx, false)
|
||||
}
|
||||
|
||||
func (i *modelInventory) Refresh(ctx context.Context) ([]LaunchModel, error) {
|
||||
return i.load(ctx, true)
|
||||
}
|
||||
|
||||
func (i *modelInventory) load(ctx context.Context, force bool) ([]LaunchModel, error) {
|
||||
if i == nil || i.client == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
i.mu.Lock()
|
||||
defer i.mu.Unlock()
|
||||
|
||||
if i.loaded && !force {
|
||||
return cloneLaunchModels(i.models), i.err
|
||||
}
|
||||
|
||||
resp, err := i.client.List(ctx)
|
||||
if err != nil {
|
||||
i.models = nil
|
||||
i.err = err
|
||||
i.loaded = true
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i.models = make([]LaunchModel, 0, len(resp.Models))
|
||||
for _, model := range resp.Models {
|
||||
i.models = append(i.models, launchModelFromListResponse(model))
|
||||
}
|
||||
i.err = nil
|
||||
i.loaded = true
|
||||
|
||||
return cloneLaunchModels(i.models), i.err
|
||||
}
|
||||
|
||||
func (i *modelInventory) Resolve(ctx context.Context, names []string) []LaunchModel {
|
||||
names = dedupeModelList(names)
|
||||
if len(names) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
models, err := i.Load(ctx)
|
||||
if err != nil {
|
||||
models = nil
|
||||
}
|
||||
|
||||
resolved, localMiss := resolveLaunchModels(names, models)
|
||||
if localMiss {
|
||||
if refreshed, err := i.Refresh(ctx); err == nil {
|
||||
resolved, _ = resolveLaunchModels(names, refreshed)
|
||||
}
|
||||
}
|
||||
return resolved
|
||||
}
|
||||
|
||||
func resolveLaunchModels(names []string, models []LaunchModel) ([]LaunchModel, bool) {
|
||||
resolved := make([]LaunchModel, 0, len(names))
|
||||
localMiss := false
|
||||
for _, name := range names {
|
||||
if model, ok := findLaunchModel(models, name); ok {
|
||||
resolved = append(resolved, model.WithCloudLimits())
|
||||
continue
|
||||
}
|
||||
if !isCloudModelName(name) {
|
||||
localMiss = true
|
||||
}
|
||||
resolved = append(resolved, fallbackLaunchModel(name))
|
||||
}
|
||||
return resolved, localMiss
|
||||
}
|
||||
|
||||
func launchModelFromListResponse(model api.ListModelResponse) LaunchModel {
|
||||
return LaunchModel{
|
||||
Name: model.Name,
|
||||
Remote: model.RemoteModel != "",
|
||||
ToolCapable: slices.Contains(model.Capabilities, modelpkg.CapabilityTools),
|
||||
Capabilities: append([]modelpkg.Capability(nil), model.Capabilities...),
|
||||
ContextLength: model.Details.ContextLength,
|
||||
EmbeddingLength: model.Details.EmbeddingLength,
|
||||
Size: model.Size,
|
||||
Details: model.Details,
|
||||
}.WithCloudLimits()
|
||||
}
|
||||
|
||||
func fallbackLaunchModel(name string) LaunchModel {
|
||||
return LaunchModel{Name: name, Remote: isCloudModelName(name)}.WithCloudLimits()
|
||||
}
|
||||
|
||||
func findLaunchModel(models []LaunchModel, name string) (LaunchModel, bool) {
|
||||
for _, model := range models {
|
||||
if launchModelMatches(model.Name, name) {
|
||||
return cloneLaunchModel(model), true
|
||||
}
|
||||
}
|
||||
return LaunchModel{}, false
|
||||
}
|
||||
|
||||
func launchModelMatches(candidate, name string) bool {
|
||||
if candidate == name {
|
||||
return true
|
||||
}
|
||||
return strings.TrimSuffix(candidate, ":latest") == name
|
||||
}
|
||||
|
||||
func cloneLaunchModel(model LaunchModel) LaunchModel {
|
||||
model.Capabilities = append([]modelpkg.Capability(nil), model.Capabilities...)
|
||||
model.Details.Families = append([]string(nil), model.Details.Families...)
|
||||
return model
|
||||
}
|
||||
|
||||
func cloneLaunchModels(models []LaunchModel) []LaunchModel {
|
||||
cloned := make([]LaunchModel, len(models))
|
||||
for i, model := range models {
|
||||
cloned[i] = cloneLaunchModel(model)
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func launchModelNames(models []LaunchModel) []string {
|
||||
names := make([]string, 0, len(models))
|
||||
for _, model := range models {
|
||||
if model.Name != "" {
|
||||
names = append(names, model.Name)
|
||||
}
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
func launchModelsFromNames(names []string) []LaunchModel {
|
||||
models := make([]LaunchModel, 0, len(names))
|
||||
for _, name := range names {
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
models = append(models, fallbackLaunchModel(name))
|
||||
}
|
||||
return models
|
||||
}
|
||||
80
cmd/launch/model_inventory_test.go
Normal file
80
cmd/launch/model_inventory_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
modelpkg "github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestModelInventoryResolveRefreshesLocalMiss(t *testing.T) {
|
||||
calls := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/tags" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
calls++
|
||||
if calls == 1 {
|
||||
fmt.Fprint(w, `{"models":[]}`)
|
||||
return
|
||||
}
|
||||
fmt.Fprint(w, `{"models":[{"name":"new-model","size":123,"details":{"context_length":65536,"embedding_length":1024},"capabilities":["vision","tools"]}]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
inventory := newModelInventory(api.NewClient(u, srv.Client()))
|
||||
|
||||
got := inventory.Resolve(context.Background(), []string{"new-model"})
|
||||
if calls != 2 {
|
||||
t.Fatalf("List calls = %d, want 2", calls)
|
||||
}
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("Resolve returned %d models, want 1", len(got))
|
||||
}
|
||||
if got[0].Name != "new-model" {
|
||||
t.Fatalf("Name = %q, want new-model", got[0].Name)
|
||||
}
|
||||
if got[0].ContextLength != 65_536 || got[0].EmbeddingLength != 1_024 {
|
||||
t.Fatalf("metadata = context %d embedding %d, want refreshed metadata", got[0].ContextLength, got[0].EmbeddingLength)
|
||||
}
|
||||
if !got[0].HasCapability(modelpkg.CapabilityVision) || !got[0].ToolCapable {
|
||||
t.Fatalf("capabilities = %v toolCapable=%v, want refreshed capabilities", got[0].Capabilities, got[0].ToolCapable)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelInventoryResolveDoesNotRefreshCloudMiss(t *testing.T) {
|
||||
calls := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/tags" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
calls++
|
||||
fmt.Fprint(w, `{"models":[]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
inventory := newModelInventory(api.NewClient(u, srv.Client()))
|
||||
|
||||
got := inventory.Resolve(context.Background(), []string{"glm-5.1:cloud"})
|
||||
if calls != 1 {
|
||||
t.Fatalf("List calls = %d, want 1", calls)
|
||||
}
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("Resolve returned %d models, want 1", len(got))
|
||||
}
|
||||
if got[0].Name != "glm-5.1:cloud" || !got[0].Remote {
|
||||
t.Fatalf("resolved model = %#v, want cloud fallback", got[0])
|
||||
}
|
||||
if got[0].ContextLength <= 0 || got[0].MaxOutputTokens <= 0 {
|
||||
t.Fatalf("cloud limits not applied: %#v", got[0])
|
||||
}
|
||||
}
|
||||
593
cmd/launch/models.go
Normal file
593
cmd/launch/models.go
Normal file
@@ -0,0 +1,593 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
"github.com/ollama/ollama/format"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/progress"
|
||||
)
|
||||
|
||||
var recommendedModels = []ModelItem{
|
||||
{Name: "kimi-k2.6:cloud", Description: "State-of-the-art coding, long-horizon execution, and multimodal agent swarm capability", Recommended: true, Details: api.ModelDetails{ContextLength: 262_144}, MaxOutputTokens: 262_144},
|
||||
{Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true, Details: api.ModelDetails{ContextLength: 262_144}, MaxOutputTokens: 32_768},
|
||||
{Name: "glm-5.1:cloud", Description: "Reasoning and code generation", Recommended: true, Details: api.ModelDetails{ContextLength: 202_752}, MaxOutputTokens: 131_072},
|
||||
{Name: "minimax-m2.7:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true, Details: api.ModelDetails{ContextLength: 204_800}, MaxOutputTokens: 128_000},
|
||||
{Name: "gemma4", Description: "Reasoning and code generation locally", Recommended: true, VRAMBytes: 12 * format.GigaByte},
|
||||
{Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true, VRAMBytes: 14 * format.GigaByte},
|
||||
}
|
||||
|
||||
func displayVRAM(vramBytes int64) string {
|
||||
if vramBytes <= 0 {
|
||||
return ""
|
||||
}
|
||||
gb := float64(vramBytes) / format.GigaByte
|
||||
if gb == math.Trunc(gb) {
|
||||
return fmt.Sprintf("~%.0fGB", gb)
|
||||
}
|
||||
return fmt.Sprintf("~%.1fGB", gb)
|
||||
}
|
||||
|
||||
// cloudModelLimit holds context and output token limits for a cloud model.
|
||||
type cloudModelLimit struct {
|
||||
Context int
|
||||
Output int
|
||||
}
|
||||
|
||||
// extraCloudModelLimits maps cloud model base names to token limits for models
|
||||
// that are not already covered by recommendedModels fallback entries.
|
||||
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
|
||||
var extraCloudModelLimits = map[string]cloudModelLimit{
|
||||
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
|
||||
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
|
||||
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
|
||||
"gemma4:31b": {Context: 262_144, Output: 131_072},
|
||||
"glm-4.6": {Context: 202_752, Output: 131_072},
|
||||
"glm-4.7": {Context: 202_752, Output: 131_072},
|
||||
"glm-5": {Context: 202_752, Output: 131_072},
|
||||
"glm-5.1": {Context: 202_752, Output: 131_072},
|
||||
"gpt-oss:120b": {Context: 131_072, Output: 131_072},
|
||||
"gpt-oss:20b": {Context: 131_072, Output: 131_072},
|
||||
"kimi-k2:1t": {Context: 262_144, Output: 262_144},
|
||||
"kimi-k2.5": {Context: 262_144, Output: 262_144},
|
||||
"kimi-k2.6": {Context: 262_144, Output: 262_144},
|
||||
"kimi-k2-thinking": {Context: 262_144, Output: 262_144},
|
||||
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
|
||||
"qwen3-coder:480b": {Context: 262_144, Output: 65_536},
|
||||
"qwen3-coder-next": {Context: 262_144, Output: 32_768},
|
||||
"qwen3-next:80b": {Context: 262_144, Output: 32_768},
|
||||
"qwen3.5": {Context: 262_144, Output: 32_768},
|
||||
}
|
||||
|
||||
var cloudModelLimits = mergeCloudModelLimits(cloudModelLimitsFromRecommendations(recommendedModels), extraCloudModelLimits)
|
||||
|
||||
var (
|
||||
dynamicCloudModelLimitsMu sync.RWMutex
|
||||
dynamicCloudModelLimits = map[string]cloudModelLimit{}
|
||||
)
|
||||
|
||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||
// It normalizes explicit cloud source suffixes before checking the shared limit map.
|
||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||
base, stripped := modelref.StripCloudSourceTag(name)
|
||||
if stripped {
|
||||
dynamicCloudModelLimitsMu.RLock()
|
||||
l, ok := dynamicCloudModelLimits[base]
|
||||
dynamicCloudModelLimitsMu.RUnlock()
|
||||
if ok {
|
||||
return l, true
|
||||
}
|
||||
if l, ok := cloudModelLimits[base]; ok {
|
||||
return l, true
|
||||
}
|
||||
}
|
||||
return cloudModelLimit{}, false
|
||||
}
|
||||
|
||||
func setDynamicCloudModelLimits(limits map[string]cloudModelLimit) {
|
||||
dynamicCloudModelLimitsMu.Lock()
|
||||
defer dynamicCloudModelLimitsMu.Unlock()
|
||||
if limits == nil {
|
||||
dynamicCloudModelLimits = map[string]cloudModelLimit{}
|
||||
return
|
||||
}
|
||||
cp := make(map[string]cloudModelLimit, len(limits))
|
||||
for k, v := range limits {
|
||||
cp[k] = v
|
||||
}
|
||||
dynamicCloudModelLimits = cp
|
||||
}
|
||||
|
||||
func cloudModelLimitsFromRecommendations(recommendations []ModelItem) map[string]cloudModelLimit {
|
||||
limits := make(map[string]cloudModelLimit, len(recommendations))
|
||||
for _, rec := range recommendations {
|
||||
if !isCloudModelName(rec.Name) || rec.Details.ContextLength <= 0 || rec.MaxOutputTokens <= 0 {
|
||||
continue
|
||||
}
|
||||
base, stripped := modelref.StripCloudSourceTag(rec.Name)
|
||||
if !stripped || base == "" {
|
||||
continue
|
||||
}
|
||||
limits[base] = cloudModelLimit{
|
||||
Context: rec.Details.ContextLength,
|
||||
Output: rec.MaxOutputTokens,
|
||||
}
|
||||
}
|
||||
return limits
|
||||
}
|
||||
|
||||
func mergeCloudModelLimits(base map[string]cloudModelLimit, overlay map[string]cloudModelLimit) map[string]cloudModelLimit {
|
||||
out := make(map[string]cloudModelLimit, len(base)+len(overlay))
|
||||
for name, limit := range base {
|
||||
out[name] = limit
|
||||
}
|
||||
for name, limit := range overlay {
|
||||
out[name] = limit
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// missingModelPolicy controls how model-not-found errors should be handled.
|
||||
type missingModelPolicy int
|
||||
|
||||
const (
|
||||
// missingModelPromptPull prompts the user to download missing local models.
|
||||
missingModelPromptPull missingModelPolicy = iota
|
||||
// missingModelAutoPull downloads missing local models without prompting.
|
||||
missingModelAutoPull
|
||||
// missingModelFail returns an error for missing local models without prompting.
|
||||
missingModelFail
|
||||
)
|
||||
|
||||
// OpenBrowser opens the URL in the user's browser.
|
||||
func OpenBrowser(url string) {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", url).Start()
|
||||
case "linux":
|
||||
// Skip on headless systems where no display server is available
|
||||
if os.Getenv("DISPLAY") == "" && os.Getenv("WAYLAND_DISPLAY") == "" {
|
||||
return
|
||||
}
|
||||
_ = exec.Command("xdg-open", url).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
|
||||
}
|
||||
}
|
||||
|
||||
// ensureAuth ensures the user is signed in before cloud-backed models run.
|
||||
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
|
||||
var selectedCloudModels []string
|
||||
for _, m := range selected {
|
||||
if cloudModels[m] {
|
||||
selectedCloudModels = append(selectedCloudModels, m)
|
||||
}
|
||||
}
|
||||
if len(selectedCloudModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
return ensureCloudAuth(ctx, client, strings.Join(selectedCloudModels, ", "))
|
||||
}
|
||||
|
||||
func ensureCloudAuth(ctx context.Context, client *api.Client, modelList string) error {
|
||||
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
|
||||
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
|
||||
}
|
||||
|
||||
user, err := whoamiWithTimeout(ctx, client)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
|
||||
if DefaultSignIn != nil {
|
||||
_, err := DefaultSignIn(modelList, aErr.SigninURL)
|
||||
if errors.Is(err, ErrCancelled) {
|
||||
return ErrCancelled
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
yes, err := ConfirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if errors.Is(err, ErrCancelled) {
|
||||
return ErrCancelled
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !yes {
|
||||
return ErrCancelled
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
OpenBrowser(aErr.SigninURL)
|
||||
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
if frame%10 == 0 {
|
||||
u, err := whoamiWithTimeout(ctx, client)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// showOrPullWithPolicy checks if a model exists and applies the provided missing-model policy.
|
||||
func showOrPullWithPolicy(ctx context.Context, client *api.Client, model string, policy missingModelPolicy, isCloudModel bool) error {
|
||||
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
|
||||
return nil
|
||||
} else {
|
||||
var statusErr api.StatusError
|
||||
if !errors.As(err, &statusErr) || statusErr.StatusCode != http.StatusNotFound {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if isCloudModel {
|
||||
if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
|
||||
return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
|
||||
}
|
||||
return fmt.Errorf("model %q not found", model)
|
||||
}
|
||||
|
||||
switch policy {
|
||||
case missingModelAutoPull:
|
||||
return pullMissingModel(ctx, client, model)
|
||||
case missingModelFail:
|
||||
return fmt.Errorf("model %q not found; run 'ollama pull %s' first, or use --yes to auto-pull", model, model)
|
||||
default:
|
||||
return confirmAndPull(ctx, client, model)
|
||||
}
|
||||
}
|
||||
|
||||
func confirmAndPull(ctx context.Context, client *api.Client, model string) error {
|
||||
if ok, err := ConfirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
return pullMissingModel(ctx, client, model)
|
||||
}
|
||||
|
||||
func pullMissingModel(ctx context.Context, client *api.Client, model string) error {
|
||||
if err := pullModel(ctx, client, model, false); err != nil {
|
||||
return fmt.Errorf("failed to pull %s: %w", model, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareEditorIntegration persists models and applies editor-managed config files.
|
||||
func prepareEditorIntegration(name string, editor Editor, models []LaunchModel) error {
|
||||
if err := editor.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
if err := config.SaveIntegration(name, launchModelNames(models)); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func prepareManagedSingleIntegration(name string, managed ManagedSingleModel, model string, models []LaunchModel) error {
|
||||
var err error
|
||||
if withModels, ok := managed.(ManagedModelListConfigurer); ok {
|
||||
err = withModels.ConfigureWithModels(model, models)
|
||||
} else {
|
||||
err = managed.Configure(model)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
if err := config.SaveIntegration(name, []string{model}); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func prepareManagedAutodiscoveryIntegration(name string, autodiscovery ManagedAutodiscoveryIntegration, model string) error {
|
||||
if err := autodiscovery.ConfigureAutodiscovery(); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
if err := config.SaveIntegration(name, []string{model}); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildModelList merges existing models with recommendations for selection UIs.
|
||||
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
|
||||
return buildModelListWithRecommendations(existing, recommendedModels, preChecked, current)
|
||||
}
|
||||
|
||||
func buildModelListWithRecommendations(existing []modelInfo, recommendations []ModelItem, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
|
||||
existingModels = make(map[string]bool)
|
||||
cloudModels = make(map[string]bool)
|
||||
recommended := make(map[string]bool)
|
||||
var hasLocalModel, hasCloudModel bool
|
||||
|
||||
recDesc := make(map[string]string)
|
||||
recByName := make(map[string]ModelItem)
|
||||
for _, rec := range recommendations {
|
||||
recommended[rec.Name] = true
|
||||
recDesc[rec.Name] = rec.Description
|
||||
recByName[rec.Name] = rec
|
||||
}
|
||||
|
||||
for _, m := range existing {
|
||||
existingModels[m.Name] = true
|
||||
if m.Remote {
|
||||
cloudModels[m.Name] = true
|
||||
hasCloudModel = true
|
||||
} else {
|
||||
hasLocalModel = true
|
||||
}
|
||||
displayName := strings.TrimSuffix(m.Name, ":latest")
|
||||
existingModels[displayName] = true
|
||||
if rec, ok := recByName[displayName]; ok {
|
||||
items = append(items, modelItemFromInventory(displayName, m, copyModelRecommendationFields(displayName, rec)))
|
||||
} else {
|
||||
items = append(items, modelItemFromInventory(displayName, m, ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]}))
|
||||
}
|
||||
}
|
||||
|
||||
for _, rec := range recommendations {
|
||||
if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
|
||||
continue
|
||||
}
|
||||
items = append(items, rec)
|
||||
if isCloudModelName(rec.Name) {
|
||||
cloudModels[rec.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
checked := make(map[string]bool, len(preChecked))
|
||||
for _, n := range preChecked {
|
||||
checked[n] = true
|
||||
}
|
||||
|
||||
if current != "" {
|
||||
matchedCurrent := false
|
||||
for _, item := range items {
|
||||
if item.Name == current {
|
||||
current = item.Name
|
||||
matchedCurrent = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matchedCurrent {
|
||||
for _, item := range items {
|
||||
if strings.HasPrefix(item.Name, current+":") {
|
||||
current = item.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if checked[current] {
|
||||
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
|
||||
}
|
||||
|
||||
notInstalled := make(map[string]bool)
|
||||
for i := range items {
|
||||
if !existingModels[items[i].Name] && !cloudModels[items[i].Name] {
|
||||
notInstalled[items[i].Name] = true
|
||||
var parts []string
|
||||
if items[i].Description != "" {
|
||||
parts = append(parts, items[i].Description)
|
||||
}
|
||||
if vram := displayVRAM(items[i].VRAMBytes); vram != "" {
|
||||
parts = append(parts, vram)
|
||||
}
|
||||
parts = append(parts, "(not downloaded)")
|
||||
items[i].Description = strings.Join(parts, ", ")
|
||||
}
|
||||
}
|
||||
|
||||
recRank := make(map[string]int)
|
||||
for i, rec := range recommendations {
|
||||
recRank[rec.Name] = i + 1
|
||||
}
|
||||
|
||||
if hasLocalModel || hasCloudModel {
|
||||
// Keep the Recommended section pinned to recommendation order. Checked
|
||||
// and default-model priority only apply within the More section.
|
||||
slices.SortStableFunc(items, func(a, b ModelItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
|
||||
aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0
|
||||
if aRec != bRec {
|
||||
if aRec {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
if aRec && bRec {
|
||||
return recRank[a.Name] - recRank[b.Name]
|
||||
}
|
||||
if ac != bc {
|
||||
if ac {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
// Among checked non-recommended items - put the default first
|
||||
if ac && !aRec && current != "" {
|
||||
aCurrent := a.Name == current
|
||||
bCurrent := b.Name == current
|
||||
if aCurrent != bCurrent {
|
||||
if aCurrent {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
}
|
||||
if aNew != bNew {
|
||||
if aNew {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
}
|
||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||
})
|
||||
}
|
||||
|
||||
return items, preChecked, existingModels, cloudModels
|
||||
}
|
||||
|
||||
func copyModelRecommendationFields(name string, rec ModelItem) ModelItem {
|
||||
rec.Name = name
|
||||
rec.Recommended = true
|
||||
return rec
|
||||
}
|
||||
|
||||
func modelItemFromInventory(name string, info modelInfo, item ModelItem) ModelItem {
|
||||
item.Name = name
|
||||
item.ToolCapable = info.ToolCapable
|
||||
item.Capabilities = slices.Clone(info.Capabilities)
|
||||
item.Size = info.Size
|
||||
item.Details = info.Details
|
||||
return item
|
||||
}
|
||||
|
||||
// isCloudModelName reports whether the model name has an explicit cloud source.
|
||||
func isCloudModelName(name string) bool {
|
||||
return modelref.HasExplicitCloudSource(name)
|
||||
}
|
||||
|
||||
// filterCloudModels drops remote-only models from the given inventory.
|
||||
func filterCloudModels(existing []modelInfo) []modelInfo {
|
||||
filtered := existing[:0]
|
||||
for _, m := range existing {
|
||||
if !m.Remote {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// filterCloudItems removes cloud models from selection items.
|
||||
func filterCloudItems(items []ModelItem) []ModelItem {
|
||||
filtered := items[:0]
|
||||
for _, item := range items {
|
||||
if !isCloudModelName(item.Name) {
|
||||
filtered = append(filtered, item)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
|
||||
if client == nil {
|
||||
return false
|
||||
}
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return resp.RemoteModel != ""
|
||||
}
|
||||
|
||||
// cloudStatusDisabled returns whether cloud usage is currently disabled.
|
||||
func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
|
||||
status, err := client.CloudStatusExperimental(ctx)
|
||||
if err != nil {
|
||||
var statusErr api.StatusError
|
||||
if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
|
||||
return false, false
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
return status.Cloud.Disabled, true
|
||||
}
|
||||
|
||||
// TODO(parthsareen): this duplicates the pull progress UI in cmd.PullHandler.
|
||||
// Move the shared pull rendering to a small utility once the package boundary settles.
|
||||
func pullModel(ctx context.Context, client *api.Client, model string, insecure bool) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
var status string
|
||||
var spinner *progress.Spinner
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
if resp.Completed == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
bar, ok := bars[resp.Digest]
|
||||
if !ok {
|
||||
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
|
||||
name = strings.TrimSpace(name)
|
||||
if isDigest {
|
||||
name = name[:min(12, len(name))]
|
||||
}
|
||||
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
|
||||
bars[resp.Digest] = bar
|
||||
p.Add(resp.Digest, bar)
|
||||
}
|
||||
|
||||
bar.Set(resp.Completed)
|
||||
} else if status != resp.Status {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
status = resp.Status
|
||||
spinner = progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
request := api.PullRequest{Name: model, Insecure: insecure}
|
||||
return client.Pull(ctx, &request, fn)
|
||||
}
|
||||
83
cmd/launch/models_test.go
Normal file
83
cmd/launch/models_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
modelpkg "github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestBuildModelList_UsesInventoryMetadataForInstalledModels(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{
|
||||
Name: "custom-tools:latest",
|
||||
ToolCapable: true,
|
||||
Capabilities: []modelpkg.Capability{modelpkg.CapabilityCompletion, modelpkg.CapabilityTools, modelpkg.CapabilityThinking},
|
||||
Size: 7500 * format.MegaByte,
|
||||
Details: api.ModelDetails{
|
||||
ParameterSize: "8B",
|
||||
QuantizationLevel: "Q4_K_M",
|
||||
ContextLength: 131_072,
|
||||
EmbeddingLength: 4096,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
var got ModelItem
|
||||
for _, item := range items {
|
||||
if item.Name == "custom-tools" {
|
||||
got = item
|
||||
break
|
||||
}
|
||||
}
|
||||
if got.Name == "" {
|
||||
t.Fatal("custom-tools not found in items")
|
||||
}
|
||||
if !got.ToolCapable {
|
||||
t.Fatal("expected installed model to preserve tool capability from tags metadata")
|
||||
}
|
||||
if got.Details.ContextLength != 131_072 {
|
||||
t.Fatalf("Details.ContextLength = %d, want 131072", got.Details.ContextLength)
|
||||
}
|
||||
if got.Size != 7500*format.MegaByte {
|
||||
t.Fatalf("Size = %d, want %d", got.Size, 7500*format.MegaByte)
|
||||
}
|
||||
if got.Description != "" {
|
||||
t.Fatalf("Description = %q, want empty for installed model without recommendation copy", got.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_InstalledRecommendedPreservesRecommendationAndMetadata(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{
|
||||
Name: "qwen3.5",
|
||||
ToolCapable: true,
|
||||
Capabilities: []modelpkg.Capability{modelpkg.CapabilityCompletion, modelpkg.CapabilityTools, modelpkg.CapabilityVision},
|
||||
Size: 14 * format.GigaByte,
|
||||
Details: api.ModelDetails{ContextLength: 262_144},
|
||||
},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
var got ModelItem
|
||||
for _, item := range items {
|
||||
if item.Name == "qwen3.5" {
|
||||
got = item
|
||||
break
|
||||
}
|
||||
}
|
||||
if got.Name == "" {
|
||||
t.Fatal("qwen3.5 not found in items")
|
||||
}
|
||||
if !got.Recommended || !got.ToolCapable {
|
||||
t.Fatalf("recommended/tool metadata = %v/%v, want true/true", got.Recommended, got.ToolCapable)
|
||||
}
|
||||
if got.Details.ContextLength != 262_144 {
|
||||
t.Fatalf("Details.ContextLength = %d, want 262144", got.Details.ContextLength)
|
||||
}
|
||||
if got.Description != "Reasoning, coding, and visual understanding locally" {
|
||||
t.Fatalf("Description = %q, want recommendation description", got.Description)
|
||||
}
|
||||
}
|
||||
991
cmd/launch/openclaw.go
Normal file
991
cmd/launch/openclaw.go
Normal file
@@ -0,0 +1,991 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
const defaultGatewayPort = 18789
|
||||
|
||||
// openclawFreshInstall is set to true when ensureOpenclawInstalled performs an install
|
||||
var openclawFreshInstall bool
|
||||
|
||||
var openclawCanInstallDaemon = canInstallDaemon
|
||||
|
||||
type Openclaw struct{}
|
||||
|
||||
func (c *Openclaw) String() string { return "OpenClaw" }
|
||||
|
||||
func (c *Openclaw) Run(model string, _ []LaunchModel, args []string) error {
|
||||
bin, err := ensureOpenclawInstalled()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
firstLaunch := !c.onboarded()
|
||||
|
||||
if firstLaunch {
|
||||
fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, " OpenClaw can read files and run actions when tools are enabled.\n")
|
||||
fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
|
||||
fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)
|
||||
|
||||
ok, err := ConfirmPrompt("I understand the risks. Continue?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure the latest version is installed before onboarding so we get
|
||||
// the newest wizard flags (e.g. --auth-choice ollama).
|
||||
if !openclawFreshInstall {
|
||||
update := exec.Command(bin, "update")
|
||||
update.Env = openclawInstallEnv()
|
||||
update.Stdout = os.Stdout
|
||||
update.Stderr = os.Stderr
|
||||
_ = update.Run() // best-effort; continue even if update fails
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset)
|
||||
|
||||
onboardArgs := []string{
|
||||
"onboard",
|
||||
"--non-interactive",
|
||||
"--accept-risk",
|
||||
"--auth-choice", "ollama",
|
||||
"--custom-base-url", envconfig.Host().String(),
|
||||
"--custom-model-id", model,
|
||||
// Launch owns the first real gateway startup immediately after onboarding,
|
||||
// so don't let OpenClaw fail the whole first-run flow on a transient
|
||||
// daemon health probe.
|
||||
"--skip-health",
|
||||
"--skip-channels",
|
||||
"--skip-skills",
|
||||
}
|
||||
if openclawCanInstallDaemon() {
|
||||
onboardArgs = append(onboardArgs, "--install-daemon")
|
||||
}
|
||||
cmd := exec.Command(bin, onboardArgs...)
|
||||
cmd.Env = openclawInstallEnv()
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return windowsHint(fmt.Errorf("openclaw onboarding failed: %w\n\nTry running: openclaw onboard", err))
|
||||
}
|
||||
|
||||
patchDeviceScopes()
|
||||
}
|
||||
|
||||
configureOllamaWebSearch()
|
||||
|
||||
// When extra args are passed through, run exactly what the user asked for
|
||||
// after setup and skip the built-in gateway+TUI convenience flow.
|
||||
if len(args) > 0 {
|
||||
cleanup := func() {}
|
||||
if shouldEnsureGatewayForArgs(args) {
|
||||
cleanupFn, _, _, err := c.ensureGatewayReady(bin)
|
||||
if err != nil {
|
||||
return windowsHint(err)
|
||||
}
|
||||
if cleanupFn != nil {
|
||||
cleanup = cleanupFn
|
||||
}
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
cmd := exec.Command(bin, args...)
|
||||
cmd.Env = openclawEnv()
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return windowsHint(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.runChannelSetupPreflight(bin); err != nil {
|
||||
return err
|
||||
}
|
||||
// Keep local pairing scopes up to date before the gateway lifecycle
|
||||
// (restart/start) regardless of channel preflight branch behavior.
|
||||
patchDeviceScopes()
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
|
||||
|
||||
cleanup, token, port, err := c.ensureGatewayReady(bin)
|
||||
if err != nil {
|
||||
return windowsHint(err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
printOpenclawReady(bin, token, port, firstLaunch)
|
||||
|
||||
tuiArgs := []string{"tui"}
|
||||
if firstLaunch {
|
||||
tuiArgs = append(tuiArgs, "--message", "Wake up, my friend!")
|
||||
}
|
||||
tui := exec.Command(bin, tuiArgs...)
|
||||
tui.Env = openclawEnv()
|
||||
tui.Stdin = os.Stdin
|
||||
tui.Stdout = os.Stdout
|
||||
tui.Stderr = os.Stderr
|
||||
if err := tui.Run(); err != nil {
|
||||
return windowsHint(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func shouldEnsureGatewayForArgs(args []string) bool {
|
||||
return len(args) > 0 && args[0] == "tui"
|
||||
}
|
||||
|
||||
func (c *Openclaw) ensureGatewayReady(bin string) (func(), string, int, error) {
|
||||
token, port := c.gatewayInfo()
|
||||
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||
|
||||
// If the gateway is already running (e.g. via the daemon), restart it
|
||||
// so it picks up any config changes (model, provider, etc.).
|
||||
if portOpen(addr) {
|
||||
restart := exec.Command(bin, "daemon", "restart")
|
||||
restart.Env = openclawEnv()
|
||||
if err := restart.Run(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: daemon restart failed: %v%s\n", ansiYellow, err, ansiReset)
|
||||
}
|
||||
if !waitForPort(addr, 10*time.Second) {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: gateway did not come back after restart%s\n", ansiYellow, ansiReset)
|
||||
}
|
||||
}
|
||||
|
||||
// If the daemon is installed but not currently listening, try to bring it
|
||||
// up before falling back to a foreground child process.
|
||||
if openclawCanInstallDaemon() && !portOpen(addr) {
|
||||
start := exec.Command(bin, "daemon", "start")
|
||||
start.Env = openclawEnv()
|
||||
if err := start.Run(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: daemon start failed: %v%s\n", ansiYellow, err, ansiReset)
|
||||
} else if waitForPort(addr, 10*time.Second) {
|
||||
fmt.Fprintf(os.Stderr, "%sStarting gateway...%s\n", ansiGray, ansiReset)
|
||||
return func() {}, token, port, nil
|
||||
}
|
||||
}
|
||||
|
||||
cleanup := func() {}
|
||||
|
||||
// If the gateway still isn't running, start it as a background child process.
|
||||
if !portOpen(addr) {
|
||||
gw := exec.Command(bin, "gateway", "run", "--force")
|
||||
gw.Env = openclawEnv()
|
||||
if err := gw.Start(); err != nil {
|
||||
return nil, "", 0, fmt.Errorf("failed to start gateway: %w", err)
|
||||
}
|
||||
cleanup = func() {
|
||||
if gw.Process != nil {
|
||||
_ = gw.Process.Kill()
|
||||
_ = gw.Wait()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sStarting gateway...%s\n", ansiGray, ansiReset)
|
||||
if !waitForPort(addr, 30*time.Second) {
|
||||
cleanup()
|
||||
return nil, "", 0, fmt.Errorf("gateway did not start on %s", addr)
|
||||
}
|
||||
|
||||
return cleanup, token, port, nil
|
||||
}
|
||||
|
||||
// runChannelSetupPreflight prompts users to connect a messaging channel before
|
||||
// starting the built-in gateway+TUI flow. In interactive sessions, it loops
|
||||
// until a channel is configured, unless the user chooses "Set up later".
|
||||
func (c *Openclaw) runChannelSetupPreflight(bin string) error {
|
||||
if !isInteractiveSession() {
|
||||
return nil
|
||||
}
|
||||
// --yes is headless; channel setup spawns an interactive picker we can't
|
||||
// auto-answer, so skip it. Users can run `openclaw channels add` later.
|
||||
if currentLaunchConfirmPolicy.yes {
|
||||
return nil
|
||||
}
|
||||
|
||||
for {
|
||||
if c.channelsConfigured() {
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nYour assistant can message you on WhatsApp, Telegram, Discord, and more.\n\n")
|
||||
ok, err := ConfirmPromptWithOptions("Connect a channel (messaging app) now?", ConfirmOptions{
|
||||
YesLabel: "Yes",
|
||||
NoLabel: "Set up later",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := exec.Command(bin, "channels", "add")
|
||||
cmd.Env = openclawEnv()
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return windowsHint(fmt.Errorf("openclaw channel setup failed: %w\n\nTry running: %s channels add", err, bin))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// channelsConfigured reports whether local OpenClaw config contains at least
|
||||
// one meaningfully configured channel entry.
|
||||
func (c *Openclaw) channelsConfigured() bool {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, path := range []string{
|
||||
filepath.Join(home, ".openclaw", "openclaw.json"),
|
||||
filepath.Join(home, ".clawdbot", "clawdbot.json"),
|
||||
} {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var cfg map[string]any
|
||||
if json.Unmarshal(data, &cfg) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
channels, _ := cfg["channels"].(map[string]any)
|
||||
if channels == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for key, value := range channels {
|
||||
if key == "defaults" || key == "modelByChannel" {
|
||||
continue
|
||||
}
|
||||
entry, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for entryKey := range entry {
|
||||
if entryKey != "enabled" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// gatewayInfo reads the gateway auth token and port from the OpenClaw config.
|
||||
func (c *Openclaw) gatewayInfo() (token string, port int) {
|
||||
port = defaultGatewayPort
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", port
|
||||
}
|
||||
|
||||
for _, path := range []string{
|
||||
filepath.Join(home, ".openclaw", "openclaw.json"),
|
||||
filepath.Join(home, ".clawdbot", "clawdbot.json"),
|
||||
} {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var config map[string]any
|
||||
if json.Unmarshal(data, &config) != nil {
|
||||
continue
|
||||
}
|
||||
gw, _ := config["gateway"].(map[string]any)
|
||||
if p, ok := gw["port"].(float64); ok && p > 0 {
|
||||
port = int(p)
|
||||
}
|
||||
auth, _ := gw["auth"].(map[string]any)
|
||||
if t, _ := auth["token"].(string); t != "" {
|
||||
token = t
|
||||
}
|
||||
return token, port
|
||||
}
|
||||
return "", port
|
||||
}
|
||||
|
||||
func printOpenclawReady(bin, token string, port int, firstLaunch bool) {
|
||||
u := fmt.Sprintf("http://127.0.0.1:%d", port)
|
||||
if token != "" {
|
||||
u += "/#token=" + url.QueryEscape(token)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%s✓ OpenClaw is running%s\n\n", ansiGreen, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, " Open the Web UI:\n")
|
||||
fmt.Fprintf(os.Stderr, " %s\n\n", hyperlink(u, u))
|
||||
|
||||
if firstLaunch {
|
||||
fmt.Fprintf(os.Stderr, "%s Quick start:%s\n", ansiBold, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s /help see all commands%s\n", ansiGray, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s %s skills browse and install skills%s\n\n", ansiGray, bin, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s The OpenClaw gateway is running in the background.%s\n", ansiYellow, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s Stop it with: %s gateway stop%s\n\n", ansiYellow, bin, ansiReset)
|
||||
}
|
||||
}
|
||||
|
||||
// openclawEnv returns the current environment with provider API keys cleared
|
||||
// so openclaw only uses the Ollama gateway, not keys from the user's shell.
|
||||
func openclawEnv() []string {
|
||||
clear := map[string]bool{
|
||||
"ANTHROPIC_API_KEY": true,
|
||||
"ANTHROPIC_OAUTH_TOKEN": true,
|
||||
"OPENAI_API_KEY": true,
|
||||
"GEMINI_API_KEY": true,
|
||||
"MISTRAL_API_KEY": true,
|
||||
"GROQ_API_KEY": true,
|
||||
"XAI_API_KEY": true,
|
||||
"OPENROUTER_API_KEY": true,
|
||||
}
|
||||
var env []string
|
||||
for _, e := range os.Environ() {
|
||||
key, _, _ := strings.Cut(e, "=")
|
||||
if !clear[key] {
|
||||
env = append(env, e)
|
||||
}
|
||||
}
|
||||
if _, ok := os.LookupEnv("OPENCLAW_PLUGIN_STAGE_DIR"); !ok {
|
||||
if dir := openclawPluginStageDir(); dir != "" {
|
||||
env = append(env, "OPENCLAW_PLUGIN_STAGE_DIR="+dir)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
func openclawInstallEnv() []string {
|
||||
env := openclawEnv()
|
||||
if _, ok := os.LookupEnv("OPENCLAW_EAGER_BUNDLED_PLUGIN_DEPS"); !ok {
|
||||
env = append(env, "OPENCLAW_EAGER_BUNDLED_PLUGIN_DEPS=1")
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
func openclawPluginStageDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(home, ".openclaw", "plugin-runtime-deps")
|
||||
}
|
||||
|
||||
// portOpen checks if a TCP port is currently accepting connections.
|
||||
func portOpen(addr string) bool {
|
||||
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
conn.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func waitForPort(addr string, timeout time.Duration) bool {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
return true
|
||||
}
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func windowsHint(err error) error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("%w\n\n"+
|
||||
"OpenClaw runs best on WSL2.\n"+
|
||||
"Quick setup: wsl --install\n"+
|
||||
"Guide: https://docs.openclaw.ai/windows", err)
|
||||
}
|
||||
|
||||
// onboarded checks if OpenClaw onboarding wizard was completed
|
||||
// by looking for the wizard.lastRunAt marker in the config
|
||||
func (c *Openclaw) onboarded() bool {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
} else if data, err := os.ReadFile(legacyPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for wizard.lastRunAt marker (set when onboarding completes)
|
||||
wizard, _ := config["wizard"].(map[string]any)
|
||||
if wizard == nil {
|
||||
return false
|
||||
}
|
||||
lastRunAt, _ := wizard["lastRunAt"].(string)
|
||||
return lastRunAt != ""
|
||||
}
|
||||
|
||||
// patchDeviceScopes upgrades the local CLI device's paired operator scopes so
|
||||
// newer gateway auth baselines (approvedScopes) allow launch+TUI reconnects
|
||||
// without forcing an interactive re-pair. Only patches the local device,
|
||||
// not remote ones. Best-effort: silently returns on any error.
|
||||
func patchDeviceScopes() {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
deviceID := readLocalDeviceID(home)
|
||||
if deviceID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
path := filepath.Join(home, ".openclaw", "devices", "paired.json")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var devices map[string]map[string]any
|
||||
if err := json.Unmarshal(data, &devices); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
dev, ok := devices[deviceID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
required := []string{
|
||||
"operator.read",
|
||||
"operator.admin",
|
||||
"operator.approvals",
|
||||
"operator.pairing",
|
||||
}
|
||||
|
||||
changed := patchScopes(dev, "scopes", required)
|
||||
if patchScopes(dev, "approvedScopes", required) {
|
||||
changed = true
|
||||
}
|
||||
if tokens, ok := dev["tokens"].(map[string]any); ok {
|
||||
for role, tok := range tokens {
|
||||
if tokenMap, ok := tok.(map[string]any); ok {
|
||||
if !isOperatorToken(role, tokenMap) {
|
||||
continue
|
||||
}
|
||||
if patchScopes(tokenMap, "scopes", required) {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return
|
||||
}
|
||||
|
||||
out, err := json.MarshalIndent(devices, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = os.WriteFile(path, out, 0o600)
|
||||
}
|
||||
|
||||
// readLocalDeviceID reads the local device ID from openclaw's identity file.
|
||||
func readLocalDeviceID(home string) string {
|
||||
data, err := os.ReadFile(filepath.Join(home, ".openclaw", "identity", "device-auth.json"))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
var auth map[string]any
|
||||
if err := json.Unmarshal(data, &auth); err != nil {
|
||||
return ""
|
||||
}
|
||||
id, _ := auth["deviceId"].(string)
|
||||
return id
|
||||
}
|
||||
|
||||
// patchScopes ensures obj[key] contains all required scopes. Returns true if
|
||||
// any scopes were added.
|
||||
func patchScopes(obj map[string]any, key string, required []string) bool {
|
||||
existing, _ := obj[key].([]any)
|
||||
have := make(map[string]bool, len(existing))
|
||||
for _, s := range existing {
|
||||
if str, ok := s.(string); ok {
|
||||
have[str] = true
|
||||
}
|
||||
}
|
||||
added := false
|
||||
for _, s := range required {
|
||||
if !have[s] {
|
||||
existing = append(existing, s)
|
||||
added = true
|
||||
}
|
||||
}
|
||||
if added {
|
||||
obj[key] = existing
|
||||
}
|
||||
return added
|
||||
}
|
||||
|
||||
func isOperatorToken(tokenRole string, token map[string]any) bool {
|
||||
if strings.EqualFold(strings.TrimSpace(tokenRole), "operator") {
|
||||
return true
|
||||
}
|
||||
role, _ := token["role"].(string)
|
||||
return strings.EqualFold(strings.TrimSpace(role), "operator")
|
||||
}
|
||||
|
||||
// canInstallDaemon reports whether the openclaw daemon can be installed as a
|
||||
// background service. Returns false on Linux when systemd is absent (e.g.
|
||||
// containers) so that --install-daemon is omitted and the gateway is started
|
||||
// as a foreground child process instead. Returns true in all other cases.
|
||||
func canInstallDaemon() bool {
|
||||
if runtime.GOOS != "linux" {
|
||||
return true
|
||||
}
|
||||
// /run/systemd/system exists as a directory when systemd is the init system.
|
||||
// This is absent in most containers.
|
||||
fi, err := os.Stat("/run/systemd/system")
|
||||
if err != nil || !fi.IsDir() {
|
||||
return false
|
||||
}
|
||||
// Even when systemd is the init system, user services require a user
|
||||
// manager instance. XDG_RUNTIME_DIR being set is a prerequisite.
|
||||
return os.Getenv("XDG_RUNTIME_DIR") != ""
|
||||
}
|
||||
|
||||
func ensureOpenclawInstalled() (string, error) {
|
||||
if _, err := exec.LookPath("openclaw"); err == nil {
|
||||
return "openclaw", nil
|
||||
}
|
||||
if _, err := exec.LookPath("clawdbot"); err == nil {
|
||||
return "clawdbot", nil
|
||||
}
|
||||
|
||||
_, npmErr := exec.LookPath("npm")
|
||||
_, gitErr := exec.LookPath("git")
|
||||
if npmErr != nil || gitErr != nil {
|
||||
var missing []string
|
||||
if npmErr != nil {
|
||||
missing = append(missing, "npm (Node.js): https://nodejs.org/")
|
||||
}
|
||||
if gitErr != nil {
|
||||
missing = append(missing, "git: https://git-scm.com/")
|
||||
}
|
||||
return "", fmt.Errorf("OpenClaw is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch openclaw", strings.Join(missing, "\n "))
|
||||
}
|
||||
|
||||
ok, err := ConfirmPrompt("OpenClaw is not installed. Install with npm?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !ok {
|
||||
return "", fmt.Errorf("openclaw installation cancelled")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nInstalling OpenClaw...\n")
|
||||
cmd := exec.Command("npm", "install", "-g", "openclaw@latest")
|
||||
cmd.Env = openclawInstallEnv()
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("failed to install openclaw: %w", err)
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("openclaw"); err != nil {
|
||||
return "", fmt.Errorf("openclaw was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
|
||||
openclawFreshInstall = true
|
||||
return "openclaw", nil
|
||||
}
|
||||
|
||||
func (c *Openclaw) Paths() []string {
|
||||
home, _ := os.UserHomeDir()
|
||||
p := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return []string{p}
|
||||
}
|
||||
legacy := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
if _, err := os.Stat(legacy); err == nil {
|
||||
return []string{legacy}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Openclaw) Edit(models []LaunchModel) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read into map[string]any to preserve unknown fields
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
} else if data, err := os.ReadFile(legacyPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
}
|
||||
|
||||
// Navigate/create: models.providers.ollama (preserving other providers)
|
||||
modelsSection, _ := config["models"].(map[string]any)
|
||||
if modelsSection == nil {
|
||||
modelsSection = make(map[string]any)
|
||||
}
|
||||
providers, _ := modelsSection["providers"].(map[string]any)
|
||||
if providers == nil {
|
||||
providers = make(map[string]any)
|
||||
}
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
if ollama == nil {
|
||||
ollama = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama["baseUrl"] = envconfig.Host().String()
|
||||
// needed to register provider
|
||||
ollama["apiKey"] = "ollama-local"
|
||||
ollama["api"] = "ollama"
|
||||
|
||||
// Build map of existing models to preserve user customizations
|
||||
existingModels, _ := ollama["models"].([]any)
|
||||
existingByID := make(map[string]map[string]any)
|
||||
for _, m := range existingModels {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if id, ok := entry["id"].(string); ok {
|
||||
existingByID[id] = entry
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var newModels []any
|
||||
for _, m := range models {
|
||||
entry, _ := openclawModelConfig(m)
|
||||
// Merge existing fields (user customizations)
|
||||
if existing, ok := existingByID[m.Name]; ok {
|
||||
for k, v := range existing {
|
||||
if _, isNew := entry[k]; !isNew {
|
||||
entry[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
newModels = append(newModels, entry)
|
||||
}
|
||||
ollama["models"] = newModels
|
||||
|
||||
providers["ollama"] = ollama
|
||||
modelsSection["providers"] = providers
|
||||
config["models"] = modelsSection
|
||||
|
||||
// Update agents.defaults.model.primary (preserving other agent settings)
|
||||
agents, _ := config["agents"].(map[string]any)
|
||||
if agents == nil {
|
||||
agents = make(map[string]any)
|
||||
}
|
||||
defaults, _ := agents["defaults"].(map[string]any)
|
||||
if defaults == nil {
|
||||
defaults = make(map[string]any)
|
||||
}
|
||||
modelConfig, _ := defaults["model"].(map[string]any)
|
||||
if modelConfig == nil {
|
||||
modelConfig = make(map[string]any)
|
||||
}
|
||||
modelConfig["primary"] = "ollama/" + models[0].Name
|
||||
defaults["model"] = modelConfig
|
||||
agents["defaults"] = defaults
|
||||
config["agents"] = agents
|
||||
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fileutil.WriteWithBackup(configPath, data, "openclaw"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Clear any per-session model overrides so the new primary takes effect
|
||||
// immediately rather than being shadowed by a cached modelOverride.
|
||||
clearSessionModelOverride(models[0].Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// clearSessionModelOverride removes per-session model overrides from the main
|
||||
// agent session so the global primary model takes effect on the next TUI launch.
|
||||
func clearSessionModelOverride(primary string) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
path := filepath.Join(home, ".openclaw", "agents", "main", "sessions", "sessions.json")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var sessions map[string]map[string]any
|
||||
if json.Unmarshal(data, &sessions) != nil {
|
||||
return
|
||||
}
|
||||
changed := false
|
||||
for _, sess := range sessions {
|
||||
if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
|
||||
delete(sess, "modelOverride")
|
||||
delete(sess, "providerOverride")
|
||||
}
|
||||
if model, _ := sess["model"].(string); model != "" && model != primary {
|
||||
sess["model"] = primary
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
return
|
||||
}
|
||||
out, err := json.MarshalIndent(sessions, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = os.WriteFile(path, out, 0o600)
|
||||
}
|
||||
|
||||
// configureOllamaWebSearch keeps launch-managed OpenClaw installs on the
|
||||
// bundled Ollama web_search provider. Older launch builds installed an
|
||||
// external openclaw-web-search plugin that added custom ollama_web_search and
|
||||
// ollama_web_fetch tools. Current OpenClaw versions ship Ollama web_search as
|
||||
// the bundled "ollama" plugin instead, so we migrate stale config and ensure
|
||||
// fresh installs select the bundled provider.
|
||||
func configureOllamaWebSearch() {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var config map[string]any
|
||||
if json.Unmarshal(data, &config) != nil {
|
||||
return
|
||||
}
|
||||
|
||||
stalePluginConfigured := false
|
||||
|
||||
plugins, _ := config["plugins"].(map[string]any)
|
||||
if plugins == nil {
|
||||
plugins = make(map[string]any)
|
||||
}
|
||||
entries, _ := plugins["entries"].(map[string]any)
|
||||
if entries == nil {
|
||||
entries = make(map[string]any)
|
||||
}
|
||||
tools, _ := config["tools"].(map[string]any)
|
||||
if tools == nil {
|
||||
tools = make(map[string]any)
|
||||
}
|
||||
web, _ := tools["web"].(map[string]any)
|
||||
if web == nil {
|
||||
web = make(map[string]any)
|
||||
}
|
||||
search, _ := web["search"].(map[string]any)
|
||||
if search == nil {
|
||||
search = make(map[string]any)
|
||||
}
|
||||
fetch, _ := web["fetch"].(map[string]any)
|
||||
if fetch == nil {
|
||||
fetch = make(map[string]any)
|
||||
}
|
||||
|
||||
alsoAllow, _ := tools["alsoAllow"].([]any)
|
||||
var filteredAlsoAllow []any
|
||||
for _, v := range alsoAllow {
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
filteredAlsoAllow = append(filteredAlsoAllow, v)
|
||||
continue
|
||||
}
|
||||
if s == "ollama_web_search" || s == "ollama_web_fetch" {
|
||||
stalePluginConfigured = true
|
||||
continue
|
||||
}
|
||||
filteredAlsoAllow = append(filteredAlsoAllow, v)
|
||||
}
|
||||
if len(filteredAlsoAllow) > 0 {
|
||||
tools["alsoAllow"] = filteredAlsoAllow
|
||||
} else {
|
||||
delete(tools, "alsoAllow")
|
||||
}
|
||||
|
||||
if _, ok := entries["openclaw-web-search"]; ok {
|
||||
delete(entries, "openclaw-web-search")
|
||||
stalePluginConfigured = true
|
||||
}
|
||||
ollamaEntry, _ := entries["ollama"].(map[string]any)
|
||||
if ollamaEntry == nil {
|
||||
ollamaEntry = make(map[string]any)
|
||||
}
|
||||
ollamaEntry["enabled"] = true
|
||||
entries["ollama"] = ollamaEntry
|
||||
plugins["entries"] = entries
|
||||
|
||||
if allow, ok := plugins["allow"].([]any); ok {
|
||||
var nextAllow []any
|
||||
hasOllama := false
|
||||
for _, v := range allow {
|
||||
s, ok := v.(string)
|
||||
if ok && s == "openclaw-web-search" {
|
||||
stalePluginConfigured = true
|
||||
continue
|
||||
}
|
||||
if ok && s == "ollama" {
|
||||
hasOllama = true
|
||||
}
|
||||
nextAllow = append(nextAllow, v)
|
||||
}
|
||||
if !hasOllama {
|
||||
nextAllow = append(nextAllow, "ollama")
|
||||
}
|
||||
plugins["allow"] = nextAllow
|
||||
}
|
||||
|
||||
if installs, ok := plugins["installs"].(map[string]any); ok {
|
||||
if _, exists := installs["openclaw-web-search"]; exists {
|
||||
delete(installs, "openclaw-web-search")
|
||||
stalePluginConfigured = true
|
||||
}
|
||||
if len(installs) > 0 {
|
||||
plugins["installs"] = installs
|
||||
} else {
|
||||
delete(plugins, "installs")
|
||||
}
|
||||
}
|
||||
|
||||
if stalePluginConfigured || search["provider"] == nil {
|
||||
search["provider"] = "ollama"
|
||||
}
|
||||
if stalePluginConfigured {
|
||||
fetch["enabled"] = true
|
||||
}
|
||||
search["enabled"] = true
|
||||
web["search"] = search
|
||||
if len(fetch) > 0 {
|
||||
web["fetch"] = fetch
|
||||
}
|
||||
tools["web"] = web
|
||||
config["plugins"] = plugins
|
||||
config["tools"] = tools
|
||||
|
||||
out, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = os.WriteFile(configPath, out, 0o600)
|
||||
}
|
||||
|
||||
// openclawModelConfig builds an OpenClaw model config entry with capability detection.
|
||||
// The second return value indicates whether the model is a cloud (remote) model.
|
||||
func openclawModelConfig(model LaunchModel) (map[string]any, bool) {
|
||||
entry := map[string]any{
|
||||
"id": model.Name,
|
||||
"name": model.Name,
|
||||
"input": []any{"text"},
|
||||
"cost": map[string]any{
|
||||
"input": 0,
|
||||
"output": 0,
|
||||
"cacheRead": 0,
|
||||
"cacheWrite": 0,
|
||||
},
|
||||
}
|
||||
|
||||
// Set input types based on vision capability
|
||||
if model.HasCapability("vision") {
|
||||
entry["input"] = []any{"text", "image"}
|
||||
}
|
||||
|
||||
// Set reasoning based on thinking capability
|
||||
if model.HasCapability("thinking") {
|
||||
entry["reasoning"] = true
|
||||
}
|
||||
|
||||
if model.ContextLength > 0 {
|
||||
entry["contextWindow"] = model.ContextLength
|
||||
}
|
||||
if model.MaxOutputTokens > 0 {
|
||||
entry["maxTokens"] = model.MaxOutputTokens
|
||||
}
|
||||
|
||||
return entry, model.Remote || isCloudModelName(model.Name)
|
||||
}
|
||||
|
||||
func (c *Openclaw) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
config, err := fileutil.ReadJSON(filepath.Join(home, ".openclaw", "openclaw.json"))
|
||||
if err != nil {
|
||||
config, err = fileutil.ReadJSON(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
modelsSection, _ := config["models"].(map[string]any)
|
||||
providers, _ := modelsSection["providers"].(map[string]any)
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
modelList, _ := ollama["models"].([]any)
|
||||
|
||||
var result []string
|
||||
for _, m := range modelList {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if id, ok := entry["id"].(string); ok {
|
||||
result = append(result, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
2778
cmd/launch/openclaw_test.go
Normal file
2778
cmd/launch/openclaw_test.go
Normal file
File diff suppressed because it is too large
Load Diff
294
cmd/launch/opencode.go
Normal file
294
cmd/launch/opencode.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration.
|
||||
// Config is passed via OPENCODE_CONFIG_CONTENT env var at launch time
|
||||
// instead of writing to opencode's config files.
|
||||
type OpenCode struct {
|
||||
configContent string // JSON config built by Edit, passed to Run via env var
|
||||
}
|
||||
|
||||
func (o *OpenCode) String() string { return "OpenCode" }
|
||||
|
||||
// findOpenCode returns the opencode binary path, checking PATH first then the
|
||||
// curl installer location (~/.opencode/bin) which may not be on PATH yet.
|
||||
func findOpenCode() (string, bool) {
|
||||
if p, err := exec.LookPath("opencode"); err == nil {
|
||||
return p, true
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
name := "opencode"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "opencode.exe"
|
||||
}
|
||||
fallback := filepath.Join(home, ".opencode", "bin", name)
|
||||
if _, err := os.Stat(fallback); err == nil {
|
||||
return fallback, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (o *OpenCode) Run(model string, models []LaunchModel, args []string) error {
|
||||
opencodePath, ok := findOpenCode()
|
||||
if !ok {
|
||||
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
|
||||
}
|
||||
|
||||
cmd := exec.Command(opencodePath, args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = os.Environ()
|
||||
if content := o.resolveContent(model, models); content != "" {
|
||||
cmd.Env = append(cmd.Env, "OPENCODE_CONFIG_CONTENT="+content)
|
||||
}
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// resolveContent returns the inline config to send via OPENCODE_CONFIG_CONTENT.
|
||||
// Returns content built by Edit if available, otherwise builds from model.json
|
||||
// with the requested model as primary (e.g. re-launch with saved config).
|
||||
func (o *OpenCode) resolveContent(model string, models []LaunchModel) string {
|
||||
if o.configContent != "" {
|
||||
return o.configContent
|
||||
}
|
||||
resolvedModels := resolveOpenCodeRunModels(model, models, readModelJSONModels())
|
||||
if len(resolvedModels) == 0 {
|
||||
return ""
|
||||
}
|
||||
content, err := buildInlineConfig(resolvedModels[0], resolvedModels)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
func resolveOpenCodeRunModels(primary string, models []LaunchModel, stateModels []string) []LaunchModel {
|
||||
if primary == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
resolved := make([]LaunchModel, 0, 1+len(models)+len(stateModels))
|
||||
appendModel := func(name string) {
|
||||
if name == "" || hasLaunchModel(resolved, name) {
|
||||
return
|
||||
}
|
||||
if model, ok := findLaunchModel(models, name); ok {
|
||||
resolved = append(resolved, model)
|
||||
return
|
||||
}
|
||||
resolved = append(resolved, fallbackLaunchModel(name))
|
||||
}
|
||||
|
||||
appendModel(primary)
|
||||
for _, model := range models {
|
||||
appendModel(model.Name)
|
||||
}
|
||||
for _, model := range stateModels {
|
||||
appendModel(model)
|
||||
}
|
||||
return resolved
|
||||
}
|
||||
|
||||
func hasLaunchModel(models []LaunchModel, name string) bool {
|
||||
for _, model := range models {
|
||||
if launchModelMatches(model.Name, name) || launchModelMatches(name, model.Name) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (o *OpenCode) Paths() []string {
|
||||
sp, err := openCodeStatePath()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if _, err := os.Stat(sp); err == nil {
|
||||
return []string{sp}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// openCodeStatePath returns the path to opencode's model state file.
|
||||
// TODO: this hardcodes the Linux/macOS XDG path. On Windows, opencode stores
|
||||
// state under %LOCALAPPDATA% (or similar) — verify and branch on runtime.GOOS.
|
||||
func openCodeStatePath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".local", "state", "opencode", "model.json"), nil
|
||||
}
|
||||
|
||||
func (o *OpenCode) Edit(models []LaunchModel) error {
|
||||
modelList := launchModelNames(models)
|
||||
if len(modelList) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
content, err := buildInlineConfig(models[0], models)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o.configContent = content
|
||||
|
||||
// Write model state file so models appear in OpenCode's model picker
|
||||
statePath, err := openCodeStatePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state := map[string]any{
|
||||
"recent": []any{},
|
||||
"favorite": []any{},
|
||||
"variant": map[string]any{},
|
||||
}
|
||||
if data, err := os.ReadFile(statePath); err == nil {
|
||||
_ = json.Unmarshal(data, &state) // Ignore parse errors; use defaults
|
||||
}
|
||||
|
||||
recent, _ := state["recent"].([]any)
|
||||
|
||||
modelSet := make(map[string]bool)
|
||||
for _, m := range modelList {
|
||||
modelSet[m] = true
|
||||
}
|
||||
|
||||
// Filter out existing Ollama models we're about to re-add
|
||||
newRecent := slices.DeleteFunc(slices.Clone(recent), func(entry any) bool {
|
||||
e, ok := entry.(map[string]any)
|
||||
if !ok || e["providerID"] != "ollama" {
|
||||
return false
|
||||
}
|
||||
modelID, _ := e["modelID"].(string)
|
||||
return modelSet[modelID]
|
||||
})
|
||||
|
||||
// Prepend models in reverse order so first model ends up first
|
||||
for _, model := range slices.Backward(modelList) {
|
||||
newRecent = slices.Insert(newRecent, 0, any(map[string]any{
|
||||
"providerID": "ollama",
|
||||
"modelID": model,
|
||||
}))
|
||||
}
|
||||
|
||||
const maxRecentModels = 10
|
||||
newRecent = newRecent[:min(len(newRecent), maxRecentModels)]
|
||||
|
||||
state["recent"] = newRecent
|
||||
|
||||
stateData, err := json.MarshalIndent(state, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fileutil.WriteWithBackup(statePath, stateData, "opencode")
|
||||
}
|
||||
|
||||
func (o *OpenCode) Models() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildInlineConfig produces the JSON string for OPENCODE_CONFIG_CONTENT.
|
||||
// primary is the model to launch with, models is the full list of available models.
|
||||
func buildInlineConfig(primary LaunchModel, models []LaunchModel) (string, error) {
|
||||
if primary.Name == "" || len(models) == 0 {
|
||||
return "", fmt.Errorf("buildInlineConfig: primary and models are required")
|
||||
}
|
||||
|
||||
config := map[string]any{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": map[string]any{
|
||||
"ollama": map[string]any{
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"options": map[string]any{
|
||||
"baseURL": envconfig.Host().String() + "/v1",
|
||||
},
|
||||
"models": buildModelEntries(models),
|
||||
},
|
||||
},
|
||||
"model": "ollama/" + primary.Name,
|
||||
}
|
||||
data, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// readModelJSONModels reads ollama model IDs from the opencode model.json state file
|
||||
func readModelJSONModels() []string {
|
||||
statePath, err := openCodeStatePath()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
data, err := os.ReadFile(statePath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var state map[string]any
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
return nil
|
||||
}
|
||||
recent, _ := state["recent"].([]any)
|
||||
var models []string
|
||||
for _, entry := range recent {
|
||||
e, ok := entry.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if e["providerID"] != "ollama" {
|
||||
continue
|
||||
}
|
||||
if id, ok := e["modelID"].(string); ok && id != "" {
|
||||
models = append(models, id)
|
||||
}
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
func buildModelEntries(modelList []LaunchModel) map[string]any {
|
||||
models := make(map[string]any)
|
||||
for _, model := range modelList {
|
||||
entry := map[string]any{
|
||||
"name": model.Name,
|
||||
}
|
||||
if model.HasCapability("vision") {
|
||||
entry["modalities"] = map[string]any{
|
||||
"input": []string{"text", "image"},
|
||||
"output": []string{"text"},
|
||||
}
|
||||
}
|
||||
if model.ContextLength > 0 || model.MaxOutputTokens > 0 {
|
||||
limit := make(map[string]any)
|
||||
if model.ContextLength > 0 {
|
||||
limit["context"] = model.ContextLength
|
||||
}
|
||||
if model.MaxOutputTokens > 0 {
|
||||
limit["output"] = model.MaxOutputTokens
|
||||
}
|
||||
entry["limit"] = limit
|
||||
}
|
||||
models[model.Name] = entry
|
||||
}
|
||||
return models
|
||||
}
|
||||
862
cmd/launch/opencode_test.go
Normal file
862
cmd/launch/opencode_test.go
Normal file
@@ -0,0 +1,862 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestOpenCodeIntegration(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := o.String(); got != "OpenCode" {
|
||||
t.Errorf("String() = %q, want %q", got, "OpenCode")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = o
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = o
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit(t *testing.T) {
|
||||
t.Run("builds config content with provider", func(t *testing.T) {
|
||||
setTestHome(t, t.TempDir())
|
||||
o := &OpenCode{}
|
||||
if err := o.Edit(testLaunchModels("llama3.2")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal([]byte(o.configContent), &cfg); err != nil {
|
||||
t.Fatalf("configContent is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Verify provider structure
|
||||
provider, _ := cfg["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
if ollama["name"] != "Ollama" {
|
||||
t.Errorf("provider name = %v, want Ollama", ollama["name"])
|
||||
}
|
||||
if ollama["npm"] != "@ai-sdk/openai-compatible" {
|
||||
t.Errorf("npm = %v, want @ai-sdk/openai-compatible", ollama["npm"])
|
||||
}
|
||||
|
||||
// Verify model exists
|
||||
models, _ := ollama["models"].(map[string]any)
|
||||
if models["llama3.2"] == nil {
|
||||
t.Error("model llama3.2 not found in config content")
|
||||
}
|
||||
|
||||
// Verify default model
|
||||
if cfg["model"] != "ollama/llama3.2" {
|
||||
t.Errorf("model = %v, want ollama/llama3.2", cfg["model"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple models", func(t *testing.T) {
|
||||
setTestHome(t, t.TempDir())
|
||||
o := &OpenCode{}
|
||||
if err := o.Edit(testLaunchModels("llama3.2", "qwen3:32b")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var cfg map[string]any
|
||||
json.Unmarshal([]byte(o.configContent), &cfg)
|
||||
provider, _ := cfg["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].(map[string]any)
|
||||
|
||||
if models["llama3.2"] == nil {
|
||||
t.Error("model llama3.2 not found")
|
||||
}
|
||||
if models["qwen3:32b"] == nil {
|
||||
t.Error("model qwen3:32b not found")
|
||||
}
|
||||
// First model should be the default
|
||||
if cfg["model"] != "ollama/llama3.2" {
|
||||
t.Errorf("default model = %v, want ollama/llama3.2", cfg["model"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty models is no-op", func(t *testing.T) {
|
||||
setTestHome(t, t.TempDir())
|
||||
o := &OpenCode{}
|
||||
if err := o.Edit(testLaunchModels()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if o.configContent != "" {
|
||||
t.Errorf("expected empty configContent for no models, got %s", o.configContent)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("does not write config files", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
o := &OpenCode{}
|
||||
o.Edit(testLaunchModels("llama3.2"))
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".config", "opencode")
|
||||
|
||||
if _, err := os.Stat(filepath.Join(configDir, "opencode.json")); !os.IsNotExist(err) {
|
||||
t.Error("opencode.json should not be created")
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(configDir, "opencode.jsonc")); !os.IsNotExist(err) {
|
||||
t.Error("opencode.jsonc should not be created")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model has limits", func(t *testing.T) {
|
||||
setTestHome(t, t.TempDir())
|
||||
o := &OpenCode{}
|
||||
if err := o.Edit(testLaunchModels("glm-4.7:cloud")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var cfg map[string]any
|
||||
json.Unmarshal([]byte(o.configContent), &cfg)
|
||||
provider, _ := cfg["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].(map[string]any)
|
||||
entry, _ := models["glm-4.7:cloud"].(map[string]any)
|
||||
|
||||
limit, ok := entry["limit"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("cloud model should have limit set")
|
||||
}
|
||||
expected := cloudModelLimits["glm-4.7"]
|
||||
if limit["context"] != float64(expected.Context) {
|
||||
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
|
||||
}
|
||||
if limit["output"] != float64(expected.Output) {
|
||||
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local model has no limits", func(t *testing.T) {
|
||||
setTestHome(t, t.TempDir())
|
||||
o := &OpenCode{}
|
||||
o.Edit(testLaunchModels("llama3.2"))
|
||||
|
||||
var cfg map[string]any
|
||||
json.Unmarshal([]byte(o.configContent), &cfg)
|
||||
provider, _ := cfg["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].(map[string]any)
|
||||
entry, _ := models["llama3.2"].(map[string]any)
|
||||
|
||||
if entry["limit"] != nil {
|
||||
t.Errorf("local model should not have limit, got %v", entry["limit"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("vision model gets image input modalities", func(t *testing.T) {
|
||||
models := buildModelEntries([]LaunchModel{{Name: "gemma4:26b", Capabilities: []model.Capability{"vision"}}})
|
||||
entry, _ := models["gemma4:26b"].(map[string]any)
|
||||
modalities, _ := entry["modalities"].(map[string]any)
|
||||
input, _ := modalities["input"].([]string)
|
||||
output, _ := modalities["output"].([]string)
|
||||
|
||||
if len(input) != 2 || input[0] != "text" || input[1] != "image" {
|
||||
t.Fatalf("modalities.input = %v, want [text image]", input)
|
||||
}
|
||||
if len(output) != 1 || output[0] != "text" {
|
||||
t.Fatalf("modalities.output = %v, want [text]", output)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildModelEntries(t *testing.T) {
|
||||
t.Run("defaults to model name without capabilities", func(t *testing.T) {
|
||||
models := buildModelEntries(testLaunchModels("llama3.2"))
|
||||
entry, _ := models["llama3.2"].(map[string]any)
|
||||
if entry["name"] != "llama3.2" {
|
||||
t.Fatalf("name = %v, want llama3.2", entry["name"])
|
||||
}
|
||||
if _, ok := entry["modalities"]; ok {
|
||||
t.Fatalf("modalities should not be set without capabilities, got %v", entry["modalities"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses context and output limits from metadata", func(t *testing.T) {
|
||||
models := buildModelEntries([]LaunchModel{{Name: "glm-5:cloud", ContextLength: 202_752, MaxOutputTokens: 131_072}})
|
||||
entry, _ := models["glm-5:cloud"].(map[string]any)
|
||||
limit, _ := entry["limit"].(map[string]any)
|
||||
if limit["context"] != 202_752 || limit["output"] != 131_072 {
|
||||
t.Fatalf("limit = %v, want context/output", limit)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenCodeModels_ReturnsNil(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
if models := o.Models(); models != nil {
|
||||
t.Errorf("Models() = %v, want nil", models)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodePaths(t *testing.T) {
|
||||
t.Run("returns nil when model.json does not exist", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
o := &OpenCode{}
|
||||
if paths := o.Paths(); paths != nil {
|
||||
t.Errorf("Paths() = %v, want nil", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns model.json path when it exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
|
||||
|
||||
o := &OpenCode{}
|
||||
paths := o.Paths()
|
||||
if len(paths) != 1 {
|
||||
t.Fatalf("Paths() returned %d paths, want 1", len(paths))
|
||||
}
|
||||
if paths[0] != filepath.Join(stateDir, "model.json") {
|
||||
t.Errorf("Paths() = %v, want %v", paths[0], filepath.Join(stateDir, "model.json"))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLookupCloudModelLimit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantOK bool
|
||||
wantContext int
|
||||
wantOutput int
|
||||
}{
|
||||
{"glm-4.7", false, 0, 0},
|
||||
{"glm-4.7:cloud", true, 202_752, 131_072},
|
||||
{"glm-5:cloud", true, 202_752, 131_072},
|
||||
{"glm-5.1:cloud", true, 202_752, 131_072},
|
||||
{"gemma4:31b-cloud", true, 262_144, 131_072},
|
||||
{"gpt-oss:120b-cloud", true, 131_072, 131_072},
|
||||
{"gpt-oss:20b-cloud", true, 131_072, 131_072},
|
||||
{"kimi-k2.5", false, 0, 0},
|
||||
{"kimi-k2.5:cloud", true, 262_144, 262_144},
|
||||
{"deepseek-v3.2", false, 0, 0},
|
||||
{"deepseek-v3.2:cloud", true, 163_840, 65_536},
|
||||
{"qwen3.5", false, 0, 0},
|
||||
{"qwen3.5:cloud", true, 262_144, 32_768},
|
||||
{"qwen3-coder:480b", false, 0, 0},
|
||||
{"qwen3-coder:480b:cloud", true, 262_144, 65_536},
|
||||
{"qwen3-coder-next:cloud", true, 262_144, 32_768},
|
||||
{"llama3.2", false, 0, 0},
|
||||
{"unknown-model:cloud", false, 0, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
l, ok := lookupCloudModelLimit(tt.name)
|
||||
if ok != tt.wantOK {
|
||||
t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
|
||||
}
|
||||
if ok {
|
||||
if l.Context != tt.wantContext {
|
||||
t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
|
||||
}
|
||||
if l.Output != tt.wantOutput {
|
||||
t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindOpenCode(t *testing.T) {
|
||||
t.Run("fallback to ~/.opencode/bin", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Ensure opencode is not on PATH
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
// Without the fallback binary, findOpenCode should fail
|
||||
if _, ok := findOpenCode(); ok {
|
||||
t.Fatal("findOpenCode should fail when binary is not on PATH or in fallback location")
|
||||
}
|
||||
|
||||
// Create a fake binary at the curl install fallback location
|
||||
binDir := filepath.Join(tmpDir, ".opencode", "bin")
|
||||
os.MkdirAll(binDir, 0o755)
|
||||
name := "opencode"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "opencode.exe"
|
||||
}
|
||||
fakeBin := filepath.Join(binDir, name)
|
||||
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
|
||||
|
||||
// Now findOpenCode should succeed via fallback
|
||||
path, ok := findOpenCode()
|
||||
if !ok {
|
||||
t.Fatal("findOpenCode should succeed with fallback binary")
|
||||
}
|
||||
if path != fakeBin {
|
||||
t.Errorf("findOpenCode = %q, want %q", path, fakeBin)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Verify that the BackfillsCloudModelLimitOnExistingEntry test from the old
|
||||
// file-based approach is covered by the new inline config approach.
|
||||
func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
expected := cloudModelLimits["glm-4.7"]
|
||||
|
||||
if err := o.Edit(testLaunchModels("glm-4.7:cloud")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var cfg map[string]any
|
||||
json.Unmarshal([]byte(o.configContent), &cfg)
|
||||
provider, _ := cfg["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].(map[string]any)
|
||||
entry, _ := models["glm-4.7:cloud"].(map[string]any)
|
||||
|
||||
limit, ok := entry["limit"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("cloud model limit was not set")
|
||||
}
|
||||
if limit["context"] != float64(expected.Context) {
|
||||
t.Errorf("context = %v, want %d", limit["context"], expected.Context)
|
||||
}
|
||||
if limit["output"] != float64(expected.Output) {
|
||||
t.Errorf("output = %v, want %d", limit["output"], expected.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
specialModel := `model-with-"quotes"`
|
||||
|
||||
err := o.Edit(testLaunchModels(specialModel))
|
||||
if err != nil {
|
||||
t.Fatalf("Edit with special chars failed: %v", err)
|
||||
}
|
||||
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal([]byte(o.configContent), &cfg); err != nil {
|
||||
t.Fatalf("resulting config is invalid JSON: %v", err)
|
||||
}
|
||||
|
||||
provider, _ := cfg["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].(map[string]any)
|
||||
if models[specialModel] == nil {
|
||||
t.Errorf("model with special chars not found in config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadModelJSONModels(t *testing.T) {
|
||||
t.Run("reads ollama models from model.json", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
state := map[string]any{
|
||||
"recent": []any{
|
||||
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
|
||||
map[string]any{"providerID": "ollama", "modelID": "qwen3:32b"},
|
||||
},
|
||||
}
|
||||
data, _ := json.MarshalIndent(state, "", " ")
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||
|
||||
models := readModelJSONModels()
|
||||
if len(models) != 2 {
|
||||
t.Fatalf("got %d models, want 2", len(models))
|
||||
}
|
||||
if models[0] != "llama3.2" || models[1] != "qwen3:32b" {
|
||||
t.Errorf("got %v, want [llama3.2 qwen3:32b]", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skips non-ollama providers", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
state := map[string]any{
|
||||
"recent": []any{
|
||||
map[string]any{"providerID": "openai", "modelID": "gpt-4"},
|
||||
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
|
||||
},
|
||||
}
|
||||
data, _ := json.MarshalIndent(state, "", " ")
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||
|
||||
models := readModelJSONModels()
|
||||
if len(models) != 1 || models[0] != "llama3.2" {
|
||||
t.Errorf("got %v, want [llama3.2]", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when file does not exist", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if models := readModelJSONModels(); models != nil {
|
||||
t.Errorf("got %v, want nil", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil for corrupt JSON", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{corrupt`), 0o644)
|
||||
|
||||
if models := readModelJSONModels(); models != nil {
|
||||
t.Errorf("got %v, want nil", models)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenCodeResolveContent(t *testing.T) {
|
||||
t.Run("returns Edit's content when set", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
o := &OpenCode{}
|
||||
if err := o.Edit(testLaunchModels("gemma4")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
editContent := o.configContent
|
||||
|
||||
// Write a different model.json — should be ignored
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
state := map[string]any{
|
||||
"recent": []any{
|
||||
map[string]any{"providerID": "ollama", "modelID": "different-model"},
|
||||
},
|
||||
}
|
||||
data, _ := json.MarshalIndent(state, "", " ")
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||
|
||||
got := o.resolveContent("gemma4", nil)
|
||||
if got != editContent {
|
||||
t.Errorf("resolveContent returned different content than Edit set\ngot: %s\nwant: %s", got, editContent)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to model.json when Edit was not called", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
state := map[string]any{
|
||||
"recent": []any{
|
||||
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
|
||||
map[string]any{"providerID": "ollama", "modelID": "qwen3:32b"},
|
||||
},
|
||||
}
|
||||
data, _ := json.MarshalIndent(state, "", " ")
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||
|
||||
o := &OpenCode{}
|
||||
content := o.resolveContent("llama3.2", nil)
|
||||
if content == "" {
|
||||
t.Fatal("resolveContent returned empty")
|
||||
}
|
||||
|
||||
var cfg map[string]any
|
||||
json.Unmarshal([]byte(content), &cfg)
|
||||
if cfg["model"] != "ollama/llama3.2" {
|
||||
t.Errorf("primary = %v, want ollama/llama3.2", cfg["model"])
|
||||
}
|
||||
provider, _ := cfg["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
cfgModels, _ := ollama["models"].(map[string]any)
|
||||
if cfgModels["llama3.2"] == nil || cfgModels["qwen3:32b"] == nil {
|
||||
t.Errorf("expected both models in config, got %v", cfgModels)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses requested model as primary even when not first in model.json", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
state := map[string]any{
|
||||
"recent": []any{
|
||||
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
|
||||
map[string]any{"providerID": "ollama", "modelID": "qwen3:32b"},
|
||||
},
|
||||
}
|
||||
data, _ := json.MarshalIndent(state, "", " ")
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||
|
||||
o := &OpenCode{}
|
||||
content := o.resolveContent("qwen3:32b", nil)
|
||||
|
||||
var cfg map[string]any
|
||||
json.Unmarshal([]byte(content), &cfg)
|
||||
if cfg["model"] != "ollama/qwen3:32b" {
|
||||
t.Errorf("primary = %v, want ollama/qwen3:32b", cfg["model"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("injects requested model when missing from model.json", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
state := map[string]any{
|
||||
"recent": []any{
|
||||
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
|
||||
},
|
||||
}
|
||||
data, _ := json.MarshalIndent(state, "", " ")
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||
|
||||
o := &OpenCode{}
|
||||
content := o.resolveContent("gemma4", nil)
|
||||
|
||||
var cfg map[string]any
|
||||
json.Unmarshal([]byte(content), &cfg)
|
||||
provider, _ := cfg["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
cfgModels, _ := ollama["models"].(map[string]any)
|
||||
if cfgModels["gemma4"] == nil {
|
||||
t.Error("requested model gemma4 not injected into config")
|
||||
}
|
||||
if cfg["model"] != "ollama/gemma4" {
|
||||
t.Errorf("primary = %v, want ollama/gemma4", cfg["model"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty when no model.json and no model param", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
o := &OpenCode{}
|
||||
if got := o.resolveContent("", nil); got != "" {
|
||||
t.Errorf("resolveContent(\"\") = %q, want empty", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses run model metadata when Edit was not called", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
state := map[string]any{
|
||||
"recent": []any{
|
||||
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
|
||||
},
|
||||
}
|
||||
data, _ := json.MarshalIndent(state, "", " ")
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||
|
||||
o := &OpenCode{}
|
||||
content := o.resolveContent("gemma4", []LaunchModel{
|
||||
{
|
||||
Name: "gemma4",
|
||||
Capabilities: []model.Capability{model.CapabilityVision},
|
||||
ContextLength: 65_536,
|
||||
MaxOutputTokens: 8_192,
|
||||
},
|
||||
})
|
||||
if content == "" {
|
||||
t.Fatal("resolveContent returned empty")
|
||||
}
|
||||
|
||||
var cfg map[string]any
|
||||
json.Unmarshal([]byte(content), &cfg)
|
||||
provider, _ := cfg["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
cfgModels, _ := ollama["models"].(map[string]any)
|
||||
entry, _ := cfgModels["gemma4"].(map[string]any)
|
||||
limit, _ := entry["limit"].(map[string]any)
|
||||
if limit["context"] != float64(65_536) || limit["output"] != float64(8_192) {
|
||||
t.Fatalf("limit = %v, want context/output from launch metadata", limit)
|
||||
}
|
||||
if _, ok := entry["modalities"].(map[string]any); !ok {
|
||||
t.Fatalf("modalities should be set from launch metadata, got %v", entry["modalities"])
|
||||
}
|
||||
if cfgModels["llama3.2"] == nil {
|
||||
t.Fatalf("state model missing from fallback config: %v", cfgModels)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("does not mutate configContent on fallback", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
state := map[string]any{
|
||||
"recent": []any{
|
||||
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
|
||||
},
|
||||
}
|
||||
data, _ := json.MarshalIndent(state, "", " ")
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||
|
||||
o := &OpenCode{}
|
||||
_ = o.resolveContent("llama3.2", nil)
|
||||
if o.configContent != "" {
|
||||
t.Errorf("resolveContent should not mutate configContent, got %q", o.configContent)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildInlineConfig(t *testing.T) {
|
||||
t.Run("returns error for empty primary", func(t *testing.T) {
|
||||
if _, err := buildInlineConfig(LaunchModel{}, testLaunchModels("llama3.2")); err == nil {
|
||||
t.Error("expected error for empty primary")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error for empty models", func(t *testing.T) {
|
||||
if _, err := buildInlineConfig(fallbackLaunchModel("llama3.2"), nil); err == nil {
|
||||
t.Error("expected error for empty models")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("primary differs from first model in list", func(t *testing.T) {
|
||||
content, err := buildInlineConfig(fallbackLaunchModel("qwen3:32b"), testLaunchModels("llama3.2", "qwen3:32b"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var cfg map[string]any
|
||||
json.Unmarshal([]byte(content), &cfg)
|
||||
if cfg["model"] != "ollama/qwen3:32b" {
|
||||
t.Errorf("primary = %v, want ollama/qwen3:32b", cfg["model"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_PreservesRecentEntries(t *testing.T) {
|
||||
t.Run("prepends new models to existing recent", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
initial := map[string]any{
|
||||
"recent": []any{
|
||||
map[string]any{"providerID": "ollama", "modelID": "old-A"},
|
||||
map[string]any{"providerID": "ollama", "modelID": "old-B"},
|
||||
},
|
||||
}
|
||||
data, _ := json.MarshalIndent(initial, "", " ")
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||
|
||||
o := &OpenCode{}
|
||||
if err := o.Edit(testLaunchModels("new-X")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
|
||||
var state map[string]any
|
||||
json.Unmarshal(stored, &state)
|
||||
recent, _ := state["recent"].([]any)
|
||||
|
||||
if len(recent) != 3 {
|
||||
t.Fatalf("expected 3 entries, got %d", len(recent))
|
||||
}
|
||||
first, _ := recent[0].(map[string]any)
|
||||
if first["modelID"] != "new-X" {
|
||||
t.Errorf("first entry = %v, want new-X", first["modelID"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("prepends multiple new models in order", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
initial := map[string]any{
|
||||
"recent": []any{
|
||||
map[string]any{"providerID": "ollama", "modelID": "old-A"},
|
||||
map[string]any{"providerID": "ollama", "modelID": "old-B"},
|
||||
},
|
||||
}
|
||||
data, _ := json.MarshalIndent(initial, "", " ")
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||
|
||||
o := &OpenCode{}
|
||||
if err := o.Edit(testLaunchModels("X", "Y", "Z")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
|
||||
var state map[string]any
|
||||
json.Unmarshal(stored, &state)
|
||||
recent, _ := state["recent"].([]any)
|
||||
|
||||
want := []string{"X", "Y", "Z", "old-A", "old-B"}
|
||||
if len(recent) != len(want) {
|
||||
t.Fatalf("expected %d entries, got %d", len(want), len(recent))
|
||||
}
|
||||
for i, w := range want {
|
||||
e, _ := recent[i].(map[string]any)
|
||||
if e["modelID"] != w {
|
||||
t.Errorf("recent[%d] = %v, want %v", i, e["modelID"], w)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves non-ollama entries", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
initial := map[string]any{
|
||||
"recent": []any{
|
||||
map[string]any{"providerID": "openai", "modelID": "gpt-4"},
|
||||
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
|
||||
},
|
||||
}
|
||||
data, _ := json.MarshalIndent(initial, "", " ")
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||
|
||||
o := &OpenCode{}
|
||||
if err := o.Edit(testLaunchModels("qwen3:32b")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
|
||||
var state map[string]any
|
||||
json.Unmarshal(stored, &state)
|
||||
recent, _ := state["recent"].([]any)
|
||||
|
||||
// Should have: qwen3:32b (new), gpt-4 (preserved openai), llama3.2 (preserved ollama)
|
||||
var foundOpenAI bool
|
||||
for _, entry := range recent {
|
||||
e, _ := entry.(map[string]any)
|
||||
if e["providerID"] == "openai" && e["modelID"] == "gpt-4" {
|
||||
foundOpenAI = true
|
||||
}
|
||||
}
|
||||
if !foundOpenAI {
|
||||
t.Errorf("non-ollama gpt-4 entry was not preserved, got %v", recent)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("deduplicates ollama models being re-added", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
initial := map[string]any{
|
||||
"recent": []any{
|
||||
map[string]any{"providerID": "ollama", "modelID": "llama3.2"},
|
||||
},
|
||||
}
|
||||
data, _ := json.MarshalIndent(initial, "", " ")
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||
|
||||
o := &OpenCode{}
|
||||
if err := o.Edit(testLaunchModels("llama3.2")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
|
||||
var state map[string]any
|
||||
json.Unmarshal(stored, &state)
|
||||
recent, _ := state["recent"].([]any)
|
||||
|
||||
count := 0
|
||||
for _, entry := range recent {
|
||||
e, _ := entry.(map[string]any)
|
||||
if e["modelID"] == "llama3.2" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 llama3.2 entry, got %d", count)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("caps recent list at 10", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
|
||||
os.MkdirAll(stateDir, 0o755)
|
||||
|
||||
// Pre-populate with 9 distinct ollama models
|
||||
recentEntries := make([]any, 0, 9)
|
||||
for i := range 9 {
|
||||
recentEntries = append(recentEntries, map[string]any{
|
||||
"providerID": "ollama",
|
||||
"modelID": fmt.Sprintf("old-%d", i),
|
||||
})
|
||||
}
|
||||
initial := map[string]any{"recent": recentEntries}
|
||||
data, _ := json.MarshalIndent(initial, "", " ")
|
||||
os.WriteFile(filepath.Join(stateDir, "model.json"), data, 0o644)
|
||||
|
||||
// Add 5 new models — should cap at 10 total
|
||||
o := &OpenCode{}
|
||||
if err := o.Edit(testLaunchModels("new-0", "new-1", "new-2", "new-3", "new-4")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stored, _ := os.ReadFile(filepath.Join(stateDir, "model.json"))
|
||||
var state map[string]any
|
||||
json.Unmarshal(stored, &state)
|
||||
recent, _ := state["recent"].([]any)
|
||||
|
||||
if len(recent) != 10 {
|
||||
t.Errorf("expected 10 entries (capped), got %d", len(recent))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenCodeEdit_BaseURL(t *testing.T) {
|
||||
o := &OpenCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Default OLLAMA_HOST
|
||||
o.Edit(testLaunchModels("llama3.2"))
|
||||
|
||||
var cfg map[string]any
|
||||
json.Unmarshal([]byte(o.configContent), &cfg)
|
||||
provider, _ := cfg["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
options, _ := ollama["options"].(map[string]any)
|
||||
|
||||
baseURL, _ := options["baseURL"].(string)
|
||||
if baseURL == "" {
|
||||
t.Error("baseURL should be set")
|
||||
}
|
||||
}
|
||||
365
cmd/launch/pi.go
Normal file
365
cmd/launch/pi.go
Normal file
@@ -0,0 +1,365 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
|
||||
type Pi struct{}
|
||||
|
||||
const (
|
||||
piNpmPackage = "@mariozechner/pi-coding-agent"
|
||||
piWebSearchSource = "npm:@ollama/pi-web-search"
|
||||
piWebSearchPkg = "@ollama/pi-web-search"
|
||||
)
|
||||
|
||||
func (p *Pi) String() string { return "Pi" }
|
||||
|
||||
func (p *Pi) Run(_ string, _ []LaunchModel, args []string) error {
|
||||
fmt.Fprintf(os.Stderr, "\n%sPreparing Pi...%s\n", ansiGray, ansiReset)
|
||||
if err := ensureNpmInstalled(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sChecking Pi installation...%s\n", ansiGray, ansiReset)
|
||||
bin, err := ensurePiInstalled()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ensurePiWebSearchPackage(bin)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%sLaunching Pi...%s\n\n", ansiGray, ansiReset)
|
||||
|
||||
cmd := exec.Command(bin, args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func ensureNpmInstalled() error {
|
||||
if _, err := exec.LookPath("npm"); err != nil {
|
||||
return fmt.Errorf("npm (Node.js) is required to launch pi\n\nInstall it first:\n https://nodejs.org/\n\nThen re-run:\n ollama launch pi")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensurePiInstalled() (string, error) {
|
||||
if _, err := exec.LookPath("pi"); err == nil {
|
||||
return "pi", nil
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("npm"); err != nil {
|
||||
return "", fmt.Errorf("pi is not installed and required dependencies are missing\n\nInstall the following first:\n npm (Node.js): https://nodejs.org/\n\nThen re-run:\n ollama launch pi")
|
||||
}
|
||||
|
||||
ok, err := ConfirmPrompt("Pi is not installed. Install with npm?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !ok {
|
||||
return "", fmt.Errorf("pi installation cancelled")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nInstalling Pi...\n")
|
||||
cmd := exec.Command("npm", "install", "-g", piNpmPackage+"@latest")
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("failed to install pi: %w", err)
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("pi"); err != nil {
|
||||
return "", fmt.Errorf("pi was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sPi installed successfully%s\n\n", ansiGreen, ansiReset)
|
||||
return "pi", nil
|
||||
}
|
||||
|
||||
func ensurePiWebSearchPackage(bin string) {
|
||||
if !shouldManagePiWebSearch() {
|
||||
fmt.Fprintf(os.Stderr, "%sCloud is disabled; skipping %s setup.%s\n", ansiGray, piWebSearchPkg, ansiReset)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sChecking Pi web search package...%s\n", ansiGray, ansiReset)
|
||||
|
||||
installed, err := piPackageInstalled(bin, piWebSearchSource)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not check %s installation: %v%s\n", ansiYellow, piWebSearchPkg, err, ansiReset)
|
||||
return
|
||||
}
|
||||
|
||||
if !installed {
|
||||
fmt.Fprintf(os.Stderr, "%sInstalling %s...%s\n", ansiGray, piWebSearchPkg, ansiReset)
|
||||
cmd := exec.Command(bin, "install", piWebSearchSource)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not install %s: %v%s\n", ansiYellow, piWebSearchPkg, err, ansiReset)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s ✓ Installed %s%s\n", ansiGreen, piWebSearchPkg, ansiReset)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sUpdating %s...%s\n", ansiGray, piWebSearchPkg, ansiReset)
|
||||
cmd := exec.Command(bin, "update", piWebSearchSource)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not update %s: %v%s\n", ansiYellow, piWebSearchPkg, err, ansiReset)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s ✓ Updated %s%s\n", ansiGreen, piWebSearchPkg, ansiReset)
|
||||
}
|
||||
|
||||
func shouldManagePiWebSearch() bool {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
disabled, known := cloudStatusDisabled(context.Background(), client)
|
||||
if known && disabled {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func piPackageInstalled(bin, source string) (bool, error) {
|
||||
cmd := exec.Command(bin, "list")
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
msg := strings.TrimSpace(string(out))
|
||||
if msg == "" {
|
||||
return false, err
|
||||
}
|
||||
return false, fmt.Errorf("%w: %s", err, msg)
|
||||
}
|
||||
|
||||
for _, line := range strings.Split(string(out), "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, source) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (p *Pi) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var paths []string
|
||||
modelsPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||
if _, err := os.Stat(modelsPath); err == nil {
|
||||
paths = append(paths, modelsPath)
|
||||
}
|
||||
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
|
||||
if _, err := os.Stat(settingsPath); err == nil {
|
||||
paths = append(paths, settingsPath)
|
||||
}
|
||||
return paths
|
||||
}
|
||||
|
||||
func (p *Pi) Edit(models []LaunchModel) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
}
|
||||
|
||||
providers, ok := config["providers"].(map[string]any)
|
||||
if !ok {
|
||||
providers = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama, ok := providers["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
ollama = map[string]any{
|
||||
"baseUrl": envconfig.Host().String() + "/v1",
|
||||
"api": "openai-completions",
|
||||
"apiKey": "ollama",
|
||||
}
|
||||
}
|
||||
|
||||
existingModels, ok := ollama["models"].([]any)
|
||||
if !ok {
|
||||
existingModels = make([]any, 0)
|
||||
}
|
||||
|
||||
// Build set of selected models to track which need to be added
|
||||
selectedSet := make(map[string]bool, len(models))
|
||||
for _, m := range models {
|
||||
selectedSet[m.Name] = true
|
||||
}
|
||||
|
||||
// Build new models list:
|
||||
// 1. Keep user-managed models (no _launch marker) - untouched
|
||||
// 2. Keep ollama-managed models (_launch marker) that are still selected,
|
||||
// except stale cloud entries that should be rebuilt below
|
||||
// 3. Add new ollama-managed models
|
||||
var newModels []any
|
||||
for _, m := range existingModels {
|
||||
if modelObj, ok := m.(map[string]any); ok {
|
||||
if id, ok := modelObj["id"].(string); ok {
|
||||
// User-managed model (no _launch marker) - always preserve
|
||||
if !isPiOllamaModel(modelObj) {
|
||||
newModels = append(newModels, m)
|
||||
} else if selectedSet[id] {
|
||||
// Rebuild stale managed cloud entries so createConfig refreshes
|
||||
// the whole entry instead of patching it in place.
|
||||
if !hasContextWindow(modelObj) {
|
||||
if _, ok := lookupCloudModelLimit(id); ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
newModels = append(newModels, m)
|
||||
selectedSet[id] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add newly selected models that weren't already in the list
|
||||
for _, model := range models {
|
||||
if selectedSet[model.Name] {
|
||||
newModels = append(newModels, createConfig(model))
|
||||
}
|
||||
}
|
||||
|
||||
ollama["models"] = newModels
|
||||
providers["ollama"] = ollama
|
||||
config["providers"] = providers
|
||||
|
||||
configData, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fileutil.WriteWithBackup(configPath, configData, "pi"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update settings.json with default provider and model
|
||||
settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
|
||||
settings := make(map[string]any)
|
||||
if data, err := os.ReadFile(settingsPath); err == nil {
|
||||
_ = json.Unmarshal(data, &settings)
|
||||
}
|
||||
|
||||
settings["defaultProvider"] = "ollama"
|
||||
settings["defaultModel"] = models[0].Name
|
||||
|
||||
settingsData, err := json.MarshalIndent(settings, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fileutil.WriteWithBackup(settingsPath, settingsData, "pi")
|
||||
}
|
||||
|
||||
func (p *Pi) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".pi", "agent", "models.json")
|
||||
config, err := fileutil.ReadJSON(configPath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
providers, _ := config["providers"].(map[string]any)
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].([]any)
|
||||
|
||||
var result []string
|
||||
for _, m := range models {
|
||||
if modelObj, ok := m.(map[string]any); ok {
|
||||
if id, ok := modelObj["id"].(string); ok {
|
||||
result = append(result, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
slices.Sort(result)
|
||||
return result
|
||||
}
|
||||
|
||||
// isPiOllamaModel reports whether a model config entry is managed by ollama launch
|
||||
func isPiOllamaModel(cfg map[string]any) bool {
|
||||
if v, ok := cfg["_launch"].(bool); ok && v {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hasContextWindow(cfg map[string]any) bool {
|
||||
switch v := cfg["contextWindow"].(type) {
|
||||
case float64:
|
||||
return v > 0
|
||||
case int:
|
||||
return v > 0
|
||||
case int64:
|
||||
return v > 0
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// createConfig builds Pi model config with capability detection.
|
||||
func createConfig(model LaunchModel) map[string]any {
|
||||
cfg := map[string]any{
|
||||
"id": model.Name,
|
||||
"_launch": true,
|
||||
}
|
||||
|
||||
// Set input types based on vision capability
|
||||
if model.HasCapability("vision") {
|
||||
cfg["input"] = []string{"text", "image"}
|
||||
} else {
|
||||
cfg["input"] = []string{"text"}
|
||||
}
|
||||
|
||||
// Set reasoning based on thinking capability
|
||||
if model.HasCapability("thinking") {
|
||||
cfg["reasoning"] = true
|
||||
}
|
||||
|
||||
if model.ContextLength > 0 {
|
||||
cfg["contextWindow"] = model.ContextLength
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
1210
cmd/launch/pi_test.go
Normal file
1210
cmd/launch/pi_test.go
Normal file
File diff suppressed because it is too large
Load Diff
51
cmd/launch/poolside.go
Normal file
51
cmd/launch/poolside.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Poolside implements Runner for Poolside's CLI.
|
||||
type Poolside struct{}
|
||||
|
||||
var poolsideGOOS = runtime.GOOS
|
||||
|
||||
func (p *Poolside) String() string { return "Pool" }
|
||||
|
||||
func poolsideUnsupportedError() error {
|
||||
return fmt.Errorf("Warning: Poolside is not currently supported on Windows")
|
||||
}
|
||||
|
||||
func (p *Poolside) args(model string, extra []string) []string {
|
||||
var args []string
|
||||
if model != "" {
|
||||
args = append(args, "-m", model)
|
||||
}
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (p *Poolside) Run(model string, _ []LaunchModel, args []string) error {
|
||||
if poolsideGOOS == "windows" {
|
||||
return poolsideUnsupportedError()
|
||||
}
|
||||
|
||||
bin, err := exec.LookPath("pool")
|
||||
if err != nil {
|
||||
return fmt.Errorf("pool is not installed")
|
||||
}
|
||||
|
||||
cmd := exec.Command(bin, p.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"POOLSIDE_STANDALONE_BASE_URL="+envconfig.Host().String()+"/v1",
|
||||
"POOLSIDE_API_KEY=ollama",
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
88
cmd/launch/poolside_test.go
Normal file
88
cmd/launch/poolside_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPoolsideArgs(t *testing.T) {
|
||||
p := &Poolside{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
extra []string
|
||||
want []string
|
||||
}{
|
||||
{name: "with model", model: "qwen3.5", want: []string{"-m", "qwen3.5"}},
|
||||
{name: "without model", extra: []string{"session"}, want: []string{"session"}},
|
||||
{name: "with model and extra args", model: "llama3.2", extra: []string{"--foo", "bar"}, want: []string{"-m", "llama3.2", "--foo", "bar"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := p.args(tt.model, tt.extra)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Fatalf("args(%q, %v) = %v, want %v", tt.model, tt.extra, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolsideRunSetsOllamaEnv(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses POSIX shell fake binary")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "pool.log")
|
||||
poolPath := filepath.Join(tmpDir, "pool")
|
||||
script := "#!/bin/sh\n" +
|
||||
"printf 'base=%s\\nkey=%s\\nargs=%s\\n' \"$POOLSIDE_STANDALONE_BASE_URL\" \"$POOLSIDE_API_KEY\" \"$*\" > \"" + logPath + "\"\n"
|
||||
if err := os.WriteFile(poolPath, []byte(script), 0o755); err != nil {
|
||||
t.Fatalf("failed to write fake pool binary: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PATH", tmpDir)
|
||||
t.Setenv("OLLAMA_HOST", "http://127.0.0.1:11434")
|
||||
|
||||
p := &Poolside{}
|
||||
if err := p.Run("qwen3.5", nil, []string{"session"}); err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(logPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read pool log: %v", err)
|
||||
}
|
||||
|
||||
got := string(data)
|
||||
if !strings.Contains(got, "base=http://127.0.0.1:11434/v1") {
|
||||
t.Fatalf("expected Poolside base URL override in log, got:\n%s", got)
|
||||
}
|
||||
if !strings.Contains(got, "key=ollama") {
|
||||
t.Fatalf("expected Poolside API key override in log, got:\n%s", got)
|
||||
}
|
||||
if !strings.Contains(got, "args=-m qwen3.5 session") {
|
||||
t.Fatalf("expected model and extra args in log, got:\n%s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolsideRunWindowsUnsupported(t *testing.T) {
|
||||
prev := poolsideGOOS
|
||||
poolsideGOOS = "windows"
|
||||
t.Cleanup(func() { poolsideGOOS = prev })
|
||||
|
||||
p := &Poolside{}
|
||||
err := p.Run("kimi-k2.6:cloud", nil, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected Windows unsupported error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not currently supported on Windows") {
|
||||
t.Fatalf("expected Windows warning, got %v", err)
|
||||
}
|
||||
}
|
||||
470
cmd/launch/registry.go
Normal file
470
cmd/launch/registry.go
Normal file
@@ -0,0 +1,470 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// IntegrationInstallSpec describes how launcher should detect and guide installation.
|
||||
type IntegrationInstallSpec struct {
|
||||
CheckInstalled func() bool
|
||||
EnsureInstalled func() error
|
||||
URL string
|
||||
Command []string
|
||||
}
|
||||
|
||||
// IntegrationSpec is the canonical registry entry for one integration.
|
||||
type IntegrationSpec struct {
|
||||
Name string
|
||||
Runner Runner
|
||||
Aliases []string
|
||||
Hidden bool
|
||||
Description string
|
||||
Install IntegrationInstallSpec
|
||||
}
|
||||
|
||||
// IntegrationInfo contains display information about a registered integration.
|
||||
type IntegrationInfo struct {
|
||||
Name string
|
||||
DisplayName string
|
||||
Description string
|
||||
}
|
||||
|
||||
var launcherIntegrationOrder = []string{"claude", "codex-app", "hermes", "openclaw", "opencode", "codex", "copilot", "droid", "pi", "pool"}
|
||||
|
||||
var integrationSpecs = []*IntegrationSpec{
|
||||
{
|
||||
Name: "claude",
|
||||
Runner: &Claude{},
|
||||
Description: "Anthropic's coding tool with subagents",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := (&Claude{}).findPath()
|
||||
return err == nil
|
||||
},
|
||||
URL: "https://code.claude.com/docs/en/quickstart",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "claude-desktop",
|
||||
Runner: &ClaudeDesktop{},
|
||||
Aliases: []string{"claude-app"},
|
||||
Description: "Claude Desktop with Ollama Cloud",
|
||||
Hidden: true,
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
return claudeDesktopInstalled()
|
||||
},
|
||||
URL: "https://claude.com/download",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "cline",
|
||||
Runner: &Cline{},
|
||||
Description: "Autonomous coding agent with parallel execution",
|
||||
Hidden: true,
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := exec.LookPath("cline")
|
||||
return err == nil
|
||||
},
|
||||
Command: []string{"npm", "install", "-g", "cline"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "codex",
|
||||
Runner: &Codex{},
|
||||
Description: "OpenAI's open-source coding agent",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := exec.LookPath("codex")
|
||||
return err == nil
|
||||
},
|
||||
URL: "https://developers.openai.com/codex/cli/",
|
||||
Command: []string{"npm", "install", "-g", "@openai/codex"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "codex-app",
|
||||
Runner: &CodexApp{},
|
||||
Aliases: []string{"codex-desktop", "codex-gui"},
|
||||
Description: "An AI agent you can delegate real work to, by OpenAI",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
return codexAppInstalled()
|
||||
},
|
||||
URL: "https://developers.openai.com/codex/quickstart",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "kimi",
|
||||
Runner: &Kimi{},
|
||||
Description: "Moonshot's coding agent for terminal and IDEs",
|
||||
Hidden: true,
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := exec.LookPath("kimi")
|
||||
return err == nil
|
||||
},
|
||||
EnsureInstalled: func() error {
|
||||
_, err := ensureKimiInstalled()
|
||||
return err
|
||||
},
|
||||
URL: "https://moonshotai.github.io/kimi-cli/en/guides/getting-started.html",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "copilot",
|
||||
Runner: &Copilot{},
|
||||
Aliases: []string{"copilot-cli"},
|
||||
Description: "GitHub's AI coding agent for the terminal",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := (&Copilot{}).findPath()
|
||||
return err == nil
|
||||
},
|
||||
URL: "https://github.com/features/copilot/cli/",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "droid",
|
||||
Runner: &Droid{},
|
||||
Description: "Factory's coding agent across terminal and IDEs",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := exec.LookPath("droid")
|
||||
return err == nil
|
||||
},
|
||||
URL: "https://docs.factory.ai/cli/getting-started/quickstart",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "opencode",
|
||||
Runner: &OpenCode{},
|
||||
Description: "Anomaly's open-source coding agent",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, ok := findOpenCode()
|
||||
return ok
|
||||
},
|
||||
URL: "https://opencode.ai",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "openclaw",
|
||||
Runner: &Openclaw{},
|
||||
Aliases: []string{"clawdbot", "moltbot"},
|
||||
Description: "Personal AI with 100+ skills",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
if _, err := exec.LookPath("openclaw"); err == nil {
|
||||
return true
|
||||
}
|
||||
if _, err := exec.LookPath("clawdbot"); err == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
EnsureInstalled: func() error {
|
||||
_, err := ensureOpenclawInstalled()
|
||||
return err
|
||||
},
|
||||
URL: "https://docs.openclaw.ai",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "pi",
|
||||
Runner: &Pi{},
|
||||
Description: "Minimal AI agent toolkit with plugin support",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := exec.LookPath("pi")
|
||||
return err == nil
|
||||
},
|
||||
EnsureInstalled: func() error {
|
||||
_, err := ensurePiInstalled()
|
||||
return err
|
||||
},
|
||||
Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent@latest"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "pool",
|
||||
Runner: &Poolside{},
|
||||
Description: "Poolside's software agent for enterprise development",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := exec.LookPath("pool")
|
||||
return err == nil
|
||||
},
|
||||
URL: "https://github.com/poolsideai/pool",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "hermes",
|
||||
Runner: &Hermes{},
|
||||
Description: "Self-improving AI agent built by Nous Research",
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
return (&Hermes{}).installed()
|
||||
},
|
||||
EnsureInstalled: func() error {
|
||||
return (&Hermes{}).ensureInstalled()
|
||||
},
|
||||
URL: "https://hermes-agent.nousresearch.com/docs/getting-started/installation/",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "vscode",
|
||||
Runner: &VSCode{},
|
||||
Aliases: []string{"code"},
|
||||
Description: "Microsoft's open-source AI code editor",
|
||||
Hidden: true,
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
return (&VSCode{}).findBinary() != ""
|
||||
},
|
||||
URL: "https://code.visualstudio.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var integrationSpecsByName map[string]*IntegrationSpec
|
||||
|
||||
func init() {
|
||||
rebuildIntegrationSpecIndexes()
|
||||
}
|
||||
|
||||
func hyperlink(url, text string) string {
|
||||
return fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", url, text)
|
||||
}
|
||||
|
||||
func rebuildIntegrationSpecIndexes() {
|
||||
integrationSpecsByName = make(map[string]*IntegrationSpec, len(integrationSpecs))
|
||||
|
||||
canonical := make(map[string]bool, len(integrationSpecs))
|
||||
for _, spec := range integrationSpecs {
|
||||
key := strings.ToLower(spec.Name)
|
||||
if key == "" {
|
||||
panic("launch: integration spec missing name")
|
||||
}
|
||||
if canonical[key] {
|
||||
panic(fmt.Sprintf("launch: duplicate integration name %q", key))
|
||||
}
|
||||
canonical[key] = true
|
||||
integrationSpecsByName[key] = spec
|
||||
}
|
||||
|
||||
seenAliases := make(map[string]string)
|
||||
for _, spec := range integrationSpecs {
|
||||
for _, alias := range spec.Aliases {
|
||||
key := strings.ToLower(alias)
|
||||
if key == "" {
|
||||
panic(fmt.Sprintf("launch: integration %q has empty alias", spec.Name))
|
||||
}
|
||||
if canonical[key] {
|
||||
panic(fmt.Sprintf("launch: alias %q collides with canonical integration name", key))
|
||||
}
|
||||
if owner, exists := seenAliases[key]; exists {
|
||||
panic(fmt.Sprintf("launch: alias %q collides between %q and %q", key, owner, spec.Name))
|
||||
}
|
||||
seenAliases[key] = spec.Name
|
||||
integrationSpecsByName[key] = spec
|
||||
}
|
||||
}
|
||||
|
||||
orderSeen := make(map[string]bool, len(launcherIntegrationOrder))
|
||||
for _, name := range launcherIntegrationOrder {
|
||||
key := strings.ToLower(name)
|
||||
if orderSeen[key] {
|
||||
panic(fmt.Sprintf("launch: duplicate launcher order entry %q", key))
|
||||
}
|
||||
orderSeen[key] = true
|
||||
|
||||
spec, ok := integrationSpecsByName[key]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("launch: unknown launcher order entry %q", key))
|
||||
}
|
||||
if spec.Name != key {
|
||||
panic(fmt.Sprintf("launch: launcher order entry %q must use canonical name, not alias", key))
|
||||
}
|
||||
if spec.Hidden {
|
||||
panic(fmt.Sprintf("launch: hidden integration %q cannot appear in launcher order", key))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LookupIntegrationSpec resolves either a canonical integration name or alias to its spec.
|
||||
func LookupIntegrationSpec(name string) (*IntegrationSpec, error) {
|
||||
spec, ok := integrationSpecsByName[strings.ToLower(name)]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
return spec, nil
|
||||
}
|
||||
|
||||
// LookupIntegration resolves a registry name to the canonical key and runner.
|
||||
func LookupIntegration(name string) (string, Runner, error) {
|
||||
spec, err := LookupIntegrationSpec(name)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return spec.Name, spec.Runner, nil
|
||||
}
|
||||
|
||||
// ListVisibleIntegrationSpecs returns the canonical integrations that should appear in interactive UIs.
|
||||
func ListVisibleIntegrationSpecs() []IntegrationSpec {
|
||||
visible := make([]IntegrationSpec, 0, len(integrationSpecs))
|
||||
for _, spec := range integrationSpecs {
|
||||
if spec.Hidden {
|
||||
continue
|
||||
}
|
||||
if supported, ok := spec.Runner.(SupportedIntegration); ok && supported.Supported() != nil {
|
||||
continue
|
||||
}
|
||||
if spec.Name == "pool" && poolsideGOOS == "windows" {
|
||||
continue
|
||||
}
|
||||
visible = append(visible, *spec)
|
||||
}
|
||||
|
||||
orderRank := make(map[string]int, len(launcherIntegrationOrder))
|
||||
for i, name := range launcherIntegrationOrder {
|
||||
orderRank[name] = i + 1
|
||||
}
|
||||
|
||||
slices.SortFunc(visible, func(a, b IntegrationSpec) int {
|
||||
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
|
||||
if aRank > 0 && bRank > 0 {
|
||||
return aRank - bRank
|
||||
}
|
||||
if aRank > 0 {
|
||||
return -1
|
||||
}
|
||||
if bRank > 0 {
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
return visible
|
||||
}
|
||||
|
||||
// ListIntegrationInfos returns the registered integrations in launcher display order.
|
||||
func ListIntegrationInfos() []IntegrationInfo {
|
||||
visible := ListVisibleIntegrationSpecs()
|
||||
infos := make([]IntegrationInfo, 0, len(visible))
|
||||
for _, spec := range visible {
|
||||
infos = append(infos, IntegrationInfo{
|
||||
Name: spec.Name,
|
||||
DisplayName: spec.Runner.String(),
|
||||
Description: spec.Description,
|
||||
})
|
||||
}
|
||||
return infos
|
||||
}
|
||||
|
||||
// IntegrationSelectionItems returns the sorted integration items shown by launcher selection UIs.
|
||||
func IntegrationSelectionItems() ([]ModelItem, error) {
|
||||
visible := ListVisibleIntegrationSpecs()
|
||||
if len(visible) == 0 {
|
||||
return nil, fmt.Errorf("no integrations available")
|
||||
}
|
||||
|
||||
items := make([]ModelItem, 0, len(visible))
|
||||
for _, spec := range visible {
|
||||
description := spec.Runner.String()
|
||||
if conn, err := loadStoredIntegrationConfig(spec.Name); err == nil && len(conn.Models) > 0 {
|
||||
description = fmt.Sprintf("%s (%s)", spec.Runner.String(), conn.Models[0])
|
||||
}
|
||||
items = append(items, ModelItem{Name: spec.Name, Description: description})
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// IsIntegrationInstalled checks if an integration binary is installed.
|
||||
func IsIntegrationInstalled(name string) bool {
|
||||
integration, err := integrationFor(name)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Ollama couldn't find integration %q, so it'll show up as not installed.\n", name)
|
||||
return false
|
||||
}
|
||||
return integration.installed
|
||||
}
|
||||
|
||||
// integration is resolved registry metadata used by launcher state and install checks.
|
||||
// It combines immutable registry spec data with computed runtime traits.
|
||||
type integration struct {
|
||||
spec *IntegrationSpec
|
||||
installed bool
|
||||
autoInstallable bool
|
||||
editor bool
|
||||
installHint string
|
||||
}
|
||||
|
||||
// integrationFor resolves an integration name into the canonical spec plus
|
||||
// derived launcher/install traits used across registry and launch flows.
|
||||
func integrationFor(name string) (integration, error) {
|
||||
spec, err := LookupIntegrationSpec(name)
|
||||
if err != nil {
|
||||
return integration{}, err
|
||||
}
|
||||
|
||||
installed := true
|
||||
if spec.Install.CheckInstalled != nil {
|
||||
installed = spec.Install.CheckInstalled()
|
||||
}
|
||||
|
||||
_, editor := spec.Runner.(Editor)
|
||||
hint := ""
|
||||
if spec.Install.URL != "" {
|
||||
hint = "Install from " + hyperlink(spec.Install.URL, spec.Install.URL)
|
||||
} else if len(spec.Install.Command) > 0 {
|
||||
hint = "Install with: " + strings.Join(spec.Install.Command, " ")
|
||||
}
|
||||
|
||||
return integration{
|
||||
spec: spec,
|
||||
installed: installed,
|
||||
autoInstallable: spec.Install.EnsureInstalled != nil,
|
||||
editor: editor,
|
||||
installHint: hint,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EnsureIntegrationInstalled installs auto-installable integrations when missing.
|
||||
func EnsureIntegrationInstalled(name string, runner Runner) error {
|
||||
integration, err := integrationFor(name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s is not installed", runner)
|
||||
}
|
||||
|
||||
if supported, ok := runner.(SupportedIntegration); ok {
|
||||
if err := supported.Supported(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if integration.spec.Name == "pool" && poolsideGOOS == "windows" {
|
||||
return poolsideUnsupportedError()
|
||||
}
|
||||
|
||||
if integration.installed {
|
||||
return nil
|
||||
}
|
||||
if integration.autoInstallable {
|
||||
return integration.spec.Install.EnsureInstalled()
|
||||
}
|
||||
|
||||
switch {
|
||||
case integration.spec.Install.URL != "":
|
||||
return fmt.Errorf("%s is not installed, install from %s", integration.spec.Name, integration.spec.Install.URL)
|
||||
case len(integration.spec.Install.Command) > 0:
|
||||
return fmt.Errorf("%s is not installed, install with: %s", integration.spec.Name, strings.Join(integration.spec.Install.Command, " "))
|
||||
default:
|
||||
return fmt.Errorf("%s is not installed", runner)
|
||||
}
|
||||
}
|
||||
21
cmd/launch/registry_test_helpers_test.go
Normal file
21
cmd/launch/registry_test_helpers_test.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package launch
|
||||
|
||||
import "strings"
|
||||
|
||||
// OverrideIntegration replaces one registry entry's runner for tests and returns a restore function.
|
||||
func OverrideIntegration(name string, runner Runner) func() {
|
||||
spec, err := LookupIntegrationSpec(name)
|
||||
if err != nil {
|
||||
key := strings.ToLower(name)
|
||||
integrationSpecsByName[key] = &IntegrationSpec{Name: key, Runner: runner}
|
||||
return func() {
|
||||
delete(integrationSpecsByName, key)
|
||||
}
|
||||
}
|
||||
|
||||
original := spec.Runner
|
||||
spec.Runner = runner
|
||||
return func() {
|
||||
spec.Runner = original
|
||||
}
|
||||
}
|
||||
95
cmd/launch/runner_exec_only_test.go
Normal file
95
cmd/launch/runner_exec_only_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
binary string
|
||||
runner Runner
|
||||
checkPath func(home string) string
|
||||
}{
|
||||
{
|
||||
name: "droid",
|
||||
binary: "droid",
|
||||
runner: &Droid{},
|
||||
checkPath: func(home string) string {
|
||||
return filepath.Join(home, ".factory", "settings.json")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "opencode",
|
||||
binary: "opencode",
|
||||
runner: &OpenCode{},
|
||||
checkPath: func(home string) string {
|
||||
return filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "cline",
|
||||
binary: "cline",
|
||||
runner: &Cline{},
|
||||
checkPath: func(home string) string {
|
||||
return filepath.Join(home, ".cline", "data", "globalState.json")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "pi",
|
||||
binary: "pi",
|
||||
runner: &Pi{},
|
||||
checkPath: func(home string) string {
|
||||
return filepath.Join(home, ".pi", "agent", "models.json")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "pool",
|
||||
binary: "pool",
|
||||
runner: &Poolside{},
|
||||
checkPath: func(home string) string {
|
||||
return filepath.Join(home, ".poolside", "config")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "kimi",
|
||||
binary: "kimi",
|
||||
runner: &Kimi{},
|
||||
checkPath: func(home string) string {
|
||||
return filepath.Join(home, ".kimi", "config.toml")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.name == "pool" && poolsideGOOS == "windows" {
|
||||
t.Skip("Poolside is intentionally unsupported on Windows")
|
||||
}
|
||||
|
||||
home := t.TempDir()
|
||||
setTestHome(t, home)
|
||||
|
||||
binDir := t.TempDir()
|
||||
writeFakeBinary(t, binDir, tt.binary)
|
||||
if tt.name == "pi" {
|
||||
writeFakeBinary(t, binDir, "npm")
|
||||
}
|
||||
if tt.name == "kimi" {
|
||||
writeFakeBinary(t, binDir, "curl")
|
||||
writeFakeBinary(t, binDir, "bash")
|
||||
}
|
||||
t.Setenv("PATH", binDir)
|
||||
|
||||
configPath := tt.checkPath(home)
|
||||
if err := tt.runner.Run("llama3.2", nil, nil); err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected Run to leave %s untouched, got err=%v", configPath, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
131
cmd/launch/selector_hooks.go
Normal file
131
cmd/launch/selector_hooks.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// ANSI escape sequences for terminal formatting.
|
||||
const (
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiGreen = "\033[32m"
|
||||
ansiYellow = "\033[33m"
|
||||
)
|
||||
|
||||
// ErrCancelled is returned when the user cancels a selection.
|
||||
var ErrCancelled = errors.New("cancelled")
|
||||
|
||||
// errCancelled is kept as an internal alias for existing call sites.
|
||||
var errCancelled = ErrCancelled
|
||||
|
||||
// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
|
||||
// When set, ConfirmPrompt delegates to it instead of using raw terminal I/O.
|
||||
var DefaultConfirmPrompt func(prompt string, options ConfirmOptions) (bool, error)
|
||||
|
||||
// ConfirmOptions customizes labels for confirmation prompts.
|
||||
type ConfirmOptions struct {
|
||||
YesLabel string
|
||||
NoLabel string
|
||||
}
|
||||
|
||||
// SingleSelector is a function type for single item selection.
|
||||
// current is the name of the previously selected item to highlight; empty means no pre-selection.
|
||||
type SingleSelector func(title string, items []SelectionItem, current string) (string, error)
|
||||
|
||||
// SingleSelectorWithUpdates is a single item selector that can receive refreshed item state while open.
|
||||
type SingleSelectorWithUpdates func(title string, items []SelectionItem, current string, updates <-chan []SelectionItem) (string, error)
|
||||
|
||||
// MultiSelector is a function type for multi item selection.
|
||||
type MultiSelector func(title string, items []SelectionItem, preChecked []string) ([]string, error)
|
||||
|
||||
// MultiSelectorWithUpdates is a multi item selector that can receive refreshed item state while open.
|
||||
type MultiSelectorWithUpdates func(title string, items []SelectionItem, preChecked []string, updates <-chan []SelectionItem) ([]string, error)
|
||||
|
||||
// DefaultSingleSelector is the default single-select implementation.
|
||||
var DefaultSingleSelector SingleSelector
|
||||
|
||||
// DefaultSingleSelectorWithUpdates is the default single-select implementation with live updates.
|
||||
var DefaultSingleSelectorWithUpdates SingleSelectorWithUpdates
|
||||
|
||||
// DefaultMultiSelector is the default multi-select implementation.
|
||||
var DefaultMultiSelector MultiSelector
|
||||
|
||||
// DefaultMultiSelectorWithUpdates is the default multi-select implementation with live updates.
|
||||
var DefaultMultiSelectorWithUpdates MultiSelectorWithUpdates
|
||||
|
||||
// DefaultSignIn provides a TUI-based sign-in flow.
|
||||
// When set, ensureAuth uses it instead of plain text prompts.
|
||||
// Returns the signed-in username or an error.
|
||||
var DefaultSignIn func(modelName, signInURL string) (string, error)
|
||||
|
||||
// DefaultUpgrade provides a TUI-based upgrade flow.
|
||||
// Returns the updated plan or an error.
|
||||
var DefaultUpgrade func(modelName, requiredPlan string) (string, error)
|
||||
|
||||
type launchConfirmPolicy struct {
|
||||
yes bool
|
||||
requireYesMessage bool
|
||||
}
|
||||
|
||||
var currentLaunchConfirmPolicy launchConfirmPolicy
|
||||
|
||||
func withLaunchConfirmPolicy(policy launchConfirmPolicy) func() {
|
||||
old := currentLaunchConfirmPolicy
|
||||
currentLaunchConfirmPolicy = policy
|
||||
return func() {
|
||||
currentLaunchConfirmPolicy = old
|
||||
}
|
||||
}
|
||||
|
||||
// ConfirmPrompt is the shared confirmation gate for launch flows (integration
|
||||
// edits, missing-model pulls, sign-in prompts, OpenClaw install/security, etc).
|
||||
// Behavior is controlled by currentLaunchConfirmPolicy, typically scoped by
|
||||
// withLaunchConfirmPolicy in LaunchCmd (e.g. auto-approve with --yes).
|
||||
func ConfirmPrompt(prompt string) (bool, error) {
|
||||
return ConfirmPromptWithOptions(prompt, ConfirmOptions{})
|
||||
}
|
||||
|
||||
// ConfirmPromptWithOptions is the shared confirmation gate for launch flows
|
||||
// that need custom yes/no labels in interactive UIs.
|
||||
func ConfirmPromptWithOptions(prompt string, options ConfirmOptions) (bool, error) {
|
||||
if currentLaunchConfirmPolicy.yes {
|
||||
return true, nil
|
||||
}
|
||||
if currentLaunchConfirmPolicy.requireYesMessage {
|
||||
return false, fmt.Errorf("%s requires confirmation; re-run with --yes to continue", prompt)
|
||||
}
|
||||
|
||||
if DefaultConfirmPrompt != nil {
|
||||
return DefaultConfirmPrompt(prompt, options)
|
||||
}
|
||||
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer term.Restore(fd, oldState)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
|
||||
|
||||
buf := make([]byte, 1)
|
||||
for {
|
||||
if _, err := os.Stdin.Read(buf); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch buf[0] {
|
||||
case 'Y', 'y', 13:
|
||||
fmt.Fprintf(os.Stderr, "yes\r\n")
|
||||
return true, nil
|
||||
case 'N', 'n', 27, 3:
|
||||
fmt.Fprintf(os.Stderr, "no\r\n")
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
112
cmd/launch/selector_test.go
Normal file
112
cmd/launch/selector_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestErrCancelled(t *testing.T) {
|
||||
t.Run("NotNil", func(t *testing.T) {
|
||||
if errCancelled == nil {
|
||||
t.Error("errCancelled should not be nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Message", func(t *testing.T) {
|
||||
if errCancelled.Error() != "cancelled" {
|
||||
t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWithLaunchConfirmPolicy_ScopesAndRestores(t *testing.T) {
|
||||
oldPolicy := currentLaunchConfirmPolicy
|
||||
oldHook := DefaultConfirmPrompt
|
||||
t.Cleanup(func() {
|
||||
currentLaunchConfirmPolicy = oldPolicy
|
||||
DefaultConfirmPrompt = oldHook
|
||||
})
|
||||
|
||||
currentLaunchConfirmPolicy = launchConfirmPolicy{}
|
||||
var hookCalls int
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
hookCalls++
|
||||
return true, nil
|
||||
}
|
||||
|
||||
restoreOuter := withLaunchConfirmPolicy(launchConfirmPolicy{requireYesMessage: true})
|
||||
restoreInner := withLaunchConfirmPolicy(launchConfirmPolicy{yes: true})
|
||||
|
||||
ok, err := ConfirmPrompt("test prompt")
|
||||
if err != nil {
|
||||
t.Fatalf("expected --yes policy to allow prompt, got error: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal("expected --yes policy to auto-accept prompt")
|
||||
}
|
||||
if hookCalls != 0 {
|
||||
t.Fatalf("expected --yes to skip hook, got %d hook calls", hookCalls)
|
||||
}
|
||||
|
||||
restoreInner()
|
||||
|
||||
_, err = ConfirmPrompt("test prompt")
|
||||
if err == nil {
|
||||
t.Fatal("expected requireYesMessage policy to block prompt")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "re-run with --yes") {
|
||||
t.Fatalf("expected actionable --yes error, got: %v", err)
|
||||
}
|
||||
if hookCalls != 0 {
|
||||
t.Fatalf("expected blocking policy to skip hook, got %d hook calls", hookCalls)
|
||||
}
|
||||
|
||||
restoreOuter()
|
||||
|
||||
ok, err = ConfirmPrompt("test prompt")
|
||||
if err != nil {
|
||||
t.Fatalf("expected restored default behavior to use hook, got error: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal("expected hook to return true")
|
||||
}
|
||||
if hookCalls != 1 {
|
||||
t.Fatalf("expected one hook call after restore, got %d", hookCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmPromptWithOptions_DelegatesToOptionsHook(t *testing.T) {
|
||||
oldPolicy := currentLaunchConfirmPolicy
|
||||
oldHook := DefaultConfirmPrompt
|
||||
t.Cleanup(func() {
|
||||
currentLaunchConfirmPolicy = oldPolicy
|
||||
DefaultConfirmPrompt = oldHook
|
||||
})
|
||||
|
||||
currentLaunchConfirmPolicy = launchConfirmPolicy{}
|
||||
called := false
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
called = true
|
||||
if prompt != "Connect now?" {
|
||||
t.Fatalf("unexpected prompt: %q", prompt)
|
||||
}
|
||||
if options.YesLabel != "Yes" || options.NoLabel != "Set up later" {
|
||||
t.Fatalf("unexpected options: %+v", options)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
ok, err := ConfirmPromptWithOptions("Connect now?", ConfirmOptions{
|
||||
YesLabel: "Yes",
|
||||
NoLabel: "Set up later",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ConfirmPromptWithOptions() error = %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal("expected confirm to return true")
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("expected options hook to be called")
|
||||
}
|
||||
}
|
||||
86
cmd/launch/test_config_helpers_test.go
Normal file
86
cmd/launch/test_config_helpers_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
)
|
||||
|
||||
var (
|
||||
integrations map[string]Runner
|
||||
integrationAliases map[string]bool
|
||||
integrationOrder = launcherIntegrationOrder
|
||||
)
|
||||
|
||||
func init() {
|
||||
integrations = buildTestIntegrations()
|
||||
integrationAliases = buildTestIntegrationAliases()
|
||||
}
|
||||
|
||||
func buildTestIntegrations() map[string]Runner {
|
||||
result := make(map[string]Runner, len(integrationSpecsByName))
|
||||
for name, spec := range integrationSpecsByName {
|
||||
result[strings.ToLower(name)] = spec.Runner
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func buildTestIntegrationAliases() map[string]bool {
|
||||
result := make(map[string]bool)
|
||||
for _, spec := range integrationSpecs {
|
||||
for _, alias := range spec.Aliases {
|
||||
result[strings.ToLower(alias)] = true
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func setTestHome(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
setLaunchTestHome(t, dir)
|
||||
}
|
||||
|
||||
func testLaunchModels(names ...string) []LaunchModel {
|
||||
return launchModelsFromNames(names)
|
||||
}
|
||||
|
||||
func SaveIntegration(appName string, models []string) error {
|
||||
return config.SaveIntegration(appName, models)
|
||||
}
|
||||
|
||||
func LoadIntegration(appName string) (*config.IntegrationConfig, error) {
|
||||
return config.LoadIntegration(appName)
|
||||
}
|
||||
|
||||
func SaveAliases(appName string, aliases map[string]string) error {
|
||||
return config.SaveAliases(appName, aliases)
|
||||
}
|
||||
|
||||
func LastModel() string {
|
||||
return config.LastModel()
|
||||
}
|
||||
|
||||
func SetLastModel(model string) error {
|
||||
return config.SetLastModel(model)
|
||||
}
|
||||
|
||||
func LastSelection() string {
|
||||
return config.LastSelection()
|
||||
}
|
||||
|
||||
func SetLastSelection(selection string) error {
|
||||
return config.SetLastSelection(selection)
|
||||
}
|
||||
|
||||
func IntegrationModel(appName string) string {
|
||||
return config.IntegrationModel(appName)
|
||||
}
|
||||
|
||||
func IntegrationModels(appName string) []string {
|
||||
return config.IntegrationModels(appName)
|
||||
}
|
||||
|
||||
func integrationOnboarded(appName string) error {
|
||||
return config.MarkIntegrationOnboarded(appName)
|
||||
}
|
||||
591
cmd/launch/vscode.go
Normal file
591
cmd/launch/vscode.go
Normal file
@@ -0,0 +1,591 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// VSCode implements Runner and Editor for Visual Studio Code integration.
|
||||
type VSCode struct{}
|
||||
|
||||
func (v *VSCode) String() string { return "Visual Studio Code" }
|
||||
|
||||
// findBinary returns the path/command to launch VS Code, or "" if not found.
|
||||
// It checks platform-specific locations only.
|
||||
func (v *VSCode) findBinary() string {
|
||||
var candidates []string
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
candidates = []string{
|
||||
"/Applications/Visual Studio Code.app",
|
||||
}
|
||||
case "windows":
|
||||
if localAppData := os.Getenv("LOCALAPPDATA"); localAppData != "" {
|
||||
candidates = append(candidates, filepath.Join(localAppData, "Programs", "Microsoft VS Code", "bin", "code.cmd"))
|
||||
}
|
||||
default: // linux
|
||||
candidates = []string{
|
||||
"/usr/bin/code",
|
||||
"/snap/bin/code",
|
||||
}
|
||||
}
|
||||
for _, c := range candidates {
|
||||
if _, err := os.Stat(c); err == nil {
|
||||
return c
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsRunning reports whether VS Code is currently running.
|
||||
// Each platform uses a pattern specific enough to avoid matching Cursor or
|
||||
// other VS Code forks.
|
||||
func (v *VSCode) IsRunning() bool {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
out, err := exec.Command("pgrep", "-f", "Visual Studio Code.app/Contents/MacOS/Code").Output()
|
||||
return err == nil && len(out) > 0
|
||||
case "windows":
|
||||
// Match VS Code by executable path to avoid matching Cursor or other forks.
|
||||
out, err := exec.Command("powershell", "-NoProfile", "-Command",
|
||||
`Get-Process Code -ErrorAction SilentlyContinue | Where-Object { $_.Path -like '*Microsoft VS Code*' } | Select-Object -First 1`).Output()
|
||||
return err == nil && len(strings.TrimSpace(string(out))) > 0
|
||||
default:
|
||||
// Match VS Code specifically by its install path to avoid matching
|
||||
// Cursor (/cursor/) or other forks.
|
||||
for _, pattern := range []string{"/usr/share/code/", "/snap/code/"} {
|
||||
out, err := exec.Command("pgrep", "-f", pattern).Output()
|
||||
if err == nil && len(out) > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Quit gracefully quits VS Code and waits for it to exit so that it flushes
|
||||
// its in-memory state back to the database.
|
||||
func (v *VSCode) Quit() {
|
||||
if !v.IsRunning() {
|
||||
return
|
||||
}
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("osascript", "-e", `quit app "Visual Studio Code"`).Run()
|
||||
case "windows":
|
||||
// Kill VS Code by executable path to avoid killing Cursor or other forks.
|
||||
_ = exec.Command("powershell", "-NoProfile", "-Command",
|
||||
`Get-Process Code -ErrorAction SilentlyContinue | Where-Object { $_.Path -like '*Microsoft VS Code*' } | Stop-Process -Force`).Run()
|
||||
default:
|
||||
for _, pattern := range []string{"/usr/share/code/", "/snap/code/"} {
|
||||
_ = exec.Command("pkill", "-f", pattern).Run()
|
||||
}
|
||||
}
|
||||
// Wait for the process to fully exit and flush its state to disk
|
||||
// TODO(hoyyeva): update spinner to use bubble tea
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
fmt.Fprintf(os.Stderr, "\033[90mRestarting VS Code... %s\033[0m", spinnerFrames[0])
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range 150 { // 150 ticks × 200ms = 30s timeout
|
||||
<-ticker.C
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mRestarting VS Code... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
if frame%5 == 0 { // check every ~1s
|
||||
if !v.IsRunning() {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
// Give VS Code a moment to finish writing its state DB
|
||||
time.Sleep(1 * time.Second)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
}
|
||||
|
||||
const (
|
||||
minCopilotChatVersion = "0.41.0"
|
||||
minVSCodeVersion = "1.113"
|
||||
)
|
||||
|
||||
func (v *VSCode) Run(model string, _ []LaunchModel, args []string) error {
|
||||
v.checkVSCodeVersion()
|
||||
v.checkCopilotChatVersion()
|
||||
|
||||
// Get all configured models (saved by the launcher framework before Run is called)
|
||||
models := []string{model}
|
||||
if cfg, err := loadStoredIntegrationConfig("vscode"); err == nil && len(cfg.Models) > 0 {
|
||||
models = cfg.Models
|
||||
}
|
||||
|
||||
// VS Code discovers models from ollama ls. Cloud models that pass Show
|
||||
// (the server knows about them) but aren't in ls need to be pulled to
|
||||
// register them so VS Code can find them.
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
v.ensureModelsRegistered(context.Background(), client, models)
|
||||
}
|
||||
|
||||
// Warn if the default model doesn't support tool calling
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
if resp, err := client.Show(context.Background(), &api.ShowRequest{Model: models[0]}); err == nil {
|
||||
hasTools := false
|
||||
for _, c := range resp.Capabilities {
|
||||
if c == "tools" {
|
||||
hasTools = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasTools {
|
||||
fmt.Fprintf(os.Stderr, "Note: %s does not support tool calling and may not appear in the Copilot Chat model picker.\n", models[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
v.printModelAccessTip()
|
||||
|
||||
if v.IsRunning() {
|
||||
restart, err := ConfirmPrompt("Restart VS Code?")
|
||||
if err != nil {
|
||||
restart = false
|
||||
}
|
||||
if restart {
|
||||
v.Quit()
|
||||
if err := v.ShowInModelPicker(models); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not update VS Code model picker: %v%s\n", ansiYellow, err, ansiReset)
|
||||
}
|
||||
v.FocusVSCode()
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\nTo get the latest model configuration, restart VS Code when you're ready.\n")
|
||||
}
|
||||
} else {
|
||||
if err := v.ShowInModelPicker(models); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not update VS Code model picker: %v%s\n", ansiYellow, err, ansiReset)
|
||||
}
|
||||
v.FocusVSCode()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureModelsRegistered pulls models that the server knows about (Show succeeds)
|
||||
// but aren't in ollama ls yet. This is needed for cloud models so that VS Code
|
||||
// can discover them from the Ollama API.
|
||||
func (v *VSCode) ensureModelsRegistered(ctx context.Context, client *api.Client, models []string) {
|
||||
listed, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
registered := make(map[string]bool, len(listed.Models))
|
||||
for _, m := range listed.Models {
|
||||
registered[m.Name] = true
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
if registered[model] {
|
||||
continue
|
||||
}
|
||||
// Also check without :latest suffix
|
||||
if !strings.Contains(model, ":") && registered[model+":latest"] {
|
||||
continue
|
||||
}
|
||||
if err := pullModel(ctx, client, model, false); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not register model %s: %v%s\n", ansiYellow, model, err, ansiReset)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FocusVSCode brings VS Code to the foreground.
|
||||
func (v *VSCode) FocusVSCode() {
|
||||
binary := v.findBinary()
|
||||
if binary == "" {
|
||||
return
|
||||
}
|
||||
if runtime.GOOS == "darwin" && strings.HasSuffix(binary, ".app") {
|
||||
_ = exec.Command("open", "-a", binary).Run()
|
||||
} else {
|
||||
_ = exec.Command(binary).Start()
|
||||
}
|
||||
}
|
||||
|
||||
// printModelAccessTip shows instructions for finding Ollama models in VS Code.
|
||||
func (v *VSCode) printModelAccessTip() {
|
||||
fmt.Fprintf(os.Stderr, "\nTip: To use Ollama models, open Copilot Chat and click the model picker.\n")
|
||||
fmt.Fprintf(os.Stderr, " If you don't see your models, click \"Other models\" to find them.\n\n")
|
||||
}
|
||||
|
||||
func (v *VSCode) Paths() []string {
|
||||
if p := v.chatLanguageModelsPath(); fileExists(p) {
|
||||
return []string{p}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *VSCode) Edit(models []LaunchModel) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write chatLanguageModels.json with Ollama vendor entry
|
||||
clmPath := v.chatLanguageModelsPath()
|
||||
if err := os.MkdirAll(filepath.Dir(clmPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var entries []map[string]any
|
||||
if data, err := os.ReadFile(clmPath); err == nil {
|
||||
_ = json.Unmarshal(data, &entries)
|
||||
}
|
||||
|
||||
// Remove any existing Ollama entries, preserve others
|
||||
filtered := make([]map[string]any, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
if vendor, _ := entry["vendor"].(string); vendor != "ollama" {
|
||||
filtered = append(filtered, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Add new Ollama entry
|
||||
filtered = append(filtered, map[string]any{
|
||||
"vendor": "ollama",
|
||||
"name": "Ollama",
|
||||
"url": envconfig.Host().String(),
|
||||
})
|
||||
|
||||
data, err := json.MarshalIndent(filtered, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fileutil.WriteWithBackup(clmPath, data, "vscode"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Clean up legacy settings from older Ollama integrations
|
||||
v.updateSettings()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *VSCode) Models() []string {
|
||||
if !v.hasOllamaVendor() {
|
||||
return nil
|
||||
}
|
||||
if cfg, err := loadStoredIntegrationConfig("vscode"); err == nil {
|
||||
return cfg.Models
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasOllamaVendor checks if chatLanguageModels.json contains an Ollama vendor entry.
|
||||
func (v *VSCode) hasOllamaVendor() bool {
|
||||
data, err := os.ReadFile(v.chatLanguageModelsPath())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var entries []map[string]any
|
||||
if err := json.Unmarshal(data, &entries); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if vendor, _ := entry["vendor"].(string); vendor == "ollama" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (v *VSCode) chatLanguageModelsPath() string {
|
||||
return v.vscodePath("chatLanguageModels.json")
|
||||
}
|
||||
|
||||
func (v *VSCode) settingsPath() string {
|
||||
return v.vscodePath("settings.json")
|
||||
}
|
||||
|
||||
// updateSettings cleans up legacy settings from older Ollama integrations.
|
||||
func (v *VSCode) updateSettings() {
|
||||
settingsPath := v.settingsPath()
|
||||
data, err := os.ReadFile(settingsPath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var settings map[string]any
|
||||
if err := json.Unmarshal(data, &settings); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
changed := false
|
||||
for _, key := range []string{"github.copilot.chat.byok.ollamaEndpoint", "ollama.launch.configured"} {
|
||||
if _, ok := settings[key]; ok {
|
||||
delete(settings, key)
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return
|
||||
}
|
||||
|
||||
updated, err := json.MarshalIndent(settings, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = fileutil.WriteWithBackup(settingsPath, updated, "vscode")
|
||||
}
|
||||
|
||||
func (v *VSCode) statePath() string {
|
||||
return v.vscodePath("globalStorage", "state.vscdb")
|
||||
}
|
||||
|
||||
// ShowInModelPicker ensures the given models are visible in VS Code's Copilot
|
||||
// Chat model picker. It sets the configured models to true in the picker
|
||||
// preferences so they appear in the dropdown. Models use the VS Code identifier
|
||||
// format "ollama/Ollama/<name>".
|
||||
func (v *VSCode) ShowInModelPicker(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
dbPath := v.statePath()
|
||||
needsCreate := !fileExists(dbPath)
|
||||
if needsCreate {
|
||||
if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil {
|
||||
return fmt.Errorf("creating state directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", dbPath+"?_busy_timeout=5000")
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening state database: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create the table if this is a fresh DB. Schema must match what VS Code creates.
|
||||
if needsCreate {
|
||||
if _, err := db.Exec("CREATE TABLE ItemTable (key TEXT UNIQUE ON CONFLICT REPLACE, value BLOB)"); err != nil {
|
||||
return fmt.Errorf("initializing state database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Read existing preferences
|
||||
prefs := make(map[string]bool)
|
||||
var prefsJSON string
|
||||
if err := db.QueryRow("SELECT value FROM ItemTable WHERE key = 'chatModelPickerPreferences'").Scan(&prefsJSON); err == nil {
|
||||
_ = json.Unmarshal([]byte(prefsJSON), &prefs)
|
||||
}
|
||||
|
||||
// Build name→ID map from VS Code's cached model list.
|
||||
// VS Code uses numeric IDs like "ollama/Ollama/4", not "ollama/Ollama/kimi-k2.5:cloud".
|
||||
nameToID := make(map[string]string)
|
||||
var cacheJSON string
|
||||
if err := db.QueryRow("SELECT value FROM ItemTable WHERE key = 'chat.cachedLanguageModels.v2'").Scan(&cacheJSON); err == nil {
|
||||
var cached []map[string]any
|
||||
if json.Unmarshal([]byte(cacheJSON), &cached) == nil {
|
||||
for _, entry := range cached {
|
||||
meta, _ := entry["metadata"].(map[string]any)
|
||||
if meta == nil {
|
||||
continue
|
||||
}
|
||||
if vendor, _ := meta["vendor"].(string); vendor == "ollama" {
|
||||
name, _ := meta["name"].(string)
|
||||
id, _ := entry["identifier"].(string)
|
||||
if name != "" && id != "" {
|
||||
nameToID[name] = id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ollama config is authoritative: always show configured models,
|
||||
// hide Ollama models that are no longer in the config.
|
||||
configuredIDs := make(map[string]bool)
|
||||
for _, m := range models {
|
||||
for _, id := range v.modelVSCodeIDs(m, nameToID) {
|
||||
prefs[id] = true
|
||||
configuredIDs[id] = true
|
||||
}
|
||||
}
|
||||
for id := range prefs {
|
||||
if strings.HasPrefix(id, "ollama/") && !configuredIDs[id] {
|
||||
prefs[id] = false
|
||||
}
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(prefs)
|
||||
if _, err = db.Exec("INSERT OR REPLACE INTO ItemTable (key, value) VALUES ('chatModelPickerPreferences', ?)", string(data)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// modelVSCodeIDs returns all possible VS Code picker IDs for a model name.
|
||||
func (v *VSCode) modelVSCodeIDs(model string, nameToID map[string]string) []string {
|
||||
var ids []string
|
||||
if id, ok := nameToID[model]; ok {
|
||||
ids = append(ids, id)
|
||||
} else if !strings.Contains(model, ":") {
|
||||
if id, ok := nameToID[model+":latest"]; ok {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
ids = append(ids, "ollama/Ollama/"+model)
|
||||
if !strings.Contains(model, ":") {
|
||||
ids = append(ids, "ollama/Ollama/"+model+":latest")
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func (v *VSCode) vscodePath(parts ...string) string {
|
||||
home, _ := os.UserHomeDir()
|
||||
var base string
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
base = filepath.Join(home, "Library", "Application Support", "Code", "User")
|
||||
case "windows":
|
||||
base = filepath.Join(os.Getenv("APPDATA"), "Code", "User")
|
||||
default:
|
||||
base = filepath.Join(home, ".config", "Code", "User")
|
||||
}
|
||||
return filepath.Join(append([]string{base}, parts...)...)
|
||||
}
|
||||
|
||||
// checkVSCodeVersion warns if VS Code is older than minVSCodeVersion.
|
||||
func (v *VSCode) checkVSCodeVersion() {
|
||||
codeCLI := v.findCodeCLI()
|
||||
if codeCLI == "" {
|
||||
return
|
||||
}
|
||||
|
||||
out, err := exec.Command(codeCLI, "--version").Output()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// "code --version" outputs: version\ncommit\narch
|
||||
lines := strings.Split(strings.TrimSpace(string(out)), "\n")
|
||||
if len(lines) == 0 || lines[0] == "" {
|
||||
return
|
||||
}
|
||||
version := strings.TrimSpace(lines[0])
|
||||
|
||||
if compareVersions(version, minVSCodeVersion) < 0 {
|
||||
fmt.Fprintf(os.Stderr, "\n%sWarning: VS Code version (%s) is older than the recommended version (%s)%s\n", ansiYellow, version, minVSCodeVersion, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "Please update VS Code to the latest version.\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
// checkCopilotChatVersion warns if the GitHub Copilot Chat extension is
|
||||
// missing or older than minCopilotChatVersion.
|
||||
func (v *VSCode) checkCopilotChatVersion() {
|
||||
codeCLI := v.findCodeCLI()
|
||||
if codeCLI == "" {
|
||||
return
|
||||
}
|
||||
|
||||
out, err := exec.Command(codeCLI, "--list-extensions", "--show-versions").Output()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
installed, version := parseCopilotChatVersion(string(out))
|
||||
if !installed {
|
||||
fmt.Fprintf(os.Stderr, "\n%sWarning: GitHub Copilot Chat extension is not installed%s\n", ansiYellow, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "Install it in VS Code: Extensions → search \"GitHub Copilot Chat\" → Install\n\n")
|
||||
return
|
||||
}
|
||||
if compareVersions(version, minCopilotChatVersion) < 0 {
|
||||
fmt.Fprintf(os.Stderr, "\n%sWarning: GitHub Copilot Chat extension version (%s) is older than the recommended version (%s)%s\n", ansiYellow, version, minCopilotChatVersion, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "Please update it in VS Code: Extensions → search \"GitHub Copilot Chat\" → Update\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
// findCodeCLI returns the path to the VS Code CLI for querying extensions.
|
||||
// On macOS, findBinary may return an .app bundle which can't run --list-extensions,
|
||||
// so this resolves to the actual CLI binary inside the bundle.
|
||||
func (v *VSCode) findCodeCLI() string {
|
||||
binary := v.findBinary()
|
||||
if binary == "" {
|
||||
return ""
|
||||
}
|
||||
if runtime.GOOS == "darwin" && strings.HasSuffix(binary, ".app") {
|
||||
bundleCLI := binary + "/Contents/Resources/app/bin/code"
|
||||
if _, err := os.Stat(bundleCLI); err == nil {
|
||||
return bundleCLI
|
||||
}
|
||||
return ""
|
||||
}
|
||||
return binary
|
||||
}
|
||||
|
||||
// parseCopilotChatVersion extracts the version of the GitHub Copilot Chat
|
||||
// extension from "code --list-extensions --show-versions" output.
|
||||
func parseCopilotChatVersion(output string) (installed bool, version string) {
|
||||
for _, line := range strings.Split(output, "\n") {
|
||||
// Format: github.copilot-chat@0.40.1
|
||||
if !strings.HasPrefix(strings.ToLower(line), "github.copilot-chat@") {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(line, "@", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
return true, strings.TrimSpace(parts[1])
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// compareVersions compares two dot-separated version strings.
|
||||
// Returns -1 if a < b, 0 if a == b, 1 if a > b.
|
||||
func compareVersions(a, b string) int {
|
||||
aParts := strings.Split(a, ".")
|
||||
bParts := strings.Split(b, ".")
|
||||
|
||||
maxLen := len(aParts)
|
||||
if len(bParts) > maxLen {
|
||||
maxLen = len(bParts)
|
||||
}
|
||||
|
||||
for i := range maxLen {
|
||||
var aNum, bNum int
|
||||
if i < len(aParts) {
|
||||
aNum, _ = strconv.Atoi(aParts[i])
|
||||
}
|
||||
if i < len(bParts) {
|
||||
bNum, _ = strconv.Atoi(bParts[i])
|
||||
}
|
||||
if aNum < bNum {
|
||||
return -1
|
||||
}
|
||||
if aNum > bNum {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
533
cmd/launch/vscode_test.go
Normal file
533
cmd/launch/vscode_test.go
Normal file
@@ -0,0 +1,533 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
)
|
||||
|
||||
func TestVSCodeIntegration(t *testing.T) {
|
||||
v := &VSCode{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := v.String(); got != "Visual Studio Code" {
|
||||
t.Errorf("String() = %q, want %q", got, "Visual Studio Code")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = v
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = v
|
||||
})
|
||||
}
|
||||
|
||||
func TestVSCodeEdit(t *testing.T) {
|
||||
v := &VSCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("XDG_CONFIG_HOME", "")
|
||||
clmPath := testVSCodePath(t, tmpDir, "chatLanguageModels.json")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setup string // initial chatLanguageModels.json content, empty means no file
|
||||
models []string
|
||||
validate func(t *testing.T, data []byte)
|
||||
}{
|
||||
{
|
||||
name: "fresh install",
|
||||
models: []string{"llama3.2"},
|
||||
validate: func(t *testing.T, data []byte) {
|
||||
assertOllamaVendorConfigured(t, data)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "preserve other vendor entries",
|
||||
setup: `[{"vendor": "azure", "name": "Azure", "url": "https://example.com"}]`,
|
||||
models: []string{"llama3.2"},
|
||||
validate: func(t *testing.T, data []byte) {
|
||||
var entries []map[string]any
|
||||
json.Unmarshal(data, &entries)
|
||||
if len(entries) != 2 {
|
||||
t.Errorf("expected 2 entries, got %d", len(entries))
|
||||
}
|
||||
// Check Azure entry preserved
|
||||
found := false
|
||||
for _, e := range entries {
|
||||
if v, _ := e["vendor"].(string); v == "azure" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("azure vendor entry was not preserved")
|
||||
}
|
||||
assertOllamaVendorConfigured(t, data)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update existing ollama entry",
|
||||
setup: `[{"vendor": "ollama", "name": "Ollama", "url": "http://old:11434"}]`,
|
||||
models: []string{"llama3.2"},
|
||||
validate: func(t *testing.T, data []byte) {
|
||||
assertOllamaVendorConfigured(t, data)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty models is no-op",
|
||||
setup: `[{"vendor": "azure", "name": "Azure"}]`,
|
||||
models: []string{},
|
||||
validate: func(t *testing.T, data []byte) {
|
||||
if string(data) != `[{"vendor": "azure", "name": "Azure"}]` {
|
||||
t.Error("empty models should not modify file")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "corrupted JSON treated as empty",
|
||||
setup: `{corrupted json`,
|
||||
models: []string{"llama3.2"},
|
||||
validate: func(t *testing.T, data []byte) {
|
||||
var entries []map[string]any
|
||||
if err := json.Unmarshal(data, &entries); err != nil {
|
||||
t.Errorf("result is not valid JSON: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
os.RemoveAll(filepath.Dir(clmPath))
|
||||
|
||||
if tt.setup != "" {
|
||||
os.MkdirAll(filepath.Dir(clmPath), 0o755)
|
||||
os.WriteFile(clmPath, []byte(tt.setup), 0o644)
|
||||
}
|
||||
|
||||
if err := v.Edit(launchModelsFromNames(tt.models)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(clmPath)
|
||||
tt.validate(t, data)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVSCodeEditCleansUpOldSettings(t *testing.T) {
|
||||
v := &VSCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("XDG_CONFIG_HOME", "")
|
||||
settingsPath := testVSCodePath(t, tmpDir, "settings.json")
|
||||
|
||||
// Create settings.json with old byok setting
|
||||
os.MkdirAll(filepath.Dir(settingsPath), 0o755)
|
||||
os.WriteFile(settingsPath, []byte(`{"github.copilot.chat.byok.ollamaEndpoint": "http://old:11434", "ollama.launch.configured": true, "editor.fontSize": 14}`), 0o644)
|
||||
|
||||
if err := v.Edit(testLaunchModels("llama3.2")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Verify old settings were removed
|
||||
data, err := os.ReadFile(settingsPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var settings map[string]any
|
||||
json.Unmarshal(data, &settings)
|
||||
if _, ok := settings["github.copilot.chat.byok.ollamaEndpoint"]; ok {
|
||||
t.Error("github.copilot.chat.byok.ollamaEndpoint should have been removed")
|
||||
}
|
||||
if _, ok := settings["ollama.launch.configured"]; ok {
|
||||
t.Error("ollama.launch.configured should have been removed")
|
||||
}
|
||||
if settings["editor.fontSize"] != float64(14) {
|
||||
t.Error("editor.fontSize should have been preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVSCodeEdit_CreatesDistinctBackupsForManagedFiles(t *testing.T) {
|
||||
v := &VSCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("XDG_CONFIG_HOME", "")
|
||||
|
||||
clmPath := testVSCodePath(t, tmpDir, "chatLanguageModels.json")
|
||||
settingsPath := testVSCodePath(t, tmpDir, "settings.json")
|
||||
backupDir := fileutil.BackupDir()
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(clmPath), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clmOriginal := `[{"vendor":"ollama","name":"Ollama","url":"http://old:11434"}]`
|
||||
settingsOriginal := `{"github.copilot.chat.byok.ollamaEndpoint":"http://old:11434","ollama.launch.configured":true,"editor.fontSize":14}`
|
||||
if err := os.WriteFile(clmPath, []byte(clmOriginal), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(settingsPath, []byte(settingsOriginal), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := v.Edit(testLaunchModels("llama3.2")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assertBackupMatches := func(pattern, want string) {
|
||||
t.Helper()
|
||||
backups, err := filepath.Glob(filepath.Join(backupDir, pattern))
|
||||
if err != nil {
|
||||
t.Fatalf("glob %q failed: %v", pattern, err)
|
||||
}
|
||||
for _, backup := range backups {
|
||||
data, err := os.ReadFile(backup)
|
||||
if err == nil && string(data) == want {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("backup matching %q with expected content not found", pattern)
|
||||
}
|
||||
|
||||
assertBackupMatches(filepath.Join("vscode", "chatLanguageModels.json.*"), clmOriginal)
|
||||
assertBackupMatches(filepath.Join("vscode", "settings.json.*"), settingsOriginal)
|
||||
}
|
||||
|
||||
func TestVSCodePaths(t *testing.T) {
|
||||
v := &VSCode{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("XDG_CONFIG_HOME", "")
|
||||
clmPath := testVSCodePath(t, tmpDir, "chatLanguageModels.json")
|
||||
|
||||
t.Run("no file returns nil", func(t *testing.T) {
|
||||
os.Remove(clmPath)
|
||||
if paths := v.Paths(); paths != nil {
|
||||
t.Errorf("expected nil, got %v", paths)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("existing file returns path", func(t *testing.T) {
|
||||
os.MkdirAll(filepath.Dir(clmPath), 0o755)
|
||||
os.WriteFile(clmPath, []byte(`[]`), 0o644)
|
||||
|
||||
if paths := v.Paths(); len(paths) != 1 {
|
||||
t.Errorf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// testVSCodePath returns the expected VS Code config path for the given file in tests.
|
||||
func testVSCodePath(t *testing.T, tmpDir, filename string) string {
|
||||
t.Helper()
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return filepath.Join(tmpDir, "Library", "Application Support", "Code", "User", filename)
|
||||
case "windows":
|
||||
t.Setenv("APPDATA", tmpDir)
|
||||
return filepath.Join(tmpDir, "Code", "User", filename)
|
||||
default:
|
||||
return filepath.Join(tmpDir, ".config", "Code", "User", filename)
|
||||
}
|
||||
}
|
||||
|
||||
func assertOllamaVendorConfigured(t *testing.T, data []byte) {
|
||||
t.Helper()
|
||||
var entries []map[string]any
|
||||
if err := json.Unmarshal(data, &entries); err != nil {
|
||||
t.Fatalf("invalid JSON: %v", err)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if vendor, _ := entry["vendor"].(string); vendor == "ollama" {
|
||||
if name, _ := entry["name"].(string); name != "Ollama" {
|
||||
t.Errorf("expected name \"Ollama\", got %q", name)
|
||||
}
|
||||
if url, _ := entry["url"].(string); url == "" {
|
||||
t.Error("url not set")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Error("no ollama vendor entry found")
|
||||
}
|
||||
|
||||
func TestShowInModelPicker(t *testing.T) {
|
||||
v := &VSCode{}
|
||||
|
||||
// helper to create a state DB with optional seed data
|
||||
setupDB := func(t *testing.T, tmpDir string, seedPrefs map[string]bool, seedCache []map[string]any) string {
|
||||
t.Helper()
|
||||
dbDir := filepath.Join(tmpDir, "globalStorage")
|
||||
os.MkdirAll(dbDir, 0o755)
|
||||
dbPath := filepath.Join(dbDir, "state.vscdb")
|
||||
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if _, err := db.Exec("CREATE TABLE ItemTable (key TEXT UNIQUE ON CONFLICT REPLACE, value BLOB)"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if seedPrefs != nil {
|
||||
data, _ := json.Marshal(seedPrefs)
|
||||
db.Exec("INSERT INTO ItemTable (key, value) VALUES ('chatModelPickerPreferences', ?)", string(data))
|
||||
}
|
||||
if seedCache != nil {
|
||||
data, _ := json.Marshal(seedCache)
|
||||
db.Exec("INSERT INTO ItemTable (key, value) VALUES ('chat.cachedLanguageModels.v2', ?)", string(data))
|
||||
}
|
||||
return dbPath
|
||||
}
|
||||
|
||||
// helper to read prefs back from DB
|
||||
readPrefs := func(t *testing.T, dbPath string) map[string]bool {
|
||||
t.Helper()
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
var raw string
|
||||
if err := db.QueryRow("SELECT value FROM ItemTable WHERE key = 'chatModelPickerPreferences'").Scan(&raw); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
prefs := make(map[string]bool)
|
||||
json.Unmarshal([]byte(raw), &prefs)
|
||||
return prefs
|
||||
}
|
||||
|
||||
t.Run("fresh DB creates table and shows models", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("XDG_CONFIG_HOME", "")
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Setenv("APPDATA", tmpDir)
|
||||
}
|
||||
|
||||
err := v.ShowInModelPicker([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dbPath := testVSCodePath(t, tmpDir, filepath.Join("globalStorage", "state.vscdb"))
|
||||
prefs := readPrefs(t, dbPath)
|
||||
if !prefs["ollama/Ollama/llama3.2"] {
|
||||
t.Error("expected llama3.2 to be shown")
|
||||
}
|
||||
if !prefs["ollama/Ollama/llama3.2:latest"] {
|
||||
t.Error("expected llama3.2:latest to be shown")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("configured models are shown", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("XDG_CONFIG_HOME", "")
|
||||
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), nil, nil)
|
||||
|
||||
err := v.ShowInModelPicker([]string{"llama3.2", "qwen3:8b"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
prefs := readPrefs(t, dbPath)
|
||||
if !prefs["ollama/Ollama/llama3.2"] {
|
||||
t.Error("expected llama3.2 to be shown")
|
||||
}
|
||||
if !prefs["ollama/Ollama/qwen3:8b"] {
|
||||
t.Error("expected qwen3:8b to be shown")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("removed models are hidden", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("XDG_CONFIG_HOME", "")
|
||||
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), map[string]bool{
|
||||
"ollama/Ollama/llama3.2": true,
|
||||
"ollama/Ollama/llama3.2:latest": true,
|
||||
"ollama/Ollama/mistral": true,
|
||||
"ollama/Ollama/mistral:latest": true,
|
||||
}, nil)
|
||||
|
||||
// Only configure llama3.2 — mistral should get hidden
|
||||
err := v.ShowInModelPicker([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
prefs := readPrefs(t, dbPath)
|
||||
if !prefs["ollama/Ollama/llama3.2"] {
|
||||
t.Error("expected llama3.2 to stay shown")
|
||||
}
|
||||
if prefs["ollama/Ollama/mistral"] {
|
||||
t.Error("expected mistral to be hidden")
|
||||
}
|
||||
if prefs["ollama/Ollama/mistral:latest"] {
|
||||
t.Error("expected mistral:latest to be hidden")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-ollama prefs are preserved", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("XDG_CONFIG_HOME", "")
|
||||
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), map[string]bool{
|
||||
"copilot/gpt-4o": true,
|
||||
}, nil)
|
||||
|
||||
err := v.ShowInModelPicker([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
prefs := readPrefs(t, dbPath)
|
||||
if !prefs["copilot/gpt-4o"] {
|
||||
t.Error("expected copilot/gpt-4o to stay shown")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses cached numeric IDs when available", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("XDG_CONFIG_HOME", "")
|
||||
cache := []map[string]any{
|
||||
{
|
||||
"identifier": "ollama/Ollama/4",
|
||||
"metadata": map[string]any{"vendor": "ollama", "name": "llama3.2"},
|
||||
},
|
||||
}
|
||||
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), nil, cache)
|
||||
|
||||
err := v.ShowInModelPicker([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
prefs := readPrefs(t, dbPath)
|
||||
if !prefs["ollama/Ollama/4"] {
|
||||
t.Error("expected numeric ID ollama/Ollama/4 to be shown")
|
||||
}
|
||||
// Name-based fallback should also be set
|
||||
if !prefs["ollama/Ollama/llama3.2"] {
|
||||
t.Error("expected name-based ID to also be shown")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty models is no-op", func(t *testing.T) {
|
||||
err := v.ShowInModelPicker([]string{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("previously hidden model is re-shown when configured", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("XDG_CONFIG_HOME", "")
|
||||
dbPath := setupDB(t, testVSCodePath(t, tmpDir, ""), map[string]bool{
|
||||
"ollama/Ollama/llama3.2": false,
|
||||
"ollama/Ollama/llama3.2:latest": false,
|
||||
}, nil)
|
||||
|
||||
// Ollama config is authoritative — should override the hidden state
|
||||
err := v.ShowInModelPicker([]string{"llama3.2"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
prefs := readPrefs(t, dbPath)
|
||||
if !prefs["ollama/Ollama/llama3.2"] {
|
||||
t.Error("expected llama3.2 to be re-shown")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseCopilotChatVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
output string
|
||||
wantInstalled bool
|
||||
wantVersion string
|
||||
}{
|
||||
{
|
||||
name: "found among other extensions",
|
||||
output: "ms-python.python@2024.1.1\ngithub.copilot-chat@0.40.1\ngithub.copilot@1.200.0\n",
|
||||
wantInstalled: true,
|
||||
wantVersion: "0.40.1",
|
||||
},
|
||||
{
|
||||
name: "only extension",
|
||||
output: "GitHub.copilot-chat@0.41.0\n",
|
||||
wantInstalled: true,
|
||||
wantVersion: "0.41.0",
|
||||
},
|
||||
{
|
||||
name: "not installed",
|
||||
output: "ms-python.python@2024.1.1\ngithub.copilot@1.200.0\n",
|
||||
wantInstalled: false,
|
||||
},
|
||||
{
|
||||
name: "empty output",
|
||||
output: "",
|
||||
wantInstalled: false,
|
||||
},
|
||||
{
|
||||
name: "case insensitive match",
|
||||
output: "GitHub.Copilot-Chat@0.39.0\n",
|
||||
wantInstalled: true,
|
||||
wantVersion: "0.39.0",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
installed, version := parseCopilotChatVersion(tt.output)
|
||||
if installed != tt.wantInstalled {
|
||||
t.Errorf("installed = %v, want %v", installed, tt.wantInstalled)
|
||||
}
|
||||
if installed && version != tt.wantVersion {
|
||||
t.Errorf("version = %q, want %q", version, tt.wantVersion)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareVersions(t *testing.T) {
|
||||
tests := []struct {
|
||||
a, b string
|
||||
want int
|
||||
}{
|
||||
{"0.40.1", "0.40.1", 0},
|
||||
{"0.40.2", "0.40.1", 1},
|
||||
{"0.40.0", "0.40.1", -1},
|
||||
{"0.41.0", "0.40.1", 1},
|
||||
{"0.39.9", "0.40.1", -1},
|
||||
{"1.0.0", "0.40.1", 1},
|
||||
{"0.40", "0.40.1", -1},
|
||||
{"0.40.1.1", "0.40.1", 1},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.a+"_vs_"+tt.b, func(t *testing.T) {
|
||||
got := compareVersions(tt.a, tt.b)
|
||||
if got != tt.want {
|
||||
t.Errorf("compareVersions(%q, %q) = %d, want %d", tt.a, tt.b, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
15
cmd/runner/main.go
Normal file
15
cmd/runner/main.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/runner"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := runner.Execute(os.Args[1:]); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
27
cmd/start.go
Normal file
27
cmd/start.go
Normal file
@@ -0,0 +1,27 @@
|
||||
//go:build darwin || windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func waitForServer(ctx context.Context, client *api.Client) error {
|
||||
// wait for the server to start
|
||||
timeout := time.After(5 * time.Second)
|
||||
tick := time.Tick(500 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
return errors.New("timed out waiting for server to start")
|
||||
case <-tick:
|
||||
if err := client.Heartbeat(ctx); err == nil {
|
||||
return nil // server has started
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
33
cmd/start_darwin.go
Normal file
33
cmd/start_darwin.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
var errNotRunning = errors.New("could not connect to ollama server, run 'ollama serve' to start it")
|
||||
|
||||
func startApp(ctx context.Context, client *api.Client) error {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return errNotRunning
|
||||
}
|
||||
link, err := os.Readlink(exe)
|
||||
if err != nil {
|
||||
return errNotRunning
|
||||
}
|
||||
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
|
||||
m := r.FindStringSubmatch(link)
|
||||
if len(m) != 1 {
|
||||
return errNotRunning
|
||||
}
|
||||
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
|
||||
return err
|
||||
}
|
||||
return waitForServer(ctx, client)
|
||||
}
|
||||
14
cmd/start_default.go
Normal file
14
cmd/start_default.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func startApp(ctx context.Context, client *api.Client) error {
|
||||
return errors.New("could not connect to ollama server, run 'ollama serve' to start it")
|
||||
}
|
||||
112
cmd/start_windows.go
Normal file
112
cmd/start_windows.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
Installer = "OllamaSetup.exe"
|
||||
)
|
||||
|
||||
func startApp(ctx context.Context, client *api.Client) error {
|
||||
if len(isProcRunning(Installer)) > 0 {
|
||||
return fmt.Errorf("upgrade in progress...")
|
||||
}
|
||||
AppName := "ollama app.exe"
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
appExe := filepath.Join(filepath.Dir(exe), AppName)
|
||||
_, err = os.Stat(appExe)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
// Try the standard install location
|
||||
localAppData := os.Getenv("LOCALAPPDATA")
|
||||
appExe = filepath.Join(localAppData, "Ollama", AppName)
|
||||
_, err := os.Stat(appExe)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
// Finally look in the path
|
||||
appExe, err = exec.LookPath(AppName)
|
||||
if err != nil {
|
||||
return errors.New("could not locate ollama app")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cmd_path := "c:\\Windows\\system32\\cmd.exe"
|
||||
cmd := exec.Command(cmd_path, "/c", appExe, "--hide", "--fast-startup")
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: 0x08000000, HideWindow: true}
|
||||
|
||||
cmd.Stdin = strings.NewReader("")
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("unable to start ollama app %w", err)
|
||||
}
|
||||
|
||||
if cmd.Process != nil {
|
||||
defer cmd.Process.Release() //nolint:errcheck
|
||||
}
|
||||
return waitForServer(ctx, client)
|
||||
}
|
||||
|
||||
func isProcRunning(procName string) []uint32 {
|
||||
pids := make([]uint32, 2048)
|
||||
var ret uint32
|
||||
if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
|
||||
slog.Debug("failed to check for running installers", "error", err)
|
||||
return nil
|
||||
}
|
||||
if ret > uint32(len(pids)) {
|
||||
pids = make([]uint32, ret+10)
|
||||
if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
|
||||
slog.Debug("failed to check for running installers", "error", err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if ret < uint32(len(pids)) {
|
||||
pids = pids[:ret]
|
||||
}
|
||||
var matches []uint32
|
||||
for _, pid := range pids {
|
||||
if pid == 0 {
|
||||
continue
|
||||
}
|
||||
hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_VM_READ, false, pid)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
defer windows.CloseHandle(hProcess)
|
||||
var module windows.Handle
|
||||
var cbNeeded uint32
|
||||
cb := (uint32)(unsafe.Sizeof(module))
|
||||
if err := windows.EnumProcessModules(hProcess, &module, cb, &cbNeeded); err != nil {
|
||||
continue
|
||||
}
|
||||
var sz uint32 = 1024 * 8
|
||||
moduleName := make([]uint16, sz)
|
||||
cb = uint32(len(moduleName)) * (uint32)(unsafe.Sizeof(uint16(0)))
|
||||
if err := windows.GetModuleBaseName(hProcess, module, &moduleName[0], cb); err != nil && err != syscall.ERROR_INSUFFICIENT_BUFFER {
|
||||
continue
|
||||
}
|
||||
exeFile := path.Base(strings.ToLower(syscall.UTF16ToString(moduleName)))
|
||||
if strings.EqualFold(exeFile, procName) {
|
||||
matches = append(matches, pid)
|
||||
}
|
||||
}
|
||||
return matches
|
||||
}
|
||||
133
cmd/tui/confirm.go
Normal file
133
cmd/tui/confirm.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
)
|
||||
|
||||
var (
|
||||
confirmActiveStyle = lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Background(lipgloss.AdaptiveColor{Light: "254", Dark: "236"})
|
||||
|
||||
confirmInactiveStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
|
||||
)
|
||||
|
||||
type confirmModel struct {
|
||||
prompt string
|
||||
yesLabel string
|
||||
noLabel string
|
||||
yes bool
|
||||
confirmed bool
|
||||
cancelled bool
|
||||
width int
|
||||
}
|
||||
|
||||
type ConfirmOptions = launch.ConfirmOptions
|
||||
|
||||
func (m confirmModel) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m confirmModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
wasSet := m.width > 0
|
||||
m.width = msg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "esc":
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
case "enter":
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
case "left":
|
||||
m.yes = true
|
||||
case "right":
|
||||
m.yes = false
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m confirmModel) View() string {
|
||||
if m.confirmed || m.cancelled {
|
||||
return ""
|
||||
}
|
||||
|
||||
var yesBtn, noBtn string
|
||||
yesLabel := m.yesLabel
|
||||
if yesLabel == "" {
|
||||
yesLabel = "Yes"
|
||||
}
|
||||
noLabel := m.noLabel
|
||||
if noLabel == "" {
|
||||
noLabel = "No"
|
||||
}
|
||||
if m.yes {
|
||||
yesBtn = confirmActiveStyle.Render(" " + yesLabel + " ")
|
||||
noBtn = confirmInactiveStyle.Render(" " + noLabel + " ")
|
||||
} else {
|
||||
yesBtn = confirmInactiveStyle.Render(" " + yesLabel + " ")
|
||||
noBtn = confirmActiveStyle.Render(" " + noLabel + " ")
|
||||
}
|
||||
|
||||
s := selectorTitleStyle.Render(m.prompt) + "\n\n"
|
||||
s += " " + yesBtn + " " + noBtn + "\n\n"
|
||||
s += selectorHelpStyle.Render("←/→ navigate • enter confirm • esc cancel")
|
||||
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// RunConfirm shows a bubbletea yes/no confirmation prompt.
|
||||
// Returns true if the user confirmed, false if cancelled.
|
||||
func RunConfirm(prompt string) (bool, error) {
|
||||
return RunConfirmWithOptions(prompt, ConfirmOptions{})
|
||||
}
|
||||
|
||||
// RunConfirmWithOptions shows a bubbletea yes/no confirmation prompt with
|
||||
// optional custom button labels.
|
||||
func RunConfirmWithOptions(prompt string, options ConfirmOptions) (bool, error) {
|
||||
yesLabel := options.YesLabel
|
||||
if yesLabel == "" {
|
||||
yesLabel = "Yes"
|
||||
}
|
||||
noLabel := options.NoLabel
|
||||
if noLabel == "" {
|
||||
noLabel = "No"
|
||||
}
|
||||
|
||||
m := confirmModel{
|
||||
prompt: prompt,
|
||||
yesLabel: yesLabel,
|
||||
noLabel: noLabel,
|
||||
yes: true, // default to yes
|
||||
}
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("error running confirm: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(confirmModel)
|
||||
if fm.cancelled {
|
||||
return false, ErrCancelled
|
||||
}
|
||||
|
||||
return fm.yes, nil
|
||||
}
|
||||
224
cmd/tui/confirm_test.go
Normal file
224
cmd/tui/confirm_test.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
func TestConfirmModel_DefaultsToYes(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download test?", yes: true}
|
||||
if !m.yes {
|
||||
t.Error("should default to yes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_View_ContainsPrompt(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download qwen3:8b?", yes: true}
|
||||
got := m.View()
|
||||
if !strings.Contains(got, "Download qwen3:8b?") {
|
||||
t.Error("should contain the prompt text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_View_ContainsButtons(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
got := m.View()
|
||||
if !strings.Contains(got, "Yes") {
|
||||
t.Error("should contain Yes button")
|
||||
}
|
||||
if !strings.Contains(got, "No") {
|
||||
t.Error("should contain No button")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_View_ContainsCustomButtons(t *testing.T) {
|
||||
m := confirmModel{
|
||||
prompt: "Connect a messaging app now?",
|
||||
yesLabel: "Yes",
|
||||
noLabel: "Set up later",
|
||||
yes: true,
|
||||
}
|
||||
got := m.View()
|
||||
if !strings.Contains(got, "Yes") {
|
||||
t.Error("should contain custom yes button")
|
||||
}
|
||||
if !strings.Contains(got, "Set up later") {
|
||||
t.Error("should contain custom no button")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_View_ContainsHelp(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
got := m.View()
|
||||
if !strings.Contains(got, "enter confirm") {
|
||||
t.Error("should contain help text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_View_ClearsAfterConfirm(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", confirmed: true}
|
||||
if m.View() != "" {
|
||||
t.Error("View should return empty string after confirmation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_View_ClearsAfterCancel(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", cancelled: true}
|
||||
if m.View() != "" {
|
||||
t.Error("View should return empty string after cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_EnterConfirmsYes(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
fm := updated.(confirmModel)
|
||||
if !fm.confirmed {
|
||||
t.Error("enter should set confirmed=true")
|
||||
}
|
||||
if !fm.yes {
|
||||
t.Error("enter with yes selected should keep yes=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("enter should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_EnterConfirmsNo(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: false}
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
fm := updated.(confirmModel)
|
||||
if !fm.confirmed {
|
||||
t.Error("enter should set confirmed=true")
|
||||
}
|
||||
if fm.yes {
|
||||
t.Error("enter with no selected should keep yes=false")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("enter should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_EscCancels(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEsc})
|
||||
fm := updated.(confirmModel)
|
||||
if !fm.cancelled {
|
||||
t.Error("esc should set cancelled=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("esc should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_CtrlCCancels(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyCtrlC})
|
||||
fm := updated.(confirmModel)
|
||||
if !fm.cancelled {
|
||||
t.Error("ctrl+c should set cancelled=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("ctrl+c should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_NDoesNothing(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'n'}})
|
||||
fm := updated.(confirmModel)
|
||||
if fm.cancelled {
|
||||
t.Error("'n' should not cancel")
|
||||
}
|
||||
if fm.confirmed {
|
||||
t.Error("'n' should not confirm")
|
||||
}
|
||||
if cmd != nil {
|
||||
t.Error("'n' should not quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_YDoesNothing(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: false}
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'y'}})
|
||||
fm := updated.(confirmModel)
|
||||
if fm.confirmed {
|
||||
t.Error("'y' should not confirm")
|
||||
}
|
||||
if fm.yes {
|
||||
t.Error("'y' should not change selection")
|
||||
}
|
||||
if cmd != nil {
|
||||
t.Error("'y' should not quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_ArrowKeysNavigate(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
|
||||
// Right moves to No
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||
fm := updated.(confirmModel)
|
||||
if fm.yes {
|
||||
t.Error("right should move to No")
|
||||
}
|
||||
if fm.confirmed || fm.cancelled {
|
||||
t.Error("navigation should not confirm or cancel")
|
||||
}
|
||||
|
||||
// Left moves back to Yes
|
||||
updated, _ = fm.Update(tea.KeyMsg{Type: tea.KeyLeft})
|
||||
fm = updated.(confirmModel)
|
||||
if !fm.yes {
|
||||
t.Error("left should move to Yes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_TabDoesNothing(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true}
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
|
||||
fm := updated.(confirmModel)
|
||||
if !fm.yes {
|
||||
t.Error("tab should not change selection")
|
||||
}
|
||||
if fm.confirmed || fm.cancelled {
|
||||
t.Error("tab should not confirm or cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_WindowSizeUpdatesWidth(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?"}
|
||||
updated, _ := m.Update(tea.WindowSizeMsg{Width: 100, Height: 40})
|
||||
fm := updated.(confirmModel)
|
||||
if fm.width != 100 {
|
||||
t.Errorf("expected width 100, got %d", fm.width)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_ResizeEntersAltScreen(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", width: 80}
|
||||
_, cmd := m.Update(tea.WindowSizeMsg{Width: 100, Height: 40})
|
||||
if cmd == nil {
|
||||
t.Error("resize (width already set) should return a command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_InitialWindowSizeNoAltScreen(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?"}
|
||||
_, cmd := m.Update(tea.WindowSizeMsg{Width: 80, Height: 40})
|
||||
if cmd != nil {
|
||||
t.Error("initial WindowSizeMsg should not return a command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmModel_ViewMaxWidth(t *testing.T) {
|
||||
m := confirmModel{prompt: "Download?", yes: true, width: 40}
|
||||
got := m.View()
|
||||
// Just ensure it doesn't panic and returns content
|
||||
if got == "" {
|
||||
t.Error("View with width set should still return content")
|
||||
}
|
||||
}
|
||||
993
cmd/tui/selector.go
Normal file
993
cmd/tui/selector.go
Normal file
@@ -0,0 +1,993 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
)
|
||||
|
||||
var (
|
||||
selectorTitleStyle = lipgloss.NewStyle().
|
||||
Bold(true)
|
||||
|
||||
selectorItemStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(4)
|
||||
|
||||
selectorSelectedItemStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(2).
|
||||
Bold(true).
|
||||
Background(lipgloss.AdaptiveColor{Light: "254", Dark: "236"})
|
||||
|
||||
selectorDescStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
|
||||
|
||||
selectorDescLineStyle = selectorDescStyle.
|
||||
PaddingLeft(6)
|
||||
|
||||
selectorFilterStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
|
||||
Italic(true)
|
||||
|
||||
selectorInputStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "235", Dark: "252"})
|
||||
|
||||
selectorDefaultTagStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
|
||||
|
||||
selectorHelpStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "244", Dark: "244"})
|
||||
|
||||
selectorMoreStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(6).
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
|
||||
Italic(true)
|
||||
|
||||
sectionHeaderStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(2).
|
||||
Bold(true).
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "240", Dark: "249"})
|
||||
)
|
||||
|
||||
const maxSelectorItems = 10
|
||||
|
||||
// ErrCancelled is returned when the user cancels the selection.
|
||||
var ErrCancelled = launch.ErrCancelled
|
||||
|
||||
type SelectItem struct {
|
||||
Name string
|
||||
Description string
|
||||
Recommended bool
|
||||
AvailabilityBadge string
|
||||
}
|
||||
|
||||
type selectorItemsUpdatedMsg struct {
|
||||
items []SelectItem
|
||||
}
|
||||
|
||||
func waitForSelectorItems(updates <-chan []SelectItem) tea.Cmd {
|
||||
if updates == nil {
|
||||
return nil
|
||||
}
|
||||
return func() tea.Msg {
|
||||
items, ok := <-updates
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return selectorItemsUpdatedMsg{items: items}
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertItems converts launch.SelectionItem slice to SelectItem slice.
|
||||
func ConvertItems(items []launch.SelectionItem) []SelectItem {
|
||||
out := make([]SelectItem, len(items))
|
||||
for i, item := range items {
|
||||
out[i] = SelectItem{
|
||||
Name: item.Name,
|
||||
Description: item.Description,
|
||||
Recommended: item.Recommended,
|
||||
AvailabilityBadge: item.AvailabilityBadge,
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ReorderItems returns a copy with recommended items first, then non-recommended,
|
||||
// preserving relative order within each group. This ensures the data order matches
|
||||
// the visual section layout (Recommended / More).
|
||||
func ReorderItems(items []SelectItem) []SelectItem {
|
||||
var rec, other []SelectItem
|
||||
for _, item := range items {
|
||||
if item.Recommended {
|
||||
rec = append(rec, item)
|
||||
} else {
|
||||
other = append(other, item)
|
||||
}
|
||||
}
|
||||
return append(rec, other...)
|
||||
}
|
||||
|
||||
// selectorModel is the bubbletea model for single selection.
|
||||
type selectorModel struct {
|
||||
title string
|
||||
items []SelectItem
|
||||
updates <-chan []SelectItem
|
||||
filter string
|
||||
cursor int
|
||||
scrollOffset int
|
||||
selected string
|
||||
cancelled bool
|
||||
helpText string
|
||||
width int
|
||||
}
|
||||
|
||||
func selectorModelWithCurrent(title string, items []SelectItem, current string) selectorModel {
|
||||
m := selectorModel{
|
||||
title: title,
|
||||
items: items,
|
||||
cursor: cursorForCurrent(items, current),
|
||||
}
|
||||
m.updateScroll(m.otherStart())
|
||||
return m
|
||||
}
|
||||
|
||||
func currentItemName(items []SelectItem, cursor int) string {
|
||||
if cursor < 0 || cursor >= len(items) {
|
||||
return ""
|
||||
}
|
||||
return items[cursor].Name
|
||||
}
|
||||
|
||||
func cursorForItemName(items []SelectItem, name string, fallback int) int {
|
||||
if len(items) == 0 {
|
||||
return 0
|
||||
}
|
||||
if name != "" {
|
||||
for i, item := range items {
|
||||
if item.Name == name {
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
if fallback < 0 {
|
||||
return 0
|
||||
}
|
||||
if fallback >= len(items) {
|
||||
return len(items) - 1
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func (m selectorModel) filteredItems() []SelectItem {
|
||||
if m.filter == "" {
|
||||
return m.items
|
||||
}
|
||||
filterLower := strings.ToLower(m.filter)
|
||||
var result []SelectItem
|
||||
for _, item := range m.items {
|
||||
if strings.Contains(strings.ToLower(item.Name), filterLower) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (m selectorModel) Init() tea.Cmd {
|
||||
return waitForSelectorItems(m.updates)
|
||||
}
|
||||
|
||||
// otherStart returns the index of the first non-recommended item in the filtered list.
|
||||
// When filtering, all items scroll together so this returns 0.
|
||||
func (m selectorModel) otherStart() int {
|
||||
if m.filter != "" {
|
||||
return 0
|
||||
}
|
||||
filtered := m.filteredItems()
|
||||
for i, item := range filtered {
|
||||
if !item.Recommended {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return len(filtered)
|
||||
}
|
||||
|
||||
// updateNavigation handles navigation keys (up/down/pgup/pgdown/filter/backspace).
|
||||
// It does NOT handle Enter, Esc, or CtrlC. This is used by both the standalone
|
||||
// selector and the TUI modal (which intercepts Enter/Esc for its own logic).
|
||||
func (m *selectorModel) updateNavigation(msg tea.KeyMsg) {
|
||||
filtered := m.filteredItems()
|
||||
otherStart := m.otherStart()
|
||||
|
||||
switch msg.Type {
|
||||
case tea.KeyUp:
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
m.updateScroll(otherStart)
|
||||
}
|
||||
|
||||
case tea.KeyDown:
|
||||
if m.cursor < len(filtered)-1 {
|
||||
m.cursor++
|
||||
m.updateScroll(otherStart)
|
||||
}
|
||||
|
||||
case tea.KeyPgUp:
|
||||
m.cursor -= maxSelectorItems
|
||||
if m.cursor < 0 {
|
||||
m.cursor = 0
|
||||
}
|
||||
m.updateScroll(otherStart)
|
||||
|
||||
case tea.KeyPgDown:
|
||||
m.cursor += maxSelectorItems
|
||||
if m.cursor >= len(filtered) {
|
||||
m.cursor = len(filtered) - 1
|
||||
}
|
||||
m.updateScroll(otherStart)
|
||||
|
||||
case tea.KeyBackspace:
|
||||
if len(m.filter) > 0 {
|
||||
m.filter = m.filter[:len(m.filter)-1]
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
|
||||
case tea.KeyRunes:
|
||||
m.filter += string(msg.Runes)
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
}
|
||||
|
||||
// updateScroll adjusts scrollOffset based on cursor position.
|
||||
// When not filtering, scrollOffset is relative to the "More" (non-recommended) section.
|
||||
// When filtering, it's relative to the full filtered list.
|
||||
func (m *selectorModel) updateScroll(otherStart int) {
|
||||
if m.filter != "" {
|
||||
if m.cursor < m.scrollOffset {
|
||||
m.scrollOffset = m.cursor
|
||||
}
|
||||
if m.cursor >= m.scrollOffset+maxSelectorItems {
|
||||
m.scrollOffset = m.cursor - maxSelectorItems + 1
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Cursor is in recommended section — reset "More" scroll to top
|
||||
if m.cursor < otherStart {
|
||||
m.scrollOffset = 0
|
||||
return
|
||||
}
|
||||
|
||||
// Cursor is in "More" section — scroll relative to others
|
||||
posInOthers := m.cursor - otherStart
|
||||
maxOthers := maxSelectorItems - otherStart
|
||||
if maxOthers < 3 {
|
||||
maxOthers = 3
|
||||
}
|
||||
if posInOthers < m.scrollOffset {
|
||||
m.scrollOffset = posInOthers
|
||||
}
|
||||
if posInOthers >= m.scrollOffset+maxOthers {
|
||||
m.scrollOffset = posInOthers - maxOthers + 1
|
||||
}
|
||||
}
|
||||
|
||||
func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
wasSet := m.width > 0
|
||||
m.width = msg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case selectorItemsUpdatedMsg:
|
||||
current := currentItemName(m.filteredItems(), m.cursor)
|
||||
m.items = msg.items
|
||||
m.cursor = cursorForItemName(m.filteredItems(), current, m.cursor)
|
||||
m.updateScroll(m.otherStart())
|
||||
return m, waitForSelectorItems(m.updates)
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyLeft:
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyEnter:
|
||||
filtered := m.filteredItems()
|
||||
if len(filtered) > 0 && m.cursor < len(filtered) {
|
||||
m.selected = filtered[m.cursor].Name
|
||||
}
|
||||
return m, tea.Quit
|
||||
|
||||
default:
|
||||
m.updateNavigation(msg)
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func cursorItemSuffix(item SelectItem) string {
|
||||
if item.AvailabilityBadge == "" {
|
||||
return ""
|
||||
}
|
||||
return " " + selectorDefaultTagStyle.Render("("+item.AvailabilityBadge+")")
|
||||
}
|
||||
|
||||
func (m selectorModel) renderItem(s *strings.Builder, item SelectItem, idx int) {
|
||||
if idx == m.cursor {
|
||||
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
|
||||
s.WriteString(cursorItemSuffix(item))
|
||||
} else {
|
||||
s.WriteString(selectorItemStyle.Render(item.Name))
|
||||
}
|
||||
s.WriteString("\n")
|
||||
if item.Description != "" {
|
||||
s.WriteString(selectorDescLineStyle.Render(item.Description))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// renderContent renders the selector content (title, items, help text) without
|
||||
// checking the cancelled/selected state. This is used by both View() (standalone mode)
|
||||
// and by the TUI modal which embeds a selectorModel.
|
||||
func (m selectorModel) renderContent() string {
|
||||
var s strings.Builder
|
||||
|
||||
s.WriteString(selectorTitleStyle.Render(m.title))
|
||||
s.WriteString(" ")
|
||||
if m.filter == "" {
|
||||
s.WriteString(selectorFilterStyle.Render("Type to filter..."))
|
||||
} else {
|
||||
s.WriteString(selectorInputStyle.Render(m.filter))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
|
||||
filtered := m.filteredItems()
|
||||
|
||||
if len(filtered) == 0 {
|
||||
s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
|
||||
s.WriteString("\n")
|
||||
} else if m.filter != "" {
|
||||
s.WriteString(sectionHeaderStyle.Render("Top Results"))
|
||||
s.WriteString("\n")
|
||||
|
||||
displayCount := min(len(filtered), maxSelectorItems)
|
||||
for i := range displayCount {
|
||||
idx := m.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
m.renderItem(&s, filtered[idx], idx)
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
} else {
|
||||
// Split into pinned recommended and scrollable others
|
||||
var recItems, otherItems []int
|
||||
for i, item := range filtered {
|
||||
if item.Recommended {
|
||||
recItems = append(recItems, i)
|
||||
} else {
|
||||
otherItems = append(otherItems, i)
|
||||
}
|
||||
}
|
||||
|
||||
// Always render all recommended items (pinned)
|
||||
if len(recItems) > 0 {
|
||||
s.WriteString(sectionHeaderStyle.Render("Recommended"))
|
||||
s.WriteString("\n")
|
||||
for _, idx := range recItems {
|
||||
m.renderItem(&s, filtered[idx], idx)
|
||||
}
|
||||
}
|
||||
|
||||
if len(otherItems) > 0 {
|
||||
s.WriteString("\n")
|
||||
s.WriteString(sectionHeaderStyle.Render("More"))
|
||||
s.WriteString("\n")
|
||||
|
||||
maxOthers := maxSelectorItems - len(recItems)
|
||||
if maxOthers < 3 {
|
||||
maxOthers = 3
|
||||
}
|
||||
displayCount := min(len(otherItems), maxOthers)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := m.scrollOffset + i
|
||||
if idx >= len(otherItems) {
|
||||
break
|
||||
}
|
||||
m.renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
|
||||
}
|
||||
|
||||
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.WriteString("\n")
|
||||
help := "↑/↓ navigate • enter select • ← back"
|
||||
if m.helpText != "" {
|
||||
help = m.helpText
|
||||
}
|
||||
s.WriteString(selectorHelpStyle.Render(help))
|
||||
|
||||
return s.String()
|
||||
}
|
||||
|
||||
func (m selectorModel) View() string {
|
||||
if m.cancelled || m.selected != "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
s := m.renderContent()
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// cursorForCurrent returns the item index matching current, or 0 if not found.
|
||||
func cursorForCurrent(items []SelectItem, current string) int {
|
||||
if current == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Prefer exact name matches before tag-prefix fallback so "qwen3.5" does not
|
||||
// incorrectly select "qwen3.5:cloud" (and vice versa) based on list order.
|
||||
for i, item := range items {
|
||||
if item.Name == current {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
if strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func SelectSingle(title string, items []SelectItem, current string) (string, error) {
|
||||
return SelectSingleWithUpdates(title, items, current, nil)
|
||||
}
|
||||
|
||||
func SelectSingleWithUpdates(title string, items []SelectItem, current string, updates <-chan []SelectItem) (string, error) {
|
||||
if len(items) == 0 {
|
||||
return "", fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
m := selectorModelWithCurrent(title, items, current)
|
||||
m.updates = updates
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error running selector: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(selectorModel)
|
||||
if fm.cancelled {
|
||||
return "", ErrCancelled
|
||||
}
|
||||
|
||||
return fm.selected, nil
|
||||
}
|
||||
|
||||
// multiSelectorModel is the bubbletea model for multi selection.
|
||||
type multiSelectorModel struct {
|
||||
title string
|
||||
items []SelectItem
|
||||
updates <-chan []SelectItem
|
||||
itemIndex map[string]int
|
||||
filter string
|
||||
cursor int
|
||||
scrollOffset int
|
||||
checked map[int]bool
|
||||
checkOrder []int
|
||||
cancelled bool
|
||||
confirmed bool
|
||||
width int
|
||||
|
||||
// multi enables full multi-select editing mode. The zero value (false)
|
||||
// shows a single-select picker where Enter adds the chosen model to
|
||||
// the existing list. Tab toggles between modes.
|
||||
multi bool
|
||||
singleAdd string // model picked in single mode
|
||||
}
|
||||
|
||||
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
|
||||
m := multiSelectorModel{
|
||||
title: title,
|
||||
items: items,
|
||||
itemIndex: make(map[string]int, len(items)),
|
||||
checked: make(map[int]bool),
|
||||
}
|
||||
|
||||
for i, item := range items {
|
||||
m.itemIndex[item.Name] = i
|
||||
}
|
||||
|
||||
// Reverse order so preChecked[0] (the current default) ends up last
|
||||
// in checkOrder, matching the "last checked = default" convention.
|
||||
for i := len(preChecked) - 1; i >= 0; i-- {
|
||||
if idx, ok := m.itemIndex[preChecked[i]]; ok {
|
||||
m.checked[idx] = true
|
||||
m.checkOrder = append(m.checkOrder, idx)
|
||||
}
|
||||
}
|
||||
|
||||
// Position cursor on the current default model
|
||||
if len(preChecked) > 0 {
|
||||
if idx, ok := m.itemIndex[preChecked[0]]; ok {
|
||||
m.cursor = idx
|
||||
m.updateScroll(m.otherStart())
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *multiSelectorModel) rebuildItemIndex() {
|
||||
m.itemIndex = make(map[string]int, len(m.items))
|
||||
for i, item := range m.items {
|
||||
m.itemIndex[item.Name] = i
|
||||
}
|
||||
}
|
||||
|
||||
func (m *multiSelectorModel) replaceItems(items []SelectItem) {
|
||||
current := currentItemName(m.filteredItems(), m.cursor)
|
||||
checkedNames := make([]string, 0, len(m.checkOrder))
|
||||
for _, idx := range m.checkOrder {
|
||||
if idx >= 0 && idx < len(m.items) {
|
||||
checkedNames = append(checkedNames, m.items[idx].Name)
|
||||
}
|
||||
}
|
||||
|
||||
m.items = items
|
||||
m.rebuildItemIndex()
|
||||
m.checked = make(map[int]bool, len(checkedNames))
|
||||
m.checkOrder = nil
|
||||
for _, name := range checkedNames {
|
||||
if idx, ok := m.itemIndex[name]; ok {
|
||||
m.checked[idx] = true
|
||||
m.checkOrder = append(m.checkOrder, idx)
|
||||
}
|
||||
}
|
||||
m.cursor = cursorForItemName(m.filteredItems(), current, m.cursor)
|
||||
m.updateScroll(m.otherStart())
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) filteredItems() []SelectItem {
|
||||
if m.filter == "" {
|
||||
return m.items
|
||||
}
|
||||
filterLower := strings.ToLower(m.filter)
|
||||
var result []SelectItem
|
||||
for _, item := range m.items {
|
||||
if strings.Contains(strings.ToLower(item.Name), filterLower) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// otherStart returns the index of the first non-recommended item in the filtered list.
|
||||
func (m multiSelectorModel) otherStart() int {
|
||||
if m.filter != "" {
|
||||
return 0
|
||||
}
|
||||
filtered := m.filteredItems()
|
||||
for i, item := range filtered {
|
||||
if !item.Recommended {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return len(filtered)
|
||||
}
|
||||
|
||||
// updateScroll adjusts scrollOffset for section-based scrolling (matches single-select).
|
||||
func (m *multiSelectorModel) updateScroll(otherStart int) {
|
||||
if m.filter != "" {
|
||||
if m.cursor < m.scrollOffset {
|
||||
m.scrollOffset = m.cursor
|
||||
}
|
||||
if m.cursor >= m.scrollOffset+maxSelectorItems {
|
||||
m.scrollOffset = m.cursor - maxSelectorItems + 1
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if m.cursor < otherStart {
|
||||
m.scrollOffset = 0
|
||||
return
|
||||
}
|
||||
|
||||
posInOthers := m.cursor - otherStart
|
||||
maxOthers := maxSelectorItems - otherStart
|
||||
if maxOthers < 3 {
|
||||
maxOthers = 3
|
||||
}
|
||||
if posInOthers < m.scrollOffset {
|
||||
m.scrollOffset = posInOthers
|
||||
}
|
||||
if posInOthers >= m.scrollOffset+maxOthers {
|
||||
m.scrollOffset = posInOthers - maxOthers + 1
|
||||
}
|
||||
}
|
||||
|
||||
func (m *multiSelectorModel) toggleItem() {
|
||||
filtered := m.filteredItems()
|
||||
if len(filtered) == 0 || m.cursor >= len(filtered) {
|
||||
return
|
||||
}
|
||||
|
||||
item := filtered[m.cursor]
|
||||
origIdx := m.itemIndex[item.Name]
|
||||
|
||||
if m.checked[origIdx] {
|
||||
wasDefault := len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx
|
||||
delete(m.checked, origIdx)
|
||||
for i, idx := range m.checkOrder {
|
||||
if idx == origIdx {
|
||||
m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if wasDefault {
|
||||
// When removing the default, pick the nearest checked model above it
|
||||
// (or below if none above) so default fallback follows list order.
|
||||
newDefault := -1
|
||||
for i := origIdx - 1; i >= 0; i-- {
|
||||
if m.checked[i] {
|
||||
newDefault = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if newDefault == -1 {
|
||||
for i := origIdx + 1; i < len(m.items); i++ {
|
||||
if m.checked[i] {
|
||||
newDefault = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if newDefault != -1 {
|
||||
for i, idx := range m.checkOrder {
|
||||
if idx == newDefault {
|
||||
m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
m.checkOrder = append(m.checkOrder, newDefault)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
m.checked[origIdx] = true
|
||||
m.checkOrder = append(m.checkOrder, origIdx)
|
||||
}
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) selectedCount() int {
|
||||
return len(m.checkOrder)
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) Init() tea.Cmd {
|
||||
return waitForSelectorItems(m.updates)
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
wasSet := m.width > 0
|
||||
m.width = msg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case selectorItemsUpdatedMsg:
|
||||
m.replaceItems(msg.items)
|
||||
return m, waitForSelectorItems(m.updates)
|
||||
|
||||
case tea.KeyMsg:
|
||||
filtered := m.filteredItems()
|
||||
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyLeft:
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyTab:
|
||||
m.multi = !m.multi
|
||||
|
||||
case tea.KeyEnter:
|
||||
if !m.multi {
|
||||
if len(filtered) > 0 && m.cursor < len(filtered) {
|
||||
m.singleAdd = filtered[m.cursor].Name
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
} else if len(m.checkOrder) > 0 {
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
case tea.KeySpace:
|
||||
if m.multi {
|
||||
m.toggleItem()
|
||||
}
|
||||
|
||||
case tea.KeyUp:
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
m.updateScroll(m.otherStart())
|
||||
}
|
||||
|
||||
case tea.KeyDown:
|
||||
if m.cursor < len(filtered)-1 {
|
||||
m.cursor++
|
||||
m.updateScroll(m.otherStart())
|
||||
}
|
||||
|
||||
case tea.KeyPgUp:
|
||||
m.cursor -= maxSelectorItems
|
||||
if m.cursor < 0 {
|
||||
m.cursor = 0
|
||||
}
|
||||
m.updateScroll(m.otherStart())
|
||||
|
||||
case tea.KeyPgDown:
|
||||
m.cursor += maxSelectorItems
|
||||
if m.cursor >= len(filtered) {
|
||||
m.cursor = len(filtered) - 1
|
||||
}
|
||||
m.updateScroll(m.otherStart())
|
||||
|
||||
case tea.KeyBackspace:
|
||||
if len(m.filter) > 0 {
|
||||
m.filter = m.filter[:len(m.filter)-1]
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
|
||||
case tea.KeyRunes:
|
||||
// On some terminals (e.g. Windows PowerShell), space arrives as
|
||||
// KeyRunes instead of KeySpace. Intercept it so toggle still works.
|
||||
if len(msg.Runes) == 1 && msg.Runes[0] == ' ' {
|
||||
if m.multi {
|
||||
m.toggleItem()
|
||||
}
|
||||
} else {
|
||||
m.filter += string(msg.Runes)
|
||||
m.cursor = 0
|
||||
m.scrollOffset = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) renderSingleItem(s *strings.Builder, item SelectItem, idx int) {
|
||||
if idx == m.cursor {
|
||||
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
|
||||
s.WriteString(cursorItemSuffix(item))
|
||||
} else {
|
||||
s.WriteString(selectorItemStyle.Render(item.Name))
|
||||
}
|
||||
s.WriteString("\n")
|
||||
if item.Description != "" {
|
||||
s.WriteString(selectorDescLineStyle.Render(item.Description))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem, idx int) {
|
||||
origIdx := m.itemIndex[item.Name]
|
||||
|
||||
var check string
|
||||
if m.checked[origIdx] {
|
||||
check = "[x] "
|
||||
} else {
|
||||
check = "[ ] "
|
||||
}
|
||||
|
||||
suffix := ""
|
||||
if len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx {
|
||||
suffix = " " + selectorDefaultTagStyle.Render("(default)")
|
||||
}
|
||||
|
||||
if idx == m.cursor {
|
||||
s.WriteString(selectorSelectedItemStyle.Render("▸ " + check + item.Name))
|
||||
s.WriteString(cursorItemSuffix(item))
|
||||
} else {
|
||||
s.WriteString(selectorItemStyle.Render(check + item.Name))
|
||||
}
|
||||
s.WriteString(suffix)
|
||||
s.WriteString("\n")
|
||||
if item.Description != "" {
|
||||
s.WriteString(selectorDescLineStyle.Render(item.Description))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) View() string {
|
||||
if m.cancelled || m.confirmed {
|
||||
return ""
|
||||
}
|
||||
|
||||
renderItem := m.renderSingleItem
|
||||
if m.multi {
|
||||
renderItem = m.renderMultiItem
|
||||
}
|
||||
|
||||
var s strings.Builder
|
||||
|
||||
s.WriteString(selectorTitleStyle.Render(m.title))
|
||||
s.WriteString(" ")
|
||||
if m.filter == "" {
|
||||
s.WriteString(selectorFilterStyle.Render("Type to filter..."))
|
||||
} else {
|
||||
s.WriteString(selectorInputStyle.Render(m.filter))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
|
||||
filtered := m.filteredItems()
|
||||
|
||||
if len(filtered) == 0 {
|
||||
s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
|
||||
s.WriteString("\n")
|
||||
} else if m.filter != "" {
|
||||
// Filtering: flat scroll through all matches
|
||||
displayCount := min(len(filtered), maxSelectorItems)
|
||||
for i := range displayCount {
|
||||
idx := m.scrollOffset + i
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
renderItem(&s, filtered[idx], idx)
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
} else {
|
||||
// Split into pinned recommended and scrollable others (matches single-select layout)
|
||||
var recItems, otherItems []int
|
||||
for i, item := range filtered {
|
||||
if item.Recommended {
|
||||
recItems = append(recItems, i)
|
||||
} else {
|
||||
otherItems = append(otherItems, i)
|
||||
}
|
||||
}
|
||||
|
||||
// Always render all recommended items (pinned)
|
||||
if len(recItems) > 0 {
|
||||
s.WriteString(sectionHeaderStyle.Render("Recommended"))
|
||||
s.WriteString("\n")
|
||||
for _, idx := range recItems {
|
||||
renderItem(&s, filtered[idx], idx)
|
||||
}
|
||||
}
|
||||
|
||||
if len(otherItems) > 0 {
|
||||
s.WriteString("\n")
|
||||
s.WriteString(sectionHeaderStyle.Render("More"))
|
||||
s.WriteString("\n")
|
||||
|
||||
maxOthers := maxSelectorItems - len(recItems)
|
||||
if maxOthers < 3 {
|
||||
maxOthers = 3
|
||||
}
|
||||
displayCount := min(len(otherItems), maxOthers)
|
||||
|
||||
for i := range displayCount {
|
||||
idx := m.scrollOffset + i
|
||||
if idx >= len(otherItems) {
|
||||
break
|
||||
}
|
||||
renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
|
||||
}
|
||||
|
||||
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.WriteString("\n")
|
||||
|
||||
count := m.selectedCount()
|
||||
if !m.multi {
|
||||
if count > 0 {
|
||||
s.WriteString(sectionHeaderStyle.Render(fmt.Sprintf("%d models selected - press tab to edit", count)))
|
||||
s.WriteString("\n\n")
|
||||
}
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • ← back"))
|
||||
} else {
|
||||
if count == 0 {
|
||||
s.WriteString(sectionHeaderStyle.Render("Select at least one model."))
|
||||
} else {
|
||||
s.WriteString(sectionHeaderStyle.Render(fmt.Sprintf("%d models selected - press enter to continue", count)))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • ← back"))
|
||||
}
|
||||
|
||||
result := s.String()
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(result)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]string, error) {
|
||||
return SelectMultipleWithUpdates(title, items, preChecked, nil)
|
||||
}
|
||||
|
||||
func SelectMultipleWithUpdates(title string, items []SelectItem, preChecked []string, updates <-chan []SelectItem) ([]string, error) {
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no items to select from")
|
||||
}
|
||||
|
||||
m := newMultiSelectorModel(title, items, preChecked)
|
||||
m.updates = updates
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error running selector: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(multiSelectorModel)
|
||||
if fm.cancelled || !fm.confirmed {
|
||||
return nil, ErrCancelled
|
||||
}
|
||||
|
||||
// Single-add mode: prepend the picked model, keep existing models deduped
|
||||
if fm.singleAdd != "" {
|
||||
result := []string{fm.singleAdd}
|
||||
for _, name := range preChecked {
|
||||
if name != fm.singleAdd {
|
||||
result = append(result, name)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Multi-edit mode: last checked is default (first in result)
|
||||
last := fm.checkOrder[len(fm.checkOrder)-1]
|
||||
result := []string{fm.items[last].Name}
|
||||
for _, idx := range fm.checkOrder {
|
||||
if idx != last {
|
||||
result = append(result, fm.items[idx].Name)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
1024
cmd/tui/selector_test.go
Normal file
1024
cmd/tui/selector_test.go
Normal file
File diff suppressed because it is too large
Load Diff
340
cmd/tui/signin.go
Normal file
340
cmd/tui/signin.go
Normal file
@@ -0,0 +1,340 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
)
|
||||
|
||||
type signInTickMsg struct{}
|
||||
|
||||
type signInCheckMsg struct {
|
||||
signedIn bool
|
||||
userName string
|
||||
}
|
||||
|
||||
type upgradeTickMsg struct{}
|
||||
|
||||
type upgradeCheckMsg struct {
|
||||
upgraded bool
|
||||
plan string
|
||||
err error
|
||||
}
|
||||
|
||||
type signInModel struct {
|
||||
modelName string
|
||||
signInURL string
|
||||
spinner int
|
||||
width int
|
||||
userName string
|
||||
cancelled bool
|
||||
}
|
||||
|
||||
type upgradeModel struct {
|
||||
modelName string
|
||||
requiredPlan string
|
||||
spinner int
|
||||
width int
|
||||
openNow bool
|
||||
polling bool
|
||||
plan string
|
||||
cancelled bool
|
||||
err error
|
||||
}
|
||||
|
||||
func (m signInModel) Init() tea.Cmd {
|
||||
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
})
|
||||
}
|
||||
|
||||
func (m signInModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
wasSet := m.width > 0
|
||||
m.width = msg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
case signInTickMsg:
|
||||
m.spinner++
|
||||
if m.spinner%5 == 0 {
|
||||
return m, tea.Batch(
|
||||
tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
}),
|
||||
checkSignIn,
|
||||
)
|
||||
}
|
||||
return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return signInTickMsg{}
|
||||
})
|
||||
|
||||
case signInCheckMsg:
|
||||
if msg.signedIn {
|
||||
m.userName = msg.userName
|
||||
return m, tea.Quit
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m signInModel) View() string {
|
||||
if m.userName != "" {
|
||||
return ""
|
||||
}
|
||||
return renderSignIn(m.modelName, m.signInURL, m.spinner, m.width)
|
||||
}
|
||||
|
||||
func (m upgradeModel) Init() tea.Cmd {
|
||||
if m.polling {
|
||||
return upgradeTickCmd()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m upgradeModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
wasSet := m.width > 0
|
||||
m.width = msg.Width
|
||||
if wasSet {
|
||||
return m, tea.EnterAltScreen
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
case tea.KeyLeft:
|
||||
if !m.polling {
|
||||
m.openNow = true
|
||||
}
|
||||
case tea.KeyRight:
|
||||
if !m.polling {
|
||||
m.openNow = false
|
||||
}
|
||||
case tea.KeyEnter:
|
||||
if !m.polling {
|
||||
if !m.openNow {
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
launch.OpenBrowser(launch.DefaultUpgradeURL)
|
||||
m.polling = true
|
||||
return m, upgradeTickCmd()
|
||||
}
|
||||
}
|
||||
|
||||
case upgradeTickMsg:
|
||||
if !m.polling {
|
||||
return m, nil
|
||||
}
|
||||
m.spinner++
|
||||
if m.spinner%5 == 0 {
|
||||
return m, tea.Batch(
|
||||
upgradeTickCmd(),
|
||||
checkUpgrade(m.requiredPlan),
|
||||
)
|
||||
}
|
||||
return m, upgradeTickCmd()
|
||||
|
||||
case upgradeCheckMsg:
|
||||
if msg.err != nil {
|
||||
m.err = msg.err
|
||||
return m, tea.Quit
|
||||
}
|
||||
if msg.upgraded {
|
||||
m.plan = msg.plan
|
||||
return m, tea.Quit
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m upgradeModel) View() string {
|
||||
if m.plan != "" {
|
||||
return ""
|
||||
}
|
||||
if m.err != nil {
|
||||
return ""
|
||||
}
|
||||
return renderUpgrade(m.modelName, m.spinner, m.width, m.polling, m.openNow)
|
||||
}
|
||||
|
||||
func renderSignIn(modelName, signInURL string, spinner, width int) string {
|
||||
spinnerFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
|
||||
frame := spinnerFrames[spinner%len(spinnerFrames)]
|
||||
|
||||
urlColor := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("117"))
|
||||
urlWrap := lipgloss.NewStyle().PaddingLeft(2)
|
||||
if width > 4 {
|
||||
urlWrap = urlWrap.Width(width - 4)
|
||||
}
|
||||
|
||||
var s strings.Builder
|
||||
|
||||
fmt.Fprintf(&s, "To use %s, please sign in.\n\n", selectorSelectedItemStyle.Render(modelName))
|
||||
|
||||
s.WriteString("Navigate to:\n")
|
||||
s.WriteString(urlWrap.Render(urlColor.Render(signInURL)))
|
||||
s.WriteString("\n\n")
|
||||
|
||||
s.WriteString(lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).Render(
|
||||
frame + " Waiting for sign in to complete..."))
|
||||
s.WriteString("\n\n")
|
||||
|
||||
s.WriteString(selectorHelpStyle.Render("esc cancel"))
|
||||
|
||||
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
|
||||
}
|
||||
|
||||
func upgradeTickCmd() tea.Cmd {
|
||||
return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return upgradeTickMsg{}
|
||||
})
|
||||
}
|
||||
|
||||
func renderUpgrade(modelName string, spinner, width int, polling, openNow bool) string {
|
||||
spinnerFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
|
||||
frame := spinnerFrames[spinner%len(spinnerFrames)]
|
||||
|
||||
urlColor := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("117"))
|
||||
urlWrap := lipgloss.NewStyle().PaddingLeft(2)
|
||||
if width > 4 {
|
||||
urlWrap = urlWrap.Width(width - 4)
|
||||
}
|
||||
|
||||
var s strings.Builder
|
||||
|
||||
fmt.Fprintf(&s, "To use %s, upgrade your Ollama plan.\n\n", selectorSelectedItemStyle.Render(modelName))
|
||||
|
||||
s.WriteString("Navigate to:\n")
|
||||
s.WriteString(urlWrap.Render(urlColor.Render(launch.DefaultUpgradeURL)))
|
||||
s.WriteString("\n\n")
|
||||
|
||||
if !polling {
|
||||
var yesBtn, noBtn string
|
||||
if openNow {
|
||||
yesBtn = confirmActiveStyle.Render(" Yes ")
|
||||
noBtn = confirmInactiveStyle.Render(" No ")
|
||||
} else {
|
||||
yesBtn = confirmInactiveStyle.Render(" Yes ")
|
||||
noBtn = confirmActiveStyle.Render(" No ")
|
||||
}
|
||||
|
||||
s.WriteString("Open now?\n")
|
||||
s.WriteString(" " + yesBtn + " " + noBtn)
|
||||
s.WriteString("\n\n")
|
||||
s.WriteString(selectorHelpStyle.Render("←/→ navigate • enter confirm • esc cancel"))
|
||||
} else {
|
||||
s.WriteString(lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).Render(
|
||||
frame + " Waiting for upgrade to complete..."))
|
||||
s.WriteString("\n\n")
|
||||
s.WriteString(selectorHelpStyle.Render("esc cancel"))
|
||||
}
|
||||
|
||||
return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
|
||||
}
|
||||
|
||||
func checkSignIn() tea.Msg {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return signInCheckMsg{signedIn: true, userName: user.Name}
|
||||
}
|
||||
return signInCheckMsg{signedIn: false}
|
||||
}
|
||||
|
||||
func checkUpgrade(requiredPlan string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return upgradeCheckMsg{err: launch.ErrPlanVerificationUnavailable}
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
user, err := client.Whoami(ctx)
|
||||
if err != nil {
|
||||
return upgradeCheckMsg{err: launch.ErrPlanVerificationUnavailable}
|
||||
}
|
||||
if err == nil && user != nil && user.Name != "" && launch.PlanSatisfies(user.Plan, requiredPlan) {
|
||||
return upgradeCheckMsg{upgraded: true, plan: user.Plan}
|
||||
}
|
||||
return upgradeCheckMsg{upgraded: false}
|
||||
}
|
||||
}
|
||||
|
||||
// RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels.
|
||||
func RunSignIn(modelName, signInURL string) (string, error) {
|
||||
launch.OpenBrowser(signInURL)
|
||||
|
||||
m := signInModel{
|
||||
modelName: modelName,
|
||||
signInURL: signInURL,
|
||||
}
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error running sign-in: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(signInModel)
|
||||
if fm.cancelled {
|
||||
return "", ErrCancelled
|
||||
}
|
||||
|
||||
return fm.userName, nil
|
||||
}
|
||||
|
||||
// RunUpgrade shows a bubbletea upgrade dialog and polls until the user's plan is updated or cancelled.
|
||||
func RunUpgrade(modelName, requiredPlan string) (string, error) {
|
||||
m := upgradeModel{
|
||||
modelName: modelName,
|
||||
requiredPlan: requiredPlan,
|
||||
openNow: true,
|
||||
}
|
||||
|
||||
p := tea.NewProgram(m)
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error running upgrade: %w", err)
|
||||
}
|
||||
|
||||
fm := finalModel.(upgradeModel)
|
||||
if fm.cancelled {
|
||||
return "", ErrCancelled
|
||||
}
|
||||
if fm.err != nil {
|
||||
return "", fm.err
|
||||
}
|
||||
|
||||
return fm.plan, nil
|
||||
}
|
||||
217
cmd/tui/signin_test.go
Normal file
217
cmd/tui/signin_test.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
)
|
||||
|
||||
func TestRenderSignIn_ContainsModelName(t *testing.T) {
|
||||
got := renderSignIn("glm-4.7:cloud", "https://example.com/signin", 0, 80)
|
||||
if !strings.Contains(got, "glm-4.7:cloud") {
|
||||
t.Error("should contain model name")
|
||||
}
|
||||
if !strings.Contains(got, "please sign in") {
|
||||
t.Error("should contain sign-in prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderSignIn_ContainsURL(t *testing.T) {
|
||||
url := "https://ollama.com/connect?key=abc123"
|
||||
got := renderSignIn("test:cloud", url, 0, 120)
|
||||
if !strings.Contains(got, url) {
|
||||
t.Errorf("should contain URL %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderSignIn_ContainsSpinner(t *testing.T) {
|
||||
got := renderSignIn("test:cloud", "https://example.com", 0, 80)
|
||||
if !strings.Contains(got, "Waiting for sign in to complete") {
|
||||
t.Error("should contain waiting message")
|
||||
}
|
||||
if !strings.Contains(got, "⠋") {
|
||||
t.Error("should contain first spinner frame at spinner=0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderSignIn_SpinnerAdvances(t *testing.T) {
|
||||
got0 := renderSignIn("test:cloud", "https://example.com", 0, 80)
|
||||
got1 := renderSignIn("test:cloud", "https://example.com", 1, 80)
|
||||
if got0 == got1 {
|
||||
t.Error("different spinner values should produce different output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderSignIn_ContainsEscHelp(t *testing.T) {
|
||||
got := renderSignIn("test:cloud", "https://example.com", 0, 80)
|
||||
if !strings.Contains(got, "esc cancel") {
|
||||
t.Error("should contain esc cancel help text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderUpgrade_AsksBeforeOpening(t *testing.T) {
|
||||
got := renderUpgrade("kimi-k2.6:cloud", 0, 80, false, true)
|
||||
if !strings.Contains(got, "kimi-k2.6:cloud") {
|
||||
t.Error("should contain model name")
|
||||
}
|
||||
if !strings.Contains(got, launch.DefaultUpgradeURL) {
|
||||
t.Error("should contain upgrade URL")
|
||||
}
|
||||
if !strings.Contains(got, "Open now?") {
|
||||
t.Error("should ask before opening")
|
||||
}
|
||||
if !strings.Contains(got, "Yes") || !strings.Contains(got, "No") {
|
||||
t.Error("should show yes/no selector")
|
||||
}
|
||||
if strings.Contains(got, "Waiting for upgrade to complete") {
|
||||
t.Error("should not start waiting before open choice is confirmed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderUpgrade_PollingShowsWaiting(t *testing.T) {
|
||||
got := renderUpgrade("kimi-k2.6:cloud", 0, 80, true, true)
|
||||
if !strings.Contains(got, "Waiting for upgrade to complete") {
|
||||
t.Error("should contain waiting message")
|
||||
}
|
||||
if strings.Contains(got, "Open now?") {
|
||||
t.Error("should not show open prompt while polling")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_EscCancels(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
}
|
||||
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEsc})
|
||||
fm := updated.(signInModel)
|
||||
if !fm.cancelled {
|
||||
t.Error("esc should set cancelled=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("esc should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpgradeModel_NoCancelsWithoutPolling(t *testing.T) {
|
||||
m := upgradeModel{
|
||||
modelName: "kimi-k2.6:cloud",
|
||||
requiredPlan: "pro",
|
||||
openNow: true,
|
||||
}
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||
fm := updated.(upgradeModel)
|
||||
if fm.openNow {
|
||||
t.Error("right should select no")
|
||||
}
|
||||
if fm.polling {
|
||||
t.Error("right should not start polling")
|
||||
}
|
||||
|
||||
updated, cmd := fm.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
fm = updated.(upgradeModel)
|
||||
if !fm.cancelled {
|
||||
t.Error("enter on no should cancel")
|
||||
}
|
||||
if fm.polling {
|
||||
t.Error("enter on no should not start polling")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("enter on no should quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_CtrlCCancels(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
}
|
||||
|
||||
updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyCtrlC})
|
||||
fm := updated.(signInModel)
|
||||
if !fm.cancelled {
|
||||
t.Error("ctrl+c should set cancelled=true")
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("ctrl+c should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_SignedInQuitsClean(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
}
|
||||
|
||||
updated, cmd := m.Update(signInCheckMsg{signedIn: true, userName: "alice"})
|
||||
fm := updated.(signInModel)
|
||||
if fm.userName != "alice" {
|
||||
t.Errorf("expected userName 'alice', got %q", fm.userName)
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("successful sign-in should return tea.Quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_SignedInViewClears(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
userName: "alice",
|
||||
}
|
||||
|
||||
got := m.View()
|
||||
if got != "" {
|
||||
t.Errorf("View should return empty string after sign-in, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_NotSignedInContinues(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
}
|
||||
|
||||
updated, _ := m.Update(signInCheckMsg{signedIn: false})
|
||||
fm := updated.(signInModel)
|
||||
if fm.userName != "" {
|
||||
t.Error("should not set userName when not signed in")
|
||||
}
|
||||
if fm.cancelled {
|
||||
t.Error("should not cancel when check returns not signed in")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_WindowSizeUpdatesWidth(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
}
|
||||
|
||||
updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
|
||||
fm := updated.(signInModel)
|
||||
if fm.width != 120 {
|
||||
t.Errorf("expected width 120, got %d", fm.width)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignInModel_TickAdvancesSpinner(t *testing.T) {
|
||||
m := signInModel{
|
||||
modelName: "test:cloud",
|
||||
signInURL: "https://example.com",
|
||||
spinner: 0,
|
||||
}
|
||||
|
||||
updated, cmd := m.Update(signInTickMsg{})
|
||||
fm := updated.(signInModel)
|
||||
if fm.spinner != 1 {
|
||||
t.Errorf("expected spinner=1, got %d", fm.spinner)
|
||||
}
|
||||
if cmd == nil {
|
||||
t.Error("tick should return a command")
|
||||
}
|
||||
}
|
||||
393
cmd/tui/tui.go
Normal file
393
cmd/tui/tui.go
Normal file
@@ -0,0 +1,393 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
var (
|
||||
versionStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "243", Dark: "250"})
|
||||
|
||||
menuItemStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(2)
|
||||
|
||||
menuSelectedItemStyle = lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Background(lipgloss.AdaptiveColor{Light: "254", Dark: "236"})
|
||||
|
||||
menuDescStyle = selectorDescStyle.
|
||||
PaddingLeft(4)
|
||||
|
||||
greyedStyle = menuItemStyle.
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
|
||||
|
||||
greyedSelectedStyle = menuSelectedItemStyle.
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
|
||||
|
||||
modelStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "243", Dark: "250"})
|
||||
|
||||
notInstalledStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
|
||||
Italic(true)
|
||||
)
|
||||
|
||||
type menuItem struct {
|
||||
title string
|
||||
description string
|
||||
integration string
|
||||
isRunModel bool
|
||||
isOthers bool
|
||||
}
|
||||
|
||||
const pinnedIntegrationCount = 4
|
||||
|
||||
var runModelMenuItem = menuItem{
|
||||
title: "Chat with a model",
|
||||
description: "Start an interactive chat with a model",
|
||||
isRunModel: true,
|
||||
}
|
||||
|
||||
var othersMenuItem = menuItem{
|
||||
title: "More...",
|
||||
description: "Show additional integrations",
|
||||
isOthers: true,
|
||||
}
|
||||
|
||||
type model struct {
|
||||
state *launch.LauncherState
|
||||
items []menuItem
|
||||
cursor int
|
||||
showOthers bool
|
||||
width int
|
||||
quitting bool
|
||||
selected bool
|
||||
action TUIAction
|
||||
}
|
||||
|
||||
func newModel(state *launch.LauncherState) model {
|
||||
m := model{
|
||||
state: state,
|
||||
}
|
||||
m.showOthers = shouldExpandOthers(state)
|
||||
m.items = buildMenuItems(state, m.showOthers)
|
||||
m.cursor = initialCursor(state, m.items)
|
||||
return m
|
||||
}
|
||||
|
||||
func shouldExpandOthers(state *launch.LauncherState) bool {
|
||||
if state == nil {
|
||||
return false
|
||||
}
|
||||
for _, item := range otherIntegrationItems(state) {
|
||||
if item.integration == state.LastSelection {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildMenuItems(state *launch.LauncherState, showOthers bool) []menuItem {
|
||||
items := []menuItem{runModelMenuItem}
|
||||
items = append(items, pinnedIntegrationItems(state)...)
|
||||
|
||||
otherItems := otherIntegrationItems(state)
|
||||
switch {
|
||||
case showOthers:
|
||||
items = append(items, otherItems...)
|
||||
case len(otherItems) > 0:
|
||||
items = append(items, othersMenuItem)
|
||||
}
|
||||
|
||||
return items
|
||||
}
|
||||
|
||||
func integrationMenuItem(state launch.LauncherIntegrationState) menuItem {
|
||||
description := state.Description
|
||||
if description == "" {
|
||||
description = "Open " + state.DisplayName + " integration"
|
||||
}
|
||||
return menuItem{
|
||||
title: "Launch " + state.DisplayName,
|
||||
description: description,
|
||||
integration: state.Name,
|
||||
}
|
||||
}
|
||||
|
||||
func otherIntegrationItems(state *launch.LauncherState) []menuItem {
|
||||
ordered := orderedIntegrationItems(state)
|
||||
if len(ordered) <= pinnedIntegrationCount {
|
||||
return nil
|
||||
}
|
||||
return ordered[pinnedIntegrationCount:]
|
||||
}
|
||||
|
||||
func pinnedIntegrationItems(state *launch.LauncherState) []menuItem {
|
||||
ordered := orderedIntegrationItems(state)
|
||||
if len(ordered) <= pinnedIntegrationCount {
|
||||
return ordered
|
||||
}
|
||||
return ordered[:pinnedIntegrationCount]
|
||||
}
|
||||
|
||||
func orderedIntegrationItems(state *launch.LauncherState) []menuItem {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
items := make([]menuItem, 0, len(state.Integrations))
|
||||
for _, info := range launch.ListIntegrationInfos() {
|
||||
integrationState, ok := state.Integrations[info.Name]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
items = append(items, integrationMenuItem(integrationState))
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func primaryMenuItemCount(state *launch.LauncherState) int {
|
||||
return 1 + len(pinnedIntegrationItems(state))
|
||||
}
|
||||
|
||||
func initialCursor(state *launch.LauncherState, items []menuItem) int {
|
||||
if state == nil || state.LastSelection == "" {
|
||||
return 0
|
||||
}
|
||||
for i, item := range items {
|
||||
if state.LastSelection == "run" && item.isRunModel {
|
||||
return i
|
||||
}
|
||||
if item.integration == state.LastSelection {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m model) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.width = msg.Width
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "q", "esc":
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
|
||||
case "up", "k":
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
}
|
||||
if m.showOthers && m.cursor < primaryMenuItemCount(m.state) {
|
||||
m.showOthers = false
|
||||
m.items = buildMenuItems(m.state, false)
|
||||
m.cursor = min(m.cursor, len(m.items)-1)
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case "down", "j":
|
||||
if m.cursor < len(m.items)-1 {
|
||||
m.cursor++
|
||||
}
|
||||
if m.cursor < len(m.items) && m.items[m.cursor].isOthers && !m.showOthers {
|
||||
m.showOthers = true
|
||||
m.items = buildMenuItems(m.state, true)
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case "enter", " ":
|
||||
if m.selectableItem(m.items[m.cursor]) {
|
||||
m.selected = true
|
||||
m.action = actionForMenuItem(m.items[m.cursor], false)
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case "right", "l":
|
||||
item := m.items[m.cursor]
|
||||
if item.isRunModel || m.changeableItem(item) {
|
||||
m.selected = true
|
||||
m.action = actionForMenuItem(item, true)
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m model) selectableItem(item menuItem) bool {
|
||||
if item.isRunModel {
|
||||
return true
|
||||
}
|
||||
if item.integration == "" || item.isOthers {
|
||||
return false
|
||||
}
|
||||
state, ok := m.state.Integrations[item.integration]
|
||||
return ok && state.Selectable
|
||||
}
|
||||
|
||||
func (m model) changeableItem(item menuItem) bool {
|
||||
if item.integration == "" || item.isOthers {
|
||||
return false
|
||||
}
|
||||
state, ok := m.state.Integrations[item.integration]
|
||||
return ok && state.Changeable
|
||||
}
|
||||
|
||||
func (m model) View() string {
|
||||
if m.quitting {
|
||||
return ""
|
||||
}
|
||||
|
||||
s := selectorTitleStyle.Render("Ollama "+versionStyle.Render(version.Version)) + "\n\n"
|
||||
|
||||
for i, item := range m.items {
|
||||
s += m.renderMenuItem(i, item)
|
||||
}
|
||||
|
||||
s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → configure • esc quit")
|
||||
|
||||
if m.width > 0 {
|
||||
return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (m model) renderMenuItem(index int, item menuItem) string {
|
||||
cursor := ""
|
||||
style := menuItemStyle
|
||||
title := item.title
|
||||
description := item.description
|
||||
modelSuffix := ""
|
||||
|
||||
if m.cursor == index {
|
||||
cursor = "▸ "
|
||||
}
|
||||
|
||||
if item.isRunModel {
|
||||
if m.cursor == index && m.state.RunModel != "" {
|
||||
modelSuffix = " " + modelStyle.Render("("+m.state.RunModel+")")
|
||||
}
|
||||
if m.cursor == index {
|
||||
style = menuSelectedItemStyle
|
||||
}
|
||||
} else if item.isOthers {
|
||||
if m.cursor == index {
|
||||
style = menuSelectedItemStyle
|
||||
}
|
||||
} else {
|
||||
integrationState := m.state.Integrations[item.integration]
|
||||
if !integrationState.Selectable {
|
||||
if m.cursor == index {
|
||||
style = greyedSelectedStyle
|
||||
} else {
|
||||
style = greyedStyle
|
||||
}
|
||||
} else if m.cursor == index {
|
||||
style = menuSelectedItemStyle
|
||||
}
|
||||
|
||||
if m.cursor == index && integrationState.CurrentModel != "" {
|
||||
modelSuffix = " " + modelStyle.Render("("+integrationState.CurrentModel+")")
|
||||
}
|
||||
|
||||
if !integrationState.Installed {
|
||||
if integrationState.AutoInstallable {
|
||||
title += " " + notInstalledStyle.Render("(install)")
|
||||
} else {
|
||||
title += " " + notInstalledStyle.Render("(not installed)")
|
||||
}
|
||||
if m.cursor == index {
|
||||
if integrationState.AutoInstallable {
|
||||
description = "Press enter to install"
|
||||
} else if integrationState.InstallHint != "" {
|
||||
description = integrationState.InstallHint
|
||||
} else {
|
||||
description = "not installed"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return style.Render(cursor+title) + modelSuffix + "\n" + menuDescStyle.Render(description) + "\n\n"
|
||||
}
|
||||
|
||||
type TUIActionKind int
|
||||
|
||||
const (
|
||||
TUIActionNone TUIActionKind = iota
|
||||
TUIActionRunModel
|
||||
TUIActionLaunchIntegration
|
||||
)
|
||||
|
||||
type TUIAction struct {
|
||||
Kind TUIActionKind
|
||||
Integration string
|
||||
ForceConfigure bool
|
||||
}
|
||||
|
||||
func (a TUIAction) LastSelection() string {
|
||||
switch a.Kind {
|
||||
case TUIActionRunModel:
|
||||
return "run"
|
||||
case TUIActionLaunchIntegration:
|
||||
return a.Integration
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (a TUIAction) RunModelRequest() launch.RunModelRequest {
|
||||
return launch.RunModelRequest{ForcePicker: a.ForceConfigure}
|
||||
}
|
||||
|
||||
func (a TUIAction) IntegrationLaunchRequest() launch.IntegrationLaunchRequest {
|
||||
return launch.IntegrationLaunchRequest{
|
||||
Name: a.Integration,
|
||||
ForceConfigure: a.ForceConfigure,
|
||||
}
|
||||
}
|
||||
|
||||
func actionForMenuItem(item menuItem, forceConfigure bool) TUIAction {
|
||||
switch {
|
||||
case item.isRunModel:
|
||||
return TUIAction{Kind: TUIActionRunModel, ForceConfigure: forceConfigure}
|
||||
case item.integration != "":
|
||||
return TUIAction{Kind: TUIActionLaunchIntegration, Integration: item.integration, ForceConfigure: forceConfigure}
|
||||
default:
|
||||
return TUIAction{Kind: TUIActionNone}
|
||||
}
|
||||
}
|
||||
|
||||
func RunMenu(state *launch.LauncherState) (TUIAction, error) {
|
||||
menu := newModel(state)
|
||||
program := tea.NewProgram(menu)
|
||||
|
||||
finalModel, err := program.Run()
|
||||
if err != nil {
|
||||
return TUIAction{Kind: TUIActionNone}, fmt.Errorf("error running TUI: %w", err)
|
||||
}
|
||||
|
||||
finalMenu := finalModel.(model)
|
||||
if !finalMenu.selected {
|
||||
return TUIAction{Kind: TUIActionNone}, nil
|
||||
}
|
||||
|
||||
return finalMenu.action, nil
|
||||
}
|
||||
296
cmd/tui/tui_test.go
Normal file
296
cmd/tui/tui_test.go
Normal file
@@ -0,0 +1,296 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/cmd/launch"
|
||||
)
|
||||
|
||||
func launcherTestState() *launch.LauncherState {
|
||||
return &launch.LauncherState{
|
||||
LastSelection: "run",
|
||||
RunModel: "qwen3:8b",
|
||||
Integrations: map[string]launch.LauncherIntegrationState{
|
||||
"claude": {
|
||||
Name: "claude",
|
||||
DisplayName: "Claude Code",
|
||||
Description: "Anthropic's coding tool with subagents",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
CurrentModel: "glm-5:cloud",
|
||||
},
|
||||
"codex": {
|
||||
Name: "codex",
|
||||
DisplayName: "Codex",
|
||||
Description: "OpenAI's open-source coding agent",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
},
|
||||
"codex-app": {
|
||||
Name: "codex-app",
|
||||
DisplayName: "Codex App",
|
||||
Description: "An AI agent you can delegate real work to, by OpenAI",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
},
|
||||
"openclaw": {
|
||||
Name: "openclaw",
|
||||
DisplayName: "OpenClaw",
|
||||
Description: "Personal AI with 100+ skills",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
AutoInstallable: true,
|
||||
},
|
||||
"opencode": {
|
||||
Name: "opencode",
|
||||
DisplayName: "OpenCode",
|
||||
Description: "Anomaly's open-source coding agent",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
},
|
||||
"hermes": {
|
||||
Name: "hermes",
|
||||
DisplayName: "Hermes Agent",
|
||||
Description: "Self-improving AI agent built by Nous Research",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
},
|
||||
"droid": {
|
||||
Name: "droid",
|
||||
DisplayName: "Droid",
|
||||
Description: "Factory's coding agent across terminal and IDEs",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
},
|
||||
"pi": {
|
||||
Name: "pi",
|
||||
DisplayName: "Pi",
|
||||
Description: "Minimal AI agent toolkit with plugin support",
|
||||
Selectable: true,
|
||||
Changeable: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func findMenuCursorByIntegration(items []menuItem, name string) int {
|
||||
for i, item := range items {
|
||||
if item.integration == name {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func integrationSequence(items []menuItem) []string {
|
||||
sequence := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
switch {
|
||||
case item.isRunModel:
|
||||
sequence = append(sequence, "run")
|
||||
case item.isOthers:
|
||||
sequence = append(sequence, "more")
|
||||
case item.integration != "":
|
||||
sequence = append(sequence, item.integration)
|
||||
}
|
||||
}
|
||||
return sequence
|
||||
}
|
||||
|
||||
func compareStrings(got, want []string) string {
|
||||
return cmp.Diff(want, got)
|
||||
}
|
||||
|
||||
func expectedCollapsedSequence(state *launch.LauncherState) []string {
|
||||
sequence := []string{"run"}
|
||||
for _, item := range pinnedIntegrationItems(state) {
|
||||
sequence = append(sequence, item.integration)
|
||||
}
|
||||
if len(otherIntegrationItems(state)) > 0 {
|
||||
sequence = append(sequence, "more")
|
||||
}
|
||||
return sequence
|
||||
}
|
||||
|
||||
func expectedExpandedSequence(state *launch.LauncherState) []string {
|
||||
sequence := []string{"run"}
|
||||
for _, item := range pinnedIntegrationItems(state) {
|
||||
sequence = append(sequence, item.integration)
|
||||
}
|
||||
for _, item := range otherIntegrationItems(state) {
|
||||
sequence = append(sequence, item.integration)
|
||||
}
|
||||
return sequence
|
||||
}
|
||||
|
||||
func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
|
||||
state := launcherTestState()
|
||||
menu := newModel(state)
|
||||
wantPrefix := []string{"run", "claude", "codex-app", "hermes", "openclaw"}
|
||||
if findMenuCursorByIntegration(menu.items, "codex-app") == -1 {
|
||||
wantPrefix = []string{"run", "claude", "hermes", "openclaw", "opencode"}
|
||||
}
|
||||
if got := integrationSequence(menu.items); len(got) < len(wantPrefix) {
|
||||
t.Fatalf("expected at least %d menu items, got %v", len(wantPrefix), got)
|
||||
} else if diff := compareStrings(got[:len(wantPrefix)], wantPrefix); diff != "" {
|
||||
t.Fatalf("unexpected primary TUI order: %s", diff)
|
||||
}
|
||||
|
||||
view := menu.View()
|
||||
for _, want := range []string{"Chat with a model", "Launch Claude Code", "Launch Hermes Agent", "Launch OpenClaw", "More..."} {
|
||||
if !strings.Contains(view, want) {
|
||||
t.Fatalf("expected menu view to contain %q\n%s", want, view)
|
||||
}
|
||||
}
|
||||
if findMenuCursorByIntegration(menu.items, "codex-app") != -1 && !strings.Contains(view, "Launch Codex App") {
|
||||
t.Fatalf("expected menu view to contain Codex App\n%s", view)
|
||||
}
|
||||
if strings.Contains(view, "Launch Claude Desktop") {
|
||||
t.Fatalf("expected hidden Claude Desktop to be absent\n%s", view)
|
||||
}
|
||||
wantOrder := expectedCollapsedSequence(state)
|
||||
if diff := compareStrings(integrationSequence(menu.items), wantOrder); diff != "" {
|
||||
t.Fatalf("unexpected pinned order: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuExpandsOthersFromLastSelection(t *testing.T) {
|
||||
state := launcherTestState()
|
||||
overflow := otherIntegrationItems(state)
|
||||
if len(overflow) == 0 {
|
||||
t.Fatal("expected at least one overflow integration")
|
||||
}
|
||||
state.LastSelection = overflow[0].integration
|
||||
|
||||
menu := newModel(state)
|
||||
if !menu.showOthers {
|
||||
t.Fatal("expected others section to expand when last selection is in the overflow list")
|
||||
}
|
||||
view := menu.View()
|
||||
if !strings.Contains(view, overflow[0].title) {
|
||||
t.Fatalf("expected expanded view to contain overflow integration\n%s", view)
|
||||
}
|
||||
if strings.Contains(view, "More...") {
|
||||
t.Fatalf("expected expanded view to replace More... item\n%s", view)
|
||||
}
|
||||
wantOrder := expectedExpandedSequence(state)
|
||||
if diff := compareStrings(integrationSequence(menu.items), wantOrder); diff != "" {
|
||||
t.Fatalf("unexpected expanded order: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuEnterOnRunSelectsRun(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
got := updated.(model)
|
||||
want := TUIAction{Kind: TUIActionRunModel}
|
||||
if !got.selected || got.action != want {
|
||||
t.Fatalf("expected enter on run to select run action, got selected=%v action=%v", got.selected, got.action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuRightOnRunSelectsChangeRun(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||
got := updated.(model)
|
||||
want := TUIAction{Kind: TUIActionRunModel, ForceConfigure: true}
|
||||
if !got.selected || got.action != want {
|
||||
t.Fatalf("expected right on run to select change-run action, got selected=%v action=%v", got.selected, got.action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuEnterOnIntegrationSelectsLaunch(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
menu.cursor = findMenuCursorByIntegration(menu.items, "claude")
|
||||
if menu.cursor == -1 {
|
||||
t.Fatal("expected claude menu item")
|
||||
}
|
||||
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
got := updated.(model)
|
||||
want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude"}
|
||||
if !got.selected || got.action != want {
|
||||
t.Fatalf("expected enter on integration to launch, got selected=%v action=%v", got.selected, got.action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuRightOnIntegrationSelectsConfigure(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
menu.cursor = findMenuCursorByIntegration(menu.items, "claude")
|
||||
if menu.cursor == -1 {
|
||||
t.Fatal("expected claude menu item")
|
||||
}
|
||||
updated, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||
got := updated.(model)
|
||||
want := TUIAction{Kind: TUIActionLaunchIntegration, Integration: "claude", ForceConfigure: true}
|
||||
if !got.selected || got.action != want {
|
||||
t.Fatalf("expected right on integration to configure, got selected=%v action=%v", got.selected, got.action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuIgnoresDisabledActions(t *testing.T) {
|
||||
state := launcherTestState()
|
||||
claude := state.Integrations["claude"]
|
||||
claude.Selectable = false
|
||||
claude.Changeable = false
|
||||
state.Integrations["claude"] = claude
|
||||
|
||||
menu := newModel(state)
|
||||
menu.cursor = findMenuCursorByIntegration(menu.items, "claude")
|
||||
if menu.cursor == -1 {
|
||||
t.Fatal("expected claude menu item")
|
||||
}
|
||||
|
||||
updatedEnter, _ := menu.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
if updatedEnter.(model).selected {
|
||||
t.Fatal("expected non-selectable integration to ignore enter")
|
||||
}
|
||||
|
||||
updatedRight, _ := menu.Update(tea.KeyMsg{Type: tea.KeyRight})
|
||||
if updatedRight.(model).selected {
|
||||
t.Fatal("expected non-changeable integration to ignore right")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuShowsCurrentModelSuffixes(t *testing.T) {
|
||||
menu := newModel(launcherTestState())
|
||||
runView := menu.View()
|
||||
if !strings.Contains(runView, "(qwen3:8b)") {
|
||||
t.Fatalf("expected run row to show current model suffix\n%s", runView)
|
||||
}
|
||||
|
||||
menu.cursor = findMenuCursorByIntegration(menu.items, "claude")
|
||||
if menu.cursor == -1 {
|
||||
t.Fatal("expected claude menu item")
|
||||
}
|
||||
integrationView := menu.View()
|
||||
if !strings.Contains(integrationView, "(glm-5:cloud)") {
|
||||
t.Fatalf("expected integration row to show current model suffix\n%s", integrationView)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMenuShowsInstallStatusAndHint(t *testing.T) {
|
||||
state := launcherTestState()
|
||||
codex := state.Integrations["codex"]
|
||||
codex.Installed = false
|
||||
codex.Selectable = false
|
||||
codex.Changeable = false
|
||||
codex.InstallHint = "Install from https://example.com/codex"
|
||||
state.Integrations["codex"] = codex
|
||||
|
||||
state.LastSelection = "codex"
|
||||
menu := newModel(state)
|
||||
menu.cursor = findMenuCursorByIntegration(menu.items, "codex")
|
||||
if menu.cursor == -1 {
|
||||
t.Fatal("expected codex menu item in overflow section")
|
||||
}
|
||||
view := menu.View()
|
||||
if !strings.Contains(view, "(not installed)") {
|
||||
t.Fatalf("expected not-installed marker\n%s", view)
|
||||
}
|
||||
if !strings.Contains(view, codex.InstallHint) {
|
||||
t.Fatalf("expected install hint in description\n%s", view)
|
||||
}
|
||||
}
|
||||
63
cmd/warn_thinking_test.go
Normal file
63
cmd/warn_thinking_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// Test that a warning is printed when thinking is requested but not supported.
|
||||
func TestWarnMissingThinking(t *testing.T) {
|
||||
cases := []struct {
|
||||
capabilities []model.Capability
|
||||
expectWarn bool
|
||||
}{
|
||||
{capabilities: []model.Capability{model.CapabilityThinking}, expectWarn: false},
|
||||
{capabilities: []model.Capability{}, expectWarn: true},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/show" || r.Method != http.MethodPost {
|
||||
t.Fatalf("unexpected request to %s %s", r.URL.Path, r.Method)
|
||||
}
|
||||
var req api.ShowRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
resp := api.ShowResponse{Capabilities: tc.capabilities}
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
t.Fatalf("encode response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
oldStderr := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
ensureThinkingSupport(t.Context(), client, "m")
|
||||
w.Close()
|
||||
os.Stderr = oldStderr
|
||||
out, _ := io.ReadAll(r)
|
||||
|
||||
warned := strings.Contains(string(out), "warning:")
|
||||
if tc.expectWarn && !warned {
|
||||
t.Errorf("expected warning, got none")
|
||||
}
|
||||
if !tc.expectWarn && warned {
|
||||
t.Errorf("did not expect warning, got: %s", string(out))
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user