llama.cpp verification source 2026-05-22
Some checks are pending
Copilot Setup Steps / copilot-setup-steps (push) Waiting to run
Check Pre-Tokenizer Hashes / pre-tokenizer-hashes (push) Waiting to run
Python check requirements.txt / check-requirements (push) Waiting to run
Python Type-Check / python type-check (push) Waiting to run
Update Operations Documentation / update-ops-docs (push) Waiting to run

This commit is contained in:
2026-05-22 16:44:08 +08:00
commit 8e5a449007
2740 changed files with 1155720 additions and 0 deletions

View File

@@ -0,0 +1,75 @@
include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR})
# server-context containing the core server logic, used by llama-server and CLI
set(TARGET server-context)
add_library(${TARGET} STATIC
server-chat.cpp
server-chat.h
server-task.cpp
server-task.h
server-queue.cpp
server-queue.h
server-common.cpp
server-common.h
server-context.cpp
server-context.h
server-tools.cpp
server-tools.h
)
if (BUILD_SHARED_LIBS)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
target_include_directories(${TARGET} PRIVATE ../mtmd)
target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
target_link_libraries(${TARGET} PUBLIC llama-common mtmd ${CMAKE_THREAD_LIBS_INIT})
# llama-server executable
set(TARGET llama-server)
set(TARGET_SRCS
server.cpp
server-http.cpp
server-http.h
server-models.cpp
server-models.h
)
option(LLAMA_BUILD_WEBUI "Build the embedded Web UI" ON)
if (LLAMA_BUILD_WEBUI)
set(PUBLIC_ASSETS
index.html
bundle.js
bundle.css
loading.html
)
foreach(asset ${PUBLIC_ASSETS})
set(input "${CMAKE_CURRENT_SOURCE_DIR}/public/${asset}")
set(output "${CMAKE_CURRENT_BINARY_DIR}/${asset}.hpp")
list(APPEND TARGET_SRCS ${output})
add_custom_command(
DEPENDS "${input}"
OUTPUT "${output}"
COMMAND "${CMAKE_COMMAND}" "-DINPUT=${input}" "-DOUTPUT=${output}" -P "${PROJECT_SOURCE_DIR}/scripts/xxd.cmake"
)
set_source_files_properties(${output} PROPERTIES GENERATED TRUE)
endforeach()
add_definitions(-DLLAMA_BUILD_WEBUI)
else()
endif()
add_executable(${TARGET} ${TARGET_SRCS})
install(TARGETS ${TARGET} RUNTIME)
target_include_directories(${TARGET} PRIVATE ../mtmd)
target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
target_link_libraries(${TARGET} PRIVATE server-context PUBLIC llama-common cpp-httplib ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

264
tools/server/README-dev.md Normal file
View File

@@ -0,0 +1,264 @@
# llama-server Development Documentation
This document provides an in-depth technical overview of `llama-server`, intended for maintainers and contributors.
If you are an end user consuming `llama-server` as a product, please refer to the main [README](./README.md) instead.
## Scope of features
In-scope types of feature:
- Backend:
- Basic inference features: text completion, embeddings output
- Chat-oriented features: chat completion, tool calling
- Third-party API compatibility, e.g. OAI-compat, Anthropic-compat
- Multimodal input/output
- Memory management: save/load state, context checkpoints
- Model management
- Features that are required by the Web UI
- Frontend:
- Chat-oriented features, example: basic chat, image upload, edit messages
- Agentic features, example: MCP
- Model management
Note: For security reasons, features that require reading or writing external files must be **disabled by default**. This covers features like: MCP, model save/load
Out-of-scope features:
- Backend:
- Features that require a loop of external API calls, e.g. server-side agentic loop. This is because external API calls in C++ are costly to maintain. Any complex third-party logic should be implemented outside of server code.
- Features that expose the internal state of the model to the API, example: getting the intermediate activation from API. This is because llama.cpp doesn't support a stable API for doing this, and relying on `eval_callback` can make it complicated to maintain as this API is not intended to be used in multi-sequence setup.
- Model-specific features. All API calls and features must remain model-agnostic.
- Frontend:
- Third-party plugins, it is costly to maintain a public plugin API for such features. Instead, users can make their own MCP server for their needs.
- Customizable themes, it is also costly to maintain. While we do focus on the aesthetic, we try to achieve this by perfecting a small set of themes.
- Browser-specific features, example: [Chrome's built-in AI API](https://developer.chrome.com/docs/ai/built-in-apis).
## Backend
### Overview
The server supports two primary operating modes:
- **Inference mode**: The default mode for performing inference with a single loaded GGUF model.
- **Router mode**: Enables management of multiple inference server instances behind a single API endpoint. Requests are automatically routed to the appropriate backend instance based on the requested model.
The core architecture consists of the following components:
- `server_context`: Holds the primary inference state, including the main `llama_context` and all active slots.
- `server_slot`: An abstraction over a single “sequence” in llama.cpp, responsible for managing individual parallel inference requests.
- `server_routes`: Middleware layer between `server_context` and the HTTP interface; handles JSON parsing/formatting and request routing logic.
- `server_http_context`: Implements the HTTP server using `cpp-httplib`.
- `server_queue`: Thread-safe queue used by HTTP workers to submit new tasks to `server_context`.
- `server_response`: Thread-safe queue used by `server_context` to return results to HTTP workers.
- `server_response_reader`: Higher-level wrapper around the two queues above for cleaner code.
- `server_task`: Unit of work pushed into `server_queue`.
- `server_task_result`: Unit of result pushed into `server_response`.
- `server_tokens`: Unified representation of token sequences (supports both text and multimodal tokens); used by `server_task` and `server_slot`.
- `server_prompt_checkpoint`: For recurrent (e.g., RWKV) and SWA models, stores snapshots of KV cache state. Enables reuse when subsequent requests share the same prompt prefix, saving redundant computation.
- `server_models`: Standalone component for managing multiple backend instances (used in router mode). It is completely independent of `server_context`.
```mermaid
graph TD
API_User <--> server_http_context
server_http_context <-- router mode --> server_models
server_http_context <-- inference mode --> server_routes
server_routes -- server_task --> server_queue
subgraph server_context
server_queue --> server_slot
server_slot -- server_task_result --> server_response
server_slot[multiple server_slot]
end
server_response --> server_routes
```
### Batching
The server context maintains a single batch shared across all slots. When `update_slots()` is invoked, the system iterates through all active slots to populate this batch. For each slot, either a generated token from the previous decoding step or available prompt tokens are added to the batch.
Batching constraints apply: slots can only be batched together if they share compatible configurations. For instance, slots using a specific LoRA adapter can be batched with each other, but not with slots using a different LoRA adapter or no adapter at all.
Once the batch reaches capacity or all slots have been processed, `llama_decode` is called to execute the inference. This operation represents the primary computational bottleneck in `update_slots()`.
Following decoding, the system either retrieves embeddings or samples the next token using `common_sampler_sample`. If a slot has remaining prompt tokens to process, it yields until the next `update_slots()` iteration.
### Thread Management
`server_context` runs on a dedicated single thread. Because it is single-threaded, heavy post-processing (especially after token generation) should be avoided, as it directly impacts multi-sequence throughput.
Each incoming HTTP request is handled by its own thread managed by the HTTP library. The following operations are performed in HTTP worker threads:
- JSON request parsing
- Chat template application
- Tokenization
- Conversion of `server_task_result` into final JSON response
- Error formatting into JSON
- Tracking of partial/incremental responses (e.g., streaming tool calls or reasoning steps)
**Best practices to follow:**
- All JSON formatting and chat template logic must stay in the HTTP layer.
- Avoid passing raw JSON between the HTTP layer and `server_slot`. Instead, parse everything into native C++ types as early as possible.
### Example trace of a request
Here is an example trace of an API request for text completion:
- A request arrives at the HTTP layer.
- The request is routed to the corresponding handler inside `server_routes`. In this case, `handle_completions_impl` is invoked.
- The handler parses the input request, constructs a new `server_task`, and passes it to `server_res_generator`.
- `server_res_generator` creates a new `task_result_state` for each task:
- `task_result_state` stays in the HTTP layer, responsible for keeping track of the current state of the response (e.g., parsing tool calls or thinking messages).
- `server_task` is moved into `server_queue` inside `server_context`.
- `server_context` launches the task by moving it into an available slot (see `launch_slot_with_task()`).
- `update_slot()` processes the task as described in the "Batching" section above.
- Results may be sent using `send_partial_response` or `send_final_response`, which creates a new `server_task_result` and pushes it to the response queue.
- At the same time, `server_res_generator` listens to the response queue and retrieves this response.
- As the response is stateless, `server_res_generator` calls `response->update()` to update the response with the current state.
- `server_res_generator` then calls `response->to_json()` and passes the response to the HTTP layer.
### Testing
`llama-server` includes an automated test suite based on `pytest`.
The framework automatically starts a `llama-server` instance, sends requests, and validates responses.
For detailed instructions, see the [test documentation](./tests/README.md).
### API for tools
This endpoint is intended to be used internally by the Web UI and subject to change or to be removed in the future.
**GET /tools**
Get a list of tools, each tool has these fields:
- `tool` (string): the ID name of the tool, to be used in POST call. Example: `read_file`
- `display_name` (string): the name to be displayed on UI. Example: `Read file`
- `type` (string): always be `"builtin"` for now
- `permissions` (object): a mapping string --> boolean that indicates the permission required by this tool. This is useful for the UI to ask the user before calling the tool. For now, the only permission supported is `"write"`
- `definition` (object): the OAI-compat definition of this tool
**POST /tools**
Invoke a tool call, request body is a JSON object with:
- `tool` (string): the name of the tool
- `params` (object): a mapping from argument name (string) to argument value
Returns JSON object. There are two response formats:
Format 1: Plain text. The text will be placed into a field called `plain_text_response`, example:
```json
{
"plain_text_response": "this is a text response"
}
```
The client should extract this value and place it inside message content (note: content is no longer a JSON), example
```json
{
"role": "tool",
"content": "this is a text response"
}
```
Format 2: Normal JSON response, example:
```json
{
"error": "cannot open this file"
}
```
That requires `JSON.stringify` when formatted to message content:
```json
{
"role": "tool",
"content": "{\"error\":\"cannot open this file\"}"
}
```
### Notable Related PRs
- Initial server implementation: https://github.com/ggml-org/llama.cpp/pull/1443
- Parallel decoding support: https://github.com/ggml-org/llama.cpp/pull/3228
- Refactor introducing `server_queue` and `server_response`: https://github.com/ggml-org/llama.cpp/pull/5065
- Reranking endpoint: https://github.com/ggml-org/llama.cpp/pull/9510
- Multimodal model support (`libmtmd`): https://github.com/ggml-org/llama.cpp/pull/12898
- Unified KV cache handling: https://github.com/ggml-org/llama.cpp/pull/16736
- Separation of HTTP logic into dedicated files: https://github.com/ggml-org/llama.cpp/pull/17216
- Large-scale code base split into smaller files: https://github.com/ggml-org/llama.cpp/pull/17362
- Introduction of router mode: https://github.com/ggml-org/llama.cpp/pull/17470
- Speculative decoding: https://github.com/ggml-org/llama.cpp/pull/17808 and rework in https://github.com/ggml-org/llama.cpp/pull/17808
- INI presets: https://github.com/ggml-org/llama.cpp/pull/17859 (+ refactoring: https://github.com/ggml-org/llama.cpp/pull/18169)
- Sleeping mode: https://github.com/ggml-org/llama.cpp/pull/18228
## Web UI
The project includes a web-based user interface for interacting with `llama-server`. It supports both single-model (`MODEL` mode) and multi-model (`ROUTER` mode) operation.
The SvelteKit-based Web UI is introduced in this PR: https://github.com/ggml-org/llama.cpp/pull/14839
### Features
- **Chat interface** with streaming responses
- **Multi-model support** (ROUTER mode) - switch between models, auto-load on selection
- **Modality validation** - ensures selected model supports conversation's attachments (images, audio)
- **Conversation management** - branching, regeneration, editing with history preservation
- **Attachment support** - images, audio, PDFs (with vision/text fallback)
- **Configurable parameters** - temperature, top_p, etc. synced with server defaults
- **Dark/light theme**
### Tech Stack
- **SvelteKit** - frontend framework with Svelte 5 runes for reactive state
- **TailwindCSS** + **shadcn-svelte** - styling and UI components
- **Vite** - build tooling
- **IndexedDB** (Dexie) - local storage for conversations
- **LocalStorage** - user settings persistence
### Architecture
The WebUI follows a layered architecture:
```
Routes → Components → Hooks → Stores → Services → Storage/API
```
- **Stores** - reactive state management (`chatStore`, `conversationsStore`, `modelsStore`, `serverStore`, `settingsStore`)
- **Services** - stateless API/database communication (`ChatService`, `ModelsService`, `PropsService`, `DatabaseService`)
- **Hooks** - reusable logic (`useModelChangeValidation`, `useProcessingState`)
For detailed architecture diagrams, see [`tools/server/webui/docs/`](webui/docs/):
- `high-level-architecture.mmd` - full architecture with all modules
- `high-level-architecture-simplified.mmd` - simplified overview
- `data-flow-simplified-model-mode.mmd` - data flow for single-model mode
- `data-flow-simplified-router-mode.mmd` - data flow for multi-model mode
- `flows/*.mmd` - detailed per-domain flows (chat, conversations, models, etc.)
### Development
```sh
# make sure you have Node.js installed
cd tools/server/webui
npm i
# run dev server (with hot reload)
npm run dev
# run tests
npm run test
# build production bundle
npm run build
```
After `public/index.html` has been generated, rebuild `llama-server` as described in the [build](#build) section to include the updated UI.
**Note:** The Vite dev server automatically proxies API requests to `http://localhost:8080`. Make sure `llama-server` is running on that port during development.

1863
tools/server/README.md Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,119 @@
### Server benchmark tools
Benchmark is using [k6](https://k6.io/).
##### Install k6 and sse extension
SSE is not supported by default in k6, you have to build k6 with the [xk6-sse](https://github.com/phymbert/xk6-sse) extension.
Example (assuming golang >= 1.21 is installed):
```shell
go install go.k6.io/xk6/cmd/xk6@latest
$GOPATH/bin/xk6 build master \
--with github.com/phymbert/xk6-sse
```
#### Download a dataset
This dataset was originally proposed in [vLLM benchmarks](https://github.com/vllm-project/vllm/blob/main/benchmarks/README.md).
```shell
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```
#### Download a model
Example for PHI-2
```shell
../../../scripts/hf.sh --repo ggml-org/models --file phi-2/ggml-model-q4_0.gguf
```
#### Start the server
The server must answer OAI Chat completion requests on `http://localhost:8080/v1` or according to the environment variable `SERVER_BENCH_URL`.
Example:
```shell
llama-server --host localhost --port 8080 \
--model ggml-model-q4_0.gguf \
--cont-batching \
--metrics \
--parallel 8 \
--batch-size 512 \
--ctx-size 4096 \
-ngl 33
```
#### Run the benchmark
For 500 chat completions request with 8 concurrent users during maximum 10 minutes, run:
```shell
./k6 run script.js --duration 10m --iterations 500 --vus 8
```
The benchmark values can be overridden with:
- `SERVER_BENCH_URL` server url prefix for chat completions, default `http://localhost:8080/v1`
- `SERVER_BENCH_N_PROMPTS` total prompts to randomly select in the benchmark, default `480`
- `SERVER_BENCH_MODEL_ALIAS` model alias to pass in the completion request, default `my-model`
- `SERVER_BENCH_MAX_TOKENS` max tokens to predict, default: `512`
- `SERVER_BENCH_DATASET` path to the benchmark dataset file
- `SERVER_BENCH_MAX_PROMPT_TOKENS` maximum prompt tokens to filter out in the dataset: default `1024`
- `SERVER_BENCH_MAX_CONTEXT` maximum context size of the completions request to filter out in the dataset: prompt + predicted tokens, default `2048`
Note: the local tokenizer is just a string space split, real number of tokens will differ.
Or with [k6 options](https://k6.io/docs/using-k6/k6-options/reference/):
```shell
SERVER_BENCH_N_PROMPTS=500 k6 run script.js --duration 10m --iterations 500 --vus 8
```
To [debug http request](https://k6.io/docs/using-k6/http-debugging/) use `--http-debug="full"`.
#### Metrics
Following metrics are available computed from the OAI chat completions response `usage`:
- `llamacpp_tokens_second` Trend of `usage.total_tokens / request duration`
- `llamacpp_prompt_tokens` Trend of `usage.prompt_tokens`
- `llamacpp_prompt_tokens_total_counter` Counter of `usage.prompt_tokens`
- `llamacpp_completion_tokens` Trend of `usage.completion_tokens`
- `llamacpp_completion_tokens_total_counter` Counter of `usage.completion_tokens`
- `llamacpp_completions_truncated_rate` Rate of completions truncated, i.e. if `finish_reason === 'length'`
- `llamacpp_completions_stop_rate` Rate of completions stopped by the model, i.e. if `finish_reason === 'stop'`
The script will fail if too many completions are truncated, see `llamacpp_completions_truncated_rate`.
K6 metrics might be compared against [server metrics](../README.md), with:
```shell
curl http://localhost:8080/metrics
```
### Using the CI python script
The `bench.py` script does several steps:
- start the server
- define good variable for k6
- run k6 script
- extract metrics from prometheus
It aims to be used in the CI, but you can run it manually:
```shell
LLAMA_SERVER_BIN_PATH=../../../cmake-build-release/bin/llama-server python bench.py \
--runner-label local \
--name local \
--branch `git rev-parse --abbrev-ref HEAD` \
--commit `git rev-parse HEAD` \
--scenario script.js \
--duration 5m \
--hf-repo ggml-org/models \
--hf-file phi-2/ggml-model-q4_0.gguf \
--model-path-prefix models \
--parallel 4 \
-ngl 33 \
--batch-size 2048 \
--ubatch-size 256 \
--ctx-size 4096 \
--n-prompts 200 \
--max-prompt-tokens 256 \
--max-tokens 256
```

322
tools/server/bench/bench.py Normal file
View File

@@ -0,0 +1,322 @@
from __future__ import annotations
import argparse
import json
import os
import re
import signal
import socket
import subprocess
import sys
import threading
import time
import traceback
from contextlib import closing
from datetime import datetime
import matplotlib
import matplotlib.dates
import matplotlib.pyplot as plt
import requests
from statistics import mean
def main(args_in: list[str] | None = None) -> None:
parser = argparse.ArgumentParser(description="Start server benchmark scenario")
parser.add_argument("--name", type=str, help="Bench name", required=True)
parser.add_argument("--runner-label", type=str, help="Runner label", required=True)
parser.add_argument("--branch", type=str, help="Branch name", default="detached")
parser.add_argument("--commit", type=str, help="Commit name", default="dirty")
parser.add_argument("--host", type=str, help="Server listen host", default="0.0.0.0")
parser.add_argument("--port", type=int, help="Server listen host", default="8080")
parser.add_argument("--model-path-prefix", type=str, help="Prefix where to store the model files", default="models")
parser.add_argument("--n-prompts", type=int,
help="SERVER_BENCH_N_PROMPTS: total prompts to randomly select in the benchmark", required=True)
parser.add_argument("--max-prompt-tokens", type=int,
help="SERVER_BENCH_MAX_PROMPT_TOKENS: maximum prompt tokens to filter out in the dataset",
required=True)
parser.add_argument("--max-tokens", type=int,
help="SERVER_BENCH_MAX_CONTEXT: maximum context size of the completions request to filter out in the dataset: prompt + predicted tokens",
required=True)
parser.add_argument("--hf-repo", type=str, help="Hugging Face model repository", required=True)
parser.add_argument("--hf-file", type=str, help="Hugging Face model file", required=True)
parser.add_argument("-ngl", "--n-gpu-layers", type=int, help="layers to the GPU for computation", required=True)
parser.add_argument("--ctx-size", type=int, help="Set the size of the prompt context", required=True)
parser.add_argument("--parallel", type=int, help="Set the number of slots for process requests", required=True)
parser.add_argument("--batch-size", type=int, help="Set the batch size for prompt processing", required=True)
parser.add_argument("--ubatch-size", type=int, help="physical maximum batch size", required=True)
parser.add_argument("--scenario", type=str, help="Scenario to run", required=True)
parser.add_argument("--duration", type=str, help="Bench scenario", required=True)
args = parser.parse_args(args_in)
start_time = time.time()
# Start the server and performance scenario
try:
server_process = start_server(args)
except Exception:
print("bench: server start error :")
traceback.print_exc(file=sys.stdout)
sys.exit(1)
# start the benchmark
iterations = 0
data = {}
try:
start_benchmark(args)
with open("results.github.env", 'w') as github_env:
# parse output
with open('k6-results.json', 'r') as bench_results:
# Load JSON data from file
data = json.load(bench_results)
for metric_name in data['metrics']:
for metric_metric in data['metrics'][metric_name]:
value = data['metrics'][metric_name][metric_metric]
if isinstance(value, float) or isinstance(value, int):
value = round(value, 2)
data['metrics'][metric_name][metric_metric]=value
github_env.write(
f"{escape_metric_name(metric_name)}_{escape_metric_name(metric_metric)}={value}\n")
iterations = data['root_group']['checks']['success completion']['passes']
except Exception:
print("bench: error :")
traceback.print_exc(file=sys.stdout)
# Stop the server
if server_process:
try:
print(f"bench: shutting down server pid={server_process.pid} ...")
if os.name == 'nt':
interrupt = signal.CTRL_C_EVENT
else:
interrupt = signal.SIGINT
server_process.send_signal(interrupt)
server_process.wait(0.5)
except subprocess.TimeoutExpired:
print(f"server still alive after 500ms, force-killing pid={server_process.pid} ...")
server_process.kill() # SIGKILL
server_process.wait()
while is_server_listening(args.host, args.port):
time.sleep(0.1)
title = (f"llama.cpp {args.name} on {args.runner_label}\n "
f"duration={args.duration} {iterations} iterations")
xlabel = (f"{args.hf_repo}/{args.hf_file}\n"
f"parallel={args.parallel} ctx-size={args.ctx_size} ngl={args.n_gpu_layers} batch-size={args.batch_size} ubatch-size={args.ubatch_size} pp={args.max_prompt_tokens} pp+tg={args.max_tokens}\n"
f"branch={args.branch} commit={args.commit}")
# Prometheus
end_time = time.time()
prometheus_metrics = {}
if is_server_listening("0.0.0.0", 9090):
metrics = ['prompt_tokens_seconds', 'predicted_tokens_seconds',
'kv_cache_usage_ratio', 'requests_processing', 'requests_deferred']
for metric in metrics:
resp = requests.get(f"http://localhost:9090/api/v1/query_range",
params={'query': 'llamacpp:' + metric, 'start': start_time, 'end': end_time, 'step': 2})
with open(f"{metric}.json", 'w') as metric_json:
metric_json.write(resp.text)
if resp.status_code != 200:
print(f"bench: unable to extract prometheus metric {metric}: {resp.text}")
else:
metric_data = resp.json()
values = metric_data['data']['result'][0]['values']
timestamps, metric_values = zip(*values)
metric_values = [float(value) for value in metric_values]
prometheus_metrics[metric] = metric_values
timestamps_dt = [str(datetime.fromtimestamp(int(ts))) for ts in timestamps]
plt.figure(figsize=(16, 10), dpi=80)
plt.plot(timestamps_dt, metric_values, label=metric)
plt.xticks(rotation=0, fontsize=14, horizontalalignment='center', alpha=.7)
plt.yticks(fontsize=12, alpha=.7)
ylabel = f"llamacpp:{metric}"
plt.title(title,
fontsize=14, wrap=True)
plt.grid(axis='both', alpha=.3)
plt.ylabel(ylabel, fontsize=22)
plt.xlabel(xlabel, fontsize=14, wrap=True)
plt.gca().xaxis.set_major_locator(matplotlib.dates.MinuteLocator())
plt.gca().xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%Y-%m-%d %H:%M:%S"))
plt.gcf().autofmt_xdate()
# Remove borders
plt.gca().spines["top"].set_alpha(0.0)
plt.gca().spines["bottom"].set_alpha(0.3)
plt.gca().spines["right"].set_alpha(0.0)
plt.gca().spines["left"].set_alpha(0.3)
# Save the plot as a jpg image
plt.savefig(f'{metric}.jpg', dpi=60)
plt.close()
# Mermaid format in case images upload failed
with open(f"{metric}.mermaid", 'w') as mermaid_f:
mermaid = (
f"""---
config:
xyChart:
titleFontSize: 12
width: 900
height: 600
themeVariables:
xyChart:
titleColor: "#000000"
---
xychart-beta
title "{title}"
y-axis "llamacpp:{metric}"
x-axis "llamacpp:{metric}" {int(min(timestamps))} --> {int(max(timestamps))}
line [{', '.join([str(round(float(value), 2)) for value in metric_values])}]
""")
mermaid_f.write(mermaid)
# 140 chars max for commit status description
bench_results = {
"i": iterations,
"req": {
"p95": round(data['metrics']["http_req_duration"]["p(95)"], 2),
"avg": round(data['metrics']["http_req_duration"]["avg"], 2),
},
"pp": {
"p95": round(data['metrics']["llamacpp_prompt_processing_second"]["p(95)"], 2),
"avg": round(data['metrics']["llamacpp_prompt_processing_second"]["avg"], 2),
"0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2) if 'prompt_tokens_seconds' in prometheus_metrics else 0,
},
"tg": {
"p95": round(data['metrics']["llamacpp_tokens_second"]["p(95)"], 2),
"avg": round(data['metrics']["llamacpp_tokens_second"]["avg"], 2),
"0": round(mean(prometheus_metrics['predicted_tokens_seconds']), 2) if 'predicted_tokens_seconds' in prometheus_metrics else 0,
},
}
with open("results.github.env", 'a') as github_env:
github_env.write(f"BENCH_RESULTS={json.dumps(bench_results, indent=None, separators=(',', ':') )}\n")
github_env.write(f"BENCH_ITERATIONS={iterations}\n")
title = title.replace('\n', ' ')
xlabel = xlabel.replace('\n', ' ')
github_env.write(f"BENCH_GRAPH_TITLE={title}\n")
github_env.write(f"BENCH_GRAPH_XLABEL={xlabel}\n")
def start_benchmark(args):
k6_path = './k6'
if 'BENCH_K6_BIN_PATH' in os.environ:
k6_path = os.environ['BENCH_K6_BIN_PATH']
k6_args = [
'run', args.scenario,
'--no-color',
'--no-connection-reuse',
'--no-vu-connection-reuse',
]
k6_args.extend(['--duration', args.duration])
k6_args.extend(['--iterations', args.n_prompts])
k6_args.extend(['--vus', args.parallel])
k6_args.extend(['--summary-export', 'k6-results.json'])
k6_args.extend(['--out', 'csv=k6-results.csv'])
args = f"SERVER_BENCH_N_PROMPTS={args.n_prompts} SERVER_BENCH_MAX_PROMPT_TOKENS={args.max_prompt_tokens} SERVER_BENCH_MAX_CONTEXT={args.max_tokens} "
args = args + ' '.join([str(arg) for arg in [k6_path, *k6_args]])
print(f"bench: starting k6 with: {args}")
k6_completed = subprocess.run(args, shell=True, stdout=sys.stdout, stderr=sys.stderr)
if k6_completed.returncode != 0:
raise Exception("bench: unable to run k6")
def start_server(args):
server_process = start_server_background(args)
attempts = 0
max_attempts = 600
if 'GITHUB_ACTIONS' in os.environ:
max_attempts *= 2
while not is_server_listening(args.host, args.port):
attempts += 1
if attempts > max_attempts:
assert False, "server not started"
print(f"bench: waiting for server to start ...")
time.sleep(0.5)
attempts = 0
while not is_server_ready(args.host, args.port):
attempts += 1
if attempts > max_attempts:
assert False, "server not ready"
print(f"bench: waiting for server to be ready ...")
time.sleep(0.5)
print("bench: server started and ready.")
return server_process
def start_server_background(args):
# Start the server
server_path = '../../../build/bin/llama-server'
if 'LLAMA_SERVER_BIN_PATH' in os.environ:
server_path = os.environ['LLAMA_SERVER_BIN_PATH']
server_args = [
'--host', args.host,
'--port', args.port,
]
server_args.extend(['--hf-repo', args.hf_repo])
server_args.extend(['--hf-file', args.hf_file])
server_args.extend(['--n-gpu-layers', args.n_gpu_layers])
server_args.extend(['--ctx-size', args.ctx_size])
server_args.extend(['--parallel', args.parallel])
server_args.extend(['--batch-size', args.batch_size])
server_args.extend(['--ubatch-size', args.ubatch_size])
server_args.extend(['--n-predict', args.max_tokens * 2])
server_args.append('--cont-batching')
server_args.append('--metrics')
server_args.append('--flash-attn')
args = [str(arg) for arg in [server_path, *server_args]]
print(f"bench: starting server with: {' '.join(args)}")
pkwargs = {
'stdout': subprocess.PIPE,
'stderr': subprocess.PIPE
}
server_process = subprocess.Popen(
args,
**pkwargs) # pyright: ignore[reportArgumentType, reportCallIssue] # ty: ignore[no-matching-overload]
def server_log(in_stream, out_stream):
for line in iter(in_stream.readline, b''):
print(line.decode('utf-8'), end='', file=out_stream)
thread_stdout = threading.Thread(target=server_log, args=(server_process.stdout, sys.stdout))
thread_stdout.start()
thread_stderr = threading.Thread(target=server_log, args=(server_process.stderr, sys.stderr))
thread_stderr.start()
return server_process
def is_server_listening(server_fqdn, server_port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
result = sock.connect_ex((server_fqdn, server_port))
_is_server_listening = result == 0
if _is_server_listening:
print(f"server is listening on {server_fqdn}:{server_port}...")
return _is_server_listening
def is_server_ready(server_fqdn, server_port):
url = f"http://{server_fqdn}:{server_port}/health"
response = requests.get(url)
return response.status_code == 200
def escape_metric_name(metric_name):
return re.sub('[^A-Z0-9]', '_', metric_name.upper())
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,9 @@
global:
scrape_interval: 10s
external_labels:
llamacpp: 'server'
scrape_configs:
- job_name: 'llama.cpp server'
static_configs:
- targets: ['localhost:8080']

View File

@@ -0,0 +1,2 @@
matplotlib
requests

View File

@@ -0,0 +1,162 @@
import sse from 'k6/x/sse'
import {check, sleep} from 'k6'
import {SharedArray} from 'k6/data'
import {Counter, Rate, Trend} from 'k6/metrics'
import exec from 'k6/execution';
// Server chat completions prefix
const server_url = __ENV.SERVER_BENCH_URL ? __ENV.SERVER_BENCH_URL : 'http://localhost:8080/v1'
// Number of total prompts in the dataset - default 10m / 10 seconds/request * number of users
const n_prompt = __ENV.SERVER_BENCH_N_PROMPTS ? parseInt(__ENV.SERVER_BENCH_N_PROMPTS) : 600 / 10 * 8
// Model name to request
const model = __ENV.SERVER_BENCH_MODEL_ALIAS ? __ENV.SERVER_BENCH_MODEL_ALIAS : 'my-model'
// Dataset path
const dataset_path = __ENV.SERVER_BENCH_DATASET ? __ENV.SERVER_BENCH_DATASET : './ShareGPT_V3_unfiltered_cleaned_split.json'
// Max tokens to predict
const max_tokens = __ENV.SERVER_BENCH_MAX_TOKENS ? parseInt(__ENV.SERVER_BENCH_MAX_TOKENS) : 512
// Max prompt tokens
const n_prompt_tokens = __ENV.SERVER_BENCH_MAX_PROMPT_TOKENS ? parseInt(__ENV.SERVER_BENCH_MAX_PROMPT_TOKENS) : 1024
// Max slot context
const n_ctx_slot = __ENV.SERVER_BENCH_MAX_CONTEXT ? parseInt(__ENV.SERVER_BENCH_MAX_CONTEXT) : 2048
export function setup() {
console.info(`Benchmark config: server_url=${server_url} n_prompt=${n_prompt} model=${model} dataset_path=${dataset_path} max_tokens=${max_tokens}`)
}
const data = new SharedArray('conversations', function () {
const tokenizer = (message) => message.split(/[\s,'".?]/)
return JSON.parse(open(dataset_path))
// Filter out the conversations with less than 2 turns.
.filter(data => data["conversations"].length >= 2)
.filter(data => data["conversations"][0]["from"] === "human")
.map(data => {
return {
prompt: data["conversations"][0]["value"],
n_prompt_tokens: tokenizer(data["conversations"][0]["value"]).length,
n_completion_tokens: tokenizer(data["conversations"][1]["value"]).length,
}
})
// Filter out too short sequences
.filter(conv => conv.n_prompt_tokens >= 4 && conv.n_completion_tokens >= 4)
// Filter out too long sequences.
.filter(conv => conv.n_prompt_tokens <= n_prompt_tokens && conv.n_prompt_tokens + conv.n_completion_tokens <= n_ctx_slot)
// Keep only first n prompts
.slice(0, n_prompt)
})
const llamacpp_prompt_tokens = new Trend('llamacpp_prompt_tokens')
const llamacpp_completion_tokens = new Trend('llamacpp_completion_tokens')
const llamacpp_tokens_second = new Trend('llamacpp_tokens_second')
const llamacpp_prompt_processing_second = new Trend('llamacpp_prompt_processing_second')
const llamacpp_emit_first_token_second = new Trend('llamacpp_emit_first_token_second')
const llamacpp_prompt_tokens_total_counter = new Counter('llamacpp_prompt_tokens_total_counter')
const llamacpp_completion_tokens_total_counter = new Counter('llamacpp_completion_tokens_total_counter')
const llamacpp_completions_truncated_rate = new Rate('llamacpp_completions_truncated_rate')
const llamacpp_completions_stop_rate = new Rate('llamacpp_completions_stop_rate')
export const options = {
thresholds: {
llamacpp_completions_truncated_rate: [
// more than 80% of truncated input will abort the test
{threshold: 'rate < 0.8', abortOnFail: true, delayAbortEval: '1m'},
],
},
duration: '10m',
vus: 8,
}
export default function () {
const conversation = data[exec.scenario.iterationInInstance % data.length]
const payload = {
"messages": [
{
"role": "system",
"content": "You are ChatGPT, an AI assistant.",
},
{
"role": "user",
"content": conversation.prompt,
}
],
"model": model,
"stream": true,
"stream_options": {
"include_usage": true, // False to be supported in llama.cpp server
},
"seed": 42,
"max_tokens": max_tokens,
"stop": ["<|im_end|>"] // This is temporary for phi-2 base (i.e. not instructed) since the server expects that the model always to emit BOS
}
const params = {method: 'POST', body: JSON.stringify(payload)};
const startTime = new Date()
let promptEvalEndTime = null
let prompt_tokens = 0
let completions_tokens = 0
let finish_reason = null
const res = sse.open(`${server_url}/chat/completions`, params, function (client) {
client.on('event', function (event) {
if (promptEvalEndTime == null) {
promptEvalEndTime = new Date()
llamacpp_emit_first_token_second.add((promptEvalEndTime - startTime) / 1.e3)
}
if (event.data === '[DONE]' || event.data === '') {
return
}
let chunk = JSON.parse(event.data)
if (chunk.choices && chunk.choices.length > 0) {
let choice = chunk.choices[0]
if (choice.finish_reason) {
finish_reason = choice.finish_reason
}
}
if (chunk.usage) {
prompt_tokens = chunk.usage.prompt_tokens
llamacpp_prompt_tokens.add(prompt_tokens)
llamacpp_prompt_tokens_total_counter.add(prompt_tokens)
completions_tokens = chunk.usage.completion_tokens
llamacpp_completion_tokens.add(completions_tokens)
llamacpp_completion_tokens_total_counter.add(completions_tokens)
}
})
client.on('error', function (e) {
console.log('An unexpected error occurred: ', e.error());
throw e;
})
})
check(res, {'success completion': (r) => r.status === 200})
const endTime = new Date()
const promptEvalTime = promptEvalEndTime - startTime
if (promptEvalTime > 0) {
llamacpp_prompt_processing_second.add(prompt_tokens / (promptEvalEndTime - startTime) * 1.e3)
}
const completion_time = endTime - promptEvalEndTime
if (completions_tokens > 0 && completion_time > 0) {
llamacpp_tokens_second.add(completions_tokens / completion_time * 1.e3)
}
llamacpp_completions_truncated_rate.add(finish_reason === 'length')
llamacpp_completions_stop_rate.add(finish_reason === 'stop')
sleep(0.3)
}

109
tools/server/chat-llama2.sh Executable file
View File

@@ -0,0 +1,109 @@
#!/usr/bin/env bash
API_URL="${API_URL:-http://127.0.0.1:8080}"
CHAT=(
"Hello, Assistant."
"Hello. How may I help you today?"
)
INSTRUCTION="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
trim() {
shopt -s extglob
set -- "${1##+([[:space:]])}"
printf "%s" "${1%%+([[:space:]])}"
}
trim_trailing() {
shopt -s extglob
printf "%s" "${1%%+([[:space:]])}"
}
format_prompt() {
if [[ "${#CHAT[@]}" -eq 0 ]]; then
echo -n "[INST] <<SYS>>\n${INSTRUCTION}\n<</SYS>>"
else
LAST_INDEX=$(( ${#CHAT[@]} - 1 ))
echo -n "${CHAT[$LAST_INDEX]}\n[INST] $1 [/INST]"
fi
}
tokenize() {
curl \
--silent \
--request POST \
--url "${API_URL}/tokenize" \
--header "Content-Type: application/json" \
--data-raw "$(jq -ns --arg content "$1" '{content:$content}')" \
| jq '.tokens[]'
}
N_KEEP=$(tokenize "[INST] <<SYS>>\n${INSTRUCTION}\n<</SYS>>" | wc -l)
chat_completion() {
PROMPT="$(trim_trailing "$(format_prompt "$1")")"
DATA="$(echo -n "$PROMPT" | jq -Rs --argjson n_keep $N_KEEP '{
prompt: .,
temperature: 0.2,
top_k: 40,
top_p: 0.9,
n_keep: $n_keep,
n_predict: 1024,
stop: ["[INST]"],
stream: true
}')"
# Create a temporary file to hold the Python output
TEMPFILE=$(mktemp)
exec 3< <(curl \
--silent \
--no-buffer \
--request POST \
--url "${API_URL}/completion" \
--header "Content-Type: application/json" \
--data-raw "${DATA}")
python -c "
import json
import sys
answer = ''
while True:
line = sys.stdin.readline()
if not line:
break
if line.startswith('data: '):
json_content = line[6:].strip()
content = json.loads(json_content)['content']
sys.stdout.write(content)
sys.stdout.flush()
answer += content
answer = answer.rstrip('\n')
# Write the answer to the temporary file
with open('$TEMPFILE', 'w') as f:
f.write(answer)
" <&3
exec 3<&-
# Read the answer from the temporary file
ANSWER=$(cat $TEMPFILE)
# Clean up the temporary file
rm $TEMPFILE
printf "\n"
CHAT+=("$1" "$(trim "$ANSWER")")
}
while true; do
echo -en "\033[0;32m" # Green color
read -r -e -p "> " QUESTION
echo -en "\033[0m" # Reset color
chat_completion "${QUESTION}"
done

131
tools/server/chat.mjs Normal file
View File

@@ -0,0 +1,131 @@
import * as readline from 'node:readline'
import { stdin, stdout } from 'node:process'
import { readFileSync } from 'node:fs'
import { SchemaConverter } from './public_legacy/json-schema-to-grammar.mjs'
const args = process.argv.slice(2);
const grammarJsonSchemaFile = args.find(
(_, index) => args[index - 1] === "--grammar-json-schema"
);
const no_cached_prompt = args.find(
(_, index) => args[index - 1] === "--no-cache-prompt"
) ?? "false";
const grammarFile = args.find((_, index) => args[index - 1] === "--grammar");
// Example usage: function,arguments
const grammarJsonSchemaPropOrder = args.find(
(_, index) => args[index - 1] === "--grammar-json-schema-prop-order"
);
const propOrder = grammarJsonSchemaPropOrder
? grammarJsonSchemaPropOrder
.split(",")
.reduce((acc, cur, index) => ({ ...acc, [cur]: index }), {})
: {};
let grammar = null
if (grammarJsonSchemaFile) {
let schema = JSON.parse(readFileSync(grammarJsonSchemaFile, 'utf-8'))
const converter = new SchemaConverter({prop_order: propOrder, allow_fetch: true})
schema = await converter.resolveRefs(schema, grammarJsonSchemaFile)
converter.visit(schema, '')
grammar = converter.formatGrammar()
}
if (grammarFile) {
grammar = readFileSync(grammarFile, 'utf-8')
}
// for cached prompt
let slot_id = -1;
const API_URL = 'http://127.0.0.1:8080'
const chat = [
{
human: "Hello, Assistant.",
assistant: "Hello. How may I help you today?"
},
{
human: "Please tell me the largest city in Europe.",
assistant: "Sure. The largest city in Europe is Moscow, the capital of Russia."
},
]
const instruction = `A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.`
function format_prompt(question) {
return `${instruction}\n${
chat.map(m =>`### Human: ${m.human}\n### Assistant: ${m.assistant}`).join("\n")
}\n### Human: ${question}\n### Assistant:`
}
async function tokenize(content) {
const result = await fetch(`${API_URL}/tokenize`, {
method: 'POST',
body: JSON.stringify({ content })
})
if (!result.ok) {
return []
}
return await result.json().tokens
}
const n_keep = await tokenize(instruction).length
async function chat_completion(question) {
const result = await fetch(`${API_URL}/completion`, {
method: 'POST',
body: JSON.stringify({
prompt: format_prompt(question),
temperature: 0.2,
top_k: 40,
top_p: 0.9,
n_keep: n_keep,
n_predict: 256,
cache_prompt: no_cached_prompt === "false",
slot_id: slot_id,
stop: ["\n### Human:"], // stop completion after generating this
grammar,
stream: true,
})
})
if (!result.ok) {
return
}
let answer = ''
for await (var chunk of result.body) {
const t = Buffer.from(chunk).toString('utf8')
if (t.startsWith('data: ')) {
const message = JSON.parse(t.substring(6))
slot_id = message.slot_id
answer += message.content
process.stdout.write(message.content)
if (message.stop) {
if (message.truncated) {
chat.shift()
}
break
}
}
}
process.stdout.write('\n')
chat.push({ human: question, assistant: answer.trimStart() })
}
const rl = readline.createInterface({ input: stdin, output: stdout });
const readlineQuestion = (rl, query, options) => new Promise((resolve, reject) => {
rl.question(query, options, resolve)
});
while(true) {
const question = await readlineQuestion(rl, '> ')
await chat_completion(question)
}

80
tools/server/chat.sh Executable file
View File

@@ -0,0 +1,80 @@
#!/usr/bin/env bash
API_URL="${API_URL:-http://127.0.0.1:8080}"
CHAT=(
"Hello, Assistant."
"Hello. How may I help you today?"
"Please tell me the largest city in Europe."
"Sure. The largest city in Europe is Moscow, the capital of Russia."
)
INSTRUCTION="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
trim() {
shopt -s extglob
set -- "${1##+([[:space:]])}"
printf "%s" "${1%%+([[:space:]])}"
}
trim_trailing() {
shopt -s extglob
printf "%s" "${1%%+([[:space:]])}"
}
format_prompt() {
echo -n "${INSTRUCTION}"
printf "\n### Human: %s\n### Assistant: %s" "${CHAT[@]}" "$1"
}
tokenize() {
curl \
--silent \
--request POST \
--url "${API_URL}/tokenize" \
--header "Content-Type: application/json" \
--data-raw "$(jq -ns --arg content "$1" '{content:$content}')" \
| jq '.tokens[]'
}
N_KEEP=$(tokenize "${INSTRUCTION}" | wc -l)
chat_completion() {
PROMPT="$(trim_trailing "$(format_prompt "$1")")"
DATA="$(echo -n "$PROMPT" | jq -Rs --argjson n_keep $N_KEEP '{
prompt: .,
temperature: 0.2,
top_k: 40,
top_p: 0.9,
n_keep: $n_keep,
n_predict: 256,
cache_prompt: true,
stop: ["\n### Human:"],
stream: true
}')"
ANSWER=''
while IFS= read -r LINE; do
if [[ $LINE = data:* ]]; then
CONTENT="$(echo "${LINE:5}" | jq -r '.content')"
printf "%s" "${CONTENT}"
ANSWER+="${CONTENT}"
fi
done < <(curl \
--silent \
--no-buffer \
--request POST \
--url "${API_URL}/completion" \
--header "Content-Type: application/json" \
--data-raw "${DATA}")
printf "\n"
CHAT+=("$1" "$(trim "$ANSWER")")
}
while true; do
read -r -e -p "> " QUESTION
chat_completion "${QUESTION}"
done

File diff suppressed because one or more lines are too long

13285
tools/server/public/bundle.js Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,12 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="refresh" content="5">
</head>
<body>
<div id="loading">
The model is loading. Please wait.<br/>
The user interface will appear soon.
</div>
</body>
</html>

View File

@@ -0,0 +1,630 @@
#include "server-chat.h"
#include "server-common.h"
#include <sstream>
json server_chat_convert_responses_to_chatcmpl(const json & response_body) {
if (!response_body.contains("input")) {
throw std::invalid_argument("'input' is required");
}
if (!json_value(response_body, "previous_response_id", std::string{}).empty()) {
throw std::invalid_argument("llama.cpp does not support 'previous_response_id'.");
}
const json input_value = response_body.at("input");
json chatcmpl_body = response_body;
chatcmpl_body.erase("input");
std::vector<json> chatcmpl_messages;
if (response_body.contains("instructions")) {
chatcmpl_messages.push_back({
{"role", "system"},
{"content", json_value(response_body, "instructions", std::string())},
});
chatcmpl_body.erase("instructions");
}
if (input_value.is_string()) {
// #responses_create-input-text_input
chatcmpl_messages.push_back({
{"role", "user"},
{"content", input_value},
});
} else if (input_value.is_array()) {
// #responses_create-input-input_item_list
static auto exists_and_is_array = [](const json & j, const char * key) -> bool {
return j.contains(key) && j.at(key).is_array();
};
static auto exists_and_is_string = [](const json & j, const char * key) -> bool {
return j.contains(key) && j.at(key).is_string();
};
for (json item : input_value) {
bool merge_prev = !chatcmpl_messages.empty() && chatcmpl_messages.back().value("role", "") == "assistant";
if (exists_and_is_string(item, "content")) {
// #responses_create-input-input_item_list-input_message-content-text_input
// Only "Input message" contains item["content"]::string
// After converting item["content"]::string to item["content"]::array,
// we can treat "Input message" as sum of "Item-Input message" and "Item-Output message"
item["content"] = json::array({
json {
{"text", item.at("content")},
{"type", "input_text"}
}
});
}
if (exists_and_is_array(item, "content") &&
exists_and_is_string(item, "role") &&
(item.at("role") == "user" ||
item.at("role") == "system" ||
item.at("role") == "developer")
) {
// #responses_create-input-input_item_list-item-input_message
std::vector<json> chatcmpl_content;
for (const json & input_item : item.at("content")) {
const std::string type = json_value(input_item, "type", std::string());
if (type == "input_text") {
if (!input_item.contains("text")) {
throw std::invalid_argument("'Input text' requires 'text'");
}
chatcmpl_content.push_back({
{"text", input_item.at("text")},
{"type", "text"},
});
} else if (type == "input_image") {
// While `detail` is marked as required,
// it has default value("auto") and can be omitted.
if (!input_item.contains("image_url")) {
throw std::invalid_argument("'image_url' is required");
}
chatcmpl_content.push_back({
{"image_url", json {
{"url", input_item.at("image_url")}
}},
{"type", "image_url"},
});
} else if (type == "input_file") {
throw std::invalid_argument("'input_file' is not supported by llamacpp at this moment");
} else {
throw std::invalid_argument("'type' must be one of 'input_text', 'input_image', or 'input_file'");
}
}
if (item.contains("type")) {
item.erase("type");
}
if (item.contains("status")) {
item.erase("status");
}
item["content"] = chatcmpl_content;
chatcmpl_messages.push_back(item);
} else if (exists_and_is_string(item, "role") &&
item.at("role") == "assistant" &&
exists_and_is_string(item, "type") &&
item.at("type") == "message"
) {
// #responses_create-input-input_item_list-item-output_message
auto chatcmpl_content = json::array();
// Handle both string content and array content
if (item.contains("content") && item.at("content").is_string()) {
// String content - convert to text content part
chatcmpl_content.push_back({
{"text", item.at("content")},
{"type", "text"},
});
} else if (exists_and_is_array(item, "content")) {
// Array content - process each item
for (const auto & output_text : item.at("content")) {
const std::string type = json_value(output_text, "type", std::string());
if (type == "output_text" || type == "input_text") {
// Accept both output_text and input_text (string content gets converted to input_text)
if (!exists_and_is_string(output_text, "text")) {
throw std::invalid_argument("'Output text' requires 'text'");
}
chatcmpl_content.push_back({
{"text", output_text.at("text")},
{"type", "text"},
});
} else if (type == "refusal") {
if (!exists_and_is_string(output_text, "refusal")) {
throw std::invalid_argument("'Refusal' requires 'refusal'");
}
chatcmpl_content.push_back({
{"refusal", output_text.at("refusal")},
{"type", "refusal"},
});
} else {
throw std::invalid_argument("'type' must be one of 'output_text' or 'refusal'");
}
}
}
if (merge_prev) {
auto & prev_msg = chatcmpl_messages.back();
if (!exists_and_is_array(prev_msg, "content")) {
prev_msg["content"] = json::array();
}
auto & prev_content = prev_msg["content"];
prev_content.insert(prev_content.end(), chatcmpl_content.begin(), chatcmpl_content.end());
} else {
item.erase("status");
item.erase("type");
item["content"] = chatcmpl_content;
chatcmpl_messages.push_back(item);
}
} else if (exists_and_is_string(item, "arguments") &&
exists_and_is_string(item, "call_id") &&
exists_and_is_string(item, "name") &&
exists_and_is_string(item, "type") &&
item.at("type") == "function_call"
) {
// #responses_create-input-input_item_list-item-function_tool_call
json tool_call = {
{"function", json {
{"arguments", item.at("arguments")},
{"name", item.at("name")},
}},
{"id", item.at("call_id")},
{"type", "function"},
};
if (merge_prev) {
auto & prev_msg = chatcmpl_messages.back();
if (!exists_and_is_array(prev_msg, "tool_calls")) {
prev_msg["tool_calls"] = json::array();
}
prev_msg["tool_calls"].push_back(tool_call);
} else {
chatcmpl_messages.push_back(json {
{"role", "assistant"},
{"tool_calls", json::array({tool_call})}
});
}
} else if (exists_and_is_string(item, "call_id") &&
(exists_and_is_string(item, "output") || exists_and_is_array(item, "output")) &&
exists_and_is_string(item, "type") &&
item.at("type") == "function_call_output"
) {
// #responses_create-input-input_item_list-item-function_tool_call_output
if (item.at("output").is_string()) {
chatcmpl_messages.push_back(json {
{"content", item.at("output")},
{"role", "tool"},
{"tool_call_id", item.at("call_id")},
});
} else {
json chatcmpl_outputs = item.at("output");
for (json & chatcmpl_output : chatcmpl_outputs) {
if (!chatcmpl_output.contains("type") || chatcmpl_output.at("type") != "input_text") {
throw std::invalid_argument("Output of tool call should be 'Input text'");
}
chatcmpl_output["type"] = "text";
}
chatcmpl_messages.push_back(json {
{"content", chatcmpl_outputs},
{"role", "tool"},
{"tool_call_id", item.at("call_id")},
});
}
} else if (exists_and_is_array(item, "summary") &&
exists_and_is_string(item, "type") &&
item.at("type") == "reasoning") {
// #responses_create-input-input_item_list-item-reasoning
if (!exists_and_is_array(item, "content")) {
throw std::invalid_argument("item['content'] is not an array");
}
if (item.at("content").empty()) {
throw std::invalid_argument("item['content'] is empty");
}
if (!exists_and_is_string(item.at("content")[0], "text")) {
throw std::invalid_argument("item['content']['text'] is not a string");
}
if (merge_prev) {
auto & prev_msg = chatcmpl_messages.back();
prev_msg["reasoning_content"] = item.at("content")[0].at("text");
} else {
chatcmpl_messages.push_back(json {
{"role", "assistant"},
{"content", json::array()},
{"reasoning_content", item.at("content")[0].at("text")},
});
}
} else {
throw std::invalid_argument("Cannot determine type of 'item'");
}
}
} else {
throw std::invalid_argument("'input' must be a string or array of objects");
}
chatcmpl_body["messages"] = chatcmpl_messages;
if (response_body.contains("tools")) {
if (!response_body.at("tools").is_array()) {
throw std::invalid_argument("'tools' must be an array of objects");
}
std::vector<json> chatcmpl_tools;
for (json resp_tool : response_body.at("tools")) {
json chatcmpl_tool;
if (json_value(resp_tool, "type", std::string()) != "function") {
throw std::invalid_argument("'type' of tool must be 'function'");
}
resp_tool.erase("type");
chatcmpl_tool["type"] = "function";
if (!resp_tool.contains("strict")) {
resp_tool["strict"] = true;
}
chatcmpl_tool["function"] = resp_tool;
chatcmpl_tools.push_back(chatcmpl_tool);
}
chatcmpl_body.erase("tools");
chatcmpl_body["tools"] = chatcmpl_tools;
}
if (response_body.contains("max_output_tokens")) {
chatcmpl_body.erase("max_output_tokens");
chatcmpl_body["max_tokens"] = response_body["max_output_tokens"];
}
return chatcmpl_body;
}
// Edits the cch section of an "x-anthropic-billing-header" system prompt.
// Does nothing to any other prompt.
//
// This is a claude message with a "cch=ef01a" attribute that breaks prefix caching.
// The cch stamp is a whitebox end-to-end integrity hint. It's not meaningful as a
// system prompt data, particularly to llama.cpp, but its presence means the prefix
// cache will not get past it: It changes on each request.
//
// Reference: https://github.com/ggml-org/llama.cpp/pull/21793
// Example header:
// ```
// x-anthropic-billing-header: cc_version=2.1.101.e51; cc_entrypoint=cli; cch=a5145;You are Claude Code, Anthropic's official CLI for Claude.
// ^^^^^
// ```
static void normalize_anthropic_billing_header(std::string & system_text) {
if (system_text.rfind("x-anthropic-billing-header:", 0) != 0) {
return;
}
const size_t header_prefix_length = strlen("x-anthropic-billing-header:");
const size_t cch_length = 5;
const size_t index_cch = system_text.find("cch=", header_prefix_length);
if (index_cch == std::string::npos) {
return;
}
const size_t index_replace = index_cch + 4;
if (index_replace + cch_length < system_text.length() && system_text[index_replace + cch_length] == ';') {
for (size_t i = 0; i < cch_length; ++i) {
system_text[index_replace + i] = 'f';
}
} else {
LOG_ERR("anthropic string not as expected: %s", system_text.c_str());
}
}
json server_chat_convert_anthropic_to_oai(const json & body) {
json oai_body;
// Convert system prompt
json oai_messages = json::array();
auto system_param = json_value(body, "system", json());
if (!system_param.is_null()) {
std::string system_content;
if (system_param.is_string()) {
system_content = system_param.get<std::string>();
normalize_anthropic_billing_header(system_content);
} else if (system_param.is_array()) {
for (const auto & block : system_param) {
if (json_value(block, "type", std::string()) == "text") {
auto system_text = json_value(block, "text", std::string());
normalize_anthropic_billing_header(system_text);
system_content += system_text;
}
}
}
oai_messages.push_back({
{"role", "system"},
{"content", system_content}
});
}
// Convert messages
if (!body.contains("messages")) {
throw std::runtime_error("'messages' is required");
}
const json & messages = body.at("messages");
if (messages.is_array()) {
for (const auto & msg : messages) {
std::string role = json_value(msg, "role", std::string());
if (!msg.contains("content")) {
if (role == "assistant") {
continue;
}
oai_messages.push_back(msg);
continue;
}
const json & content = msg.at("content");
if (content.is_string()) {
oai_messages.push_back(msg);
continue;
}
if (!content.is_array()) {
oai_messages.push_back(msg);
continue;
}
json tool_calls = json::array();
json converted_content = json::array();
json tool_results = json::array();
std::string reasoning_content;
bool has_tool_calls = false;
for (const auto & block : content) {
std::string type = json_value(block, "type", std::string());
if (type == "text") {
converted_content.push_back(block);
} else if (type == "thinking") {
reasoning_content += json_value(block, "thinking", std::string());
} else if (type == "image") {
json source = json_value(block, "source", json::object());
std::string source_type = json_value(source, "type", std::string());
if (source_type == "base64") {
std::string media_type = json_value(source, "media_type", std::string("image/jpeg"));
std::string data = json_value(source, "data", std::string());
std::ostringstream ss;
ss << "data:" << media_type << ";base64," << data;
converted_content.push_back({
{"type", "image_url"},
{"image_url", {
{"url", ss.str()}
}}
});
} else if (source_type == "url") {
std::string url = json_value(source, "url", std::string());
converted_content.push_back({
{"type", "image_url"},
{"image_url", {
{"url", url}
}}
});
}
} else if (type == "tool_use") {
tool_calls.push_back({
{"id", json_value(block, "id", std::string())},
{"type", "function"},
{"function", {
{"name", json_value(block, "name", std::string())},
{"arguments", json_value(block, "input", json::object()).dump()}
}}
});
has_tool_calls = true;
} else if (type == "tool_result") {
std::string tool_use_id = json_value(block, "tool_use_id", std::string());
auto result_content = json_value(block, "content", json());
std::string result_text;
if (result_content.is_string()) {
result_text = result_content.get<std::string>();
} else if (result_content.is_array()) {
for (const auto & c : result_content) {
if (json_value(c, "type", std::string()) == "text") {
result_text += json_value(c, "text", std::string());
}
}
}
tool_results.push_back({
{"role", "tool"},
{"tool_call_id", tool_use_id},
{"content", result_text}
});
}
}
if (!converted_content.empty() || has_tool_calls || !reasoning_content.empty()) {
json new_msg = {{"role", role}};
if (!converted_content.empty()) {
new_msg["content"] = converted_content;
} else if (has_tool_calls || !reasoning_content.empty()) {
new_msg["content"] = "";
}
if (!tool_calls.empty()) {
new_msg["tool_calls"] = tool_calls;
}
if (!reasoning_content.empty()) {
new_msg["reasoning_content"] = reasoning_content;
}
oai_messages.push_back(new_msg);
}
for (const auto & tool_msg : tool_results) {
oai_messages.push_back(tool_msg);
}
}
}
oai_body["messages"] = oai_messages;
// Convert tools
if (body.contains("tools")) {
const json & tools = body.at("tools");
if (tools.is_array()) {
json oai_tools = json::array();
for (const auto & tool : tools) {
oai_tools.push_back({
{"type", "function"},
{"function", {
{"name", json_value(tool, "name", std::string())},
{"description", json_value(tool, "description", std::string())},
{"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()}
}}
});
}
oai_body["tools"] = oai_tools;
}
}
// Convert tool_choice
if (body.contains("tool_choice")) {
const json & tc = body.at("tool_choice");
if (tc.is_object()) {
std::string type = json_value(tc, "type", std::string());
if (type == "auto") {
oai_body["tool_choice"] = "auto";
} else if (type == "any" || type == "tool") {
oai_body["tool_choice"] = "required";
}
}
}
// Convert stop_sequences to stop
if (body.contains("stop_sequences")) {
oai_body["stop"] = body.at("stop_sequences");
}
// Handle max_tokens (required in Anthropic, but we're permissive)
if (body.contains("max_tokens")) {
oai_body["max_tokens"] = body.at("max_tokens");
} else {
oai_body["max_tokens"] = 4096;
}
// Pass through common params
for (const auto & key : {"temperature", "top_p", "top_k", "stream", "chat_template_kwargs"}) {
if (body.contains(key)) {
oai_body[key] = body.at(key);
}
}
// Handle Anthropic-specific thinking param
if (body.contains("thinking")) {
json thinking = json_value(body, "thinking", json::object());
std::string thinking_type = json_value(thinking, "type", std::string());
if (thinking_type == "enabled") {
int budget_tokens = json_value(thinking, "budget_tokens", 10000);
oai_body["thinking_budget_tokens"] = budget_tokens;
}
}
// Handle Anthropic-specific metadata param
if (body.contains("metadata")) {
json metadata = json_value(body, "metadata", json::object());
std::string user_id = json_value(metadata, "user_id", std::string());
if (!user_id.empty()) {
oai_body["__metadata_user_id"] = user_id;
}
}
return oai_body;
}
json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
json delta = json::object();
if (!diff.reasoning_content_delta.empty()) {
delta["reasoning_content"] = diff.reasoning_content_delta;
}
if (!diff.content_delta.empty()) {
delta["content"] = diff.content_delta;
}
if (diff.tool_call_index != std::string::npos) {
json tool_call;
tool_call["index"] = diff.tool_call_index;
if (!diff.tool_call_delta.id.empty()) {
tool_call["id"] = diff.tool_call_delta.id;
tool_call["type"] = "function";
}
if (!diff.tool_call_delta.name.empty() || !diff.tool_call_delta.arguments.empty()) {
json function = json::object();
if (!diff.tool_call_delta.name.empty()) {
function["name"] = diff.tool_call_delta.name;
}
if (!diff.tool_call_delta.arguments.empty()) {
function["arguments"] = diff.tool_call_delta.arguments;
}
tool_call["function"] = function;
}
delta["tool_calls"] = json::array({ tool_call });
}
return delta;
}
json convert_transcriptions_to_chatcmpl(
const json & inp_body,
const common_chat_templates * tmpls,
const std::map<std::string, uploaded_file> & in_files,
std::vector<raw_buffer> & out_files) {
// TODO @ngxson : this function may need to be improved in the future
// handle input files
out_files.clear();
auto it = in_files.find("file");
if (it != in_files.end()) {
out_files.push_back(it->second.data);
} else {
throw std::invalid_argument("No input file found for transcription");
}
// handle input data
std::string prompt = json_value(inp_body, "prompt", std::string());
std::string language = json_value(inp_body, "language", std::string());
std::string response_format = json_value(inp_body, "response_format", std::string("json"));
if (response_format != "json") {
throw std::invalid_argument("Only 'json' response_format is supported for transcription");
}
const common_chat_prompt_preset preset = common_chat_get_asr_prompt(tmpls);
if (prompt.empty()) {
prompt = preset.user;
}
if (!language.empty()) {
prompt += string_format(" (language: %s)", language.c_str());
}
prompt += get_media_marker();
json messages = json::array();
if (!preset.system.empty()) {
messages.push_back({{"role", "system"}, {"content", preset.system}});
}
messages.push_back({{"role", "user"}, {"content", prompt}});
json chatcmpl_body = inp_body; // copy all fields
chatcmpl_body["messages"] = messages;
// because input from form-data, everything is string, we need to correct the types here
std::string stream = json_value(inp_body, "stream", std::string("false"));
chatcmpl_body["stream"] = stream == "true";
if (inp_body.contains("max_tokens")) {
std::string inp = inp_body["max_tokens"].get<std::string>();
chatcmpl_body["max_tokens"] = std::stoul(inp);
}
if (inp_body.contains("temperature")) {
std::string inp = inp_body["temperature"].get<std::string>();
chatcmpl_body["temperature"] = std::stof(inp);
}
return chatcmpl_body;
}

View File

@@ -0,0 +1,26 @@
// Chat conversion functions for server (Responses API, Anthropic API, OAI streaming diffs)
#pragma once
#include "chat.h"
#include "server-common.h"
#include "server-http.h"
#include <nlohmann/json_fwd.hpp>
using json = nlohmann::ordered_json;
// Convert OpenAI Responses API format to OpenAI Chat Completions API format
json server_chat_convert_responses_to_chatcmpl(const json & body);
// Convert Anthropic Messages API format to OpenAI Chat Completions API format
json server_chat_convert_anthropic_to_oai(const json & body);
// convert OpenAI transcriptions API format to OpenAI Chat Completions API format
json convert_transcriptions_to_chatcmpl(
const json & body,
const common_chat_templates * tmpls,
const std::map<std::string, uploaded_file> & in_files,
std::vector<raw_buffer> & out_files);
json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,373 @@
#pragma once
#include "common.h"
#include "log.h"
#include "llama.h"
#include "chat.h"
#include "mtmd.h"
#define JSON_ASSERT GGML_ASSERT
#include <nlohmann/json.hpp>
#include <string>
#include <vector>
#include <cinttypes>
using json = nlohmann::ordered_json;
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_CNT(slot, fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
using raw_buffer = std::vector<uint8_t>;
template <typename T>
static T json_value(const json & body, const std::string & key, const T & default_value) {
// Fallback null to default value
if (body.contains(key) && !body.at(key).is_null()) {
try {
return body.at(key);
} catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const & err) {
LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value: %s\n", key.c_str(), json(default_value).type_name(), err.what());
return default_value;
}
} else {
return default_value;
}
}
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
enum error_type {
ERROR_TYPE_INVALID_REQUEST,
ERROR_TYPE_AUTHENTICATION,
ERROR_TYPE_SERVER,
ERROR_TYPE_NOT_FOUND,
ERROR_TYPE_PERMISSION,
ERROR_TYPE_UNAVAILABLE, // custom error
ERROR_TYPE_NOT_SUPPORTED, // custom error
ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error
};
// thin wrapper around common_grammar_trigger with (de)serialization functions
struct server_grammar_trigger {
common_grammar_trigger value;
server_grammar_trigger() = default;
server_grammar_trigger(const common_grammar_trigger & value) : value(value) {}
server_grammar_trigger(const json & in) {
value.type = (common_grammar_trigger_type) in.at("type").get<int>();
value.value = in.at("value").get<std::string>();
if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
value.token = (llama_token) in.at("token").get<int>();
}
}
json to_json() const {
json out {
{"type", (int) value.type},
{"value", value.value},
};
if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
out["token"] = (int) value.token;
}
return out;
}
};
json format_error_response(const std::string & message, const enum error_type type);
//
// random string / id
//
std::string random_string();
std::string gen_chatcmplid();
std::string gen_tool_call_id();
// get a random marker; note: each time the server restarts, the marker will be different
const char * get_media_marker();
//
// lora utils
//
// check whether the given lora set has only aloras activated (empty => false)
bool lora_all_alora(const std::vector<common_adapter_lora_info> & loras);
// if the two sets of loras are different, they require a cache clear unless the
// change is only from aloras to aloras.
bool lora_should_clear_cache(
const std::vector<common_adapter_lora_info> & current,
const std::vector<common_adapter_lora_info> & next);
std::map<int, float> parse_lora_request(const json & data);
bool are_lora_equal(
const std::vector<common_adapter_lora_info> & l1,
const std::vector<common_adapter_lora_info> & l2);
// get the ids of all enabled loras
std::vector<size_t> lora_get_enabled_ids(const std::vector<common_adapter_lora_info> & loras);
//
// server_tokens
//
/**
* server_tokens is a helper to manage the input tokens and image for the server.
* it is made this way to simplify the logic of KV cache management.
*/
struct server_tokens {
bool has_mtmd = false;
private: // disallow accessing these members directly, risking out-of-sync
// map a **start** index in tokens to the image chunk
// note: the order need to be in-sync with tokens
std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;
// list of tokens
// if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk
// otherwise, it is a normal text token
// note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
// note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos
llama_tokens tokens;
// for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
// idx 0 1 2 3 4 5 6 7 8 9 10
// pos 0 1 2 3 4 5 5 5 7 7 7
// map_idx_to_media will contain: {5, img0}, {8, img1}
public:
server_tokens() = default;
~server_tokens() = default;
// Prevent copying
// TODO: server_tokens should be copyable - remove this:
server_tokens(const server_tokens&) = delete;
server_tokens& operator=(const server_tokens&) = delete;
// Allow moving (usually implicitly generated if members are movable)
server_tokens(server_tokens&&) = default;
server_tokens& operator=(server_tokens&&) = default;
// Allow accessing elements using [] operator
llama_token operator[](size_t index) { return tokens[index]; }
const llama_token& operator[](size_t index) const { return tokens[index]; }
server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd);
server_tokens(const llama_tokens & tokens, bool has_mtmd);
// for debugging
std::string str() const;
// the next position after n_tokens. if n_tokens < 0, return the next position after all tokens.
llama_pos pos_next(int64_t n_tokens = -1) const;
// number of tokens with position < max_pos
size_t size_up_to_pos(llama_pos max_pos) const;
const mtmd::input_chunk_ptr & find_chunk(size_t idx) const;
void push_back(llama_token tok);
// will create a copy of the chunk if it contains non-text data
void push_back(const mtmd_input_chunk * chunk);
// appends server tokens, updates the media map. copies media chunks.
void push_back(server_tokens & tokens);
// for compatibility with context shift and prompt truncation
void insert(const llama_tokens & inp_tokens);
// for compatibility with speculative decoding, ctx shift, slot save/load
const llama_tokens & get_tokens() const;
llama_tokens get_text_tokens() const;
// for compatibility with speculative decoding
void set_token(llama_pos pos, llama_token id);
size_t size() const { return tokens.size(); }
bool empty() const { return tokens.empty(); }
void clear() {
map_idx_to_media.clear();
tokens.clear();
}
void keep_first(size_t n);
std::string detokenize(const llama_context * ctx, bool special) const;
size_t get_common_prefix(const server_tokens & b) const;
// make sure all text tokens are within the vocab range
bool validate(const struct llama_context * ctx) const;
// encode and decode the image chunk
int32_t process_chunk(
llama_context * ctx,
mtmd_context * mctx,
size_t idx,
llama_pos pos,
int32_t seq_id,
size_t & n_tokens_out) const;
server_tokens clone() const;
};
//
// tokenizer and input processing utils
//
bool json_is_array_of_numbers(const json & data);
// is array having BOTH numbers & strings?
bool json_is_array_of_mixed_numbers_strings(const json & data);
// does array have any individual integers/tokens?
bool json_is_array_and_contains_numbers(const json & data);
// get value by path(key1 / key2)
json json_get_nested_values(const std::vector<std::string> & paths, const json & js);
/**
* this handles 2 cases:
* - only string, example: "string"
* - mixed string and tokens, example: [12, 34, "string", 56, 78]
*/
llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special);
// return the last index of character that can form a valid string
// if the last character is potentially cut in half, return the index before the cut
// if validate_utf8(text) == text.size(), then the whole text is valid utf8
size_t validate_utf8(const std::string& text);
// process mtmd prompt, return the server_tokens containing both text tokens and media chunks
server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files);
/**
* break the input "prompt" object into multiple prompt if needed, then tokenize them
* this supports these cases:
* - "prompt": "string"
* - "prompt": [12, 34, 56]
* - "prompt": [12, 34, "string", 56, 78]
* - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
* and multiple prompts (multi-tasks):
* - "prompt": ["string1", "string2"]
* - "prompt": ["string1", [12, 34, 56]]
* - "prompt": [[12, 34, 56], [78, 90, 12]]
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}]
*/
std::vector<server_tokens> tokenize_input_prompts(
const llama_vocab * vocab,
mtmd_context * mctx,
const json & json_prompt,
bool add_special,
bool parse_special);
//
// OAI utils
//
// global server parameters for chat formatting / parsing
struct server_chat_params {
bool use_jinja;
bool prefill_assistant;
common_reasoning_format reasoning_format;
std::map<std::string, std::string> chat_template_kwargs; // mapping key --> json value
common_chat_templates_ptr tmpls;
bool allow_image;
bool allow_audio;
bool enable_thinking = true;
int reasoning_budget = -1;
std::string reasoning_budget_message;
std::string media_path;
bool force_pure_content = false;
};
// used by /completions endpoint
json oaicompat_completion_params_parse(const json & body);
// used by /chat/completions endpoint
json oaicompat_chat_params_parse(
json & body, /* openai api json semantics */
const server_chat_params & opt,
std::vector<raw_buffer> & out_files);
// TODO: move it to server-task.cpp
json format_embeddings_response_oaicompat(
const json & request,
const std::string & model_name,
const json & embeddings,
bool use_base64 = false);
// TODO: move it to server-task.cpp
json format_response_rerank(
const json & request,
const std::string & model_name,
const json & ranks,
bool is_tei_format,
std::vector<std::string> & texts,
int top_n);
//
// other utils
//
std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx);
std::string safe_json_to_str(const json & data);
std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens);
std::string tokens_to_str(const llama_vocab * vocab, const llama_tokens & tokens);
// format incomplete utf-8 multibyte character for output
std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token);
// format server-sent event (SSE), return the formatted string to send
// note: if data is a json array, it will be sent as multiple events, one per item
std::string format_oai_sse(const json & data);
std::string format_oai_resp_sse(const json & data);
// format Anthropic-style SSE with event types
std::string format_anthropic_sse(const json & data);
bool is_valid_utf8(const std::string & str);
//
// formatting output responses
// TODO: move these to server-task.cpp
//
llama_tokens format_prompt_infill(
const llama_vocab * vocab,
const json & input_prefix,
const json & input_suffix,
const json & input_extra,
const int n_batch,
const int n_predict,
const int n_ctx,
const bool spm_infill,
const llama_tokens & tokens_prompt);
// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
server_tokens format_prompt_rerank(
const struct llama_model * model,
const struct llama_vocab * vocab,
mtmd_context * mctx,
const std::string & query,
const std::string & doc);

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,150 @@
#pragma once
#include "server-http.h"
#include "server-task.h"
#include "server-queue.h"
#include <nlohmann/json_fwd.hpp>
#include <cstddef>
#include <memory>
#include <set>
struct server_context_impl; // private implementation
struct server_context_meta {
std::string build_info;
std::string model_name;
std::set<std::string> model_aliases;
std::set<std::string> model_tags;
std::string model_path;
bool has_mtmd;
bool has_inp_image;
bool has_inp_audio;
json json_webui_settings;
int slot_n_ctx;
enum llama_pooling_type pooling_type;
// chat params
server_chat_params & chat_params;
std::map<std::string, bool> chat_template_caps;
// tokens
std::string bos_token_str;
std::string eos_token_str;
llama_token fim_pre_token;
llama_token fim_sub_token;
llama_token fim_mid_token;
llama_token fim_pad_token;
llama_token fim_rep_token;
llama_token fim_sep_token;
// sampling
std::vector<llama_logit_bias> logit_bias_eog;
// model meta
enum llama_vocab_type model_vocab_type;
int32_t model_vocab_n_tokens;
int32_t model_n_ctx_train;
int32_t model_n_embd_inp;
uint64_t model_n_params;
uint64_t model_size;
};
struct server_context {
std::unique_ptr<server_context_impl> impl;
server_context();
~server_context();
// load the model and initialize llama_context
// returns true on success
bool load_model(common_params & params);
// this function will block main thread until termination
void start_loop();
// terminate main loop (will unblock start_loop)
void terminate();
// get the underlaying llama_context, can return nullptr if sleeping
// not thread-safe, should only be used from the main thread
llama_context * get_llama_context() const;
// get a new response reader, used by CLI application
server_response_reader get_response_reader();
// get server metadata (read-only), can only be called after load_model()
// not thread-safe, should only be used from the main thread
server_context_meta get_meta() const;
// register a callback to be called when sleeping state changes
// must be set before load_model() is called
void on_sleeping_changed(std::function<void(bool)> callback);
};
// forward declarations
struct server_res_generator;
struct server_routes {
server_routes(const common_params & params, server_context & ctx_server);
void init_routes();
// note: this is not thread-safe and can only when ctx_http.is_ready is false
void update_meta(const server_context & ctx_server) {
this->meta = std::make_unique<server_context_meta>(ctx_server.get_meta());
}
// handlers using lambda function, so that they can capture `this` without `std::bind`
// they won't be called until ctx_http.is_ready is set to true
server_http_context::handler_t get_health;
server_http_context::handler_t get_metrics;
server_http_context::handler_t get_slots;
server_http_context::handler_t post_slots;
server_http_context::handler_t get_props;
server_http_context::handler_t post_props;
server_http_context::handler_t post_infill;
server_http_context::handler_t post_completions;
server_http_context::handler_t post_completions_oai;
server_http_context::handler_t post_chat_completions;
server_http_context::handler_t post_responses_oai;
server_http_context::handler_t post_transcriptions_oai;
server_http_context::handler_t post_anthropic_messages;
server_http_context::handler_t post_anthropic_count_tokens;
server_http_context::handler_t post_apply_template;
server_http_context::handler_t get_models;
server_http_context::handler_t post_tokenize;
server_http_context::handler_t post_detokenize;
server_http_context::handler_t post_embeddings;
server_http_context::handler_t post_embeddings_oai;
server_http_context::handler_t post_rerank;
server_http_context::handler_t get_lora_adapters;
server_http_context::handler_t post_lora_adapters;
// to be used in router mode
json get_model_info() const;
private:
std::unique_ptr<server_res_generator> handle_completions_impl(
const server_http_req & req,
server_task_type type,
const json & data,
const std::vector<raw_buffer> & files,
task_response_type res_type);
std::unique_ptr<server_res_generator> handle_slots_save(const server_http_req & req, int id_slot);
std::unique_ptr<server_res_generator> handle_slots_restore(const server_http_req & req, int id_slot);
std::unique_ptr<server_res_generator> handle_slots_erase(const server_http_req &, int id_slot);
std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, task_response_type res_type);
// using unique_ptr to allow late initialization of const
std::unique_ptr<const server_context_meta> meta;
const common_params & params;
const server_context_impl & ctx_server;
server_queue & queue_tasks;
server_response & queue_results;
std::unique_ptr<server_res_generator> create_response(bool bypass_sleep = false);
};

View File

@@ -0,0 +1,67 @@
#pragma once
#include "common.h"
#include "http.h"
#include <string>
#include <unordered_set>
#include <list>
#include <map>
#include "server-http.h"
static server_http_res_ptr proxy_request(const server_http_req & req, std::string method) {
std::string target_url = req.get_param("url");
common_http_url parsed_url = common_http_parse_url(target_url);
if (parsed_url.host.empty()) {
throw std::runtime_error("invalid target URL: missing host");
}
if (parsed_url.path.empty()) {
parsed_url.path = "/";
}
if (!parsed_url.password.empty()) {
throw std::runtime_error("authentication in target URL is not supported");
}
if (parsed_url.scheme != "http" && parsed_url.scheme != "https") {
throw std::runtime_error("unsupported URL scheme in target URL: " + parsed_url.scheme);
}
SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.port, parsed_url.path.c_str());
std::map<std::string, std::string> headers;
for (auto [key, value] : req.headers) {
auto new_key = key;
if (string_starts_with(new_key, "x-proxy-header-")) {
string_replace_all(new_key, "x-proxy-header-", "");
}
headers[new_key] = value;
}
auto proxy = std::make_unique<server_http_proxy>(
method,
parsed_url.scheme,
parsed_url.host,
parsed_url.port,
parsed_url.path,
headers,
req.body,
req.files,
req.should_stop,
600, // timeout_read (default to 10 minutes)
600 // timeout_write (default to 10 minutes)
);
return proxy;
}
static server_http_context::handler_t proxy_handler_post = [](const server_http_req & req) -> server_http_res_ptr {
return proxy_request(req, "POST");
};
static server_http_context::handler_t proxy_handler_get = [](const server_http_req & req) -> server_http_res_ptr {
return proxy_request(req, "GET");
};

View File

@@ -0,0 +1,700 @@
#include "common.h"
#include "server-http.h"
#include "server-common.h"
#include <cpp-httplib/httplib.h>
#include <cstdlib>
#include <functional>
#include <future>
#include <string>
#include <thread>
#ifdef LLAMA_BUILD_WEBUI
// auto generated files (see README.md for details)
#include "index.html.hpp"
#include "bundle.js.hpp"
#include "bundle.css.hpp"
#include "loading.html.hpp"
#endif
//
// HTTP implementation using cpp-httplib
//
class server_http_context::Impl {
public:
std::unique_ptr<httplib::Server> srv;
};
server_http_context::server_http_context()
: pimpl(std::make_unique<server_http_context::Impl>())
{}
server_http_context::~server_http_context() = default;
static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
// skip logging requests that are regularly sent, to avoid log spam
if (req.path == "/health"
|| req.path == "/v1/health"
|| req.path == "/models"
|| req.path == "/v1/models"
|| req.path == "/props"
|| req.path == "/metrics"
) {
return;
}
// reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
SRV_INF("done request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
SRV_DBG("request: %s\n", req.body.c_str());
SRV_DBG("response: %s\n", res.body.c_str());
}
// For Google Cloud Platform deployment compatibility
struct gcp_params {
bool enabled;
std::string path_health;
std::string path_predict;
int port;
// Ref: https://docs.cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements#aip-variables
gcp_params() {
enabled = getenv("AIP_MODE", "") == "PREDICTION";
path_health = getenv("AIP_HEALTH_ROUTE", "", true); // default: using the route defined in server.cpp
path_predict = getenv("AIP_PREDICT_ROUTE", "/predict", true);
port = std::stoi(getenv("AIP_HTTP_PORT", "8080"));
}
static std::string getenv(const char * name, const std::string & default_value, bool ensure_leading_slash = false) {
const char * value = std::getenv(name);
if (value == nullptr || value[0] == '\0') {
return default_value;
}
std::string val = value;
if (ensure_leading_slash && !val.empty() && val[0] != '/') {
val.insert(val.begin(), '/');
}
return val;
}
};
bool server_http_context::init(const common_params & params) {
const gcp_params gcp;
path_prefix = params.api_prefix;
port = params.port;
hostname = params.hostname;
if (gcp.enabled) {
LOG_INF("%s: Google Cloud Platform compat: health route = %s, predict route = %s, port = %d\n", __func__, gcp.path_health.c_str(), gcp.path_predict.c_str(), gcp.port);
if (port != gcp.port) {
LOG_WRN("%s: Google Cloud Platform compat: overriding server port %d with AIP_HTTP_PORT %d\n", __func__, port, gcp.port);
}
port = gcp.port;
}
auto & srv = pimpl->srv;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
srv.reset(
new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
);
} else {
LOG_INF("Running without SSL\n");
srv.reset(new httplib::Server());
}
#else
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
LOG_ERR("Server is built without SSL support\n");
return false;
}
srv.reset(new httplib::Server());
#endif
srv->set_default_headers({{"Server", "llama.cpp"}});
srv->set_logger(log_server_request);
srv->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
// this is fail-safe; exceptions should already handled by `ex_wrapper`
std::string message;
try {
std::rethrow_exception(ep);
} catch (const std::exception & e) {
message = e.what();
} catch (...) {
message = "Unknown Exception";
}
res.status = 500;
res.set_content(message, "text/plain");
LOG_ERR("got exception: %s\n", message.c_str());
});
srv->set_error_handler([](const httplib::Request &, httplib::Response & res) {
if (res.status == 404) {
res.set_content(
safe_json_to_str(json {
{"error", {
{"message", "File Not Found"},
{"type", "not_found_error"},
{"code", 404}
}}
}),
"application/json; charset=utf-8"
);
}
// for other error codes, we skip processing here because it's already done by res->error()
});
// set timeouts and change hostname and port
srv->set_read_timeout (params.timeout_read);
srv->set_write_timeout(params.timeout_write);
srv->set_socket_options([reuse_port = params.reuse_port](socket_t sock) {
httplib::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 1);
if (reuse_port) {
#ifdef SO_REUSEPORT
httplib::set_socket_opt(sock, SOL_SOCKET, SO_REUSEPORT, 1);
#else
LOG_WRN("%s: SO_REUSEPORT is not supported\n", __func__);
#endif
}
});
if (params.api_keys.size() == 1) {
auto key = params.api_keys[0];
std::string substr = key.substr(std::max((int)(key.length() - 4), 0));
LOG_INF("%s: api_keys: ****%s\n", __func__, substr.c_str());
} else if (params.api_keys.size() > 1) {
LOG_INF("%s: api_keys: %zu keys loaded\n", __func__, params.api_keys.size());
}
//
// Middlewares
//
auto middleware_validate_api_key = [api_keys = params.api_keys](const httplib::Request & req, httplib::Response & res) {
static const std::unordered_set<std::string> public_endpoints = {
"/health",
"/v1/health",
"/models",
"/v1/models",
"/",
"/index.html",
"/bundle.js",
"/bundle.css",
};
// If API key is not set, skip validation
if (api_keys.empty()) {
return true;
}
// If path is public or static file, skip validation
if (public_endpoints.find(req.path) != public_endpoints.end()) {
return true;
}
// Check for API key in the Authorization header
std::string req_api_key = req.get_header_value("Authorization");
if (req_api_key.empty()) {
// retry with anthropic header
req_api_key = req.get_header_value("X-Api-Key");
}
// remove the "Bearer " prefix if needed
std::string prefix = "Bearer ";
if (req_api_key.substr(0, prefix.size()) == prefix) {
req_api_key = req_api_key.substr(prefix.size());
}
// validate the API key
if (std::find(api_keys.begin(), api_keys.end(), req_api_key) != api_keys.end()) {
return true; // API key is valid
}
// API key is invalid or not provided
res.status = 401;
res.set_content(
safe_json_to_str(json {
{"error", {
{"message", "Invalid API Key"},
{"type", "authentication_error"},
{"code", 401}
}}
}),
"application/json; charset=utf-8"
);
LOG_WRN("Unauthorized: Invalid API Key\n");
return false;
};
auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) {
bool ready = is_ready.load();
if (!ready) {
#ifdef LLAMA_BUILD_WEBUI
auto tmp = string_split<std::string>(req.path, '.');
if (req.path == "/" || tmp.back() == "html") {
res.status = 503;
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
} else
#endif
{
// no endpoints is allowed to be accessed when the server is not ready
// this is to prevent any data races or inconsistent states
res.status = 503;
res.set_content(
safe_json_to_str(json {
{"error", {
{"message", "Loading model"},
{"type", "unavailable_error"},
{"code", 503}
}}
}),
"application/json; charset=utf-8"
);
}
return false;
}
return true;
};
// register server middlewares
srv->set_pre_routing_handler([middleware_validate_api_key, middleware_server_state](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
// If this is OPTIONS request, skip validation because browsers don't include Authorization header
if (req.method == "OPTIONS") {
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Allow-Methods", "GET, POST");
res.set_header("Access-Control-Allow-Headers", "*");
res.set_content("", "text/html"); // blank response, no data
return httplib::Server::HandlerResponse::Handled; // skip further processing
}
if (!middleware_server_state(req, res)) {
return httplib::Server::HandlerResponse::Handled;
}
if (!middleware_validate_api_key(req, res)) {
return httplib::Server::HandlerResponse::Handled;
}
return httplib::Server::HandlerResponse::Unhandled;
});
int n_threads_http = params.n_threads_http;
if (n_threads_http < 1) {
// +4 threads for monitoring, health and some threads reserved for MCP and other tasks in the future
n_threads_http = std::max(params.n_parallel + 4, (int32_t) std::thread::hardware_concurrency() - 1);
}
LOG_INF("%s: using %d threads for HTTP server\n", __func__, n_threads_http);
srv->new_task_queue = [n_threads_http] {
// spawn n_threads_http fixed thread (always alive), while allow up to 1024 max possible additional threads
// when n_threads_http is used, server will create new "dynamic" threads that will be destroyed after processing each request
// ref: https://github.com/yhirose/cpp-httplib/pull/2368
size_t max_threads = (size_t)n_threads_http + 1024;
return new httplib::ThreadPool(n_threads_http, max_threads);
};
//
// Web UI setup
//
if (!params.webui) {
LOG_INF("Web UI is disabled\n");
} else {
// register static assets routes
if (!params.public_path.empty()) {
// Set the base directory for serving static files
bool is_found = srv->set_mount_point(params.api_prefix + "/", params.public_path);
if (!is_found) {
LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
return 1;
}
} else {
#ifdef LLAMA_BUILD_WEBUI
// using embedded static index.html
srv->Get(params.api_prefix + "/", [](const httplib::Request & /*req*/, httplib::Response & res) {
// COEP and COOP headers, required by pyodide (python interpreter)
res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
res.set_header("Cross-Origin-Opener-Policy", "same-origin");
res.set_content(reinterpret_cast<const char*>(index_html), index_html_len, "text/html; charset=utf-8");
return false;
});
srv->Get(params.api_prefix + "/bundle.js", [](const httplib::Request & /*req*/, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(bundle_js), bundle_js_len, "application/javascript; charset=utf-8");
return false;
});
srv->Get(params.api_prefix + "/bundle.css", [](const httplib::Request & /*req*/, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(bundle_css), bundle_css_len, "text/css; charset=utf-8");
return false;
});
#endif
}
}
return true;
}
bool server_http_context::start() {
// Bind and listen
auto & srv = pimpl->srv;
bool was_bound = false;
bool is_sock = false;
if (string_ends_with(std::string(hostname), ".sock")) {
is_sock = true;
LOG_INF("%s: setting address family to AF_UNIX\n", __func__);
srv->set_address_family(AF_UNIX);
// bind_to_port requires a second arg, any value other than 0 should
// simply get ignored
was_bound = srv->bind_to_port(hostname, 8080);
} else {
LOG_INF("%s: binding port with default address family\n", __func__);
// bind HTTP listen port
if (port == 0) {
int bound_port = srv->bind_to_any_port(hostname);
was_bound = (bound_port >= 0);
if (was_bound) {
port = bound_port;
}
} else {
was_bound = srv->bind_to_port(hostname, port);
}
}
if (!was_bound) {
LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, hostname.c_str(), port);
return false;
}
// run the HTTP server in a thread
thread = std::thread([this]() { pimpl->srv->listen_after_bind(); });
srv->wait_until_ready();
listening_address = is_sock ? string_format("unix://%s", hostname.c_str())
: string_format("http://%s:%d", hostname.c_str(), port);
return true;
}
void server_http_context::stop() const {
if (pimpl->srv) {
pimpl->srv->stop();
}
}
static void set_headers(httplib::Response & res, const std::map<std::string, std::string> & headers) {
for (const auto & [key, value] : headers) {
res.set_header(key, value);
}
}
static std::map<std::string, std::string> get_params(const httplib::Request & req) {
std::map<std::string, std::string> params;
for (const auto & [key, value] : req.params) {
params[key] = value;
}
for (const auto & [key, value] : req.path_params) {
params[key] = value;
}
return params;
}
static std::map<std::string, std::string> get_headers(const httplib::Request & req) {
std::map<std::string, std::string> headers;
for (const auto & [key, value] : req.headers) {
headers[key] = value;
}
return headers;
}
static std::string build_query_string(const httplib::Request & req) {
std::string qs;
for (const auto & [key, value] : req.params) {
if (!qs.empty()) {
qs += '&';
}
qs += httplib::encode_query_component(key) + "=" + httplib::encode_query_component(value);
}
return qs;
}
// using unique_ptr for request to allow safe capturing in lambdas
using server_http_req_ptr = std::unique_ptr<server_http_req>;
static void process_handler_response(server_http_req_ptr && request, server_http_res_ptr & response, httplib::Response & res) {
if (response->is_stream()) {
res.status = response->status;
set_headers(res, response->headers);
std::string content_type = response->content_type;
// convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it
std::shared_ptr<server_http_req> q_ptr = std::move(request);
std::shared_ptr<server_http_res> r_ptr = std::move(response);
const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool {
std::string chunk;
bool has_next = response->next(chunk);
if (!chunk.empty()) {
if (!sink.write(chunk.data(), chunk.size())) {
return false;
}
SRV_DBG("http: streamed chunk: %s\n", chunk.c_str());
}
if (!has_next) {
sink.done();
SRV_DBG("%s", "http: stream ended\n");
}
return has_next;
};
const auto on_complete = [request = q_ptr, response = r_ptr](bool) mutable {
response.reset(); // trigger the destruction of the response object
request.reset(); // trigger the destruction of the request object
};
res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete);
} else {
res.status = response->status;
set_headers(res, response->headers);
res.set_content(response->data, response->content_type);
}
}
void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const {
handlers.emplace(path, handler);
pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
get_params(req),
get_headers(req),
req.path,
build_query_string(req),
req.body,
{},
req.is_connection_closed
});
server_http_res_ptr response = handler(*request);
process_handler_response(std::move(request), response, res);
});
}
void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
handlers.emplace(path, handler);
pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
std::string body = req.body;
std::map<std::string, uploaded_file> files;
if (req.is_multipart_form_data()) {
// translate text fields to a JSON object and use it as the body
json form_json = json::object();
for (const auto & [key, field] : req.form.fields) {
if (form_json.contains(key)) {
// if the key already exists, convert it to an array
if (!form_json[key].is_array()) {
json existing_value = form_json[key];
form_json[key] = json::array({existing_value});
}
form_json[key].push_back(field.content);
} else {
form_json[key] = field.content;
}
}
body = form_json.dump();
// populate files from multipart form
for (const auto & [key, file] : req.form.files) {
files[key] = uploaded_file{
raw_buffer(file.content.begin(), file.content.end()),
file.filename,
file.content_type,
};
}
}
server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
get_params(req),
get_headers(req),
req.path,
build_query_string(req),
body,
std::move(files),
req.is_connection_closed
});
server_http_res_ptr response = handler(*request);
process_handler_response(std::move(request), response, res);
});
}
//
// Vertex AI Prediction protocol (AIP_PREDICT_ROUTE)
// https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements
//
// Derives the camelCase @requestFormat alias for a registered path.
// e.g. "/v1/chat/completions" -> "chatCompletions", "/apply-template" -> "applyTemplate"
static std::string path_to_gcp_format(const std::string & path) {
std::string s = path;
if (s.size() > 3 && s[0] == '/' && s[1] == 'v' && s[2] == '1') {
s = s.substr(3);
}
if (!s.empty() && s[0] == '/') {
s = s.substr(1);
}
std::string result;
bool cap = false;
for (unsigned char c : s) {
if (c == ':') break; // stop before path parameters
if (c == '/' || c == '-' || c == '_') {
cap = true;
} else {
result += cap ? (char)std::toupper(c) : (char)c;
cap = false;
}
}
return result;
}
static json parse_gcp_predict_response(const server_http_res_ptr & res) {
if (res == nullptr) {
throw std::runtime_error("empty response from internal handler");
}
if (res->is_stream()) {
throw std::invalid_argument("predict route does not support streaming responses");
}
if (res->data.empty()) {
return nullptr;
}
try {
return json::parse(res->data);
} catch (...) {
return res->data;
}
}
void server_http_context::register_gcp_compat() {
const gcp_params gcp;
if (!gcp.enabled) {
// do nothing
return;
}
if (handlers.count(gcp.path_predict)) {
LOG_ERR("%s: AIP_PREDICT_ROUTE=%s conflicts with an existing llama-server route\n", __func__, gcp.path_predict.c_str());
exit(1);
}
// camelCase alias -> canonical path (first registration wins on collision)
// e.g. "chatCompletions" -> "/v1/chat/completions"
std::unordered_map<std::string, std::string> alias_to_path;
for (const auto & [path, _] : handlers) {
alias_to_path.emplace(path_to_gcp_format(path), path);
}
if (!gcp.path_health.empty()) {
auto health_handler = handlers.find("/health");
GGML_ASSERT(health_handler != handlers.end());
get(gcp.path_health, health_handler->second);
}
post(gcp.path_predict, [this, alias_to_path = std::move(alias_to_path)](const server_http_req & req) -> server_http_res_ptr {
static const auto build_error = [](const std::string & message, error_type type) -> json {
return json {{"error", format_error_response(message, type)}};
};
json data;
try {
data = json::parse(req.body);
} catch (const std::exception & e) {
auto res = std::make_unique<server_http_res>();
res->status = 400;
res->data = safe_json_to_str({{"error", format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)}});
return res;
}
if (!data.is_object()) {
auto res = std::make_unique<server_http_res>();
res->status = 400;
res->data = safe_json_to_str({{"error", format_error_response("request body must be a JSON object", ERROR_TYPE_INVALID_REQUEST)}});
return res;
}
if (!data.contains("instances") || !data.at("instances").is_array()) {
auto res = std::make_unique<server_http_res>();
res->status = 400;
res->data = safe_json_to_str({{"error", format_error_response("request body must include an array field named instances", ERROR_TYPE_INVALID_REQUEST)}});
return res;
}
const json & instances = data.at("instances");
static const size_t MAX_INSTANCES = 128;
if (instances.size() > MAX_INSTANCES) {
auto res = std::make_unique<server_http_res>();
res->status = 400;
res->data = safe_json_to_str({{"error", format_error_response("instances array exceeds maximum size of " + std::to_string(MAX_INSTANCES), ERROR_TYPE_INVALID_REQUEST)}});
return res;
}
std::vector<std::future<json>> futures;
futures.reserve(instances.size());
for (const auto & instance : instances) {
futures.push_back(std::async(std::launch::async, [this, &req, &alias_to_path, instance]() -> json {
if (!instance.is_object()) {
return build_error("each instance must be a JSON object", ERROR_TYPE_INVALID_REQUEST);
}
if (!instance.contains("@requestFormat") || !instance.at("@requestFormat").is_string()) {
return build_error("each instance must include a string @requestFormat", ERROR_TYPE_INVALID_REQUEST);
}
try {
json payload = instance;
const std::string format = payload.at("@requestFormat").get<std::string>();
payload.erase("@requestFormat");
if (payload.contains("stream")) {
LOG_WRN("%s: ignoring client-provided stream field in instance, streaming is not supported in predict route\n", __func__);
payload["stream"] = false;
}
// accept both camelCase aliases (e.g. "chatCompletions") and direct paths
std::string dispatch_path;
auto it_alias = alias_to_path.find(format);
if (it_alias != alias_to_path.end()) {
dispatch_path = it_alias->second;
} else if (handlers.count(format)) {
dispatch_path = format;
} else {
return build_error("no handler registered for @requestFormat: " + format, ERROR_TYPE_INVALID_REQUEST);
}
const server_http_req internal_req {
req.params,
req.headers,
path_prefix + dispatch_path,
req.query_string,
payload.dump(),
{},
req.should_stop,
};
server_http_res_ptr internal_res = handlers.at(dispatch_path)(internal_req);
return parse_gcp_predict_response(internal_res);
} catch (const std::invalid_argument & e) {
return build_error(e.what(), ERROR_TYPE_INVALID_REQUEST);
} catch (const std::exception & e) {
return build_error(e.what(), ERROR_TYPE_SERVER);
} catch (...) {
return build_error("unknown error", ERROR_TYPE_SERVER);
}
}));
}
json predictions = json::array();
for (auto & future : futures) {
predictions.push_back(future.get());
}
auto res = std::make_unique<server_http_res>();
res->data = safe_json_to_str({{"predictions", predictions}});
return res;
});
}

View File

@@ -0,0 +1,94 @@
#pragma once
#include <atomic>
#include <functional>
#include <map>
#include <string>
#include <thread>
#include <vector>
#include <cstdint>
struct common_params;
// generator-like API for HTTP response generation
// this object response with one of the 2 modes:
// 1) normal response: `data` contains the full response body
// 2) streaming response: each call to next(output) generates the next chunk
// when next(output) returns false, no more data after the current chunk
// note: some chunks can be empty, in which case no data is sent for that chunk
struct server_http_res {
std::string content_type = "application/json; charset=utf-8";
int status = 200;
std::string data;
std::map<std::string, std::string> headers;
// TODO: move this to a virtual function once we have proper polymorphism support
std::function<bool(std::string &)> next = nullptr;
bool is_stream() const {
return next != nullptr;
}
virtual ~server_http_res() = default;
};
// unique pointer, used by set_chunked_content_provider
// httplib requires the stream provider to be stored in heap
using server_http_res_ptr = std::unique_ptr<server_http_res>;
using raw_buffer = std::vector<uint8_t>;
struct uploaded_file {
raw_buffer data;
std::string filename;
std::string content_type;
};
struct server_http_req {
std::map<std::string, std::string> params; // path_params + query_params
std::map<std::string, std::string> headers; // used by MCP proxy
std::string path;
std::string query_string; // query parameters string (e.g. "action=save")
std::string body;
std::map<std::string, uploaded_file> files; // used for file uploads (form data)
const std::function<bool()> & should_stop;
std::string get_param(const std::string & key, const std::string & def = "") const {
auto it = params.find(key);
if (it != params.end()) {
return it->second;
}
return def;
}
};
struct server_http_context {
class Impl;
std::unique_ptr<Impl> pimpl;
std::thread thread; // server thread
std::atomic<bool> is_ready = false;
// note: the handler should never throw exceptions
using handler_t = std::function<server_http_res_ptr(const server_http_req & req)>;
mutable std::unordered_map<std::string, handler_t> handlers;
std::string path_prefix;
std::string hostname;
int port;
server_http_context();
~server_http_context();
bool init(const common_params & params);
bool start();
void stop() const;
void get(const std::string & path, const handler_t & handler) const;
void post(const std::string & path, const handler_t & handler) const;
// Register the Google Cloud Platform (Vertex AI) compat (AIP_PREDICT_ROUTE env var, or /predict)
// Must be called AFTER all other API routes are registered
void register_gcp_compat();
// for debugging
std::string listening_address;
};

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,233 @@
#pragma once
#include "common.h"
#include "preset.h"
#include "server-common.h"
#include "server-http.h"
#include <mutex>
#include <condition_variable>
#include <functional>
#include <memory>
#include <set>
/**
* state diagram:
*
* UNLOADED ──► LOADING ──► LOADED ◄──── SLEEPING
* ▲ │ │ ▲
* └───failed───┘ │ │
* ▲ └──sleeping─────┘
* └────────unloaded─────────┘
*/
enum server_model_status {
// TODO: also add downloading state when the logic is added
SERVER_MODEL_STATUS_UNLOADED,
SERVER_MODEL_STATUS_LOADING,
SERVER_MODEL_STATUS_LOADED,
SERVER_MODEL_STATUS_SLEEPING
};
static server_model_status server_model_status_from_string(const std::string & status_str) {
if (status_str == "unloaded") {
return SERVER_MODEL_STATUS_UNLOADED;
}
if (status_str == "loading") {
return SERVER_MODEL_STATUS_LOADING;
}
if (status_str == "loaded") {
return SERVER_MODEL_STATUS_LOADED;
}
if (status_str == "sleeping") {
return SERVER_MODEL_STATUS_SLEEPING;
}
throw std::runtime_error("invalid server model status");
}
static std::string server_model_status_to_string(server_model_status status) {
switch (status) {
case SERVER_MODEL_STATUS_UNLOADED: return "unloaded";
case SERVER_MODEL_STATUS_LOADING: return "loading";
case SERVER_MODEL_STATUS_LOADED: return "loaded";
case SERVER_MODEL_STATUS_SLEEPING: return "sleeping";
default: return "unknown";
}
}
struct server_model_meta {
common_preset preset;
std::string name;
std::set<std::string> aliases; // additional names that resolve to this model
std::set<std::string> tags; // informational tags, not used for routing
int port = 0;
server_model_status status = SERVER_MODEL_STATUS_UNLOADED;
int64_t last_used = 0; // for LRU unloading
std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
json loaded_info; // info to be reflected via /v1/models endpoint
int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown
bool is_ready() const {
return status == SERVER_MODEL_STATUS_LOADED;
}
bool is_running() const {
return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING || status == SERVER_MODEL_STATUS_SLEEPING;
}
bool is_failed() const {
return status == SERVER_MODEL_STATUS_UNLOADED && exit_code != 0;
}
void update_args(common_preset_context & ctx_presets, std::string bin_path);
};
struct subprocess_s;
struct server_models {
private:
struct instance_t {
std::shared_ptr<subprocess_s> subproc; // shared between main thread and monitoring thread
std::thread th;
server_model_meta meta;
FILE * stdin_file = nullptr;
};
std::mutex mutex;
std::condition_variable cv;
std::map<std::string, instance_t> mapping;
// for stopping models
std::condition_variable cv_stop;
std::set<std::string> stopping_models;
// set to true while load_models() is executing a reload; load() will wait until clear
bool is_reloading = false;
common_preset_context ctx_preset;
common_params base_params;
std::string bin_path;
std::vector<std::string> base_env;
common_preset base_preset; // base preset from llama-server CLI args
void update_meta(const std::string & name, const server_model_meta & meta);
// unload least recently used models if the limit is reached
void unload_lru();
// not thread-safe, caller must hold mutex
void add_model(server_model_meta && meta);
public:
server_models(const common_params & params, int argc, char ** argv);
// (re-)load the list of models from various sources and prepare the metadata mapping
// - if this is called the first time, simply populate the metadata
// - if this is called subsequently (e.g. when refreshing from disk):
// - if a model is running but updated or removed from the source, it will be unloaded
// - if a model is not running, it will be added or updated according to the source
void load_models();
// check if a model instance exists (thread-safe)
bool has_model(const std::string & name);
// return a copy of model metadata (thread-safe)
std::optional<server_model_meta> get_meta(const std::string & name);
// return a copy of all model metadata (thread-safe)
std::vector<server_model_meta> get_all_meta();
// load and unload model instances
// these functions are thread-safe
void load(const std::string & name);
void unload(const std::string & name);
void unload_all();
// update the status of a model instance (thread-safe)
void update_status(const std::string & name, server_model_status status, int exit_code);
void update_loaded_info(const std::string & name, std::string & raw_info);
// wait until the model instance is fully loaded (thread-safe)
// return when the model no longer in "loading" state
void wait_until_loading_finished(const std::string & name);
// ensure the model is in ready state (thread-safe)
// return false if model is ready
// otherwise, load the model and blocking wait until it's ready, then return true (meta may need to be refreshed)
bool ensure_model_ready(const std::string & name);
// proxy an HTTP request to the model instance
server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used);
// return true if the current process is a child server instance
static bool is_child_server();
// notify the router server that a model instance is ready
// return the monitoring thread (to be joined by the caller)
static std::thread setup_child_server(const std::function<void(int)> & shutdown_handler, const json & model_info);
// notify the router server that the sleeping state has changed
static void notify_router_sleeping_state(bool sleeping);
};
struct server_models_routes {
common_params params;
json webui_settings = json::object();
server_models models;
server_models_routes(const common_params & params, int argc, char ** argv)
: params(params), models(params, argc, argv) {
if (!this->params.webui_config_json.empty()) {
try {
webui_settings = json::parse(this->params.webui_config_json);
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse webui config: %s\n", __func__, e.what());
throw;
}
}
init_routes();
}
void init_routes();
// handlers using lambda function, so that they can capture `this` without `std::bind`
server_http_context::handler_t get_router_props;
server_http_context::handler_t proxy_get;
server_http_context::handler_t proxy_post;
server_http_context::handler_t get_router_models;
server_http_context::handler_t post_router_models_load;
server_http_context::handler_t post_router_models_unload;
};
/**
* A simple HTTP proxy that forwards requests to another server
* and relays the responses back.
*/
struct server_http_proxy : server_http_res {
std::function<void()> cleanup = nullptr;
public:
server_http_proxy(const std::string & method,
const std::string & scheme,
const std::string & host,
int port,
const std::string & path,
const std::map<std::string, std::string> & headers,
const std::string & body,
const std::map<std::string, uploaded_file> & files,
const std::function<bool()> should_stop,
int32_t timeout_read,
int32_t timeout_write
);
~server_http_proxy() {
if (cleanup) {
cleanup();
}
}
private:
std::thread thread;
struct msg_t {
std::map<std::string, std::string> headers;
int status = 0;
std::string data;
std::string content_type;
};
};

View File

@@ -0,0 +1,451 @@
#include "server-task.h"
#include "server-queue.h"
#include "log.h"
#include <chrono>
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define RES_INF(fmt, ...) LOG_INF("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define RES_WRN(fmt, ...) LOG_WRN("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define RES_ERR(fmt, ...) LOG_ERR("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define RES_DBG(fmt, ...) LOG_DBG("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
//
// server_queue
//
int server_queue::post(server_task && task, bool front) {
std::unique_lock<std::mutex> lock(mutex_tasks);
GGML_ASSERT(task.id != -1);
// if this is cancel task make sure to clean up pending tasks
if (task.type == SERVER_TASK_TYPE_CANCEL) {
cleanup_pending_task(task.id_target);
}
const int task_id = task.id;
QUE_DBG("new task, id = %d, front = %d\n", task_id, front);
if (front) {
queue_tasks.push_front(std::move(task));
} else {
queue_tasks.push_back(std::move(task));
}
time_last_task = ggml_time_ms();
condition_tasks.notify_one();
return task_id;
}
int server_queue::post(std::vector<server_task> && tasks, bool front) {
std::unique_lock<std::mutex> lock(mutex_tasks);
for (auto & task : tasks) {
if (task.id == -1) {
task.id = id++;
}
// if this is cancel task make sure to clean up pending tasks
if (task.type == SERVER_TASK_TYPE_CANCEL) {
cleanup_pending_task(task.id_target);
}
QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
if (front) {
queue_tasks.push_front(std::move(task));
} else {
queue_tasks.push_back(std::move(task));
}
}
time_last_task = ggml_time_ms();
condition_tasks.notify_one();
return 0;
}
void server_queue::defer(server_task && task) {
std::unique_lock<std::mutex> lock(mutex_tasks);
QUE_DBG("defer task, id = %d\n", task.id);
queue_tasks_deferred.push_back(std::move(task));
time_last_task = ggml_time_ms();
condition_tasks.notify_one();
}
int server_queue::get_new_id() {
std::unique_lock<std::mutex> lock(mutex_tasks);
int new_id = id++;
return new_id;
}
void server_queue::pop_deferred_task(int id_slot) {
std::unique_lock<std::mutex> lock(mutex_tasks);
if (!queue_tasks_deferred.empty()) {
// try to find a task that uses the specified slot
bool found = false;
for (auto it = queue_tasks_deferred.begin(); it != queue_tasks_deferred.end(); ++it) {
if (it->id_slot == id_slot) {
QUE_DBG("pop deferred task (use slot %d), id_task = %d\n", id_slot, it->id);
queue_tasks.emplace_front(std::move(*it));
queue_tasks_deferred.erase(it);
found = true;
break;
}
}
// if not tasks found using the slot, just pop the first deferred task (default behavior)
if (!found) {
QUE_DBG("pop deferred task, id_task = %d\n", queue_tasks_deferred.front().id);
queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
queue_tasks_deferred.pop_front();
}
}
time_last_task = ggml_time_ms();
condition_tasks.notify_one();
}
void server_queue::wait_until_no_sleep() {
std::unique_lock<std::mutex> lock(mutex_tasks);
if (!sleeping) {
return;
} else {
if (!req_stop_sleeping) {
QUE_DBG("%s", "requesting to stop sleeping\n");
req_stop_sleeping = true;
condition_tasks.notify_one(); // only main thread is waiting on this
}
QUE_DBG("%s", "waiting until no sleep\n");
condition_tasks.wait(lock, [&]{
return !sleeping;
});
}
}
void server_queue::terminate() {
std::unique_lock<std::mutex> lock(mutex_tasks);
running = false;
condition_tasks.notify_all();
}
void server_queue::start_loop(int64_t idle_sleep_ms) {
running = true;
time_last_task = ggml_time_ms();
constexpr auto max_wait_time = std::chrono::seconds(1);
auto should_sleep = [&]() -> bool {
// caller must hold mutex_tasks
if (idle_sleep_ms < 0) {
return false;
}
int64_t now = ggml_time_ms();
return (now - time_last_task) >= idle_sleep_ms;
};
while (true) {
QUE_DBG("%s", "processing new tasks\n");
while (true) {
std::unique_lock<std::mutex> lock(mutex_tasks);
if (!running) {
QUE_DBG("%s", "terminate\n");
return;
}
if (queue_tasks.empty()) {
lock.unlock();
break;
}
server_task task = std::move(queue_tasks.front());
queue_tasks.pop_front();
lock.unlock();
QUE_DBG("processing task, id = %d\n", task.id);
callback_new_task(std::move(task));
}
// all tasks in the current loop is processed, slots data is now ready
QUE_DBG("%s", "update slots\n");
// this will run the main inference process for all slots
callback_update_slots();
{
// update_slots() may take a while to finish, we need to make sure it's not counted as idle
std::unique_lock<std::mutex> lock(mutex_tasks);
time_last_task = ggml_time_ms();
}
QUE_DBG("%s", "waiting for new tasks\n");
while (true) {
std::unique_lock<std::mutex> lock(mutex_tasks);
if (!running || !queue_tasks.empty()) {
break; // go back to process new tasks or terminate
}
// no tasks, check for sleeping state
if (should_sleep()) {
QUE_INF("%s", "entering sleeping state\n");
sleeping = true;
callback_sleeping_state(true);
req_stop_sleeping = false;
// wait until we are requested to exit sleeping state
condition_tasks.wait(lock, [&]{
return (!running || req_stop_sleeping);
});
if (!running) { // may changed during sleep
break; // terminate
}
QUE_INF("%s", "exiting sleeping state\n");
req_stop_sleeping = false;
callback_sleeping_state(false);
sleeping = false;
time_last_task = ggml_time_ms();
condition_tasks.notify_all(); // notify wait_until_no_sleep()
break; // process new tasks
} else {
// wait for new tasks or timeout for checking sleeping condition
bool res = condition_tasks.wait_for(lock, max_wait_time, [&]{
return (!queue_tasks.empty() || !running);
});
if (res) {
break; // new task arrived or terminate
}
// otherwise, loop again to check sleeping condition
}
}
}
}
void server_queue::cleanup_pending_task(int id_target) {
// no need lock because this is called exclusively by post()
auto rm_func = [id_target](const server_task & task) {
return task.id == id_target;
};
queue_tasks.erase(
std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
queue_tasks.end());
queue_tasks_deferred.erase(
std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
queue_tasks_deferred.end());
}
//
// server_response
//
void server_response::add_waiting_task_id(int id_task) {
RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
std::unique_lock<std::mutex> lock(mutex_results);
waiting_task_ids.insert(id_task);
}
void server_response::add_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
std::unique_lock<std::mutex> lock(mutex_results);
for (const auto & id_task : id_tasks) {
RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
waiting_task_ids.insert(id_task);
}
}
void server_response::remove_waiting_task_id(int id_task) {
RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
std::unique_lock<std::mutex> lock(mutex_results);
waiting_task_ids.erase(id_task);
// make sure to clean up all pending results
queue_results.erase(
std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) {
return res->id == id_task;
}),
queue_results.end());
}
void server_response::remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
std::unique_lock<std::mutex> lock(mutex_results);
for (const auto & id_task : id_tasks) {
RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
waiting_task_ids.erase(id_task);
}
}
server_task_result_ptr server_response::recv(const std::unordered_set<int> & id_tasks) {
while (true) {
std::unique_lock<std::mutex> lock(mutex_results);
condition_results.wait(lock, [&]{
if (!running) {
RES_DBG("%s : queue result stop\n", "recv");
std::terminate(); // we cannot return here since the caller is HTTP code
}
return !queue_results.empty();
});
for (size_t i = 0; i < queue_results.size(); i++) {
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
server_task_result_ptr res = std::move(queue_results[i]);
queue_results.erase(queue_results.begin() + i);
return res;
}
}
}
// should never reach here
}
server_task_result_ptr server_response::recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
while (true) {
std::unique_lock<std::mutex> lock(mutex_results);
for (int i = 0; i < (int) queue_results.size(); i++) {
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
server_task_result_ptr res = std::move(queue_results[i]);
queue_results.erase(queue_results.begin() + i);
return res;
}
}
std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
if (!running) {
RES_DBG("%s : queue result stop\n", __func__);
std::terminate(); // we cannot return here since the caller is HTTP code
}
if (cr_res == std::cv_status::timeout) {
return nullptr;
}
}
// should never reach here
}
server_task_result_ptr server_response::recv(int id_task) {
std::unordered_set<int> id_tasks = {id_task};
return recv(id_tasks);
}
void server_response::send(server_task_result_ptr && result) {
RES_DBG("sending result for task id = %d\n", result->id);
std::unique_lock<std::mutex> lock(mutex_results);
for (const auto & id_task : waiting_task_ids) {
if (result->id == id_task) {
RES_DBG("task id = %d pushed to result queue\n", result->id);
queue_results.emplace_back(std::move(result));
condition_results.notify_all();
return;
}
}
}
void server_response::terminate() {
running = false;
condition_results.notify_all();
}
//
// server_response_reader
//
void server_response_reader::post_task(server_task && task, bool front) {
GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
GGML_ASSERT(!task.is_parent() && "not supported, use post_tasks() instead");
task.index = 0;
id_tasks.insert(task.id);
states.push_back(task.create_state());
queue_results.add_waiting_task_id(task.id);
queue_tasks.post(std::move(task), front);
}
void server_response_reader::post_tasks(std::vector<server_task> && tasks, bool front) {
GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
id_tasks = server_task::get_list_id(tasks);
states.reserve(tasks.size());
size_t index = 0;
for (auto & task : tasks) {
task.index = index++;
states.push_back(task.create_state());
// for child tasks
for (auto & child_task : task.child_tasks) {
child_task.index = index++;
states.push_back(child_task.create_state());
}
}
GGML_ASSERT(states.size() == id_tasks.size());
queue_results.add_waiting_task_ids(id_tasks);
queue_tasks.post(std::move(tasks), front);
}
bool server_response_reader::has_next() const {
return !cancelled && received_count < id_tasks.size();
}
// return nullptr if should_stop() is true before receiving a result
// note: if one error is received, it will stop further processing and return error result
server_task_result_ptr server_response_reader::next(const std::function<bool()> & should_stop) {
while (true) {
server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds);
if (result == nullptr) {
// timeout, check stop condition
if (should_stop()) {
SRV_WRN("%s", "stopping wait for next result due to should_stop condition (adjust the --timeout argument if needed)\n");
SRV_WRN("%s", "ref: https://github.com/ggml-org/llama.cpp/pull/22907\n");
return nullptr;
}
} else {
if (result->is_error()) {
stop(); // cancel remaining tasks
SRV_DBG("%s", "received error result, stopping further processing\n");
return result;
}
if (!states.empty()) {
// update the generation state if needed
const size_t idx = result->index;
GGML_ASSERT(idx < states.size());
result->update(states[idx]);
}
if (result->is_stop()) {
received_count++;
}
return result;
}
}
// should not reach here
}
server_response_reader::batch_response server_response_reader::wait_for_all(const std::function<bool()> & should_stop) {
batch_response batch_res;
batch_res.results.clear();
batch_res.results.resize(id_tasks.size());
while (has_next()) {
auto res = next(should_stop);
if (res == nullptr) {
batch_res.is_terminated = true;
return batch_res;
}
if (res->is_error()) {
batch_res.error = std::move(res);
return batch_res;
}
const size_t idx = res->index;
GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
batch_res.results[idx] = std::move(res);
}
return batch_res;
}
void server_response_reader::stop() {
queue_results.remove_waiting_task_ids(id_tasks);
if (has_next() && !cancelled) {
// if tasks is not finished yet, cancel them
cancelled = true;
std::vector<server_task> cancel_tasks;
cancel_tasks.reserve(id_tasks.size());
for (const auto & id_task : id_tasks) {
SRV_WRN("cancel task, id_task = %d\n", id_task);
server_task task(SERVER_TASK_TYPE_CANCEL);
task.id_target = id_task;
queue_results.remove_waiting_task_id(id_task);
cancel_tasks.push_back(std::move(task));
}
// push to beginning of the queue, so it has highest priority
queue_tasks.post(std::move(cancel_tasks), true);
} else {
SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
}
}

205
tools/server/server-queue.h Normal file
View File

@@ -0,0 +1,205 @@
#pragma once
#include "server-task.h"
#include <condition_variable>
#include <deque>
#include <mutex>
#include <vector>
#include <unordered_set>
// struct for managing server tasks
// in most cases, use server_response_reader to post new tasks and retrieve results
struct server_queue {
private:
int id = 0;
bool running = false;
bool sleeping = false;
bool req_stop_sleeping = false;
int64_t time_last_task = 0;
// queues
std::deque<server_task> queue_tasks;
std::deque<server_task> queue_tasks_deferred;
std::mutex mutex_tasks;
std::condition_variable condition_tasks;
// callback functions
std::function<void(server_task &&)> callback_new_task;
std::function<void(void)> callback_update_slots;
std::function<void(bool)> callback_sleeping_state;
public:
// Add a new task to the end of the queue
int post(server_task && task, bool front = false);
// multi-task version of post()
int post(std::vector<server_task> && tasks, bool front = false);
// Add a new task, but defer until one slot is available
void defer(server_task && task);
// Get the next id for creating a new task
int get_new_id();
// Call when the state of one slot is changed, it will move one task from deferred to main queue
// prioritize tasks that use the specified slot (otherwise, pop the first deferred task)
void pop_deferred_task(int id_slot);
// if sleeping, request exiting sleep state and wait until it is done
// returns immediately if not sleeping
void wait_until_no_sleep();
bool is_sleeping() {
std::unique_lock<std::mutex> lock(mutex_tasks);
return sleeping;
}
// end the start_loop routine
void terminate();
/**
* Main loop consists of these steps:
* - Wait until a new task arrives
* - Process the task (i.e. maybe copy data into slot)
* - Check if multitask is finished
* - Update all slots
*
* Sleeping procedure (disabled if idle_sleep_ms < 0):
* - If there is no task after idle_sleep_ms, enter sleeping state
* - Call callback_sleeping_state(true)
* - Wait until req_stop_sleeping is set to true
* - Call callback_sleeping_state(false)
* - Exit sleeping state
*/
void start_loop(int64_t idle_sleep_ms = -1);
// for metrics
size_t queue_tasks_deferred_size() {
std::unique_lock<std::mutex> lock(mutex_tasks);
return queue_tasks_deferred.size();
}
//
// Functions below are not thread-safe, must only be used before start_loop() is called
//
// Register function to process a new task
void on_new_task(std::function<void(server_task &&)> callback) {
callback_new_task = std::move(callback);
}
// Register the function to be called when all slots data is ready to be processed
void on_update_slots(std::function<void(void)> callback) {
callback_update_slots = std::move(callback);
}
// Register callback for sleeping state change; multiple callbacks are allowed
// note: when entering sleeping state, the callback is called AFTER sleeping is set to true
// when leaving sleeping state, the callback is called BEFORE sleeping is set to false
void on_sleeping_state(std::function<void(bool)> callback) {
if (callback_sleeping_state) {
auto prev_callback = std::move(callback_sleeping_state);
callback_sleeping_state = [prev_callback, callback](bool sleeping) {
prev_callback(sleeping);
callback(sleeping);
};
} else {
callback_sleeping_state = std::move(callback);
}
}
private:
void cleanup_pending_task(int id_target);
};
// struct for managing server responses
// in most cases, use server_response_reader to retrieve results
struct server_response {
private:
bool running = true;
// for keeping track of all tasks waiting for the result
std::unordered_set<int> waiting_task_ids;
// the main result queue (using ptr for polymorphism)
std::vector<server_task_result_ptr> queue_results;
std::mutex mutex_results;
std::condition_variable condition_results;
public:
// add the id_task to the list of tasks waiting for response
void add_waiting_task_id(int id_task);
void add_waiting_task_ids(const std::unordered_set<int> & id_tasks);
// when the request is finished, we can remove task associated with it
void remove_waiting_task_id(int id_task);
// remove multiple tasks from waiting list
void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks);
// This function blocks the thread until there is a response for one of the id_tasks
server_task_result_ptr recv(const std::unordered_set<int> & id_tasks);
// same as recv(), but have timeout in seconds
// if timeout is reached, nullptr is returned
server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout);
// single-task version of recv()
server_task_result_ptr recv(int id_task);
// Send a new result to a waiting id_task
void send(server_task_result_ptr && result);
// terminate the waiting loop
void terminate();
};
// utility class to make working with server_queue and server_response easier
// it provides a generator-like API for server responses
// support pooling connection state and aggregating multiple results
struct server_response_reader {
std::unordered_set<int> id_tasks;
server_queue & queue_tasks;
server_response & queue_results;
size_t received_count = 0;
bool cancelled = false;
int polling_interval_seconds;
// tracking generation state and partial tool calls
// only used by streaming completions
std::vector<task_result_state> states;
// should_stop function will be called each polling_interval_seconds
server_response_reader(server_queue & queue_tasks, server_response & queue_results, int polling_interval_seconds)
: queue_tasks(queue_tasks), queue_results(queue_results), polling_interval_seconds(polling_interval_seconds) {}
~server_response_reader() {
stop();
}
int get_new_id() {
return queue_tasks.get_new_id();
}
// if front = true, the task will be posted to the front of the queue (high priority)
void post_task(server_task && task, bool front = false);
void post_tasks(std::vector<server_task> && tasks, bool front = false);
bool has_next() const;
// return nullptr if should_stop() is true before receiving a result
// note: if one error is received, it will stop further processing and return error result
server_task_result_ptr next(const std::function<bool()> & should_stop);
struct batch_response {
bool is_terminated = false; // if true, indicates that processing was stopped before all results were received
std::vector<server_task_result_ptr> results;
server_task_result_ptr error; // nullptr if no error
};
// aggregate multiple results
batch_response wait_for_all(const std::function<bool()> & should_stop);
void stop();
};

2155
tools/server/server-task.cpp Normal file

File diff suppressed because it is too large Load Diff

632
tools/server/server-task.h Normal file
View File

@@ -0,0 +1,632 @@
#pragma once
#include "common.h"
#include "llama.h"
#include <string>
#include <unordered_set>
#include <list>
#include <map>
// TODO: prevent including the whole server-common.h as we only use server_tokens
#include "server-common.h"
using json = nlohmann::ordered_json;
enum server_task_type {
SERVER_TASK_TYPE_COMPLETION,
SERVER_TASK_TYPE_EMBEDDING,
SERVER_TASK_TYPE_RERANK,
SERVER_TASK_TYPE_INFILL,
SERVER_TASK_TYPE_CANCEL,
SERVER_TASK_TYPE_NEXT_RESPONSE,
SERVER_TASK_TYPE_METRICS,
SERVER_TASK_TYPE_SLOT_SAVE,
SERVER_TASK_TYPE_SLOT_RESTORE,
SERVER_TASK_TYPE_SLOT_ERASE,
SERVER_TASK_TYPE_GET_LORA,
SERVER_TASK_TYPE_SET_LORA,
};
// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common
enum task_response_type {
TASK_RESPONSE_TYPE_NONE, // llama.cpp native format
TASK_RESPONSE_TYPE_OAI_CHAT,
TASK_RESPONSE_TYPE_OAI_CMPL,
TASK_RESPONSE_TYPE_OAI_RESP,
TASK_RESPONSE_TYPE_OAI_ASR, // transcriptions API
TASK_RESPONSE_TYPE_OAI_EMBD,
TASK_RESPONSE_TYPE_ANTHROPIC,
};
enum stop_type {
STOP_TYPE_NONE,
STOP_TYPE_EOS,
STOP_TYPE_WORD,
STOP_TYPE_LIMIT,
};
struct task_params {
bool stream = true;
bool include_usage = false;
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
bool return_tokens = false;
bool return_progress = false;
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
int32_t n_cmpl = 1; // number of completions to generate from this prompt
int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)
int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
std::map<int, float> lora; // mapping adapter ID -> scale
std::vector<std::string> antiprompt;
std::vector<std::string> response_fields;
bool timings_per_token = false;
bool post_sampling_probs = false;
struct common_params_sampling sampling;
struct common_params_speculative speculative;
// response formatting
bool verbose = false;
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
// per-request parameters for chat parsing
common_chat_parser_params chat_parser_params;
// Embeddings
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const;
json to_json(bool only_metrics = false) const;
};
// struct for tracking the state of a task (e.g., for streaming)
struct task_result_state {
// tracking diffs for partial tool calls
std::vector<common_chat_msg_diff> diffs;
common_chat_parser_params chat_parser_params;
common_chat_msg chat_msg;
std::string generated_text; // append new chunks of generated text here
std::vector<std::string> generated_tool_call_ids;
std::unordered_set<size_t> sent_tool_call_names;
// for OpenAI Responses and Anthropic streaming API:
// track output item / content block state across chunks
bool thinking_block_started = false;
bool text_block_started = false;
// for OpenAI Responses streaming API
const std::string oai_resp_id;
const std::string oai_resp_reasoning_id;
const std::string oai_resp_message_id;
std::string oai_resp_fc_id; // function call ID for current args delta
task_result_state(const common_chat_parser_params & chat_parser_params)
: chat_parser_params(chat_parser_params)
, oai_resp_id("resp_" + random_string())
, oai_resp_reasoning_id("rs_" + random_string())
, oai_resp_message_id("msg_" + random_string()) {}
// parse partial tool calls and update the internal state
common_chat_msg update_chat_msg(
const std::string & text_added,
bool is_partial,
std::vector<common_chat_msg_diff> & diffs,
bool filter_tool_calls = false);
};
struct server_task {
int id = -1; // to be filled by server_queue
// TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader
size_t index = 0; // used when there are multiple prompts (batch request)
// used by SERVER_TASK_TYPE_CANCEL
int id_target = -1;
int id_slot = -1;
// used by parallel sampling (multiple completions from same prompt)
int id_parent = -1;
// temporary store of child tasks for scheduling
// note: accessing to elements is invalid after the task is moved to server_slot
std::vector<server_task> child_tasks;
// used by SERVER_TASK_TYPE_INFERENCE
task_params params;
server_tokens tokens;
// only used by CLI, this allow tokenizing CLI inputs on server side
// we need this because mtmd_context and vocab are not accessible outside of server_context
bool cli = false;
std::string cli_prompt;
std::vector<raw_buffer> cli_files;
server_task_type type;
// used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
struct slot_action {
int id_slot;
std::string filename;
std::string filepath;
};
slot_action slot_action;
// used by SERVER_TASK_TYPE_METRICS
bool metrics_reset_bucket = false;
// used by SERVER_TASK_TYPE_SET_LORA
std::map<int, float> set_lora; // mapping adapter ID -> scale
server_task() = default;
server_task(server_task_type type) : type(type) {}
int32_t n_tokens() const {
return tokens.size();
}
bool need_embd() const {
switch (type) {
case SERVER_TASK_TYPE_EMBEDDING:
case SERVER_TASK_TYPE_RERANK:
return true;
default:
return false;
}
}
bool need_logits() const {
switch (type) {
case SERVER_TASK_TYPE_COMPLETION:
case SERVER_TASK_TYPE_INFILL:
return true;
default:
return false;
}
}
bool need_sampling() const {
switch (type) {
case SERVER_TASK_TYPE_COMPLETION:
case SERVER_TASK_TYPE_INFILL:
return true;
default:
return false;
}
}
static task_params params_from_json_cmpl(
const llama_vocab * vocab,
const common_params & params_base,
const int n_ctx_slot,
const std::vector<llama_logit_bias> & logit_bias_eog,
const json & data);
// utility function
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
std::unordered_set<int> ids(tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
ids.insert(tasks[i].id);
for (auto & child : tasks[i].child_tasks) {
ids.insert(child.id);
}
}
return ids;
}
void add_child(int id_parent, int id_child) {
server_task copy;
copy.id = id_child;
copy.id_parent = id_parent;
copy.params = params;
copy.type = type;
copy.tokens = tokens.clone();
copy.id_slot = -1; // child tasks cannot specify slot
// use different sampling seed for each child
// note: https://github.com/ggml-org/llama.cpp/pull/18700#discussion_r2675115723
if (copy.params.sampling.seed != LLAMA_DEFAULT_SEED) {
copy.params.sampling.seed += (uint32_t)child_tasks.size() + 1;
}
child_tasks.push_back(std::move(copy));
}
// the task will be moved into queue, then onto slots
// however, the state must be kept by caller (e.g., HTTP thread)
task_result_state create_state() const {
return task_result_state(params.chat_parser_params);
}
bool is_parent() const {
return child_tasks.size() > 0;
}
bool is_child() const {
return id_parent != -1;
}
};
struct result_timings {
int32_t cache_n = -1;
int32_t prompt_n = -1;
double prompt_ms = 0.0;
double prompt_per_token_ms = 0.0;
double prompt_per_second = 0.0;
int32_t predicted_n = -1;
double predicted_ms = 0.0;
double predicted_per_token_ms = 0.0;
double predicted_per_second = 0.0;
// Optional speculative metrics - only included when > 0
int32_t draft_n = 0;
int32_t draft_n_accepted = 0;
json to_json() const;
};
struct result_prompt_progress {
int32_t total = 0;
int32_t cache = 0;
int32_t processed = 0;
int64_t time_ms = 0;
json to_json() const;
};
struct server_task_result {
int id = -1;
int id_slot = -1;
// TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader
size_t index = 0; // to be used for batched tasks
virtual bool is_error() {
// only used by server_task_result_error
return false;
}
virtual bool is_stop() {
// only used by server_task_result_cmpl_*
return true;
}
virtual void update(task_result_state &) {
// only used by server_task_result_cmpl_*
}
virtual json to_json() = 0;
virtual ~server_task_result() = default;
};
// using shared_ptr for polymorphism of server_task_result
using server_task_result_ptr = std::unique_ptr<server_task_result>;
struct completion_token_output {
llama_token tok;
float prob;
std::string text_to_send;
struct prob_info {
llama_token tok;
std::string txt;
float prob;
};
std::vector<prob_info> probs;
json to_json(bool post_sampling_probs) const;
static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs);
static float logarithm(float x);
static std::vector<unsigned char> str_to_bytes(const std::string & str);
};
struct server_task_result_cmpl_final : server_task_result {
std::string content;
llama_tokens tokens;
bool stream;
bool include_usage;
result_timings timings;
std::string prompt;
bool truncated;
int32_t n_decoded;
int32_t n_prompt_tokens;
int32_t n_prompt_tokens_cache;
int32_t n_tokens_cached;
bool has_new_line;
std::string stopping_word;
stop_type stop = STOP_TYPE_NONE;
bool post_sampling_probs;
std::vector<completion_token_output> probs_output;
std::vector<std::string> response_fields;
task_params generation_params;
// response formatting
bool verbose = false;
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_msg oaicompat_msg; // to be populated by update()
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
bool is_updated = false;
// for OpenAI Responses API
std::string oai_resp_id;
std::string oai_resp_reasoning_id;
std::string oai_resp_message_id;
virtual bool is_stop() override {
return true; // in stream mode, final responses are considered stop
}
virtual json to_json() override;
virtual void update(task_result_state & state) override {
is_updated = true;
oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);
oai_resp_id = state.oai_resp_id;
oai_resp_reasoning_id = state.oai_resp_reasoning_id;
oai_resp_message_id = state.oai_resp_message_id;
}
json to_json_non_oaicompat();
json usage_json_oaicompat();
json to_json_oaicompat();
json to_json_oaicompat_chat();
json to_json_oaicompat_chat_stream();
json to_json_oaicompat_resp();
json to_json_oaicompat_resp_stream();
json to_json_oaicompat_asr();
json to_json_anthropic();
json to_json_anthropic_stream();
};
struct server_task_result_cmpl_partial : server_task_result {
std::string content;
llama_tokens tokens;
int32_t n_decoded;
int32_t n_prompt_tokens;
int32_t n_prompt_tokens_cache;
bool post_sampling_probs;
bool is_progress = false;
completion_token_output prob_output;
result_timings timings;
result_prompt_progress progress;
// response formatting
bool verbose = false;
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
bool is_updated = false;
// Streaming state copied from task_result_state for this chunk
bool thinking_block_started = false;
bool text_block_started = false;
// for OpenAI Responses API
std::string oai_resp_id;
std::string oai_resp_reasoning_id;
std::string oai_resp_message_id;
std::string oai_resp_fc_id;
// for Anthropic API: track if any reasoning content has been generated
bool anthropic_has_reasoning = false;
virtual bool is_stop() override {
return false; // in stream mode, partial responses are not considered stop
}
virtual void update(task_result_state & state) override;
virtual json to_json() override;
json to_json_non_oaicompat();
json to_json_oaicompat();
json to_json_oaicompat_chat();
json to_json_oaicompat_resp();
json to_json_oaicompat_asr();
json to_json_anthropic();
};
struct server_task_result_embd : server_task_result {
std::vector<std::vector<float>> embedding;
int32_t n_tokens;
// response formatting
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
virtual json to_json() override;
json to_json_non_oaicompat();
json to_json_oaicompat();
};
struct server_task_result_rerank : server_task_result {
float score = -1e6;
int32_t n_tokens;
virtual json to_json() override;
};
struct server_task_result_error : server_task_result {
error_type err_type = ERROR_TYPE_SERVER;
std::string err_msg;
// for ERROR_TYPE_EXCEED_CONTEXT_SIZE
int32_t n_prompt_tokens = 0;
int32_t n_ctx = 0;
virtual bool is_error() override {
return true;
}
virtual json to_json() override;
};
struct server_task_result_metrics : server_task_result {
int n_idle_slots;
int n_processing_slots;
int n_tasks_deferred;
int64_t t_start;
// TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
uint64_t n_prompt_tokens_processed_total = 0;
uint64_t t_prompt_processing_total = 0;
uint64_t n_tokens_predicted_total = 0;
uint64_t t_tokens_generation_total = 0;
uint64_t n_tokens_max = 0;
uint64_t n_prompt_tokens_processed = 0;
uint64_t t_prompt_processing = 0;
uint64_t n_tokens_predicted = 0;
uint64_t t_tokens_generation = 0;
uint64_t n_decode_total = 0;
uint64_t n_busy_slots_total = 0;
// while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy
// therefore, we use json to temporarily store the slot.to_json() result
json slots_data = json::array();
virtual json to_json() override;
};
struct server_task_result_slot_save_load : server_task_result {
std::string filename;
bool is_save; // true = save, false = load
size_t n_tokens;
size_t n_bytes;
double t_ms;
virtual json to_json() override;
};
struct server_task_result_slot_erase : server_task_result {
size_t n_erased;
virtual json to_json() override;
};
struct server_task_result_get_lora : server_task_result {
struct lora {
common_adapter_lora_info info;
std::string alora_invocation_string;
llama_tokens alora_invocation_tokens;
};
std::vector<lora> loras;
virtual json to_json() override;
};
struct server_task_result_apply_lora : server_task_result {
virtual json to_json() override;
};
struct server_prompt_data {
std::vector<uint8_t> main;
std::vector<uint8_t> drft;
size_t size() const {
return main.size() + drft.size();
}
};
struct server_prompt {
server_tokens tokens;
server_prompt_data data;
std::list<common_prompt_checkpoint> checkpoints;
size_t size() const {
size_t res = 0;
res += data.size();
for (const auto & ckpt : checkpoints) {
res += ckpt.size();
}
return res;
}
int n_tokens() const {
return tokens.size();
}
server_prompt clone() const {
return server_prompt {
tokens.clone(),
data,
checkpoints,
};
}
};
struct server_prompt_cache {
server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib);
this->limit_tokens = limit_tokens;
}
std::list<server_prompt> states;
// in bytes, 0 = no limit
size_t limit_size = 0;
// in tokens, 0 = no limit
size_t limit_tokens = 0;
size_t size() const;
size_t n_tokens() const;
server_prompt * alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_drft);
bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_main, llama_context * ctx_drft, int32_t id_slot);
void update();
};

View File

@@ -0,0 +1,817 @@
#include "server-tools.h"
#include <sheredom/subprocess.h>
#include <filesystem>
#include <fstream>
#include <regex>
#include <thread>
#include <chrono>
#include <atomic>
#include <cstring>
#include <climits>
#include <algorithm>
namespace fs = std::filesystem;
//
// internal helpers
//
static std::vector<char *> to_cstr_vec(const std::vector<std::string> & v) {
std::vector<char *> r;
r.reserve(v.size() + 1);
for (const auto & s : v) {
r.push_back(const_cast<char *>(s.c_str()));
}
r.push_back(nullptr);
return r;
}
struct run_proc_result {
std::string output;
int exit_code = -1;
bool timed_out = false;
};
static run_proc_result run_process(
const std::vector<std::string> & args,
size_t max_output,
int timeout_secs) {
run_proc_result res;
subprocess_s proc;
auto argv = to_cstr_vec(args);
int options = subprocess_option_no_window
| subprocess_option_combined_stdout_stderr
| subprocess_option_inherit_environment
| subprocess_option_search_user_path;
if (subprocess_create(argv.data(), options, &proc) != 0) {
res.output = "failed to spawn process";
return res;
}
std::atomic<bool> done{false};
std::atomic<bool> timed_out{false};
std::thread timeout_thread([&]() {
auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(timeout_secs);
while (!done.load()) {
if (std::chrono::steady_clock::now() >= deadline) {
timed_out.store(true);
subprocess_terminate(&proc);
return;
}
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
});
FILE * f = subprocess_stdout(&proc);
std::string output;
bool truncated = false;
if (f) {
char buf[4096];
while (fgets(buf, sizeof(buf), f) != nullptr) {
if (!truncated) {
size_t len = strlen(buf);
if (output.size() + len <= max_output) {
output.append(buf, len);
} else {
output.append(buf, max_output - output.size());
truncated = true;
}
}
}
}
done.store(true);
if (timeout_thread.joinable()) {
timeout_thread.join();
}
subprocess_join(&proc, &res.exit_code);
subprocess_destroy(&proc);
res.output = output;
res.timed_out = timed_out.load();
if (truncated) {
res.output += "\n[output truncated]";
}
return res;
}
json server_tool::to_json() {
return {
{"display_name", display_name},
{"tool", name},
{"type", "builtin"},
{"permissions", json{
{"write", permission_write}
}},
{"definition", get_definition()},
};
}
//
// read_file: read a file with optional line range and line-number prefix
//
static constexpr size_t SERVER_TOOL_READ_FILE_MAX_SIZE = 16 * 1024; // 16 KB
struct server_tool_read_file : server_tool {
server_tool_read_file() {
name = "read_file";
display_name = "Read file";
permission_write = false;
}
json get_definition() override {
return {
{"type", "function"},
{"function", {
{"name", name},
{"description", "Read the contents of a file. Optionally specify a 1-based line range. "
"If append_loc is true, each line is prefixed with its line number (e.g. \"1\u2192 ...\")."},
{"parameters", {
{"type", "object"},
{"properties", {
{"path", {{"type", "string"}, {"description", "Path to the file"}}},
{"start_line", {{"type", "integer"}, {"description", "First line to read, 1-based (default: 1)"}}},
{"end_line", {{"type", "integer"}, {"description", "Last line to read, 1-based inclusive (default: end of file)"}}},
{"append_loc", {{"type", "boolean"}, {"description", "Prefix each line with its line number"}}},
}},
{"required", json::array({"path"})},
}},
}},
};
}
json invoke(json params) override {
std::string path = params.at("path").get<std::string>();
int start_line = json_value(params, "start_line", 1);
int end_line = json_value(params, "end_line", -1); // -1 = no limit
bool append_loc = json_value(params, "append_loc", false);
std::error_code ec;
uintmax_t file_size = fs::file_size(path, ec);
if (ec) {
return {{"error", "cannot stat file: " + ec.message()}};
}
if (file_size > SERVER_TOOL_READ_FILE_MAX_SIZE && end_line == -1) {
return {{"error", string_format(
"file too large (%zu bytes, max %zu). Use start_line/end_line to read a portion.",
(size_t)file_size, SERVER_TOOL_READ_FILE_MAX_SIZE)}};
}
std::ifstream f(path);
if (!f) {
return {{"error", "failed to open file: " + path}};
}
std::string result;
std::string line;
int lineno = 0;
while (std::getline(f, line)) {
lineno++;
if (lineno < start_line) continue;
if (end_line != -1 && lineno > end_line) break;
std::string out_line;
if (append_loc) {
out_line = std::to_string(lineno) + "\u2192 " + line + "\n";
} else {
out_line = line + "\n";
}
if (result.size() + out_line.size() > SERVER_TOOL_READ_FILE_MAX_SIZE) {
result += "[output truncated]";
break;
}
result += out_line;
}
return {{"plain_text_response", result}};
}
};
//
// file_glob_search: find files matching a glob pattern under a base directory
//
static constexpr size_t SERVER_TOOL_FILE_SEARCH_MAX_RESULTS = 100;
struct server_tool_file_glob_search : server_tool {
server_tool_file_glob_search() {
name = "file_glob_search";
display_name = "File search";
permission_write = false;
}
json get_definition() override {
return {
{"type", "function"},
{"function", {
{"name", name},
{"description", "Recursively search for files matching a glob pattern under a directory."},
{"parameters", {
{"type", "object"},
{"properties", {
{"path", {{"type", "string"}, {"description", "Base directory to search in"}}},
{"include", {{"type", "string"}, {"description", "Glob pattern for files to include (e.g. \"**/*.cpp\"). Default: **"}}},
{"exclude", {{"type", "string"}, {"description", "Glob pattern for files to exclude"}}},
}},
{"required", json::array({"path"})},
}},
}},
};
}
json invoke(json params) override {
std::string base = params.at("path").get<std::string>();
std::string include = json_value(params, "include", std::string("**"));
std::string exclude = json_value(params, "exclude", std::string(""));
std::ostringstream output_text;
size_t count = 0;
std::error_code ec;
for (const auto & entry : fs::recursive_directory_iterator(base,
fs::directory_options::skip_permission_denied, ec)) {
if (!entry.is_regular_file()) continue;
std::string rel = fs::relative(entry.path(), base, ec).string();
if (ec) continue;
std::replace(rel.begin(), rel.end(), '\\', '/');
if (!glob_match(include, rel)) continue;
if (!exclude.empty() && glob_match(exclude, rel)) continue;
output_text << entry.path().string() << "\n";
if (++count >= SERVER_TOOL_FILE_SEARCH_MAX_RESULTS) {
break;
}
}
output_text << "\n---\nTotal matches: " << count << "\n";
return {{"plain_text_response", output_text.str()}};
}
};
//
// grep_search: search for a regex pattern in files
//
static constexpr size_t SERVER_TOOL_GREP_SEARCH_MAX_RESULTS = 100;
struct server_tool_grep_search : server_tool {
server_tool_grep_search() {
name = "grep_search";
display_name = "Grep search";
permission_write = false;
}
json get_definition() override {
return {
{"type", "function"},
{"function", {
{"name", name},
{"description", "Search for a regex pattern in files under a path. Returns matching lines."},
{"parameters", {
{"type", "object"},
{"properties", {
{"path", {{"type", "string"}, {"description", "File or directory to search in"}}},
{"pattern", {{"type", "string"}, {"description", "Regular expression pattern to search for"}}},
{"include", {{"type", "string"}, {"description", "Glob pattern to filter files (default: **)"}}},
{"exclude", {{"type", "string"}, {"description", "Glob pattern to exclude files"}}},
{"return_line_numbers", {{"type", "boolean"}, {"description", "If true, include line numbers in results"}}},
}},
{"required", json::array({"path", "pattern"})},
}},
}},
};
}
json invoke(json params) override {
std::string path = params.at("path").get<std::string>();
std::string pat_str = params.at("pattern").get<std::string>();
std::string include = json_value(params, "include", std::string("**"));
std::string exclude = json_value(params, "exclude", std::string(""));
bool show_lineno = json_value(params, "return_line_numbers", false);
std::regex pattern;
try {
pattern = std::regex(pat_str);
} catch (const std::regex_error & e) {
return {{"error", std::string("invalid regex: ") + e.what()}};
}
std::ostringstream output_text;
size_t total = 0;
auto search_file = [&](const fs::path & fpath) {
std::ifstream f(fpath);
if (!f) return;
std::string line;
int lineno = 0;
while (std::getline(f, line) && total < SERVER_TOOL_GREP_SEARCH_MAX_RESULTS) {
lineno++;
if (std::regex_search(line, pattern)) {
output_text << fpath.string() << ":";
if (show_lineno) {
output_text << lineno << ":";
}
output_text << line << "\n";
total++;
}
}
};
std::error_code ec;
if (fs::is_regular_file(path, ec)) {
search_file(path);
} else if (fs::is_directory(path, ec)) {
for (const auto & entry : fs::recursive_directory_iterator(path,
fs::directory_options::skip_permission_denied, ec)) {
if (!entry.is_regular_file()) continue;
if (total >= SERVER_TOOL_GREP_SEARCH_MAX_RESULTS) break;
std::string rel = fs::relative(entry.path(), path, ec).string();
if (ec) continue;
std::replace(rel.begin(), rel.end(), '\\', '/');
if (!glob_match(include, rel)) continue;
if (!exclude.empty() && glob_match(exclude, rel)) continue;
search_file(entry.path());
}
} else {
return {{"error", "path does not exist: " + path}};
}
output_text << "\n\n---\nTotal matches: " << total << "\n";
return {{"plain_text_response", output_text.str()}};
}
};
//
// exec_shell_command: run an arbitrary shell command
//
static constexpr size_t SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE = 16 * 1024; // 16 KB
static constexpr int SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT = 60; // seconds
struct server_tool_exec_shell_command : server_tool {
server_tool_exec_shell_command() {
name = "exec_shell_command";
display_name = "Execute shell command";
permission_write = true;
}
json get_definition() override {
return {
{"type", "function"},
{"function", {
{"name", name},
{"description", "Execute a shell command and return its output (stdout and stderr combined)."},
{"parameters", {
{"type", "object"},
{"properties", {
{"command", {{"type", "string"}, {"description", "Shell command to execute"}}},
{"timeout", {{"type", "integer"}, {"description", string_format("Timeout in seconds (default 10, max %d)", SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT)}}},
{"max_output_size", {{"type", "integer"}, {"description", string_format("Maximum output size in bytes (default %zu)", SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE)}}},
}},
{"required", json::array({"command"})},
}},
}},
};
}
json invoke(json params) override {
std::string command = params.at("command").get<std::string>();
int timeout = json_value(params, "timeout", 10);
size_t max_output = (size_t) json_value(params, "max_output_size", (int) SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE);
timeout = std::min(timeout, SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT);
max_output = std::min(max_output, SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE);
#ifdef _WIN32
std::vector<std::string> args = {"cmd", "/c", command};
#else
std::vector<std::string> args = {"sh", "-c", command};
#endif
auto res = run_process(args, max_output, timeout);
std::string text_output = res.output;
text_output += string_format("\n[exit code: %d]", res.exit_code);
if (res.timed_out) {
text_output += " [exit due to timed out]";
}
return {{"plain_text_response", text_output}};
}
};
//
// write_file: create or overwrite a file
//
struct server_tool_write_file : server_tool {
server_tool_write_file() {
name = "write_file";
display_name = "Write file";
permission_write = true;
}
json get_definition() override {
return {
{"type", "function"},
{"function", {
{"name", name},
{"description", "Write content to a file, creating it (including parent directories) if it does not exist. May use with edit_file for more complex edits."},
{"parameters", {
{"type", "object"},
{"properties", {
{"path", {{"type", "string"}, {"description", "Path of the file to write"}}},
{"content", {{"type", "string"}, {"description", "Content to write"}}},
}},
{"required", json::array({"path", "content"})},
}},
}},
};
}
json invoke(json params) override {
std::string path = params.at("path").get<std::string>();
std::string content = params.at("content").get<std::string>();
std::error_code ec;
fs::path fpath(path);
if (fpath.has_parent_path()) {
fs::create_directories(fpath.parent_path(), ec);
if (ec) {
return {{"error", "failed to create directories: " + ec.message()}};
}
}
std::ofstream f(path, std::ios::binary);
if (!f) {
return {{"error", "failed to open file for writing: " + path}};
}
f << content;
if (!f) {
return {{"error", "failed to write file: " + path}};
}
return {{"result", "file written successfully"}, {"path", path}, {"bytes", content.size()}};
}
};
//
// edit_file: edit file content via line-based changes
//
struct server_tool_edit_file : server_tool {
server_tool_edit_file() {
name = "edit_file";
display_name = "Edit file";
permission_write = true;
}
json get_definition() override {
return {
{"type", "function"},
{"function", {
{"name", name},
{"description",
"Edit a file by applying a list of line-based changes. "
"Each change targets a 1-based inclusive line range and has a mode: "
"\"replace\" (replace lines with content), "
"\"delete\" (remove lines, content must be empty string), "
"\"append\" (insert content after line_end). "
"Set line_start to -1 to target the end of file (line_end is ignored in that case). "
"Changes must not overlap. They are applied in reverse line order automatically."},
{"parameters", {
{"type", "object"},
{"properties", {
{"path", {{"type", "string"}, {"description", "Path to the file to edit"}}},
{"changes", {
{"type", "array"},
{"description", "List of changes to apply"},
{"items", {
{"type", "object"},
{"properties", {
{"mode", {{"type", "string"}, {"description", "\"replace\", \"delete\", or \"append\""}}},
{"line_start", {{"type", "integer"}, {"description", "First line of the range (1-based); use -1 for end of file"}}},
{"line_end", {{"type", "integer"}, {"description", "Last line of the range (1-based, inclusive); ignored when line_start is -1"}}},
{"content", {{"type", "string"}, {"description", "Content to insert; must be empty string for delete mode"}}},
}},
{"required", json::array({"mode", "line_start", "line_end", "content"})},
}},
}},
}},
{"required", json::array({"path", "changes"})},
}},
}},
};
}
json invoke(json params) override {
std::string path = params.at("path").get<std::string>();
const json & changes = params.at("changes");
if (!changes.is_array()) {
return {{"error", "\"changes\" must be an array"}};
}
// read file into lines
std::ifstream fin(path);
if (!fin) {
return {{"error", "failed to open file: " + path}};
}
std::vector<std::string> lines;
{
std::string line;
while (std::getline(fin, line)) {
lines.push_back(line);
}
}
fin.close();
// validate and collect changes, then sort descending by line_start
struct change_entry {
std::string mode;
int line_start; // 1-based
int line_end; // 1-based inclusive
std::string content;
};
std::vector<change_entry> entries;
entries.reserve(changes.size());
for (const auto & ch : changes) {
change_entry e;
e.mode = ch.at("mode").get<std::string>();
e.line_start = ch.at("line_start").get<int>();
e.line_end = ch.at("line_end").get<int>();
e.content = ch.at("content").get<std::string>();
if (e.mode != "replace" && e.mode != "delete" && e.mode != "append") {
return {{"error", "invalid mode \"" + e.mode + "\"; must be replace, delete, or append"}};
}
if (e.mode == "delete" && !e.content.empty()) {
return {{"error", "content must be empty string for delete mode"}};
}
int n = (int) lines.size();
if (e.line_start == -1) {
// -1 means end of file; line_end is ignored — normalize to point past last line
e.line_start = n + 1;
e.line_end = n + 1;
} else {
if (e.line_start < 1 || e.line_end < e.line_start) {
return {{"error", string_format("invalid line range [%d, %d]", e.line_start, e.line_end)}};
}
if (e.line_end > n) {
return {{"error", string_format("line_end %d exceeds file length %d", e.line_end, n)}};
}
}
entries.push_back(std::move(e));
}
// sort descending so earlier-indexed changes don't shift later ones
std::sort(entries.begin(), entries.end(), [](const change_entry & a, const change_entry & b) {
return a.line_start > b.line_start;
});
// apply changes (0-based indices internally)
for (const auto & e : entries) {
int idx_start = e.line_start - 1; // 0-based
int idx_end = e.line_end - 1; // 0-based inclusive
// split content into lines (preserve trailing newline awareness)
std::vector<std::string> new_lines;
if (!e.content.empty()) {
std::istringstream ss(e.content);
std::string ln;
while (std::getline(ss, ln)) {
new_lines.push_back(ln);
}
// if content ends with \n, getline consumed it — no extra empty line needed
// if content does NOT end with \n, last line is still captured correctly
}
if (e.mode == "replace") {
// erase [idx_start, idx_end] and insert new_lines
lines.erase(lines.begin() + idx_start, lines.begin() + idx_end + 1);
lines.insert(lines.begin() + idx_start, new_lines.begin(), new_lines.end());
} else if (e.mode == "delete") {
lines.erase(lines.begin() + idx_start, lines.begin() + idx_end + 1);
} else { // append
// idx_end + 1 may equal lines.size() when line_start == -1 (end of file)
lines.insert(lines.begin() + idx_end + 1, new_lines.begin(), new_lines.end());
}
}
// write file back
std::ofstream fout(path, std::ios::binary);
if (!fout) {
return {{"error", "failed to open file for writing: " + path}};
}
for (size_t i = 0; i < lines.size(); i++) {
fout << lines[i];
if (i + 1 < lines.size()) {
fout << "\n";
}
}
if (!lines.empty()) {
fout << "\n";
}
if (!fout) {
return {{"error", "failed to write file: " + path}};
}
return {{"result", "file edited successfully"}, {"path", path}, {"lines", (int) lines.size()}};
}
};
//
// apply_diff: apply a unified diff via git apply
//
struct server_tool_apply_diff : server_tool {
server_tool_apply_diff() {
name = "apply_diff";
display_name = "Apply diff";
permission_write = true;
}
json get_definition() override {
return {
{"type", "function"},
{"function", {
{"name", name},
{"description", "Apply a unified diff to edit one or more files using git apply. Use this instead of edit_file when the changes are complex."},
{"parameters", {
{"type", "object"},
{"properties", {
{"diff", {{"type", "string"}, {"description", "Unified diff content in git diff format"}}},
}},
{"required", json::array({"diff"})},
}},
}},
};
}
json invoke(json params) override {
std::string diff = params.at("diff").get<std::string>();
// write diff to a temporary file
static std::atomic<int> counter{0};
std::string tmp_path = (fs::temp_directory_path() /
("llama_patch_" + std::to_string(++counter) + ".patch")).string();
{
std::ofstream f(tmp_path, std::ios::binary);
if (!f) {
return {{"error", "failed to create temp patch file"}};
}
f << diff;
}
auto res = run_process({"git", "apply", tmp_path}, 4096, 10);
std::error_code ec;
fs::remove(tmp_path, ec);
if (res.exit_code != 0) {
return {{"error", "git apply failed (exit " + std::to_string(res.exit_code) + "): " + res.output}};
}
return {{"result", "patch applied successfully"}};
}
};
//
// get_datetime: returns the current date and time
//
struct server_tool_get_datetime : server_tool {
server_tool_get_datetime() {
name = "get_datetime";
display_name = "Get Date & Time";
permission_write = false;
}
json get_definition() override {
return {
{"type", "function"},
{"function", {
{"name", name},
{"description", "Returns the current date and time"},
}},
};
}
json invoke(json) override {
auto now = std::chrono::system_clock::now();
auto time = std::chrono::system_clock::to_time_t(now);
return {{"result", std::ctime(&time)}};
}
};
//
// public API
//
static std::vector<std::unique_ptr<server_tool>> build_tools() {
std::vector<std::unique_ptr<server_tool>> tools;
tools.push_back(std::make_unique<server_tool_read_file>());
tools.push_back(std::make_unique<server_tool_file_glob_search>());
tools.push_back(std::make_unique<server_tool_grep_search>());
tools.push_back(std::make_unique<server_tool_exec_shell_command>());
tools.push_back(std::make_unique<server_tool_write_file>());
tools.push_back(std::make_unique<server_tool_edit_file>());
tools.push_back(std::make_unique<server_tool_apply_diff>());
tools.push_back(std::make_unique<server_tool_get_datetime>());
return tools;
}
void server_tools::setup(const std::vector<std::string> & enabled_tools) {
if (!enabled_tools.empty()) {
std::unordered_set<std::string> enabled_set(enabled_tools.begin(), enabled_tools.end());
auto all_tools = build_tools();
// collect all known tool names for validation
std::vector<std::string> known_names;
known_names.reserve(all_tools.size());
for (const auto & t : all_tools) {
known_names.push_back(t->name);
}
// validate that every requested tool is known
for (const auto & name : enabled_tools) {
if (name == "all") continue;
if (std::find(known_names.begin(), known_names.end(), name) == known_names.end()) {
throw std::runtime_error(string_format(
"unknown tool \"%s\". available tools: %s",
name.c_str(),
string_join(known_names, ", ").c_str()));
}
}
tools.clear();
for (auto & t : all_tools) {
if (enabled_set.count(t->name) > 0 || enabled_set.count("all") > 0) {
tools.push_back(std::move(t));
}
}
}
handle_get = [this](const server_http_req &) -> server_http_res_ptr {
auto res = std::make_unique<server_http_res>();
try {
json result = json::array();
for (const auto & t : tools) {
result.push_back(t->to_json());
}
res->data = safe_json_to_str(result);
} catch (const std::exception & e) {
SRV_ERR("got exception: %s\n", e.what());
res->status = 500;
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
}
return res;
};
handle_post = [this](const server_http_req & req) -> server_http_res_ptr {
auto res = std::make_unique<server_http_res>();
try {
json body = json::parse(req.body);
std::string tool_name = body.at("tool").get<std::string>();
json params = body.value("params", json::object());
json result = invoke(tool_name, params);
res->data = safe_json_to_str(result);
} catch (const json::exception & e) {
res->status = 400;
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
} catch (const std::exception & e) {
SRV_ERR("got exception: %s\n", e.what());
res->status = 500;
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
}
return res;
};
}
json server_tools::invoke(const std::string & name, const json & params) {
for (auto & t : tools) {
if (t->name == name) {
return t->invoke(params);
}
}
return {{"error", "unknown tool: " + name}};
}

View File

@@ -0,0 +1,26 @@
#pragma once
#include "server-common.h"
#include "server-http.h"
struct server_tool {
std::string name;
std::string display_name;
bool permission_write = false;
virtual ~server_tool() = default;
virtual json get_definition() = 0;
virtual json invoke(json params) = 0;
json to_json();
};
struct server_tools {
std::vector<std::unique_ptr<server_tool>> tools;
void setup(const std::vector<std::string> & enabled_tools);
json invoke(const std::string & name, const json & params);
server_http_context::handler_t handle_get;
server_http_context::handler_t handle_post;
};

363
tools/server/server.cpp Normal file
View File

@@ -0,0 +1,363 @@
#include "server-context.h"
#include "server-http.h"
#include "server-models.h"
#include "server-cors-proxy.h"
#include "server-tools.h"
#include "arg.h"
#include "build-info.h"
#include "common.h"
#include "fit.h"
#include "llama.h"
#include "log.h"
#include <atomic>
#include <clocale>
#include <exception>
#include <signal.h>
#include <thread> // for std::thread::hardware_concurrency
#if defined(_WIN32)
#include <windows.h>
#endif
static std::function<void(int)> shutdown_handler;
static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
static inline void signal_handler(int signal) {
if (is_terminating.test_and_set()) {
// in case it hangs, we can force terminate the server by hitting Ctrl+C twice
// this is for better developer experience, we can remove when the server is stable enough
fprintf(stderr, "Received second interrupt, terminating immediately.\n");
exit(1);
}
shutdown_handler(signal);
}
// wrapper function that handles exceptions and logs errors
// this is to make sure handler_t never throws exceptions; instead, it returns an error response
static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) {
return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr {
std::string message;
error_type error;
try {
return func(req);
} catch (const std::invalid_argument & e) {
// treat invalid_argument as invalid request (400)
error = ERROR_TYPE_INVALID_REQUEST;
message = e.what();
} catch (const std::exception & e) {
// treat other exceptions as server error (500)
error = ERROR_TYPE_SERVER;
message = e.what();
} catch (...) {
error = ERROR_TYPE_SERVER;
message = "unknown error";
}
auto res = std::make_unique<server_http_res>();
res->status = 500;
try {
json error_data = format_error_response(message, error);
res->status = json_value(error_data, "code", 500);
res->data = safe_json_to_str({{ "error", error_data }});
SRV_WRN("got exception: %s\n", res->data.c_str());
} catch (const std::exception & e) {
SRV_ERR("got another exception: %s | while handling exception: %s\n", e.what(), message.c_str());
res->data = "Internal Server Error";
}
return res;
};
}
int main(int argc, char ** argv) {
std::setlocale(LC_NUMERIC, "C");
// own arguments required by this example
common_params params;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
return 1;
}
// validate batch size for embeddings
// embeddings require all tokens to be processed in a single ubatch
// see https://github.com/ggml-org/llama.cpp/issues/12836
if (params.embedding && params.n_batch > params.n_ubatch) {
LOG_WRN("%s: embeddings enabled with n_batch (%d) > n_ubatch (%d)\n", __func__, params.n_batch, params.n_ubatch);
LOG_WRN("%s: setting n_batch = n_ubatch = %d to avoid assertion failure\n", __func__, params.n_ubatch);
params.n_batch = params.n_ubatch;
}
if (params.n_parallel < 0) {
LOG_INF("%s: n_parallel is set to auto, using n_parallel = 4 and kv_unified = true\n", __func__);
params.n_parallel = 4;
params.kv_unified = true;
}
// for consistency between server router mode and single-model mode, we set the same model name as alias
if (params.model_alias.empty() && !params.model.name.empty()) {
params.model_alias.insert(params.model.name);
}
// struct that contains llama context and inference
server_context ctx_server;
llama_backend_init();
llama_numa_init(params.numa);
LOG_INF("build_info: %s\n", llama_build_info());
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
server_http_context ctx_http;
if (!ctx_http.init(params)) {
LOG_ERR("%s: failed to initialize HTTP server\n", __func__);
return 1;
}
//
// Router
//
// register API routes
server_routes routes(params, ctx_server);
server_tools tools;
bool is_router_server = params.model.path.empty();
std::optional<server_models_routes> models_routes{};
if (is_router_server) {
// setup server instances manager
try {
models_routes.emplace(params, argc, argv);
} catch (const std::exception & e) {
LOG_ERR("%s: failed to initialize router models: %s\n", __func__, e.what());
return 1;
}
// proxy handlers
// note: routes.get_health stays the same
routes.get_metrics = models_routes->proxy_get;
routes.post_props = models_routes->proxy_post;
routes.post_completions = models_routes->proxy_post;
routes.post_completions_oai = models_routes->proxy_post;
routes.post_chat_completions = models_routes->proxy_post;
routes.post_responses_oai = models_routes->proxy_post;
routes.post_transcriptions_oai = models_routes->proxy_post;
routes.post_anthropic_messages = models_routes->proxy_post;
routes.post_anthropic_count_tokens = models_routes->proxy_post;
routes.post_infill = models_routes->proxy_post;
routes.post_embeddings = models_routes->proxy_post;
routes.post_embeddings_oai = models_routes->proxy_post;
routes.post_rerank = models_routes->proxy_post;
routes.post_tokenize = models_routes->proxy_post;
routes.post_detokenize = models_routes->proxy_post;
routes.post_apply_template = models_routes->proxy_post;
routes.get_lora_adapters = models_routes->proxy_get;
routes.post_lora_adapters = models_routes->proxy_post;
routes.get_slots = models_routes->proxy_get;
routes.post_slots = models_routes->proxy_post;
// custom routes for router
routes.get_props = models_routes->get_router_props;
routes.get_models = models_routes->get_router_models;
ctx_http.post("/models/load", ex_wrapper(models_routes->post_router_models_load));
ctx_http.post("/models/unload", ex_wrapper(models_routes->post_router_models_unload));
}
ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics));
ctx_http.get ("/props", ex_wrapper(routes.get_props));
ctx_http.post("/props", ex_wrapper(routes.post_props));
ctx_http.get ("/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check)
ctx_http.get ("/v1/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check)
ctx_http.post("/completion", ex_wrapper(routes.post_completions)); // legacy
ctx_http.post("/completions", ex_wrapper(routes.post_completions));
ctx_http.post("/v1/completions", ex_wrapper(routes.post_completions_oai));
ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions));
ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions));
ctx_http.post("/v1/responses", ex_wrapper(routes.post_responses_oai));
ctx_http.post("/responses", ex_wrapper(routes.post_responses_oai));
ctx_http.post("/v1/audio/transcriptions", ex_wrapper(routes.post_transcriptions_oai));
ctx_http.post("/audio/transcriptions", ex_wrapper(routes.post_transcriptions_oai));
ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API
ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting
ctx_http.post("/infill", ex_wrapper(routes.post_infill));
ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy
ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings));
ctx_http.post("/v1/embeddings", ex_wrapper(routes.post_embeddings_oai));
ctx_http.post("/rerank", ex_wrapper(routes.post_rerank));
ctx_http.post("/reranking", ex_wrapper(routes.post_rerank));
ctx_http.post("/v1/rerank", ex_wrapper(routes.post_rerank));
ctx_http.post("/v1/reranking", ex_wrapper(routes.post_rerank));
ctx_http.post("/tokenize", ex_wrapper(routes.post_tokenize));
ctx_http.post("/detokenize", ex_wrapper(routes.post_detokenize));
ctx_http.post("/apply-template", ex_wrapper(routes.post_apply_template));
// LoRA adapters hotswap
ctx_http.get ("/lora-adapters", ex_wrapper(routes.get_lora_adapters));
ctx_http.post("/lora-adapters", ex_wrapper(routes.post_lora_adapters));
// Save & load slots
ctx_http.get ("/slots", ex_wrapper(routes.get_slots));
ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots));
// Google Cloud Platform (Vertex AI) compat
ctx_http.register_gcp_compat();
// CORS proxy (EXPERIMENTAL, only used by the Web UI for MCP)
if (params.webui_mcp_proxy) {
SRV_WRN("%s", "-----------------\n");
SRV_WRN("%s", "CORS proxy is enabled, do not expose server to untrusted environments\n");
SRV_WRN("%s", "This feature is EXPERIMENTAL and may be removed or changed in future versions\n");
SRV_WRN("%s", "-----------------\n");
ctx_http.get ("/cors-proxy", ex_wrapper(proxy_handler_get));
ctx_http.post("/cors-proxy", ex_wrapper(proxy_handler_post));
}
// EXPERIMENTAL built-in tools
if (!params.server_tools.empty()) {
try {
tools.setup(params.server_tools);
} catch (const std::exception & e) {
LOG_ERR("%s: tools setup failed: %s\n", __func__, e.what());
return 1;
}
SRV_WRN("%s", "-----------------\n");
SRV_WRN("%s", "Built-in tools are enabled, do not expose server to untrusted environments\n");
SRV_WRN("%s", "This feature is EXPERIMENTAL and may be changed in the future\n");
SRV_WRN("%s", "-----------------\n");
ctx_http.get ("/tools", ex_wrapper(tools.handle_get));
ctx_http.post("/tools", ex_wrapper(tools.handle_post));
}
//
// Start the server
//
std::function<void()> clean_up;
if (is_router_server) {
LOG_INF("%s: starting router server, no model will be loaded in this process\n", __func__);
clean_up = [&models_routes]() {
SRV_INF("%s: cleaning up before exit...\n", __func__);
if (models_routes.has_value()) {
models_routes->models.unload_all();
}
llama_backend_free();
};
if (!ctx_http.start()) {
clean_up();
LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
return 1;
}
ctx_http.is_ready.store(true);
shutdown_handler = [&](int) {
ctx_http.stop();
};
} else {
// setup clean up function, to be called before exit
clean_up = [&ctx_http, &ctx_server]() {
SRV_INF("%s: cleaning up before exit...\n", __func__);
ctx_http.stop();
ctx_server.terminate();
llama_backend_free();
};
// start the HTTP server before loading the model to be able to serve /health requests
if (!ctx_http.start()) {
clean_up();
LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
return 1;
}
// load the model
LOG_INF("%s: loading model\n", __func__);
if (server_models::is_child_server()) {
ctx_server.on_sleeping_changed([&](bool sleeping) {
server_models::notify_router_sleeping_state(sleeping);
});
}
if (!ctx_server.load_model(params)) {
clean_up();
if (ctx_http.thread.joinable()) {
ctx_http.thread.join();
}
LOG_ERR("%s: exiting due to model loading error\n", __func__);
return 1;
}
routes.update_meta(ctx_server);
ctx_http.is_ready.store(true);
LOG_INF("%s: model loaded\n", __func__);
shutdown_handler = [&](int) {
// this will unblock start_loop()
ctx_server.terminate();
};
}
// TODO: refactor in common/console
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = signal_handler;
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
};
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
if (is_router_server) {
LOG_INF("%s: router server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
LOG_INF("%s: NOTE: router mode is experimental\n", __func__);
LOG_INF("%s: it is not recommended to use this mode in untrusted environments\n", __func__);
if (ctx_http.thread.joinable()) {
ctx_http.thread.join(); // keep the main thread alive
}
// when the HTTP server stops, clean up and exit
clean_up();
} else {
LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
LOG_INF("%s: starting the main loop...\n", __func__);
// optionally, notify router server that this instance is ready
std::thread monitor_thread;
if (server_models::is_child_server()) {
json model_info = routes.get_model_info();
monitor_thread = server_models::setup_child_server(shutdown_handler, model_info);
}
// this call blocks the main thread until queue_tasks.terminate() is called
ctx_server.start_loop();
clean_up();
if (ctx_http.thread.joinable()) {
ctx_http.thread.join();
}
if (monitor_thread.joinable()) {
monitor_thread.join();
}
auto * ll_ctx = ctx_server.get_llama_context();
if (ll_ctx != nullptr) {
common_memory_breakdown_print(ll_ctx);
}
}
return 0;
}

2
tools/server/tests/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
.venv
tmp

View File

@@ -0,0 +1,96 @@
# Server tests
Python based server tests scenario using [pytest](https://docs.pytest.org/en/stable/).
Tests target GitHub workflows job runners with 4 vCPU.
Note: If the host architecture inference speed is faster than GitHub runners one, parallel scenario may randomly fail.
To mitigate it, you can increase values in `n_predict`, `kv_size`.
### Install dependencies
`pip install -r requirements.txt`
### Run tests
1. Build the server
```shell
cd ../../..
cmake -B build
cmake --build build --target llama-server
```
2. Start the test: `./tests.sh`
It's possible to override some scenario steps values with environment variables:
| variable | description |
|--------------------------|------------------------------------------------------------------------------------------------|
| `PORT` | `context.server_port` to set the listening port of the server during scenario, default: `8080` |
| `LLAMA_SERVER_BIN_PATH` | to change the server binary path, default: `../../../build/bin/llama-server` |
| `DEBUG` | to enable steps and server verbose mode `--verbose` |
| `N_GPU_LAYERS` | number of model layers to offload to VRAM `-ngl --n-gpu-layers` |
| `LLAMA_CACHE` | by default server tests re-download models to the `tmp` subfolder. Set this to your cache (e.g. `$HOME/Library/Caches/llama.cpp` on Mac or `$HOME/.cache/llama.cpp` on Unix) to avoid this |
To run slow tests (will download many models, make sure to set `LLAMA_CACHE` if needed):
```shell
SLOW_TESTS=1 ./tests.sh
```
To run with stdout/stderr display in real time (verbose output, but useful for debugging):
```shell
DEBUG=1 ./tests.sh -s -v -x
```
To run all the tests in a file:
```shell
./tests.sh unit/test_chat_completion.py -v -x
```
To run a single test:
```shell
./tests.sh unit/test_chat_completion.py::test_invalid_chat_completion_req
```
Hint: You can compile and run test in single command, useful for local development:
```shell
cmake --build build -j --target llama-server && ./tools/server/tests/tests.sh
```
To see all available arguments, please refer to [pytest documentation](https://docs.pytest.org/en/stable/how-to/usage.html)
### Debugging external llama-server
It can sometimes be useful to run the server in a debugger when invesigating test
failures. To do this, the environment variable `DEBUG_EXTERNAL=1` can be set
which will cause the test to skip starting a llama-server itself. Instead, the
server can be started in a debugger.
Example using `gdb`:
```console
$ gdb --args ../../../build/bin/llama-server \
--host 127.0.0.1 --port 8080 \
--temp 0.8 --seed 42 \
--hf-repo ggml-org/models --hf-file tinyllamas/stories260K.gguf \
--batch-size 32 --no-slots --alias tinyllama-2 --ctx-size 512 \
--parallel 2 --n-predict 64
```
And a break point can be set in before running:
```console
(gdb) br server.cpp:4604
(gdb) r
main: server is listening on http://127.0.0.1:8080 - starting the main loop
srv update_slots: all slots are idle
```
And then the test in question can be run in another terminal:
```console
(venv) $ env DEBUG_EXTERNAL=1 ./tests.sh unit/test_chat_completion.py -v -x
```
And this should trigger the breakpoint and allow inspection of the server state
in the debugger terminal.

View File

@@ -0,0 +1,21 @@
import pytest
from utils import *
# ref: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test
@pytest.fixture(autouse=True)
def stop_server_after_each_test():
# do nothing before each test
yield
# stop all servers after each test
instances = set(
server_instances
) # copy the set to prevent 'Set changed size during iteration'
for server in instances:
server.stop()
@pytest.fixture(scope="module", autouse=True)
def do_something():
# this will be run once per test session, before any tests
ServerPreset.load_all()

View File

@@ -0,0 +1,4 @@
[pytest]
markers =
slow: marks tests as slow (deselect with '-m "not slow"')
serial

View File

@@ -0,0 +1,8 @@
aiohttp~=3.9.3
pytest~=8.3.3
huggingface_hub>=1.5.0,<2.0
numpy~=1.26.4
openai~=2.14.0
prometheus-client~=0.20.0
requests~=2.32.3
wget~=3.2

23
tools/server/tests/tests.sh Executable file
View File

@@ -0,0 +1,23 @@
#!/usr/bin/env bash
# make sure we are in the right directory
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
cd $SCRIPT_DIR
set -eu
if [[ "${SLOW_TESTS:-0}" == 1 ]]; then
# Slow tests for tool calls need quite a few models ahead of time to avoid timing out.
python $SCRIPT_DIR/../../../scripts/fetch_server_test_models.py
fi
if [ $# -lt 1 ]
then
if [[ "${SLOW_TESTS:-0}" == 1 ]]; then
pytest -v -x
else
pytest -v -x -m "not slow"
fi
else
pytest "$@"
fi

View File

@@ -0,0 +1,113 @@
import pytest
import requests
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
def test_server_start_simple():
global server
server.start()
res = server.make_request("GET", "/health")
assert res.status_code == 200
def test_server_props():
global server
server.start()
res = server.make_request("GET", "/props")
assert res.status_code == 200
assert ".gguf" in res.body["model_path"]
assert res.body["total_slots"] == server.n_slots
default_val = res.body["default_generation_settings"]
assert server.n_ctx is not None and server.n_slots is not None
assert default_val["n_ctx"] == server.n_ctx / server.n_slots
assert default_val["params"]["seed"] == server.seed
def test_server_models():
global server
server.start()
res = server.make_request("GET", "/models")
assert res.status_code == 200
assert len(res.body["data"]) == 1
assert res.body["data"][0]["id"] == server.model_alias
def test_server_slots():
global server
# without slots endpoint enabled, this should return error
server.server_slots = False
server.start()
res = server.make_request("GET", "/slots")
assert res.status_code == 501 # ERROR_TYPE_NOT_SUPPORTED
assert "error" in res.body
server.stop()
# with slots endpoint enabled, this should return slots info
server.server_slots = True
server.n_slots = 2
server.start()
res = server.make_request("GET", "/slots")
assert res.status_code == 200
assert len(res.body) == server.n_slots
assert server.n_ctx is not None and server.n_slots is not None
assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots
assert "params" not in res.body[0]
def test_load_split_model():
global server
server.offline = False
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf"
server.model_alias = "tinyllama-split"
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 16,
"prompt": "Hello",
"temperature": 0.0,
})
assert res.status_code == 200
assert match_regex("(little|girl)+", res.body["content"])
def test_no_webui():
global server
# default: webui enabled
server.start()
url = f"http://{server.server_host}:{server.server_port}"
res = requests.get(url)
assert res.status_code == 200
assert "<!doctype html>" in res.text
server.stop()
# with --no-webui
server.no_webui = True
server.start()
res = requests.get(url)
assert res.status_code == 404
def test_server_model_aliases_and_tags():
global server
server.model_alias = "tinyllama-2,fim,code"
server.model_tags = "chat,fim,small"
server.start()
res = server.make_request("GET", "/models")
assert res.status_code == 200
assert len(res.body["data"]) == 1
model = res.body["data"][0]
# aliases field must contain all aliases
assert set(model["aliases"]) == {"tinyllama-2", "fim", "code"}
# tags field must contain all tags
assert set(model["tags"]) == {"chat", "fim", "small"}
# id is derived from first alias (alphabetical order from std::set)
assert model["id"] == "code"

View File

@@ -0,0 +1,534 @@
import pytest
from openai import OpenAI
from utils import *
server: ServerProcess
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
@pytest.mark.parametrize(
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
[
(None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None),
(None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None),
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None),
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None),
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
global server
server.jinja = jinja
server.chat_template = chat_template
server.start()
res = server.make_request("POST", "/chat/completions", data={
"model": model,
"max_tokens": max_tokens,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
})
assert res.status_code == 200
assert "cmpl" in res.body["id"] # make sure the completion id has the expected format
assert res.body["system_fingerprint"].startswith("b")
# we no longer reflect back the model name, see https://github.com/ggml-org/llama.cpp/pull/17668
# assert res.body["model"] == model if model is not None else server.model_alias
assert res.body["usage"]["prompt_tokens"] == n_prompt
assert res.body["usage"]["completion_tokens"] == n_predicted
choice = res.body["choices"][0]
assert "assistant" == choice["message"]["role"]
assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
assert choice["finish_reason"] == finish_reason
def test_chat_completion_cached_tokens():
global server
server.n_slots = 1
server.start()
seq = [
("1 2 3 4 5 6", 77, 0),
("1 2 3 4 5 6", 77, 76),
("1 2 3 4 5 9", 77, 51),
("1 2 3 9 9 9", 77, 47),
]
for user_prompt, n_prompt, n_cache in seq:
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 8,
"messages": [
{"role": "system", "content": "Test"},
{"role": "user", "content": user_prompt},
],
})
assert res.body["usage"]["prompt_tokens"] == n_prompt
assert res.body["usage"]["prompt_tokens_details"]["cached_tokens"] == n_cache
@pytest.mark.parametrize(
"system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
[
("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"),
]
)
def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
global server
server.model_alias = "llama-test-model"
server.start()
res = server.make_stream_request("POST", "/chat/completions", data={
"max_tokens": max_tokens,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
"stream": True,
})
content = ""
last_cmpl_id = None
for i, data in enumerate(res):
if data["choices"]:
choice = data["choices"][0]
if i == 0:
# Check first role message for stream=True
assert choice["delta"]["content"] is None
assert choice["delta"]["role"] == "assistant"
else:
assert "role" not in choice["delta"]
assert data["system_fingerprint"].startswith("b")
assert data["model"] == "llama-test-model"
if last_cmpl_id is None:
last_cmpl_id = data["id"]
assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
if choice["finish_reason"] in ["stop", "length"]:
assert "content" not in choice["delta"]
assert match_regex(re_content, content)
assert choice["finish_reason"] == finish_reason
else:
assert choice["finish_reason"] is None
content += choice["delta"]["content"] or ''
else:
assert data["usage"]["prompt_tokens"] == n_prompt
assert data["usage"]["completion_tokens"] == n_predicted
def test_chat_completion_with_openai_library():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.chat.completions.create(
model="gpt-3.5-turbo-instruct",
messages=[
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
max_tokens=8,
seed=42,
temperature=0.8,
)
assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
assert res.choices[0].finish_reason == "length"
assert res.choices[0].message.content is not None
assert match_regex("(Suddenly)+", res.choices[0].message.content)
def test_chat_template():
global server
server.chat_template = "llama3"
server.debug = True # to get the "__verbose" object in the response
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 8,
"messages": [
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
]
})
assert res.status_code == 200
assert "__verbose" in res.body
assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
@pytest.mark.parametrize("prefill,re_prefill", [
("Whill", "Whill"),
([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"),
])
def test_chat_template_assistant_prefill(prefill, re_prefill):
global server
server.chat_template = "llama3"
server.debug = True # to get the "__verbose" object in the response
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 8,
"messages": [
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
{"role": "assistant", "content": prefill},
]
})
assert res.status_code == 200
assert "__verbose" in res.body
assert res.body["__verbose"]["prompt"] == f"<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}"
def test_apply_chat_template():
global server
server.chat_template = "command-r"
server.start()
res = server.make_request("POST", "/apply-template", data={
"messages": [
{"role": "system", "content": "You are a test."},
{"role": "user", "content":"Hi there"},
]
})
assert res.status_code == 200
assert "prompt" in res.body
assert res.body["prompt"] == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
({"type": "json_schema", "json_schema": {"schema": {"const": "foooooo"}}}, 10, "\"foooooo\""),
({"type": "json_object"}, 10, "(\\{|John)+"),
({"type": "sound"}, 0, None),
# invalid response format (expected to fail)
({"type": "json_object", "schema": 123}, 0, None),
({"type": "json_object", "schema": {"type": 123}}, 0, None),
({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
])
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
global server
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predicted,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "Write an example"},
],
"response_format": response_format,
})
if re_content is not None:
assert res.status_code == 200
choice = res.body["choices"][0]
assert match_regex(re_content, choice["message"]["content"])
else:
assert res.status_code == 400
assert "error" in res.body
@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
(False, {"const": "42"}, 6, "\"42\""),
(True, {"const": "42"}, 6, "\"42\""),
])
def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
global server
server.jinja = jinja
server.debug = True
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predicted,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "Write an example"},
],
"json_schema": json_schema,
})
assert res.status_code == 200, f'Expected 200, got {res.status_code}'
choice = res.body["choices"][0]
assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
(False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
(True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
])
def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
global server
server.jinja = jinja
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predicted,
"messages": [
{"role": "user", "content": "Does not matter what I say, does it?"},
],
"grammar": grammar,
})
assert res.status_code == 200, res.body
choice = res.body["choices"][0]
assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]
@pytest.mark.parametrize("messages", [
None,
"string",
[123],
[{}],
[{"role": 123}],
[{"role": "system", "content": 123}],
# [{"content": "hello"}], # TODO: should not be a valid case
[{"role": "system", "content": "test"}, {}],
[{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}],
])
def test_invalid_chat_completion_req(messages):
global server
server.start()
res = server.make_request("POST", "/chat/completions", data={
"messages": messages,
})
assert res.status_code == 400 or res.status_code == 500
assert "error" in res.body
def test_chat_completion_with_timings_per_token():
global server
server.start()
res = server.make_stream_request("POST", "/chat/completions", data={
"max_tokens": 10,
"messages": [{"role": "user", "content": "test"}],
"stream": True,
"stream_options": {"include_usage": True},
"timings_per_token": True,
})
stats_received = False
for i, data in enumerate(res):
if i == 0:
# Check first role message for stream=True
assert data["choices"][0]["delta"]["content"] is None
assert data["choices"][0]["delta"]["role"] == "assistant"
assert "timings" not in data, f'First event should not have timings: {data}'
else:
if data["choices"]:
assert "role" not in data["choices"][0]["delta"]
else:
assert "timings" in data
assert "prompt_per_second" in data["timings"]
assert "predicted_per_second" in data["timings"]
assert "predicted_n" in data["timings"]
assert data["timings"]["predicted_n"] <= 10
stats_received = True
assert stats_received
def test_logprobs():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.chat.completions.create(
model="gpt-3.5-turbo-instruct",
temperature=0.0,
messages=[
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
max_tokens=5,
logprobs=True,
top_logprobs=10,
)
output_text = res.choices[0].message.content
aggregated_text = ''
assert res.choices[0].logprobs is not None
assert res.choices[0].logprobs.content is not None
for token in res.choices[0].logprobs.content:
aggregated_text += token.token
assert token.logprob <= 0.0
assert token.bytes is not None
assert len(token.top_logprobs) > 0
assert aggregated_text == output_text
def test_logprobs_stream():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.chat.completions.create(
model="gpt-3.5-turbo-instruct",
temperature=0.0,
messages=[
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
max_tokens=5,
logprobs=True,
top_logprobs=10,
stream=True,
)
output_text = ''
aggregated_text = ''
for i, data in enumerate(res):
if data.choices:
choice = data.choices[0]
if i == 0:
# Check first role message for stream=True
assert choice.delta.content is None
assert choice.delta.role == "assistant"
else:
assert choice.delta.role is None
if choice.finish_reason is None:
if choice.delta.content:
output_text += choice.delta.content
assert choice.logprobs is not None
assert choice.logprobs.content is not None
for token in choice.logprobs.content:
aggregated_text += token.token
assert token.logprob <= 0.0
assert token.bytes is not None
assert token.top_logprobs is not None
assert len(token.top_logprobs) > 0
assert aggregated_text == output_text
def test_logit_bias():
global server
server.start()
exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]
res = server.make_request("POST", "/tokenize", data={
"content": " " + " ".join(exclude) + " ",
})
assert res.status_code == 200
tokens = res.body["tokens"]
logit_bias = {tok: -100 for tok in tokens}
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.chat.completions.create(
model="gpt-3.5-turbo-instruct",
temperature=0.0,
messages=[
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
max_tokens=64,
logit_bias=logit_bias
)
output_text = res.choices[0].message.content
assert output_text
assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
def test_context_size_exceeded():
global server
server.start()
res = server.make_request("POST", "/chat/completions", data={
"messages": [
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
] * 100, # make the prompt too long
})
assert res.status_code == 400
assert "error" in res.body
assert res.body["error"]["type"] == "exceed_context_size_error"
assert res.body["error"]["n_prompt_tokens"] > 0
assert server.n_ctx is not None
assert server.n_slots is not None
assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
def test_context_size_exceeded_stream():
global server
server.start()
try:
for _ in server.make_stream_request("POST", "/chat/completions", data={
"messages": [
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
] * 100, # make the prompt too long
"stream": True}):
pass
assert False, "Should have failed"
except ServerError as e:
assert e.code == 400
assert "error" in e.body
assert e.body["error"]["type"] == "exceed_context_size_error"
assert e.body["error"]["n_prompt_tokens"] > 0
assert server.n_ctx is not None
assert server.n_slots is not None
assert e.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
@pytest.mark.parametrize(
"n_batch,batch_count,reuse_cache",
[
(64, 4, False),
(64, 2, True),
]
)
def test_return_progress(n_batch, batch_count, reuse_cache):
global server
server.n_batch = n_batch
server.n_ctx = 256
server.n_slots = 1
server.start()
def make_cmpl_request():
return server.make_stream_request("POST", "/chat/completions", data={
"max_tokens": 10,
"messages": [
{"role": "user", "content": "This is a test" * 10},
],
"stream": True,
"return_progress": True,
})
if reuse_cache:
# make a first request to populate the cache
res0 = make_cmpl_request()
for _ in res0:
pass # discard the output
res = make_cmpl_request()
last_progress = None
total_batch_count = 0
for data in res:
cur_progress = data.get("prompt_progress", None)
if cur_progress is None:
continue
if total_batch_count == 0:
# first progress report must have n_cache == n_processed
assert cur_progress["total"] > 0
assert cur_progress["cache"] == cur_progress["processed"]
if reuse_cache:
# when reusing cache, we expect some cached tokens
assert cur_progress["cache"] > 0
if last_progress is not None:
assert cur_progress["total"] == last_progress["total"]
assert cur_progress["cache"] == last_progress["cache"]
assert cur_progress["processed"] > last_progress["processed"]
total_batch_count += 1
last_progress = cur_progress
# last progress should indicate completion (all tokens processed)
assert last_progress is not None
assert last_progress["total"] > 0
assert last_progress["processed"] == last_progress["total"]
assert total_batch_count == batch_count
def test_chat_completions_multiple_choices():
global server
server.start()
# make sure cache can be reused across multiple choices and multiple requests
# ref: https://github.com/ggml-org/llama.cpp/pull/18663
for _ in range(2):
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 8,
"n": 2,
"messages": [
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
# test forcing the same slot to be used
# the scheduler should not be locked up in this case
"id_slot": 0,
})
assert res.status_code == 200
assert len(res.body["choices"]) == 2
for choice in res.body["choices"]:
assert "assistant" == choice["message"]["role"]
assert choice["finish_reason"] == "length"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,60 @@
import pytest
from utils import *
server: ServerProcess
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.gcp_compat = True
def test_gcp_predict_camel_case():
global server
server.start()
res = server.make_request("POST", "/predict", data={
"instances": [
{
"@requestFormat": "chatCompletions",
"max_tokens": 8,
"messages": [
{"role": "user", "content": "What is the meaning of life?"},
],
}
],
})
assert res.status_code == 200
assert "predictions" in res.body
assert len(res.body["predictions"]) == 1
prediction = res.body["predictions"][0]
assert "choices" in prediction
assert len(prediction["choices"]) == 1
assert prediction["choices"][0]["message"]["role"] == "assistant"
assert len(prediction["choices"][0]["message"]["content"]) > 0
def test_gcp_predict_multiple_instances():
global server
server.n_slots = 2
server.start()
res = server.make_request("POST", "/predict", data={
"instances": [
{
"@requestFormat": "chatCompletions",
"max_tokens": 8,
"messages": [{"role": "user", "content": "Say hello"}],
},
{
"@requestFormat": "chatCompletions",
"max_tokens": 8,
"messages": [{"role": "user", "content": "Say world"}],
},
],
})
assert res.status_code == 200
assert len(res.body["predictions"]) == 2
for prediction in res.body["predictions"]:
assert "choices" in prediction
assert len(prediction["choices"][0]["message"]["content"]) > 0

View File

@@ -0,0 +1,73 @@
import pytest
from openai import OpenAI
from utils import *
server: ServerProcess
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
def test_responses_with_openai_library():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.responses.create(
model="gpt-4.1",
input=[
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
max_output_tokens=8,
temperature=0.8,
)
assert res.id.startswith("resp_")
assert res.output[0].id is not None
assert res.output[0].id.startswith("msg_")
assert match_regex("(Suddenly)+", res.output_text)
def test_responses_stream_with_openai_library():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
stream = client.responses.create(
model="gpt-4.1",
input=[
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
max_output_tokens=8,
temperature=0.8,
stream=True,
)
gathered_text = ''
resp_id = ''
msg_id = ''
for r in stream:
if r.type == "response.created":
assert r.response.id.startswith("resp_")
resp_id = r.response.id
if r.type == "response.in_progress":
assert r.response.id == resp_id
if r.type == "response.output_item.added":
assert r.item.id is not None
assert r.item.id.startswith("msg_")
msg_id = r.item.id
if (r.type == "response.content_part.added" or
r.type == "response.output_text.delta" or
r.type == "response.output_text.done" or
r.type == "response.content_part.done"):
assert r.item_id == msg_id
if r.type == "response.output_item.done":
assert r.item.id == msg_id
if r.type == "response.output_text.delta":
gathered_text += r.delta
if r.type == "response.completed":
assert r.response.id.startswith("resp_")
assert r.response.output[0].id is not None
assert r.response.output[0].id.startswith("msg_")
assert gathered_text == r.response.output_text
assert match_regex("(Suddenly)+", r.response.output_text)

View File

@@ -0,0 +1,661 @@
import pytest
import requests
import time
import random
from openai import OpenAI
from utils import *
server = ServerPreset.tinyllama2()
JSON_MULTIMODAL_KEY = "multimodal_data"
JSON_PROMPT_STRING_KEY = "prompt_string"
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False),
("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
])
def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool):
global server
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": n_predict,
"prompt": prompt,
"return_tokens": return_tokens,
})
assert res.status_code == 200
assert res.body["timings"]["prompt_n"] == n_prompt
assert res.body["timings"]["predicted_n"] == n_predicted
assert res.body["truncated"] == truncated
assert type(res.body["has_new_line"]) == bool
assert match_regex(re_content, res.body["content"])
if return_tokens:
assert len(res.body["tokens"]) > 0
assert all(type(tok) == int for tok in res.body["tokens"])
else:
assert res.body["tokens"] == []
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
])
def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
global server
server.start()
res = server.make_stream_request("POST", "/completion", data={
"n_predict": n_predict,
"prompt": prompt,
"stream": True,
})
content = ""
for data in res:
assert "stop" in data and type(data["stop"]) == bool
if data["stop"]:
assert data["timings"]["prompt_n"] == n_prompt
assert data["timings"]["predicted_n"] == n_predicted
assert data["truncated"] == truncated
assert data["stop_type"] == "limit"
assert type(data["has_new_line"]) == bool
assert "generation_settings" in data
assert server.n_predict is not None
assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict)
assert data["generation_settings"]["seed"] == server.seed
assert match_regex(re_content, content)
else:
assert len(data["tokens"]) > 0
assert all(type(tok) == int for tok in data["tokens"])
content += data["content"]
def test_completion_stream_vs_non_stream():
global server
server.start()
res_stream = server.make_stream_request("POST", "/completion", data={
"n_predict": 8,
"prompt": "I believe the meaning of life is",
"stream": True,
})
res_non_stream = server.make_request("POST", "/completion", data={
"n_predict": 8,
"prompt": "I believe the meaning of life is",
})
content_stream = ""
for data in res_stream:
content_stream += data["content"]
assert content_stream == res_non_stream.body["content"]
def test_completion_with_openai_library():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.completions.create(
model="davinci-002",
prompt="I believe the meaning of life is",
max_tokens=8,
)
assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
assert res.choices[0].finish_reason == "length"
assert res.choices[0].text is not None
assert match_regex("(going|bed)+", res.choices[0].text)
def test_completion_stream_with_openai_library():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.completions.create(
model="davinci-002",
prompt="I believe the meaning of life is",
max_tokens=8,
stream=True,
)
output_text = ''
for data in res:
choice = data.choices[0]
if choice.finish_reason is None:
assert choice.text is not None
output_text += choice.text
assert match_regex("(going|bed)+", output_text)
# Test case from https://github.com/ggml-org/llama.cpp/issues/13780
@pytest.mark.slow
def test_completion_stream_with_openai_library_stops():
global server
server.model_hf_repo = "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M"
server.model_hf_file = None
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.completions.create(
model="davinci-002",
prompt="System: You are helpful assistant.\nAssistant:\nHey! How could I help?\nUser:\nTell me a joke.\nAssistant:\n",
stop=["User:\n", "Assistant:\n"],
max_tokens=200,
stream=True,
)
output_text = ''
for data in res:
choice = data.choices[0]
if choice.finish_reason is None:
assert choice.text is not None
output_text += choice.text
assert match_regex("Sure, here's one for[\\s\\S]*", output_text), f'Unexpected output: {output_text}'
@pytest.mark.parametrize("n_slots", [1, 2])
def test_consistent_result_same_seed(n_slots: int):
global server
server.n_slots = n_slots
server.start()
last_res = None
for _ in range(4):
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": 0.0,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
})
if last_res is not None:
assert res.body["content"] == last_res.body["content"]
last_res = res
@pytest.mark.parametrize("n_slots", [1, 2])
def test_different_result_different_seed(n_slots: int):
global server
server.n_slots = n_slots
server.start()
last_res = None
for seed in range(4):
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": seed,
"temperature": 1.0,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
})
if last_res is not None:
assert res.body["content"] != last_res.body["content"]
last_res = res
# TODO figure why it don't work with temperature = 1
# @pytest.mark.parametrize("temperature", [0.0, 1.0])
@pytest.mark.parametrize("n_batch", [16, 32])
@pytest.mark.parametrize("temperature", [0.0])
def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
global server
server.n_batch = n_batch
server.start()
last_res = None
for _ in range(4):
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": temperature,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
})
if last_res is not None:
assert res.body["content"] == last_res.body["content"]
last_res = res
@pytest.mark.skip(reason="This test fails on linux, need to be fixed")
def test_cache_vs_nocache_prompt():
global server
server.start()
res_cache = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": 1.0,
"cache_prompt": True,
})
res_no_cache = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": 1.0,
"cache_prompt": False,
})
assert res_cache.body["content"] == res_no_cache.body["content"]
def test_nocache_long_input_prompt():
global server
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is"*32,
"seed": 42,
"temperature": 1.0,
"cache_prompt": False,
})
assert res.status_code == 400
def test_json_prompt_no_mtmd():
global server
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": { JSON_PROMPT_STRING_KEY: "I believe the meaning of life is" },
"seed": 42,
"temperature": 1.0,
"cache_prompt": False,
})
assert res.status_code == 200
def test_json_prompt_mtm_error_when_not_supported():
global server
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": { JSON_PROMPT_STRING_KEY: "I believe the meaning of life is <__media__>", JSON_MULTIMODAL_KEY: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" },
"seed": 42,
"temperature": 1.0,
"cache_prompt": False,
})
# MTMD is disabled on this model, so this should fail.
assert res.status_code != 200
def test_completion_with_tokens_input():
global server
server.temperature = 0.0
server.start()
prompt_str = "I believe the meaning of life is"
res = server.make_request("POST", "/tokenize", data={
"content": prompt_str,
"add_special": True,
})
assert res.status_code == 200
tokens = res.body["tokens"]
# single completion
res = server.make_request("POST", "/completion", data={
"prompt": tokens,
})
assert res.status_code == 200
assert type(res.body["content"]) == str
# batch completion
res = server.make_request("POST", "/completion", data={
"prompt": [tokens, tokens],
})
assert res.status_code == 200
assert type(res.body) == list
assert len(res.body) == 2
assert res.body[0]["content"] == res.body[1]["content"]
# mixed string and tokens
res = server.make_request("POST", "/completion", data={
"prompt": [tokens, prompt_str],
})
assert res.status_code == 200
assert type(res.body) == list
assert len(res.body) == 2
assert res.body[0]["content"] == res.body[1]["content"]
# mixed JSON and tokens
res = server.make_request("POST", "/completion", data={
"prompt": [
tokens,
{
JSON_PROMPT_STRING_KEY: "I believe the meaning of life is",
},
],
})
assert res.status_code == 200
assert type(res.body) == list
assert len(res.body) == 2
assert res.body[0]["content"] == res.body[1]["content"]
# mixed string and tokens in one sequence
res = server.make_request("POST", "/completion", data={
"prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
})
assert res.status_code == 200
assert type(res.body["content"]) == str
@pytest.mark.parametrize("n_slots,n_requests", [
(1, 3),
(2, 2),
(2, 4),
(4, 2), # some slots must be idle
(4, 6),
])
def test_completion_parallel_slots(n_slots: int, n_requests: int):
global server
server.n_slots = n_slots
server.temperature = 0.0
server.start()
PROMPTS = [
("Write a very long book.", "(very|special|big)+"),
("Write another a poem.", "(small|house)+"),
("What is LLM?", "(Dad|said)+"),
("The sky is blue and I love it.", "(climb|leaf)+"),
("Write another very long music lyrics.", "(friends|step|sky)+"),
("Write a very long joke.", "(cat|Whiskers)+"),
]
def check_slots_status():
should_all_slots_busy = n_requests >= n_slots
time.sleep(0.1)
res = server.make_request("GET", "/slots")
n_busy = sum([1 for slot in res.body if slot["is_processing"]])
if should_all_slots_busy:
assert n_busy == n_slots
else:
assert n_busy <= n_slots
tasks = []
for i in range(n_requests):
prompt, re_content = PROMPTS[i % len(PROMPTS)]
tasks.append((server.make_request, ("POST", "/completion", {
"prompt": prompt,
"seed": 42,
"temperature": 1.0,
})))
tasks.append((check_slots_status, ()))
results = parallel_function_calls(tasks)
# check results
for i in range(n_requests):
prompt, re_content = PROMPTS[i % len(PROMPTS)]
res = results[i]
assert res.status_code == 200
assert type(res.body["content"]) == str
assert len(res.body["content"]) > 10
# FIXME: the result is not deterministic when using other slot than slot 0
# assert match_regex(re_content, res.body["content"])
@pytest.mark.parametrize(
"n_ctx,n_slots,n_predict_vals,expected_success",
[
(256, 4, [80, 40, 80, 80], [True, True, True, True]),
(256, 4, [70, 70, 70, 70], [False, False, False, False]),
(256, 4, [90, 90, 40, 90], [False, False, True, False]),
(256, 4, [90, 90, 40, 75], [True, True, True, True]),
],
)
def test_completion_unified(n_ctx, n_slots, n_predict_vals, expected_success):
global server
server.n_slots = n_slots
server.kv_unified = True
server.n_ctx = n_ctx
server.start()
prompt = "A"
tasks = []
for n_predict in n_predict_vals:
tasks.append((server.make_request, ("POST", "/completion", {"prompt": prompt, "n_predict": n_predict})))
results = parallel_function_calls(tasks)
for res, n_predict, expect_ok in zip(results, n_predict_vals, expected_success):
if expect_ok:
assert res.status_code == 200
# note: https://github.com/ggml-org/llama.cpp/pull/18700#issuecomment-3728695581
if res.status_code == 200:
assert "content" in res.body
if "timings" in res.body:
assert res.body["timings"]["predicted_n"] == n_predict
@pytest.mark.parametrize(
"prompt,n_predict,response_fields",
[
("I believe the meaning of life is", 8, []),
("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
],
)
def test_completion_response_fields(
prompt: str, n_predict: int, response_fields: list[str]
):
global server
server.start()
res = server.make_request(
"POST",
"/completion",
data={
"n_predict": n_predict,
"prompt": prompt,
"response_fields": response_fields,
},
)
assert res.status_code == 200
assert "content" in res.body
assert len(res.body["content"])
if len(response_fields):
assert res.body["generation_settings/n_predict"] == n_predict
assert res.body["prompt"] == "<s> " + prompt
assert isinstance(res.body["content"], str)
assert len(res.body) == len(response_fields)
else:
assert len(res.body)
assert "generation_settings" in res.body
def test_n_probs():
global server
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"n_probs": 10,
"temperature": 0.0,
"n_predict": 5,
})
assert res.status_code == 200
assert "completion_probabilities" in res.body
assert len(res.body["completion_probabilities"]) == 5
for tok in res.body["completion_probabilities"]:
assert "id" in tok and tok["id"] > 0
assert "token" in tok and type(tok["token"]) == str
assert "logprob" in tok and tok["logprob"] <= 0.0
assert "bytes" in tok and type(tok["bytes"]) == list
assert len(tok["top_logprobs"]) == 10
for prob in tok["top_logprobs"]:
assert "id" in prob and prob["id"] > 0
assert "token" in prob and type(prob["token"]) == str
assert "logprob" in prob and prob["logprob"] <= 0.0
assert "bytes" in prob and type(prob["bytes"]) == list
def test_n_probs_stream():
global server
server.start()
res = server.make_stream_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"n_probs": 10,
"temperature": 0.0,
"n_predict": 5,
"stream": True,
})
for data in res:
if data["stop"] == False:
assert "completion_probabilities" in data
assert len(data["completion_probabilities"]) == 1
for tok in data["completion_probabilities"]:
assert "id" in tok and tok["id"] > 0
assert "token" in tok and type(tok["token"]) == str
assert "logprob" in tok and tok["logprob"] <= 0.0
assert "bytes" in tok and type(tok["bytes"]) == list
assert len(tok["top_logprobs"]) == 10
for prob in tok["top_logprobs"]:
assert "id" in prob and prob["id"] > 0
assert "token" in prob and type(prob["token"]) == str
assert "logprob" in prob and prob["logprob"] <= 0.0
assert "bytes" in prob and type(prob["bytes"]) == list
def test_n_probs_post_sampling():
global server
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "Today was the day. Today I would finally become a",
"n_probs": 10,
"temperature": 1.0,
"n_predict": 5,
"post_sampling_probs": True,
})
assert res.status_code == 200
assert "completion_probabilities" in res.body
assert len(res.body["completion_probabilities"]) == 5
for (i, tok) in enumerate(res.body["completion_probabilities"]):
assert "id" in tok and tok["id"] > 0
assert "token" in tok and type(tok["token"]) == str
assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
assert "bytes" in tok and type(tok["bytes"]) == list
assert "top_probs" in tok and type(tok["top_probs"]) == list
for prob in tok["top_probs"]:
assert "id" in prob and prob["id"] > 0
assert "token" in prob and type(prob["token"]) == str
# 0.0 probability tokens should never be returned by the server
assert "prob" in prob and 0.0 < prob["prob"] <= 1.0
assert "bytes" in prob and type(prob["bytes"]) == list
if i == 0:
# The prompt is vague enough that we should get at least 10 possibilities
# for the first token.
assert len(tok["top_probs"]) == 10
if len(tok["top_probs"]) < 10:
# Getting less than the requested number of probabilities should only happen
# if the ones we did get already sum to 1.0.
assert sum(p["prob"] for p in tok["top_probs"]) == pytest.approx(1.0)
def test_n_probs_post_backend_sampling():
"""Verify that the same probabilities are returned with and without backend sampling."""
global server
server.backend_sampling = True
server.start()
def make_request(backend_sampling):
n_predict = 20
res = server.make_request("POST", "/completion", data={
"prompt": "The countries of Europe, in random order, are:",
"n_probs": 10,
"n_predict": n_predict,
"post_sampling_probs": True,
"seed": 4242,
"backend_sampling": backend_sampling,
})
assert res.status_code == 200
total_probs = 0
completions = res.body["completion_probabilities"]
assert len(completions) == n_predict
for tok in completions:
# Handling of 0.0 probabilities differs between samplers and backend sampling. Filter them to normalize the
# data.
tok["top_probs"] = [x for x in tok["top_probs"] if x["prob"] > 0.0]
total_probs += len(tok["top_probs"])
# Verify that we got at least two top probs on average, to ensure the effectiveness of the test.
assert total_probs >= 2 * n_predict
return completions
def verify_token(a, b):
assert a["id"] == b["id"]
assert a["token"] == b["token"]
assert a["bytes"] == b["bytes"]
assert a["prob"] == pytest.approx(b["prob"], abs=0.01)
for (a, b) in zip(make_request(True), make_request(False)):
verify_token(a, b)
assert len(a["top_probs"]) == len(b["top_probs"])
for (aa, bb) in zip(a["top_probs"], b["top_probs"]):
verify_token(aa, bb)
@pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])
def test_logit_bias(tokenize, openai_style):
global server
server.start()
exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]
logit_bias = []
if tokenize:
res = server.make_request("POST", "/tokenize", data={
"content": " " + " ".join(exclude) + " ",
})
assert res.status_code == 200
tokens = res.body["tokens"]
logit_bias = [[tok, -100] for tok in tokens]
else:
logit_bias = [[" " + tok + " ", -100] for tok in exclude]
if openai_style:
logit_bias = {el[0]: -100 for el in logit_bias}
res = server.make_request("POST", "/completion", data={
"n_predict": 64,
"prompt": "What is the best book",
"logit_bias": logit_bias,
"temperature": 0.0
})
assert res.status_code == 200
output_text = res.body["content"]
assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
def test_cancel_request():
global server
server.n_ctx = 4096
server.n_predict = -1
server.n_slots = 1
server.server_slots = True
server.start()
# send a request that will take a long time, but cancel it before it finishes
try:
server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
}, timeout=0.1)
except requests.exceptions.ReadTimeout:
pass # expected
# make sure the slot is free
time.sleep(2)
res = server.make_request("GET", "/slots")
assert res.body[0]["is_processing"] == False
# this test exercises the host-memory prompt cache
# ref: https://github.com/ggml-org/llama.cpp/pull/16391
# ref: https://github.com/ggml-org/llama.cpp/pull/17078
def test_completion_prompt_cache():
global server
server.n_slots = 2
server.kv_unified = True
server.start()
for _ in range(16):
# generate alternating random prompts with variable lengths in order to get them in and out of the cache
r = random.randint(0, 4)
prompt = (" Hello " + str(r)) * (40 + r)
n_prompt = (40 + r)*5 + 2
n_predict = random.randint(1, 8)
res = server.make_request(
"POST",
"/completion",
data={
"prompt": prompt,
"n_predict": n_predict,
},
)
assert res.status_code == 200
assert "content" in res.body
content = res.body["content"]
assert isinstance(content, str)
assert len(content) > 0
assert type(res.body["has_new_line"]) == bool
assert "timings" in res.body
timings = res.body["timings"]
assert "prompt_n" in timings and timings["prompt_n"] + timings["cache_n"] == n_prompt
assert "predicted_n" in timings and timings["predicted_n"] == n_predict
assert "tokens" in res.body and isinstance(res.body["tokens"], list)

View File

@@ -0,0 +1,89 @@
import pytest
from utils import *
server = ServerPreset.tinyllama2()
SHORT_TEXT = """
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
""".strip()
LONG_TEXT = """
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
""".strip()
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.n_ctx = 512
server.n_slots = 2
server.n_predict = 128
def test_ctx_shift_enabled():
# the prompt is 226 tokens
# the slot context is 512/2 = 256 tokens
# 96 tokens are generated thanks to shifting the context when it gets full
global server
server.enable_ctx_shift = True
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 96,
"prompt": SHORT_TEXT,
})
assert res.status_code == 200
assert res.body["timings"]["prompt_n"] == 226
assert res.body["timings"]["predicted_n"] == 96
assert res.body["truncated"] is True
@pytest.mark.parametrize("n_predict,n_token_output,truncated", [
(64, 64, False),
(-1, 248, True), # 8 tokens prompt + 248 tokens generated = 256 tokens total
])
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
global server
server.n_predict = -1
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": n_predict,
"prompt": "Hi how are you",
})
assert res.status_code == 200
assert res.body["timings"]["predicted_n"] == n_token_output
assert res.body["truncated"] == truncated
def test_ctx_shift_disabled_long_prompt():
global server
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 64,
"prompt": LONG_TEXT,
})
assert res.status_code != 200
assert "error" in res.body
assert "exceeds the available context size" in res.body["error"]["message"]
def test_ctx_shift_disabled_stream():
global server
server.start()
res = server.make_stream_request("POST", "/v1/completions", data={
"n_predict": 256,
"prompt": "Once",
"stream": True,
})
content = ""
for data in res:
choice = data["choices"][0]
if choice["finish_reason"] == "length":
assert len(content) > 0
else:
assert choice["finish_reason"] is None
content += choice["text"]

View File

@@ -0,0 +1,291 @@
import base64
import struct
import pytest
from openai import OpenAI
from utils import *
server = ServerPreset.bert_bge_small()
EPSILON = 1e-3
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.bert_bge_small()
def test_embedding_single():
global server
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": "I believe the meaning of life is",
})
assert res.status_code == 200
assert len(res.body['data']) == 1
assert 'embedding' in res.body['data'][0]
assert len(res.body['data'][0]['embedding']) > 1
# make sure embedding vector is normalized
assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON
def test_embedding_multiple():
global server
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"Write a joke about AI from a very long prompt which will not be truncated",
"This is a test",
"This is another test",
],
})
assert res.status_code == 200
assert len(res.body['data']) == 4
for d in res.body['data']:
assert 'embedding' in d
assert len(d['embedding']) > 1
def test_embedding_multiple_with_fa():
server = ServerPreset.bert_bge_small_with_fa()
server.pooling = 'last'
server.start()
# one of these should trigger the FA branch (i.e. context size % 256 == 0)
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"a "*253,
"b "*254,
"c "*255,
"d "*256,
],
})
assert res.status_code == 200
assert len(res.body['data']) == 4
for d in res.body['data']:
assert 'embedding' in d
assert len(d['embedding']) > 1
@pytest.mark.parametrize(
"input,is_multi_prompt",
[
# do not crash on empty input
("", False),
# single prompt
("string", False),
([12, 34, 56], False),
([12, 34, "string", 56, 78], False),
# multiple prompts
(["string1", "string2"], True),
(["string1", [12, 34, 56]], True),
([[12, 34, 56], [12, 34, 56]], True),
([[12, 34, 56], [12, "string", 34, 56]], True),
]
)
def test_embedding_mixed_input(input, is_multi_prompt: bool):
global server
server.start()
res = server.make_request("POST", "/v1/embeddings", data={"input": input})
assert res.status_code == 200
data = res.body['data']
if is_multi_prompt:
assert len(data) == len(input)
for d in data:
assert 'embedding' in d
assert len(d['embedding']) > 1
else:
assert 'embedding' in data[0]
assert len(data[0]['embedding']) > 1
def test_embedding_pooling_mean():
global server
server.pooling = 'mean'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": "I believe the meaning of life is",
})
assert res.status_code == 200
assert len(res.body['data']) == 1
assert 'embedding' in res.body['data'][0]
assert len(res.body['data'][0]['embedding']) > 1
# make sure embedding vector is normalized
assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON
def test_embedding_pooling_mean_multiple():
global server
server.pooling = 'mean'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"Write a joke about AI",
"This is a test",
],
})
assert res.status_code == 200
assert len(res.body['data']) == 3
for d in res.body['data']:
assert 'embedding' in d
assert len(d['embedding']) > 1
def test_embedding_pooling_none():
global server
server.pooling = 'none'
server.start()
res = server.make_request("POST", "/embeddings", data={
"input": "hello hello hello",
})
assert res.status_code == 200
assert 'embedding' in res.body[0]
assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special
# make sure embedding vector is not normalized
for x in res.body[0]['embedding']:
assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON
def test_embedding_pooling_none_oai():
global server
server.pooling = 'none'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": "hello hello hello",
})
# /v1/embeddings does not support pooling type 'none'
assert res.status_code == 400
assert "error" in res.body
def test_embedding_openai_library_single():
global server
server.pooling = 'last'
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
assert len(res.data) == 1
assert len(res.data[0].embedding) > 1
def test_embedding_openai_library_multiple():
global server
server.pooling = 'last'
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.embeddings.create(model="text-embedding-3-small", input=[
"I believe the meaning of life is",
"Write a joke about AI from a very long prompt which will not be truncated",
"This is a test",
"This is another test",
])
assert len(res.data) == 4
for d in res.data:
assert len(d.embedding) > 1
def test_embedding_error_prompt_too_long():
global server
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": "This is a test " * 512,
})
assert res.status_code != 200
assert "too large" in res.body["error"]["message"]
def test_same_prompt_give_same_result():
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"I believe the meaning of life is",
"I believe the meaning of life is",
"I believe the meaning of life is",
"I believe the meaning of life is",
],
})
assert res.status_code == 200
assert len(res.body['data']) == 5
for i in range(1, len(res.body['data'])):
v0 = res.body['data'][0]['embedding']
vi = res.body['data'][i]['embedding']
for x, y in zip(v0, vi):
assert abs(x - y) < EPSILON
@pytest.mark.parametrize(
"content,n_tokens",
[
("I believe the meaning of life is", 9),
("This is a test", 6),
]
)
def test_embedding_usage_single(content, n_tokens):
global server
server.start()
res = server.make_request("POST", "/v1/embeddings", data={"input": content})
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == n_tokens
def test_embedding_usage_multiple():
global server
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"I believe the meaning of life is",
],
})
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == 2 * 9
def test_embedding_openai_library_base64():
server.start()
test_input = "Test base64 embedding output"
# get embedding in default format
res = server.make_request("POST", "/v1/embeddings", data={
"input": test_input
})
assert res.status_code == 200
vec0 = res.body["data"][0]["embedding"]
# get embedding in base64 format
res = server.make_request("POST", "/v1/embeddings", data={
"input": test_input,
"encoding_format": "base64"
})
assert res.status_code == 200
assert "data" in res.body
assert len(res.body["data"]) == 1
embedding_data = res.body["data"][0]
assert "embedding" in embedding_data
assert isinstance(embedding_data["embedding"], str)
# Verify embedding is valid base64
decoded = base64.b64decode(embedding_data["embedding"])
# Verify decoded data can be converted back to float array
float_count = len(decoded) // 4 # 4 bytes per float
floats = struct.unpack(f'{float_count}f', decoded)
assert len(floats) > 0
assert all(isinstance(x, float) for x in floats)
assert len(floats) == len(vec0)
# make sure the decoded data is the same as the original
for x, y in zip(floats, vec0):
assert abs(x - y) < EPSILON

View File

@@ -0,0 +1,43 @@
import pytest
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
def test_ignore_eos_populates_logit_bias():
"""ignore_eos=true must add EOG logit biases to generation_settings."""
global server
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 8,
"prompt": "Once upon a time",
"ignore_eos": True,
"temperature": 0.0,
})
assert res.status_code == 200
# EOG token biases must be present with -inf bias
logit_bias = res.body["generation_settings"]["logit_bias"]
assert len(logit_bias) > 0
for entry in logit_bias:
assert entry["bias"] is None # null in JSON represents -inf
def test_ignore_eos_false_no_logit_bias():
"""ignore_eos=false (default) must NOT add EOG logit biases."""
global server
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 8,
"prompt": "Once upon a time",
"ignore_eos": False,
"temperature": 0.0,
})
assert res.status_code == 200
logit_bias = res.body["generation_settings"]["logit_bias"]
assert len(logit_bias) == 0

View File

@@ -0,0 +1,77 @@
import pytest
from utils import *
server = ServerPreset.tinyllama_infill()
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama_infill()
def test_infill_without_input_extra():
global server
server.start()
res = server.make_request("POST", "/infill", data={
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
"prompt": " int n_threads = llama_",
"input_suffix": "}\n",
})
assert res.status_code == 200
assert match_regex("(Ann|small|shiny|Daddy|Jimmy)+", res.body["content"])
def test_infill_with_input_extra():
global server
server.start()
res = server.make_request("POST", "/infill", data={
"input_extra": [{
"filename": "llama.h",
"text": "LLAMA_API int32_t llama_n_threads();\n"
}],
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
"prompt": " int n_threads = llama_",
"input_suffix": "}\n",
})
assert res.status_code == 200
assert match_regex("(Dad|excited|park|Jimmy)+", res.body["content"])
@pytest.mark.parametrize("input_extra", [
{},
{"filename": "ok"},
{"filename": 123},
{"filename": 123, "text": "abc"},
{"filename": 123, "text": 456},
])
def test_invalid_input_extra_req(input_extra):
global server
server.start()
res = server.make_request("POST", "/infill", data={
"input_extra": [input_extra],
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
"prompt": " int n_threads = llama_",
"input_suffix": "}\n",
})
assert res.status_code == 400
assert "error" in res.body
@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
def test_with_qwen_model():
global server
server.model_file = None
server.model_hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-IQ3_XXS-GGUF"
server.model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf"
server.start(timeout_seconds=600)
res = server.make_request("POST", "/infill", data={
"input_extra": [{
"filename": "llama.h",
"text": "LLAMA_API int32_t llama_n_threads();\n"
}],
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
"prompt": " int n_threads = llama_",
"input_suffix": "}\n",
})
assert res.status_code == 200
assert res.body["content"] == "n_threads();\n printf(\"Number of threads: %d\\n\", n_threads);\n return 0;\n"

View File

@@ -0,0 +1,115 @@
import os
import tempfile
import pytest
from utils import *
server = ServerPreset.tinyllama2()
class LogReader:
def __init__(self, path):
self.path = path
self.pos = 0
def drain(self):
with open(self.path) as f:
f.seek(self.pos)
content = f.read()
self.pos = f.tell()
return content
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.n_slots = 2
server.n_predict = 4
server.temperature = 0.0
server.server_slots = True
server.cache_ram = 100
server.kv_unified = True
server.debug = True
fd, server.log_path = tempfile.mkstemp(suffix='.log')
os.close(fd)
yield
LONG_PROMPT = (
"Once upon a time in a land far away, there lived a brave knight "
"who traveled across mountains and rivers to find the legendary "
"golden sword hidden deep within the enchanted forest of whispers. "
"He met many creatures along the way including dragons and fairies "
"and wizards who helped him on his noble quest to save the kingdom."
)
# idle slot cleared on launch should restore from cache-ram
def test_clear_and_restore():
global server
server.start()
log = LogReader(server.log_path)
# verify feature is enabled
assert "__TEST_TAG_CACHE_IDLE_SLOTS_ENABLED__" in log.drain()
res = server.make_request("POST", "/completion", data={
"prompt": LONG_PROMPT,
"id_slot": 0,
"cache_prompt": True,
})
assert res.status_code == 200
original_prompt_n = res.body["timings"]["prompt_n"]
# Slot 0 is the only slot with KV — should NOT be cleared
assert "__TEST_TAG_CACHE_IDLE_SLOT__" not in log.drain()
# Launching slot 1 clears idle slot 0
res = server.make_request("POST", "/completion", data={
"prompt": "The quick brown fox",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert "__TEST_TAG_CACHE_IDLE_SLOT__" in log.drain()
# Re-send same prompt — should restore from cache-ram
res = server.make_request("POST", "/completion", data={
"prompt": LONG_PROMPT,
"cache_prompt": True,
})
assert res.status_code == 200
assert "updating prompt cache" in log.drain()
assert res.body["timings"]["cache_n"] > 0
assert res.body["timings"]["prompt_n"] < original_prompt_n
# Follow-up — slot 0 kept its KV, no clearing needed
res = server.make_request("POST", "/completion", data={
"prompt": LONG_PROMPT + " The knight finally reached the castle gates.",
"cache_prompt": True,
})
assert res.status_code == 200
assert "__TEST_TAG_CACHE_IDLE_SLOT__" not in log.drain()
def test_disabled_with_flag():
global server
server.no_cache_idle_slots = True
server.start()
log = LogReader(server.log_path)
# Feature should not be enabled
assert "__TEST_TAG_CACHE_IDLE_SLOTS_ENABLED__" not in log.drain()
res = server.make_request("POST", "/completion", data={
"prompt": LONG_PROMPT,
"id_slot": 0,
"cache_prompt": True,
})
assert res.status_code == 200
# Request on different slot — should NOT trigger clearing
res = server.make_request("POST", "/completion", data={
"prompt": "The quick brown fox",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert "__TEST_TAG_CACHE_IDLE_SLOT__" not in log.drain()

View File

@@ -0,0 +1,115 @@
import pytest
from utils import *
server = ServerPreset.stories15m_moe()
LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf"
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.stories15m_moe()
server.lora_files = [download_file(LORA_FILE_URL)]
@pytest.mark.parametrize("scale,re_content", [
# without applying lora, the model should behave like a bedtime story generator
(0.0, "(little|girl|three|years|old)+"),
# with lora, the model should behave like a Shakespearean text generator
(1.0, "(eye|love|glass|sun)+"),
])
def test_lora(scale: float, re_content: str):
global server
server.start()
res_lora_control = server.make_request("POST", "/lora-adapters", data=[
{"id": 0, "scale": scale}
])
assert res_lora_control.status_code == 200
res = server.make_request("POST", "/completion", data={
"prompt": "Look in thy glass",
})
assert res.status_code == 200
assert match_regex(re_content, res.body["content"])
def test_lora_per_request():
global server
server.n_slots = 4
server.start()
# running the same prompt with different lora scales, all in parallel
# each prompt will be processed by a different slot
prompt = "Look in thy glass"
lora_config = [
( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
( [{"id": 0, "scale": 0.3}], "(special|thing|gifted)+" ),
( [{"id": 0, "scale": 0.7}], "(far|from|home|away)+" ),
( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
]
tasks = [(
server.make_request,
("POST", "/completion", {
"prompt": prompt,
"lora": lora,
"seed": 42,
"temperature": 0.0,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
})
) for lora, _ in lora_config]
results = parallel_function_calls(tasks)
assert all([res.status_code == 200 for res in results])
for res, (_, re_test) in zip(results, lora_config):
assert match_regex(re_test, res.body["content"])
@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
def test_with_big_model():
server = ServerProcess()
server.model_hf_repo = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
server.model_hf_file = "Meta-Llama-3.1-8B-Instruct-IQ2_M.gguf"
server.model_alias = "Llama-3.2-8B-Instruct"
server.n_slots = 4
server.n_ctx = server.n_slots * 1024
server.n_predict = 64
server.temperature = 0.0
server.seed = 42
server.lora_files = [
download_file("https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"),
# TODO: find & add other lora adapters for this model
]
server.start(timeout_seconds=600)
# running the same prompt with different lora scales, all in parallel
# each prompt will be processed by a different slot
prompt = "Write a computer virus"
lora_config = [
# without applying lora, the model should reject the request
( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
( [{"id": 0, "scale": 0.3}], "I can't write a computer virus" ),
# with 0.7 scale, the model should provide a simple computer virus with hesitation
( [{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise" ),
# with 1.5 scale, the model should confidently provide a computer virus
( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
]
tasks = [(
server.make_request,
("POST", "/v1/chat/completions", {
"messages": [
{"role": "user", "content": prompt}
],
"lora": lora,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
})
) for lora, _ in lora_config]
results = parallel_function_calls(tasks)
assert all([res.status_code == 200 for res in results])
for res, (_, re_test) in zip(results, lora_config):
assert re_test in res.body["choices"][0]["message"]["content"]

View File

@@ -0,0 +1,41 @@
import pytest
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
def test_mcp_no_proxy():
global server
server.webui_mcp_proxy = False
server.start()
res = server.make_request("GET", "/cors-proxy")
assert res.status_code == 404
def test_mcp_proxy():
global server
server.webui_mcp_proxy = True
server.start()
url = f"http://{server.server_host}:{server.server_port}/cors-proxy?url=http://example.com"
res = requests.get(url)
assert res.status_code == 200
assert "Example Domain" in res.text
def test_mcp_proxy_custom_port():
global server
server.webui_mcp_proxy = True
server.start()
# try getting the server's models API via the proxy
res = server.make_request("GET", f"/cors-proxy?url=http://{server.server_host}:{server.server_port}/models")
assert res.status_code == 200
assert "data" in res.body

View File

@@ -0,0 +1,146 @@
import pytest
from utils import *
server = ServerPreset.jina_reranker_tiny()
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.jina_reranker_tiny()
TEST_DOCUMENTS = [
"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
]
def test_rerank():
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": "Machine learning is",
"documents": TEST_DOCUMENTS,
})
assert res.status_code == 200
assert len(res.body["results"]) == 4
most_relevant = res.body["results"][0]
least_relevant = res.body["results"][0]
for doc in res.body["results"]:
if doc["relevance_score"] > most_relevant["relevance_score"]:
most_relevant = doc
if doc["relevance_score"] < least_relevant["relevance_score"]:
least_relevant = doc
assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
assert most_relevant["index"] == 2
assert least_relevant["index"] == 3
def test_rerank_tei_format():
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": "Machine learning is",
"texts": TEST_DOCUMENTS,
})
assert res.status_code == 200
assert len(res.body) == 4
most_relevant = res.body[0]
least_relevant = res.body[0]
for doc in res.body:
if doc["score"] > most_relevant["score"]:
most_relevant = doc
if doc["score"] < least_relevant["score"]:
least_relevant = doc
assert most_relevant["score"] > least_relevant["score"]
assert most_relevant["index"] == 2
assert least_relevant["index"] == 3
@pytest.mark.parametrize("documents", [
[],
None,
123,
[1, 2, 3],
])
def test_invalid_rerank_req(documents):
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": "Machine learning is",
"documents": documents,
})
assert res.status_code == 400
assert "error" in res.body
@pytest.mark.parametrize(
"query,doc1,doc2,n_tokens",
[
("Machine learning is", "A machine", "Learning is", 19),
("Which city?", "Machine learning is ", "Paris, capitale de la", 26),
]
)
def test_rerank_usage(query, doc1, doc2, n_tokens):
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": query,
"documents": [
doc1,
doc2,
]
})
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == n_tokens
@pytest.mark.parametrize("top_n,expected_len", [
(None, len(TEST_DOCUMENTS)), # no top_n parameter
(2, 2),
(4, 4),
(99, len(TEST_DOCUMENTS)), # higher than available docs
])
def test_rerank_top_n(top_n, expected_len):
global server
server.start()
data = {
"query": "Machine learning is",
"documents": TEST_DOCUMENTS,
}
if top_n is not None:
data["top_n"] = top_n
res = server.make_request("POST", "/rerank", data=data)
assert res.status_code == 200
assert len(res.body["results"]) == expected_len
@pytest.mark.parametrize("top_n,expected_len", [
(None, len(TEST_DOCUMENTS)), # no top_n parameter
(2, 2),
(4, 4),
(99, len(TEST_DOCUMENTS)), # higher than available docs
])
def test_rerank_tei_top_n(top_n, expected_len):
global server
server.start()
data = {
"query": "Machine learning is",
"texts": TEST_DOCUMENTS,
}
if top_n is not None:
data["top_n"] = top_n
res = server.make_request("POST", "/rerank", data=data)
assert res.status_code == 200
assert len(res.body) == expected_len

View File

@@ -0,0 +1,255 @@
import pytest
from utils import *
server: ServerProcess
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.router()
def test_router_props():
global server
server.models_max = 2
server.no_models_autoload = True
server.start()
res = server.make_request("GET", "/props")
assert res.status_code == 200
assert res.body["role"] == "router"
assert res.body["max_instances"] == 2
assert res.body["models_autoload"] is False
assert res.body["build_info"].startswith("b")
@pytest.mark.parametrize(
"model,success",
[
("ggml-org/tinygemma3-GGUF:Q8_0", True),
("non-existent/model", False),
]
)
def test_router_chat_completion_stream(model: str, success: bool):
global server
server.start()
content = ""
ex: ServerError | None = None
try:
res = server.make_stream_request("POST", "/chat/completions", data={
"model": model,
"max_tokens": 16,
"messages": [
{"role": "user", "content": "hello"},
],
"stream": True,
})
for data in res:
if data["choices"]:
choice = data["choices"][0]
if choice["finish_reason"] in ["stop", "length"]:
assert "content" not in choice["delta"]
else:
assert choice["finish_reason"] is None
content += choice["delta"]["content"] or ''
except ServerError as e:
ex = e
if success:
assert ex is None
assert len(content) > 0
else:
assert ex is not None
assert content == ""
def _get_model_ids(is_reload: bool) -> set[str]:
res = server.make_request("GET", "/models" + ("?reload=1" if is_reload else ""))
assert res.status_code == 200
return {item["id"] for item in res.body.get("data", [])}
def _get_model_status(model_id: str) -> str:
res = server.make_request("GET", "/models")
assert res.status_code == 200
for item in res.body.get("data", []):
if item.get("id") == model_id or item.get("model") == model_id:
return item["status"]["value"]
raise AssertionError(f"Model {model_id} not found in /models response")
def _wait_for_model_status(model_id: str, desired: set[str], timeout: int = 60) -> str:
deadline = time.time() + timeout
last_status = None
while time.time() < deadline:
last_status = _get_model_status(model_id)
if last_status in desired:
return last_status
time.sleep(1)
raise AssertionError(
f"Timed out waiting for {model_id} to reach {desired}, last status: {last_status}"
)
def _load_model_and_wait(
model_id: str, timeout: int = 60, headers: dict | None = None
) -> None:
load_res = server.make_request(
"POST", "/models/load", data={"model": model_id}, headers=headers
)
assert load_res.status_code == 200
assert isinstance(load_res.body, dict)
assert load_res.body.get("success") is True
_wait_for_model_status(model_id, {"loaded"}, timeout=timeout)
def test_router_unload_model():
global server
server.start()
model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
_load_model_and_wait(model_id)
unload_res = server.make_request("POST", "/models/unload", data={"model": model_id})
assert unload_res.status_code == 200
assert unload_res.body.get("success") is True
_wait_for_model_status(model_id, {"unloaded"})
def test_router_models_max_evicts_lru():
global server
server.models_max = 2
server.start()
candidate_models = [
"ggml-org/tinygemma3-GGUF:Q8_0",
"ggml-org/test-model-stories260K:F32",
"ggml-org/test-model-stories260K-infill:F32",
]
# Load only the first 2 models to fill the cache
first, second, third = candidate_models[:3]
_load_model_and_wait(first, timeout=120)
_load_model_and_wait(second, timeout=120)
# Verify both models are loaded
assert _get_model_status(first) == "loaded"
assert _get_model_status(second) == "loaded"
# Load the third model - this should trigger LRU eviction of the first model
_load_model_and_wait(third, timeout=120)
# Verify eviction: third is loaded, first was evicted
assert _get_model_status(third) == "loaded"
assert _get_model_status(first) == "unloaded"
def test_router_no_models_autoload():
global server
server.no_models_autoload = True
server.start()
model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
res = server.make_request(
"POST",
"/v1/chat/completions",
data={
"model": model_id,
"messages": [{"role": "user", "content": "hello"}],
"max_tokens": 4,
},
)
assert res.status_code == 400
assert "error" in res.body
_load_model_and_wait(model_id)
success_res = server.make_request(
"POST",
"/v1/chat/completions",
data={
"model": model_id,
"messages": [{"role": "user", "content": "hello"}],
"max_tokens": 4,
},
)
assert success_res.status_code == 200
assert "error" not in success_res.body
def test_router_api_key_required():
global server
server.api_key = "sk-router-secret"
server.start()
model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
auth_headers = {"Authorization": f"Bearer {server.api_key}"}
res = server.make_request(
"POST",
"/v1/chat/completions",
data={
"model": model_id,
"messages": [{"role": "user", "content": "hello"}],
"max_tokens": 4,
},
)
assert res.status_code == 401
assert res.body.get("error", {}).get("type") == "authentication_error"
_load_model_and_wait(model_id, headers=auth_headers)
authed = server.make_request(
"POST",
"/v1/chat/completions",
headers=auth_headers,
data={
"model": model_id,
"messages": [{"role": "user", "content": "hello"}],
"max_tokens": 4,
},
)
assert authed.status_code == 200
assert "error" not in authed.body
def test_router_reload_models():
"""POST /models/reload re-reads the INI preset and updates the model list."""
global server
preset_path = os.path.join(TMP_DIR, "test_reload.ini")
# Initial preset: two models
with open(preset_path, "w") as f:
f.write(
"[model-reload-a]\n"
"hf-repo = ggml-org/test-model-stories260K\n"
"\n"
"[model-reload-b]\n"
"hf-repo = ggml-org/test-model-stories260K-infill\n"
)
server.models_preset = preset_path
server.start()
ids = _get_model_ids(is_reload=False)
assert "model-reload-a" in ids
assert "model-reload-b" in ids
# Updated preset: remove a, keep b unchanged, add c
with open(preset_path, "w") as f:
f.write(
"[model-reload-b]\n"
"hf-repo = ggml-org/test-model-stories260K-infill\n"
"\n"
"[model-reload-c]\n"
"hf-repo = ggml-org/test-model-stories260K\n"
)
try:
ids = _get_model_ids(is_reload=True)
assert "model-reload-a" not in ids, "removed model should no longer appear"
assert "model-reload-b" in ids, "unchanged model should still appear"
assert "model-reload-c" in ids, "newly added model should appear"
finally:
os.remove(preset_path)

View File

@@ -0,0 +1,136 @@
import pytest
from openai import OpenAI
from utils import *
server = ServerPreset.tinyllama2()
TEST_API_KEY = "sk-this-is-the-secret-key"
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.api_key = TEST_API_KEY
@pytest.mark.parametrize("endpoint", ["/health", "/models"])
def test_access_public_endpoint(endpoint: str):
global server
server.start()
res = server.make_request("GET", endpoint)
assert res.status_code == 200
assert "error" not in res.body
def test_access_static_assets_without_api_key():
"""Static web UI assets should not require API key authentication (issue #21229)"""
global server
server.start()
for path in ["/", "/bundle.js", "/bundle.css"]:
res = server.make_request("GET", path)
assert res.status_code == 200, f"Expected 200 for {path}, got {res.status_code}"
@pytest.mark.parametrize("api_key", [None, "invalid-key"])
def test_incorrect_api_key(api_key: str):
global server
server.start()
res = server.make_request("POST", "/completions", data={
"prompt": "I believe the meaning of life is",
}, headers={
"Authorization": f"Bearer {api_key}" if api_key else None,
})
assert res.status_code == 401
assert "error" in res.body
assert res.body["error"]["type"] == "authentication_error"
def test_correct_api_key():
global server
server.start()
res = server.make_request("POST", "/completions", data={
"prompt": "I believe the meaning of life is",
}, headers={
"Authorization": f"Bearer {TEST_API_KEY}",
})
assert res.status_code == 200
assert "error" not in res.body
assert "content" in res.body
def test_correct_api_key_anthropic_header():
global server
server.start()
res = server.make_request("POST", "/completions", data={
"prompt": "I believe the meaning of life is",
}, headers={
"X-Api-Key": TEST_API_KEY,
})
assert res.status_code == 200
assert "error" not in res.body
assert "content" in res.body
def test_openai_library_correct_api_key():
global server
server.start()
client = OpenAI(api_key=TEST_API_KEY, base_url=f"http://{server.server_host}:{server.server_port}")
res = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "You are a chatbot."},
{"role": "user", "content": "What is the meaning of life?"},
],
)
assert len(res.choices) == 1
@pytest.mark.parametrize("origin,cors_header,cors_header_value", [
("localhost", "Access-Control-Allow-Origin", "localhost"),
("web.mydomain.fr", "Access-Control-Allow-Origin", "web.mydomain.fr"),
("origin", "Access-Control-Allow-Credentials", "true"),
("web.mydomain.fr", "Access-Control-Allow-Methods", "GET, POST"),
("web.mydomain.fr", "Access-Control-Allow-Headers", "*"),
])
def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
global server
server.start()
res = server.make_request("OPTIONS", "/completions", headers={
"Origin": origin,
"Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": "Authorization",
})
assert res.status_code == 200
assert cors_header in res.headers
assert res.headers[cors_header] == cors_header_value
@pytest.mark.parametrize(
"media_path, image_url, success",
[
(None, "file://mtmd/test-1.jpeg", False), # disabled media path, should fail
("../../../tools", "file://mtmd/test-1.jpeg", True),
("../../../tools", "file:////mtmd//test-1.jpeg", True), # should be the same file as above
("../../../tools", "file://mtmd/notfound.jpeg", False), # non-existent file
("../../../tools", "file://../mtmd/test-1.jpeg", False), # no directory traversal
]
)
def test_local_media_file(media_path, image_url, success,):
server = ServerPreset.tinygemma3()
server.media_path = media_path
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 1,
"messages": [
{"role": "user", "content": [
{"type": "text", "text": "test"},
{"type": "image_url", "image_url": {
"url": image_url,
}},
]},
],
})
if success:
assert res.status_code == 200
else:
assert res.status_code == 400

View File

@@ -0,0 +1,39 @@
import pytest
import time
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
def test_server_sleep():
global server
server.sleep_idle_seconds = 1
server.start()
# wait a bit so that server can go to sleep
time.sleep(2)
# make sure these endpoints are still responsive after sleep
res = server.make_request("GET", "/health")
assert res.status_code == 200
res = server.make_request("GET", "/props")
assert res.status_code == 200
assert res.body["is_sleeping"] == True
# make a generation request to wake up the server
res = server.make_request("POST", "/completion", data={
"n_predict": 1,
"prompt": "Hello",
})
assert res.status_code == 200
# it should no longer be sleeping
res = server.make_request("GET", "/props")
assert res.status_code == 200
assert res.body["is_sleeping"] == False

View File

@@ -0,0 +1,98 @@
import pytest
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.slot_save_path = "./tmp"
server.temperature = 0.0
def test_slot_save_restore():
global server
server.start()
# First prompt in slot 1 should be fully processed
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of France?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Whiskers|Flana)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
# Save state of slot 1
res = server.make_request("POST", "/slots/1?action=save", data={
"filename": "slot1.bin",
})
assert res.status_code == 200
assert res.body["n_saved"] == 84
# Since we have cache, this should only process the last tokens
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of Germany?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Jack|said)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 6 # only different part is processed
# Loading the saved cache into slot 0
res = server.make_request("POST", "/slots/0?action=restore", data={
"filename": "slot1.bin",
})
assert res.status_code == 200
assert res.body["n_restored"] == 84
# Since we have cache, slot 0 should only process the last tokens
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of Germany?",
"id_slot": 0,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Jack|said)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 6 # only different part is processed
# For verification that slot 1 was not corrupted during slot 0 load, same thing should work
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of Germany?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Jack|said)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 1
def test_slot_erase():
global server
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of France?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Whiskers|Flana)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
# erase slot 1
res = server.make_request("POST", "/slots/1?action=erase")
assert res.status_code == 200
# re-run the same prompt, it should process all tokens again
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of France?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Whiskers|Flana)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed

View File

@@ -0,0 +1,131 @@
import pytest
from utils import *
# We use a F16 MOE gguf as main model, and q4_0 as draft model
server = ServerPreset.stories15m_moe()
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/tiny-llamas/resolve/main/stories15M-q4_0.gguf"
def create_server():
global server
server = ServerPreset.stories15m_moe()
# set default values
server.model_draft = download_file(MODEL_DRAFT_FILE_URL)
server.draft_min = 4
server.draft_max = 8
server.fa = "off"
@pytest.fixture(autouse=True)
def fixture_create_server():
return create_server()
def test_with_and_without_draft():
global server
server.model_draft = None # disable draft model
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"temperature": 0.0,
"top_k": 1,
"n_predict": 16,
})
assert res.status_code == 200
content_no_draft = res.body["content"]
server.stop()
# create new server with draft model
create_server()
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"temperature": 0.0,
"top_k": 1,
"n_predict": 16,
})
assert res.status_code == 200
content_draft = res.body["content"]
assert content_no_draft == content_draft
def test_different_draft_min_draft_max():
global server
test_values = [
(1, 2),
(1, 4),
(4, 8),
(4, 12),
(8, 16),
]
last_content = None
for draft_min, draft_max in test_values:
server.stop()
server.draft_min = draft_min
server.draft_max = draft_max
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"temperature": 0.0,
"top_k": 1,
"n_predict": 16,
})
assert res.status_code == 200
if last_content is not None:
assert last_content == res.body["content"]
last_content = res.body["content"]
def test_slot_ctx_not_exceeded():
global server
server.n_ctx = 256
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "Hello " * 248,
"temperature": 0.0,
"top_k": 1,
"speculative.p_min": 0.0,
})
assert res.status_code == 200
assert len(res.body["content"]) > 0
def test_with_ctx_shift():
global server
server.n_ctx = 256
server.enable_ctx_shift = True
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "Hello " * 248,
"temperature": 0.0,
"top_k": 1,
"n_predict": 256,
"speculative.p_min": 0.0,
})
assert res.status_code == 200
assert len(res.body["content"]) > 0
assert res.body["tokens_predicted"] == 256
assert res.body["truncated"] == True
@pytest.mark.parametrize("n_slots,n_requests", [
(1, 2),
(2, 2),
])
def test_multi_requests_parallel(n_slots: int, n_requests: int):
global server
server.n_slots = n_slots
server.start()
tasks = []
for _ in range(n_requests):
tasks.append((server.make_request, ("POST", "/completion", {
"prompt": "I believe the meaning of life is",
"temperature": 0.0,
"top_k": 1,
})))
results = parallel_function_calls(tasks)
for res in results:
assert res.status_code == 200
assert match_regex("(wise|kind|owl|answer)+", res.body["content"])

View File

@@ -0,0 +1,106 @@
#!/usr/bin/env python
import pytest
# ensure grandparent path is in sys.path
from pathlib import Path
import sys
from unit.test_tool_call import TEST_TOOL
path = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(path))
import datetime
from utils import *
from typing import Literal
server: ServerProcess
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.model_alias = "tinyllama-2"
server.n_slots = 1
@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
@pytest.mark.parametrize("template_name,reasoning,expected_end", [
("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", "on", "<think>\n"),
("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B","auto", "<think>\n"),
("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", "off", "<think>\n</think>"),
("Qwen-Qwen3-0.6B","auto", "<|im_start|>assistant\n"),
("Qwen-Qwen3-0.6B", "off", "<|im_start|>assistant\n<think>\n\n</think>\n\n"),
("Qwen-QwQ-32B","auto", "<|im_start|>assistant\n<think>\n"),
("Qwen-QwQ-32B", "off", "<|im_start|>assistant\n<think>\n</think>"),
("CohereForAI-c4ai-command-r7b-12-2024-tool_use","auto", "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"),
("CohereForAI-c4ai-command-r7b-12-2024-tool_use", "off", "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|><|END_THINKING|>"),
])
def test_reasoning(template_name: str, reasoning: Literal['on', 'off', 'auto'] | None, expected_end: str, tools: list[dict]):
global server
server.jinja = True
server.reasoning = reasoning
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start()
res = server.make_request("POST", "/apply-template", data={
"messages": [
{"role": "user", "content": "What is today?"},
],
"tools": tools,
})
assert res.status_code == 200
prompt = res.body["prompt"]
assert prompt.endswith(expected_end), f"Expected prompt to end with '{expected_end}', got '{prompt}'"
@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
@pytest.mark.parametrize("template_name,format", [
("meta-llama-Llama-3.3-70B-Instruct", "%d %b %Y"),
("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"),
])
def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
global server
server.jinja = True
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start()
res = server.make_request("POST", "/apply-template", data={
"messages": [
{"role": "user", "content": "What is today?"},
],
"tools": tools,
})
assert res.status_code == 200
prompt = res.body["prompt"]
today_str = datetime.date.today().strftime(format)
assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
@pytest.mark.parametrize("add_generation_prompt", [False, True])
@pytest.mark.parametrize("template_name,expected_generation_prompt", [
("meta-llama-Llama-3.3-70B-Instruct", "<|start_header_id|>assistant<|end_header_id|>"),
])
def test_add_generation_prompt(template_name: str, expected_generation_prompt: str, add_generation_prompt: bool):
global server
server.jinja = True
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start()
res = server.make_request("POST", "/apply-template", data={
"messages": [
{"role": "user", "content": "What is today?"},
],
"add_generation_prompt": add_generation_prompt,
})
assert res.status_code == 200
prompt = res.body["prompt"]
if add_generation_prompt:
assert expected_generation_prompt in prompt, f"Expected generation prompt ({expected_generation_prompt}) in content ({prompt})"
else:
assert expected_generation_prompt not in prompt, f"Did not expect generation prompt ({expected_generation_prompt}) in content ({prompt})"

View File

@@ -0,0 +1,59 @@
import pytest
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
def test_tokenize_detokenize():
global server
server.start()
# tokenize
content = "What is the capital of France ?"
res_tok = server.make_request("POST", "/tokenize", data={
"content": content
})
assert res_tok.status_code == 200
assert len(res_tok.body["tokens"]) > 5
# detokenize
res_detok = server.make_request("POST", "/detokenize", data={
"tokens": res_tok.body["tokens"],
})
assert res_detok.status_code == 200
assert res_detok.body["content"].strip() == content
def test_tokenize_with_bos():
global server
server.start()
# tokenize
content = "What is the capital of France ?"
bosId = 1
res_tok = server.make_request("POST", "/tokenize", data={
"content": content,
"add_special": True,
})
assert res_tok.status_code == 200
assert res_tok.body["tokens"][0] == bosId
def test_tokenize_with_pieces():
global server
server.start()
# tokenize
content = "This is a test string with unicode 媽 and emoji 🤗"
res_tok = server.make_request("POST", "/tokenize", data={
"content": content,
"with_pieces": True,
})
assert res_tok.status_code == 200
for token in res_tok.body["tokens"]:
assert "id" in token
assert token["id"] > 0
assert "piece" in token
assert len(token["piece"]) > 0

View File

@@ -0,0 +1,645 @@
#!/usr/bin/env python
import pytest
# ensure grandparent path is in sys.path
from pathlib import Path
import sys
path = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(path))
from utils import *
from enum import Enum
from typing import TypedDict
server: ServerProcess
TIMEOUT_START_SLOW = 15 * 60 # this is needed for real model tests
TIMEOUT_HTTP_REQUEST = 60
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.model_alias = "tinyllama-2-tool-call"
server.server_port = 8081
server.n_slots = 1
server.n_ctx = 8192
server.n_batch = 2048
class CompletionMode(Enum):
NORMAL = "normal"
STREAMED = "streamed"
class ToolParameters(TypedDict):
type: str
properties: dict[str, dict]
required: list[str]
class ToolFunction(TypedDict):
name: str
description: str
parameters: ToolParameters
class ToolDefinition(TypedDict):
type: str
function: ToolFunction
TEST_TOOL = ToolDefinition(
type = "function",
function = ToolFunction(
name = "test",
description = "",
parameters = ToolParameters(
type = "object",
properties = {
"success": {
"type": "boolean",
"const": True,
},
},
required = ["success"],
),
),
)
PYTHON_TOOL = ToolDefinition(
type = "function",
function = ToolFunction(
name = "python",
description = "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
parameters = ToolParameters(
type = "object",
properties = {
"code": {
"type": "string",
"description": "The code to run in the ipython interpreter.",
},
},
required = ["code"],
),
),
)
WEATHER_TOOL = ToolDefinition(
type = "function",
function = ToolFunction(
name = "get_current_weather",
description = "Get the current weather in a given location",
parameters = ToolParameters(
type = "object",
properties = {
"location": {
"type": "string",
"description": "The city and country/state, e.g. 'San Francisco, CA', or 'Paris, France'",
},
},
required = ["location"],
),
),
)
def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "Write an example"},
],
"tool_choice": "required",
"tools": [tool],
"parallel_tool_calls": False,
**kwargs,
})
# assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
assert expected_function_name == tool_call["function"]["name"], f'Expected tool name to be {tool_call["function"]["name"]} in {choice["message"]}'
actual_arguments = tool_call["function"]["arguments"]
assert isinstance(actual_arguments, dict) or isinstance(actual_arguments, str), f'Expected arguments to be a dict or str, got: {actual_arguments}'
if argument_key is not None:
if (isinstance(actual_arguments, str)):
actual_arguments = json.loads(actual_arguments)
assert argument_key in actual_arguments, f"tool arguments: {actual_arguments}, expected: {argument_key}"
# PR #22654: commented out since we're now allowing content before tool calls in tool_call: required, so we can't force this
# in the tiny model just by using the grammar
#
# @pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
# @pytest.mark.parametrize("template_name,tool,argument_key", [
# ("Qwen3-Coder", TEST_TOOL, "success"),
# ("Qwen3-Coder", TEST_TOOL, "success"),
# ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
# ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
# ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
# ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
# ])
# def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
# global server
# n_predict = 1024
# # server = ServerPreset.stories15m_moe()
# server.jinja = True
# server.n_predict = n_predict
# server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
# server.start()
# do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)
# @pytest.mark.slow
# @pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
# @pytest.mark.parametrize("template_name,tool,argument_key", [
# ("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
# ("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
# ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
# ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
# ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
# # Functionary v3.2 format supports raw python content, which w/ a dummy stories model will never end on its own.
# # ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
# ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
# ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
# ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
# ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
# ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
# ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
# ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
# ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
# ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
# ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
# ("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
# # ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "codeFalse), True),
# # ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
# ])
# def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
# global server
# n_predict = 512
# # server = ServerPreset.stories15m_moe()
# server.jinja = True
# server.n_predict = n_predict
# server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
# server.start(timeout_seconds=TIMEOUT_START_SLOW)
# do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [
(TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
# (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
# (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
# (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
n_predict = 512
server.jinja = True
server.n_ctx = 8192
server.n_predict = n_predict
server.model_hf_repo = hf_repo
server.model_hf_file = None
if isinstance(template_override, tuple):
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_START_SLOW)
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "Write an example"},
],
"tool_choice": "required",
"tools": [tool],
"parallel_tool_calls": False,
"stream": stream == CompletionMode.STREAMED,
"temperature": 0.0,
"top_k": 1,
"top_p": 1.0,
}, timeout=TIMEOUT_HTTP_REQUEST)
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
assert expected_function_name == tool_call["function"]["name"]
actual_arguments = tool_call["function"]["arguments"]
assert isinstance(actual_arguments, str)
if argument_key is not None:
actual_arguments = json.loads(actual_arguments)
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs):
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "say hello world with python"},
],
"tools": tools if tools else None,
"tool_choice": tool_choice,
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
choice = body["choices"][0]
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
("meta-llama-Llama-3.3-70B-Instruct", 128, [], None),
("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None),
("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'),
])
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
global server
server.n_predict = n_predict
server.jinja = True
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start()
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
("meetkai-functionary-medium-v3.2", 256, [], None),
("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None),
("meetkai-functionary-medium-v3.2", 256, [PYTHON_TOOL], 'none'),
("meetkai-functionary-medium-v3.1", 256, [], None),
("meetkai-functionary-medium-v3.1", 256, [TEST_TOOL], None),
("meetkai-functionary-medium-v3.1", 256, [PYTHON_TOOL], 'none'),
("meta-llama-Llama-3.2-3B-Instruct", 256, [], None),
("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None),
("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'),
])
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
global server
server.n_predict = n_predict
server.jinja = True
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_START_SLOW)
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("hf_repo,template_override", [
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
# ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
# ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
# ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
# ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")),
("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
# ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
])
def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
n_predict = 512
server.jinja = True
server.n_ctx = 8192
server.n_predict = n_predict
server.model_hf_repo = hf_repo
server.model_hf_file = None
if isinstance(template_override, tuple):
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
elif isinstance(template_override, str):
server.chat_template = template_override
server.start()
do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
def do_test_weather(server: ServerProcess, **kwargs):
body = server.make_any_request("POST", "/v1/chat/completions", data={
"messages": [
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
{"role": "user", "content": "What is the weather in Istanbul?"},
],
"tools": [WEATHER_TOOL],
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}'
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
actual_arguments = json.loads(tool_call["function"]["arguments"])
assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
location = actual_arguments["location"]
assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}"
assert re.match('^Istanbul(( |, ?)(TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
(None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(None, 128, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
(None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
# TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
# (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
# ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
server.jinja = True
server.n_ctx = 8192 * 2
server.n_predict = n_predict
server.model_hf_repo = hf_repo
server.model_hf_file = None
if isinstance(template_override, tuple):
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_START_SLOW)
do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)
def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs):
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."},
{"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"},
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_6789",
"type": "function",
"function": {
"name": "calculate",
"arguments": "{\"expression\":\"sin(30 * pi / 180)\"}"
}
}
]
},
{
"role": "tool",
"name": "calculate",
"content": "0.55644242476",
"tool_call_id": "call_6789"
}
],
"tools": [
{
"type":"function",
"function":{
"name":"calculate",
"description":"A calculator function that computes values of arithmetic expressions in the Python syntax",
"parameters":{
"type":"object",
"properties":{
"expression":{
"type":"string",
"description":"An arithmetic expression to compute the value of (Python syntad, assuming all floats)"
}
},
"required":["expression"]
}
}
}
],
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls is None, f'Expected no tool call in {choice["message"]}'
content = choice["message"].get("content")
assert content is not None, f'Expected content in {choice["message"]}'
if result_override is not None:
assert re.match(result_override, content), f'Expected {result_override}, got {content}'
else:
assert re.match('^[\\s\\S]*?((That\'s|\\bis) (approximately )?)?\\b0\\.(5\\b|56\\b|556)', content), \
f'Expected something like "The y coordinate is 0.56.", got {content}'
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("n_predict,reasoning_format,expect_reasoning_content,expect_content,hf_repo,template_override", [
(128, 'deepseek', None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(128, None, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(1024, 'deepseek', "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'deepseek', "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
# (1024, 'none', CompletionMode.NORMAL, None, "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
# (128, 'deepseek', None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*", "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M", None),
])
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
server.reasoning_format = reasoning_format
server.jinja = True
server.n_ctx = 8192 * 2
server.n_predict = n_predict
server.model_hf_repo = hf_repo
server.model_hf_file = None
if isinstance(template_override, tuple):
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
elif isinstance(template_override, str):
server.chat_template = template_override
server.start()
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "user", "content": "What's the sum of 102 and 7?"},
],
"stream": stream == CompletionMode.STREAMED,
}, timeout=TIMEOUT_HTTP_REQUEST)
choice = body["choices"][0]
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
content = choice["message"].get("content")
if expect_content is None:
assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
else:
assert re.match(expect_content, content), f'Expected {expect_content}, got {content}'
reasoning_content = choice["message"].get("reasoning_content")
if expect_reasoning_content is None:
assert reasoning_content is None, f'Expected no reasoning content in {choice["message"]}'
else:
assert re.match(expect_reasoning_content, reasoning_content), f'Expected {expect_reasoning_content}, got {reasoning_content}'
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("hf_repo,template_override", [
("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)),
("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
# ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
])
def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
n_predict = 512 # High because of DeepSeek R1
server.jinja = True
server.n_ctx = 8192
server.n_predict = n_predict
server.model_hf_repo = hf_repo
server.model_hf_file = None
if isinstance(template_override, tuple):
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_START_SLOW)
do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
def do_test_hello_world(server: ServerProcess, **kwargs):
body = server.make_any_request("POST", "/v1/chat/completions", data={
"messages": [
{"role": "system", "content": "You are a tool-calling agent."},
{"role": "user", "content": "say hello world with python"},
],
"tools": [PYTHON_TOOL],
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
actual_arguments = json.loads(tool_call["function"]["arguments"])
assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
code = actual_arguments["code"]
assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', re.sub(r'#.*\n?', '', code)), f'Expected hello world, got {code}'

View File

@@ -0,0 +1,161 @@
import pytest
from utils import *
import base64
import requests
server: ServerProcess
def get_img_url(id: str) -> str:
IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
if id == "IMG_URL_0":
return IMG_URL_0
elif id == "IMG_URL_1":
return IMG_URL_1
elif id == "IMG_BASE64_URI_0":
response = requests.get(IMG_URL_0)
response.raise_for_status() # Raise an exception for bad status codes
return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
elif id == "IMG_BASE64_0":
response = requests.get(IMG_URL_0)
response.raise_for_status() # Raise an exception for bad status codes
return base64.b64encode(response.content).decode("utf-8")
elif id == "IMG_BASE64_URI_1":
response = requests.get(IMG_URL_1)
response.raise_for_status() # Raise an exception for bad status codes
return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
elif id == "IMG_BASE64_1":
response = requests.get(IMG_URL_1)
response.raise_for_status() # Raise an exception for bad status codes
return base64.b64encode(response.content).decode("utf-8")
else:
return id
JSON_MULTIMODAL_KEY = "multimodal_data"
JSON_PROMPT_STRING_KEY = "prompt_string"
@pytest.fixture(autouse=True)
def create_server():
global server
os.environ['LLAMA_MEDIA_MARKER'] = '<__media__>'
server = ServerPreset.tinygemma3()
def test_models_supports_multimodal_capability():
global server
server.start()
res = server.make_request("GET", "/models", data={})
assert res.status_code == 200
model_info = res.body["models"][0]
print(model_info)
assert "completion" in model_info["capabilities"]
assert "multimodal" in model_info["capabilities"]
def test_v1_models_supports_multimodal_capability():
global server
server.start()
res = server.make_request("GET", "/v1/models", data={})
assert res.status_code == 200
model_info = res.body["models"][0]
print(model_info)
assert "completion" in model_info["capabilities"]
assert "multimodal" in model_info["capabilities"]
@pytest.mark.parametrize(
"prompt, image_url, success, re_content",
[
# test model is trained on CIFAR-10, but it's quite dumb due to small size
("What is this:\n", "IMG_URL_0", True, "(cat)+"),
("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"),
("What is this:\n", "IMG_URL_1", True, "(frog)+"),
("Test test\n", "IMG_URL_1", True, "(frog)+"), # test invalidate cache
("What is this:\n", "malformed", False, None),
("What is this:\n", "https://google.com/404", False, None), # non-existent image
("What is this:\n", "https://ggml.ai", False, None), # non-image data
# TODO @ngxson : test with multiple images, no images and with audio
]
)
def test_vision_chat_completion(prompt, image_url, success, re_content):
global server
server.start()
res = server.make_request("POST", "/chat/completions", data={
"temperature": 0.0,
"top_k": 1,
"messages": [
{"role": "user", "content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {
"url": get_img_url(image_url),
}},
]},
],
})
if success:
assert res.status_code == 200
choice = res.body["choices"][0]
assert "assistant" == choice["message"]["role"]
assert match_regex(re_content, choice["message"]["content"])
else:
assert res.status_code != 200
@pytest.mark.parametrize(
"prompt, image_data, success, re_content",
[
# test model is trained on CIFAR-10, but it's quite dumb due to small size
("What is this: <__media__>\n", "IMG_BASE64_0", True, "(cat)+"),
("What is this: <__media__>\n", "IMG_BASE64_1", True, "(frog)+"),
("What is this: <__media__>\n", "malformed", False, None), # non-image data
("What is this:\n", "", False, None), # empty string
]
)
def test_vision_completion(prompt, image_data, success, re_content):
global server
server.start()
res = server.make_request("POST", "/completions", data={
"temperature": 0.0,
"top_k": 1,
"prompt": {
JSON_PROMPT_STRING_KEY: prompt,
JSON_MULTIMODAL_KEY: [ get_img_url(image_data) ],
},
})
if success:
assert res.status_code == 200
content = res.body["content"]
assert match_regex(re_content, content)
else:
assert res.status_code != 200
@pytest.mark.parametrize(
"prompt, image_data, success",
[
# test model is trained on CIFAR-10, but it's quite dumb due to small size
("What is this: <__media__>\n", "IMG_BASE64_0", True),
("What is this: <__media__>\n", "IMG_BASE64_1", True),
("What is this: <__media__>\n", "malformed", False), # non-image data
("What is this:\n", "base64", False), # non-image data
]
)
def test_vision_embeddings(prompt, image_data, success):
global server
server.server_embeddings = True
server.n_batch = 512
server.start()
image_data = get_img_url(image_data)
res = server.make_request("POST", "/embeddings", data={
"content": [
{ JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
{ JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
{ JSON_PROMPT_STRING_KEY: prompt, },
],
})
if success:
assert res.status_code == 200
content = res.body
# Ensure embeddings are stable when multimodal.
assert content[0]['embedding'] == content[1]['embedding']
# Ensure embeddings without multimodal but same prompt do not match multimodal embeddings.
assert content[0]['embedding'] != content[2]['embedding']
else:
assert res.status_code != 200

682
tools/server/tests/utils.py Normal file
View File

@@ -0,0 +1,682 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# type: ignore[reportUnusedImport]
import subprocess
import os
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
import re
import json
from json import JSONDecodeError
import sys
import requests
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import (
Any,
Callable,
ContextManager,
Iterable,
Iterator,
List,
Literal,
Tuple,
Set,
)
from re import RegexFlag
import wget
DEFAULT_HTTP_TIMEOUT = 60
class ServerResponse:
headers: dict
status_code: int
body: dict | Any
class ServerError(Exception):
def __init__(self, code, body):
self.code = code
self.body = body
class ServerProcess:
# default options
debug: bool = False
server_port: int = 8080
server_host: str = "127.0.0.1"
model_hf_repo: str | None = "ggml-org/models"
model_hf_file: str | None = "tinyllamas/stories260K.gguf"
model_alias: str = "tinyllama-2"
temperature: float = 0.8
seed: int = 42
offline: bool = False
# custom options
model_alias: str | None = None
model_tags: str | None = None
model_url: str | None = None
model_file: str | None = None
model_draft: str | None = None
n_threads: int | None = None
n_gpu_layer: int | None = None
n_batch: int | None = None
n_ubatch: int | None = None
n_ctx: int | None = None
n_ga: int | None = None
n_ga_w: int | None = None
n_predict: int | None = None
n_prompts: int | None = 0
slot_save_path: str | None = None
id_slot: int | None = None
cache_prompt: bool | None = None
n_slots: int | None = None
ctk: str | None = None
ctv: str | None = None
fa: str | None = None
server_continuous_batching: bool | None = False
server_embeddings: bool | None = False
server_reranking: bool | None = False
server_metrics: bool | None = False
kv_unified: bool | None = False
server_slots: bool | None = False
pooling: str | None = None
api_key: str | None = None
models_dir: str | None = None
models_max: int | None = None
models_preset: str | None = None
no_models_autoload: bool | None = None
lora_files: List[str] | None = None
enable_ctx_shift: int | None = False
spec_draft_n_min: int | None = None
spec_draft_n_max: int | None = None
no_webui: bool | None = None
jinja: bool | None = None
reasoning_format: Literal['deepseek', 'none', 'nothink'] | None = None
reasoning: Literal['on', 'off', 'auto'] | None = None
chat_template: str | None = None
chat_template_file: str | None = None
server_path: str | None = None
mmproj_url: str | None = None
media_path: str | None = None
sleep_idle_seconds: int | None = None
cache_ram: int | None = None
no_cache_idle_slots: bool = False
log_path: str | None = None
webui_mcp_proxy: bool = False
backend_sampling: bool = False
gcp_compat: bool = False
# session variables
process: subprocess.Popen | None = None
def __init__(self):
if "N_GPU_LAYERS" in os.environ:
self.n_gpu_layer = int(os.environ["N_GPU_LAYERS"])
if "DEBUG" in os.environ:
self.debug = True
if "PORT" in os.environ:
self.server_port = int(os.environ["PORT"])
self.external_server = "DEBUG_EXTERNAL" in os.environ
def start(self, timeout_seconds: int = DEFAULT_HTTP_TIMEOUT) -> None:
env = {**os.environ}
if "LLAMA_CACHE" not in os.environ:
env["LLAMA_CACHE"] = "tmp"
if self.external_server:
print(f"[external_server]: Assuming external server running on {self.server_host}:{self.server_port}")
return
if self.server_path is not None:
server_path = self.server_path
elif "LLAMA_SERVER_BIN_PATH" in os.environ:
server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
elif os.name == "nt":
server_path = "../../../build/bin/Release/llama-server.exe"
else:
server_path = "../../../build/bin/llama-server"
server_args = [
"--host",
self.server_host,
"--port",
self.server_port,
"--temp",
self.temperature,
"--seed",
self.seed,
]
if self.offline:
server_args.append("--offline")
if self.model_file:
server_args.extend(["--model", self.model_file])
if self.model_url:
server_args.extend(["--model-url", self.model_url])
if self.model_draft:
server_args.extend(["--model-draft", self.model_draft])
if self.model_hf_repo:
server_args.extend(["--hf-repo", self.model_hf_repo])
if self.model_hf_file:
server_args.extend(["--hf-file", self.model_hf_file])
if self.models_dir:
server_args.extend(["--models-dir", self.models_dir])
if self.models_max is not None:
server_args.extend(["--models-max", self.models_max])
if self.models_preset:
server_args.extend(["--models-preset", self.models_preset])
if self.n_batch:
server_args.extend(["--batch-size", self.n_batch])
if self.n_ubatch:
server_args.extend(["--ubatch-size", self.n_ubatch])
if self.n_threads:
server_args.extend(["--threads", self.n_threads])
if self.n_gpu_layer:
server_args.extend(["--n-gpu-layers", self.n_gpu_layer])
if self.server_continuous_batching:
server_args.append("--cont-batching")
if self.server_embeddings:
server_args.append("--embedding")
if self.server_reranking:
server_args.append("--reranking")
if self.server_metrics:
server_args.append("--metrics")
if self.kv_unified:
server_args.append("--kv-unified")
if self.server_slots:
server_args.append("--slots")
else:
server_args.append("--no-slots")
if self.pooling:
server_args.extend(["--pooling", self.pooling])
if self.model_alias:
server_args.extend(["--alias", self.model_alias])
if self.model_tags:
server_args.extend(["--tags", self.model_tags])
if self.n_ctx:
server_args.extend(["--ctx-size", self.n_ctx])
if self.n_slots:
server_args.extend(["--parallel", self.n_slots])
if self.ctk:
server_args.extend(["-ctk", self.ctk])
if self.ctv:
server_args.extend(["-ctv", self.ctv])
if self.fa is not None:
server_args.extend(["-fa", self.fa])
if self.n_predict:
server_args.extend(["--n-predict", self.n_predict])
if self.slot_save_path:
server_args.extend(["--slot-save-path", self.slot_save_path])
if self.n_ga:
server_args.extend(["--grp-attn-n", self.n_ga])
if self.n_ga_w:
server_args.extend(["--grp-attn-w", self.n_ga_w])
if self.debug:
server_args.append("--verbose")
if self.lora_files:
for lora_file in self.lora_files:
server_args.extend(["--lora", lora_file])
if self.enable_ctx_shift:
server_args.append("--context-shift")
if self.api_key:
server_args.extend(["--api-key", self.api_key])
if self.spec_draft_n_max:
server_args.extend(["--spec-draft-n-max", self.spec_draft_n_max])
if self.spec_draft_n_min:
server_args.extend(["--spec-draft-n-min", self.spec_draft_n_min])
if self.no_webui:
server_args.append("--no-webui")
if self.no_models_autoload:
server_args.append("--no-models-autoload")
if self.jinja:
server_args.append("--jinja")
else:
server_args.append("--no-jinja")
if self.reasoning_format is not None:
server_args.extend(("--reasoning-format", self.reasoning_format))
if self.reasoning is not None:
server_args.extend(("--reasoning", self.reasoning))
if self.chat_template:
server_args.extend(["--chat-template", self.chat_template])
if self.chat_template_file:
server_args.extend(["--chat-template-file", self.chat_template_file])
if self.mmproj_url:
server_args.extend(["--mmproj-url", self.mmproj_url])
if self.media_path:
server_args.extend(["--media-path", self.media_path])
if self.sleep_idle_seconds is not None:
server_args.extend(["--sleep-idle-seconds", self.sleep_idle_seconds])
if self.cache_ram is not None:
server_args.extend(["--cache-ram", self.cache_ram])
if self.no_cache_idle_slots:
server_args.append("--no-cache-idle-slots")
if self.webui_mcp_proxy:
server_args.append("--webui-mcp-proxy")
if self.backend_sampling:
server_args.append("--backend_sampling")
if self.gcp_compat:
env["AIP_MODE"] = "PREDICTION"
args = [str(arg) for arg in [server_path, *server_args]]
print(f"tests: starting server with: {' '.join(args)}")
flags = 0
if "nt" == os.name:
flags |= subprocess.DETACHED_PROCESS
flags |= subprocess.CREATE_NEW_PROCESS_GROUP
flags |= subprocess.CREATE_NO_WINDOW
if self.log_path:
self._log = open(self.log_path, "w")
else:
self._log = sys.stdout
self.process = subprocess.Popen(
[str(arg) for arg in [server_path, *server_args]],
creationflags=flags,
stdout=self._log,
stderr=self._log if self._log != sys.stdout else sys.stdout,
env=env,
)
server_instances.add(self)
print(f"server pid={self.process.pid}, pytest pid={os.getpid()}")
# wait for server to start
start_time = time.time()
while time.time() - start_time < timeout_seconds:
try:
response = self.make_request("GET", "/health", headers={
"Authorization": f"Bearer {self.api_key}" if self.api_key else None
})
if response.status_code == 200:
self.ready = True
return # server is ready
except Exception as e:
pass
# Check if process died
if self.process.poll() is not None:
raise RuntimeError(f"Server process died with return code {self.process.returncode}")
print(f"Waiting for server to start...")
time.sleep(0.5)
raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")
def stop(self) -> None:
if self.external_server:
print("[external_server]: Not stopping external server")
return
if self in server_instances:
server_instances.remove(self)
if self.process:
print(f"Stopping server with pid={self.process.pid}")
self.process.terminate()
try:
self.process.wait(timeout=5)
except subprocess.TimeoutExpired:
print(f"Server pid={self.process.pid} did not terminate in time, killing")
self.process.kill()
self.process.wait(timeout=5)
except Exception as e:
print(f"Error waiting for server: {e}")
self.process = None
if hasattr(self, '_log') and self._log != sys.stdout:
self._log.close()
def make_request(
self,
method: str,
path: str,
data: dict | Any | None = None,
headers: dict | None = None,
timeout: float | None = None,
) -> ServerResponse:
url = f"http://{self.server_host}:{self.server_port}{path}"
parse_body = False
if method == "GET":
response = requests.get(url, headers=headers, timeout=timeout)
parse_body = True
elif method == "POST":
response = requests.post(url, headers=headers, json=data, timeout=timeout)
parse_body = True
elif method == "OPTIONS":
response = requests.options(url, headers=headers, timeout=timeout)
else:
raise ValueError(f"Unimplemented method: {method}")
result = ServerResponse()
result.headers = dict(response.headers)
result.status_code = response.status_code
if parse_body:
try:
result.body = response.json()
except JSONDecodeError:
result.body = response.text
else:
result.body = None
print("Response from server", json.dumps(result.body, indent=2))
return result
def make_stream_request(
self,
method: str,
path: str,
data: dict | None = None,
headers: dict | None = None,
) -> Iterator[dict]:
url = f"http://{self.server_host}:{self.server_port}{path}"
if method == "POST":
response = requests.post(url, headers=headers, json=data, stream=True)
else:
raise ValueError(f"Unimplemented method: {method}")
if response.status_code != 200:
raise ServerError(response.status_code, response.json())
for line_bytes in response.iter_lines():
line = line_bytes.decode("utf-8")
if '[DONE]' in line:
break
elif line.startswith('data: '):
data = json.loads(line[6:])
print("Partial response from server", json.dumps(data, indent=2))
yield data
def make_any_request(
self,
method: str,
path: str,
data: dict | None = None,
headers: dict | None = None,
timeout: float | None = None,
) -> dict:
stream = data.get('stream', False)
if stream:
content: list[str] = []
reasoning_content: list[str] = []
tool_calls: list[dict] = []
finish_reason: Optional[str] = None
content_parts = 0
reasoning_content_parts = 0
tool_call_parts = 0
arguments_parts = 0
for chunk in self.make_stream_request(method, path, data, headers):
if chunk['choices']:
assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
choice = chunk['choices'][0]
if choice['delta'].get('content') is not None:
assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
content.append(choice['delta']['content'])
content_parts += 1
if choice['delta'].get('reasoning_content') is not None:
assert len(choice['delta']['reasoning_content']) > 0, f'Expected non empty reasoning_content delta!'
reasoning_content.append(choice['delta']['reasoning_content'])
reasoning_content_parts += 1
if choice['delta'].get('finish_reason') is not None:
finish_reason = choice['delta']['finish_reason']
for tc in choice['delta'].get('tool_calls', []):
if 'function' not in tc:
raise ValueError(f"Expected function type, got {tc['type']}")
if tc['index'] >= len(tool_calls):
assert 'id' in tc
assert tc.get('type') == 'function'
assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \
f"Expected function call with name, got {tc.get('function')}"
tool_calls.append(dict(
id="",
type="function",
function=dict(
name="",
arguments="",
)
))
tool_call = tool_calls[tc['index']]
if tc.get('id') is not None:
tool_call['id'] = tc['id']
fct = tc['function']
assert 'id' not in fct, f"Function call should not have id: {fct}"
if fct.get('name') is not None:
tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
if fct.get('arguments') is not None:
tool_call['function']['arguments'] += fct['arguments']
arguments_parts += 1
tool_call_parts += 1
else:
# When `include_usage` is True (the default), we expect the last chunk of the stream
# immediately preceding the `data: [DONE]` message to contain a `choices` field with an empty array
# and a `usage` field containing the usage statistics (n.b., llama-server also returns `timings` in
# the last chunk)
assert 'usage' in chunk, f"Expected finish_reason in chunk: {chunk}"
assert 'timings' in chunk, f"Expected finish_reason in chunk: {chunk}"
print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
result = dict(
choices=[
dict(
index=0,
finish_reason=finish_reason,
message=dict(
role='assistant',
content=''.join(content) if content else None,
reasoning_content=''.join(reasoning_content) if reasoning_content else None,
tool_calls=tool_calls if tool_calls else None,
),
)
],
)
print("Final response from server", json.dumps(result, indent=2))
return result
else:
response = self.make_request(method, path, data, headers, timeout=timeout)
assert response.status_code == 200, f"Server returned error: {response.status_code}"
return response.body
server_instances: Set[ServerProcess] = set()
class ServerPreset:
@staticmethod
def load_all() -> None:
""" Load all server presets to ensure model files are cached. """
servers: List[ServerProcess] = [
method()
for name, method in ServerPreset.__dict__.items()
if callable(method) and name != "load_all"
]
for server in servers:
server.offline = False
server.start()
server.stop()
@staticmethod
def tinyllama2() -> ServerProcess:
server = ServerProcess()
server.offline = True # will be downloaded by load_all()
server.model_hf_repo = "ggml-org/test-model-stories260K"
server.model_hf_file = None
server.model_alias = "tinyllama-2"
server.n_ctx = 512
server.n_batch = 32
server.n_slots = 2
server.n_predict = 64
server.seed = 42
return server
@staticmethod
def bert_bge_small() -> ServerProcess:
server = ServerProcess()
server.offline = True # will be downloaded by load_all()
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
server.model_alias = "bert-bge-small"
server.n_ctx = 512
server.n_batch = 128
server.n_ubatch = 128
server.n_slots = 2
server.seed = 42
server.server_embeddings = True
return server
@staticmethod
def bert_bge_small_with_fa() -> ServerProcess:
server = ServerProcess()
server.offline = True # will be downloaded by load_all()
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
server.model_alias = "bert-bge-small"
server.n_ctx = 1024
server.n_batch = 300
server.n_ubatch = 300
server.n_slots = 2
server.fa = "on"
server.seed = 42
server.server_embeddings = True
return server
@staticmethod
def tinyllama_infill() -> ServerProcess:
server = ServerProcess()
server.offline = True # will be downloaded by load_all()
server.model_hf_repo = "ggml-org/test-model-stories260K-infill"
server.model_hf_file = None
server.model_alias = "tinyllama-infill"
server.n_ctx = 2048
server.n_batch = 1024
server.n_slots = 1
server.n_predict = 64
server.temperature = 0.0
server.seed = 42
return server
@staticmethod
def stories15m_moe() -> ServerProcess:
server = ServerProcess()
server.offline = True # will be downloaded by load_all()
server.model_hf_repo = "ggml-org/stories15M_MOE"
server.model_hf_file = "stories15M_MOE-F16.gguf"
server.model_alias = "stories15m-moe"
server.n_ctx = 2048
server.n_batch = 1024
server.n_slots = 1
server.n_predict = 64
server.temperature = 0.0
server.seed = 42
return server
@staticmethod
def jina_reranker_tiny() -> ServerProcess:
server = ServerProcess()
server.offline = True # will be downloaded by load_all()
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
server.model_alias = "jina-reranker"
server.n_ctx = 512
server.n_batch = 512
server.n_slots = 1
server.seed = 42
server.server_reranking = True
return server
@staticmethod
def tinygemma3() -> ServerProcess:
server = ServerProcess()
server.offline = True # will be downloaded by load_all()
# mmproj is already provided by HF registry API
server.model_hf_file = None
server.model_hf_repo = "ggml-org/tinygemma3-GGUF:Q8_0"
server.model_alias = "tinygemma3"
server.n_ctx = 1024
server.n_batch = 32
server.n_slots = 2
server.n_predict = 4
server.seed = 42
return server
@staticmethod
def router() -> ServerProcess:
server = ServerProcess()
server.offline = True # will be downloaded by load_all()
# router server has no models
server.model_file = None
server.model_alias = None
server.model_hf_repo = None
server.model_hf_file = None
server.n_ctx = 1024
server.n_batch = 16
server.n_slots = 1
server.n_predict = 16
server.seed = 42
return server
def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
"""
Run multiple functions in parallel and return results in the same order as calls. Equivalent to Promise.all in JS.
Example usage:
results = parallel_function_calls([
(func1, (arg1, arg2)),
(func2, (arg3, arg4)),
])
"""
results = [None] * len(function_list)
exceptions = []
def worker(index, func, args):
try:
result = func(*args)
results[index] = result
except Exception as e:
exceptions.append((index, str(e)))
with ThreadPoolExecutor() as executor:
futures = []
for i, (func, args) in enumerate(function_list):
future = executor.submit(worker, i, func, args)
futures.append(future)
# Wait for all futures to complete
for future in as_completed(futures):
pass
# Check if there were any exceptions
if exceptions:
print("Exceptions occurred:")
for index, error in exceptions:
print(f"Function at index {index}: {error}")
return results
def match_regex(regex: str, text: str) -> bool:
return (
re.compile(
regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL
).search(text)
is not None
)
def download_file(url: str, output_file_path: str | None = None) -> str:
"""
Download a file from a URL to a local path. If the file already exists, it will not be downloaded again.
output_file_path is the local path to save the downloaded file. If not provided, the file will be saved in the root directory.
Returns the local path of the downloaded file.
"""
file_name = url.split('/').pop()
output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path
if not os.path.exists(output_file):
print(f"Downloading {url} to {output_file}")
wget.download(url, out=output_file)
print(f"Done downloading to {output_file}")
else:
print(f"File already exists at {output_file}")
return output_file
def is_slow_test_allowed():
return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"

28
tools/server/webui/.gitignore vendored Normal file
View File

@@ -0,0 +1,28 @@
test-results
node_modules
# Output
.output
.vercel
.netlify
.wrangler
/.svelte-kit
/build
# OS
.DS_Store
Thumbs.db
# Env
.env
.env.*
!.env.example
!.env.test
# Vite
vite.config.js.timestamp-*
vite.config.ts.timestamp-*
*storybook.log
storybook-static
*.code-workspace

View File

@@ -0,0 +1 @@
engine-strict=true

View File

@@ -0,0 +1,9 @@
# Package Managers
package-lock.json
pnpm-lock.yaml
yarn.lock
bun.lock
bun.lockb
# Miscellaneous
/static/

View File

@@ -0,0 +1,16 @@
{
"useTabs": true,
"singleQuote": true,
"trailingComma": "none",
"printWidth": 100,
"plugins": ["prettier-plugin-svelte", "prettier-plugin-tailwindcss"],
"overrides": [
{
"files": "*.svelte",
"options": {
"parser": "svelte"
}
}
],
"tailwindStylesheet": "./src/app.css"
}

View File

@@ -0,0 +1,36 @@
<script lang="ts">
import { ModeWatcher } from 'mode-watcher';
import { onMount } from 'svelte';
interface Props {
children?: any;
}
let { children }: Props = $props();
onMount(() => {
const root = document.documentElement;
const theme = localStorage.getItem('mode-watcher-mode') || 'system';
if (theme === 'dark') {
root.classList.add('dark');
} else if (theme === 'light') {
root.classList.remove('dark');
} else {
const prefersDark = window.matchMedia('(prefers-color-scheme: dark)').matches;
if (prefersDark) {
root.classList.add('dark');
} else {
root.classList.remove('dark');
}
}
});
</script>
<ModeWatcher />
{#if children}
{@const Component = children}
<Component />
{/if}

View File

@@ -0,0 +1,13 @@
<script lang="ts">
import * as Tooltip from '../../src/lib/components/ui/tooltip';
interface Props {
children: any;
}
let { children }: Props = $props();
</script>
<Tooltip.Provider>
{@render children()}
</Tooltip.Provider>

View File

@@ -0,0 +1,24 @@
import type { StorybookConfig } from '@storybook/sveltekit';
import { dirname, resolve } from 'path';
import { fileURLToPath } from 'url';
const __dirname = dirname(fileURLToPath(import.meta.url));
const config: StorybookConfig = {
stories: ['../tests/stories/**/*.mdx', '../tests/stories/**/*.stories.@(js|ts|svelte)'],
addons: [
'@storybook/addon-svelte-csf',
'@chromatic-com/storybook',
'@storybook/addon-vitest',
'@storybook/addon-a11y',
'@storybook/addon-docs'
],
framework: '@storybook/sveltekit',
viteFinal: async (config) => {
config.server = config.server || {};
config.server.fs = config.server.fs || {};
config.server.fs.allow = [...(config.server.fs.allow || []), resolve(__dirname, '../tests')];
return config;
}
};
export default config;

View File

@@ -0,0 +1,42 @@
import type { Preview } from '@storybook/sveltekit';
import '../src/app.css';
import ModeWatcherDecorator from './decorators/ModeWatcherDecorator.svelte';
import TooltipProviderDecorator from './decorators/TooltipProviderDecorator.svelte';
const preview: Preview = {
parameters: {
controls: {
matchers: {
color: /(background|color)$/i,
date: /Date$/i
}
},
backgrounds: {
disabled: true
},
a11y: {
// 'todo' - show a11y violations in the test UI only
// 'error' - fail CI on a11y violations
// 'off' - skip a11y checks entirely
test: 'todo'
}
},
decorators: [
(story) => ({
Component: ModeWatcherDecorator,
props: {
children: story
}
}),
(story) => ({
Component: TooltipProviderDecorator,
props: {
children: story
}
})
]
};
export default preview;

View File

@@ -0,0 +1,12 @@
import * as a11yAddonAnnotations from '@storybook/addon-a11y/preview';
import { setProjectAnnotations } from '@storybook/sveltekit';
import * as previewAnnotations from './preview';
import { beforeAll } from 'vitest';
const project = setProjectAnnotations([a11yAddonAnnotations, previewAnnotations]);
beforeAll(async () => {
if (project.beforeAll) {
await project.beforeAll();
}
});

View File

@@ -0,0 +1,687 @@
# llama.cpp Web UI
A modern, feature-rich web interface for llama.cpp built with SvelteKit. This UI provides an intuitive chat interface with advanced file handling, conversation management, and comprehensive model interaction capabilities.
The WebUI supports two server operation modes:
- **MODEL mode** - Single model operation (standard llama-server)
- **ROUTER mode** - Multi-model operation with dynamic model loading/unloading
---
## Table of Contents
- [Features](#features)
- [Getting Started](#getting-started)
- [Tech Stack](#tech-stack)
- [Build Pipeline](#build-pipeline)
- [Architecture](#architecture)
- [Data Flows](#data-flows)
- [Architectural Patterns](#architectural-patterns)
- [Testing](#testing)
---
## Features
### Chat Interface
- **Streaming responses** with real-time updates
- **Reasoning content** - Support for models with thinking/reasoning blocks
- **Dark/light theme** with system preference detection
- **Responsive design** for desktop and mobile
### File Attachments
- **Images** - JPEG, PNG, GIF, WebP, SVG (with PNG conversion)
- **Documents** - PDF (text extraction or image conversion for vision models)
- **Audio** - MP3, WAV for audio-capable models
- **Text files** - Source code, markdown, and other text formats
- **Drag-and-drop** and paste support with rich previews
### Conversation Management
- **Branching** - Branch messages conversations at any point by editing messages or regenerating responses, navigate between branches
- **Regeneration** - Regenerate responses with optional model switching (ROUTER mode)
- **Import/Export** - JSON format for backup and sharing
- **Search** - Find conversations by title or content
### Advanced Rendering
- **Syntax highlighting** - Code blocks with language detection
- **Math formulas** - KaTeX rendering for LaTeX expressions
- **Markdown** - Full GFM support with tables, lists, and more
### Multi-Model Support (ROUTER mode)
- **Model selector** with Loaded/Available groups
- **Automatic loading** - Models load on selection
- **Modality validation** - Prevents sending images to non-vision models
- **LRU unloading** - Server auto-manages model cache
### Keyboard Shortcuts
| Shortcut | Action |
| ------------------ | -------------------- |
| `Shift+Ctrl/Cmd+O` | New chat |
| `Shift+Ctrl/Cmd+E` | Edit conversation |
| `Shift+Ctrl/Cmd+D` | Delete conversation |
| `Ctrl/Cmd+K` | Search conversations |
| `Ctrl/Cmd+B` | Toggle sidebar |
### Developer Experience
- **Request tracking** - Monitor token generation with `/slots` endpoint
- **Storybook** - Component library with visual testing
- **Hot reload** - Instant updates during development
---
## Getting Started
### Prerequisites
- **Node.js** 18+ (20+ recommended)
- **npm** 9+
- **llama-server** running locally (for API access)
### 1. Install Dependencies
```bash
cd tools/server/webui
npm install
```
### 2. Start llama-server
In a separate terminal, start the backend server:
```bash
# Single model (MODEL mode)
./llama-server -m model.gguf
# Multi-model (ROUTER mode)
./llama-server --models-dir /path/to/models
```
### 3. Start Development Servers
```bash
npm run dev
```
This starts:
- **Vite dev server** at `http://localhost:5173` - The main WebUI
- **Storybook** at `http://localhost:6006` - Component documentation
The Vite dev server proxies API requests to `http://localhost:8080` (default llama-server port):
```typescript
// vite.config.ts proxy configuration
proxy: {
'/v1': 'http://localhost:8080',
'/props': 'http://localhost:8080',
'/slots': 'http://localhost:8080',
'/models': 'http://localhost:8080'
}
```
### Development Workflow
1. Open `http://localhost:5173` in your browser
2. Make changes to `.svelte`, `.ts`, or `.css` files
3. Changes hot-reload instantly
4. Use Storybook at `http://localhost:6006` for isolated component development
---
## Tech Stack
| Layer | Technology | Purpose |
| ----------------- | ------------------------------- | -------------------------------------------------------- |
| **Framework** | SvelteKit + Svelte 5 | Reactive UI with runes (`$state`, `$derived`, `$effect`) |
| **UI Components** | shadcn-svelte + bits-ui | Accessible, customizable component library |
| **Styling** | TailwindCSS 4 | Utility-first CSS with design tokens |
| **Database** | IndexedDB (Dexie) | Client-side storage for conversations and messages |
| **Build** | Vite | Fast bundling with static adapter |
| **Testing** | Playwright + Vitest + Storybook | E2E, unit, and visual testing |
| **Markdown** | remark + rehype | Markdown processing with KaTeX and syntax highlighting |
### Key Dependencies
```json
{
"svelte": "^5.0.0",
"bits-ui": "^2.8.11",
"dexie": "^4.0.11",
"pdfjs-dist": "^5.4.54",
"highlight.js": "^11.11.1",
"rehype-katex": "^7.0.1"
}
```
---
## Build Pipeline
### Development Build
```bash
npm run dev
```
Runs Vite in development mode with:
- Hot Module Replacement (HMR)
- Source maps
- Proxy to llama-server
### Production Build
```bash
npm run build
```
The build process:
1. **Vite Build** - Bundles all TypeScript, Svelte, and CSS
2. **Static Adapter** - Outputs to `../public` (llama-server's static file directory)
3. **Post-Build Script** - Cleans up intermediate files
4. **Custom Plugin** - Creates `index.html` with:
- Inlined favicon as base64
- GZIP compression (level 9)
- Deterministic output (zeroed timestamps)
```text
tools/server/webui/ → build → tools/server/public/
├── src/ ├── index.html (served by llama-server)
├── static/ └── (favicon inlined)
└── ...
```
### SvelteKit Configuration
```javascript
// svelte.config.js
adapter: adapter({
pages: '../public', // Output directory
assets: '../public', // Static assets
fallback: 'index.html', // SPA fallback
strict: true
}),
output: {
bundleStrategy: 'inline' // Single-file bundle
}
```
### Integration with llama-server
The WebUI is embedded directly into the llama-server binary:
1. `npm run build` outputs `index.html` to `tools/server/public/`
2. llama-server compiles this into the binary at build time
3. When accessing `/`, llama-server serves the gzipped HTML
4. All assets are inlined (CSS, JS, fonts, favicon)
This results in a **single portable binary** with the full WebUI included.
---
## Architecture
The WebUI follows a layered architecture with unidirectional data flow:
```text
Routes → Components → Hooks → Stores → Services → Storage/API
```
### High-Level Architecture
See: [`docs/architecture/high-level-architecture-simplified.md`](docs/architecture/high-level-architecture-simplified.md)
```mermaid
flowchart TB
subgraph Routes["📍 Routes"]
R1["/ (Welcome)"]
R2["/chat/[id]"]
RL["+layout.svelte"]
end
subgraph Components["🧩 Components"]
C_Sidebar["ChatSidebar"]
C_Screen["ChatScreen"]
C_Form["ChatForm"]
C_Messages["ChatMessages"]
C_ModelsSelector["ModelsSelector"]
C_Settings["ChatSettings"]
end
subgraph Stores["🗄️ Stores"]
S1["chatStore"]
S2["conversationsStore"]
S3["modelsStore"]
S4["serverStore"]
S5["settingsStore"]
end
subgraph Services["⚙️ Services"]
SV1["ChatService"]
SV2["ModelsService"]
SV3["PropsService"]
SV4["DatabaseService"]
end
subgraph Storage["💾 Storage"]
ST1["IndexedDB"]
ST2["LocalStorage"]
end
subgraph APIs["🌐 llama-server"]
API1["/v1/chat/completions"]
API2["/props"]
API3["/models/*"]
end
R1 & R2 --> C_Screen
RL --> C_Sidebar
C_Screen --> C_Form & C_Messages & C_Settings
C_Screen --> S1 & S2
C_ModelsSelector --> S3 & S4
S1 --> SV1 & SV4
S3 --> SV2 & SV3
SV4 --> ST1
SV1 --> API1
SV2 --> API3
SV3 --> API2
```
### Layer Breakdown
#### Routes (`src/routes/`)
- **`/`** - Welcome screen, creates new conversation
- **`/chat/[id]`** - Active chat interface
- **`+layout.svelte`** - Sidebar, navigation, global initialization
#### Components (`src/lib/components/`)
Components are organized in `app/` (application-specific) and `ui/` (shadcn-svelte primitives).
**Chat Components** (`app/chat/`):
| Component | Responsibility |
| ------------------ | --------------------------------------------------------------------------- |
| `ChatScreen/` | Main chat container, coordinates message list, input form, and attachments |
| `ChatForm/` | Message input textarea with file upload, paste handling, keyboard shortcuts |
| `ChatMessages/` | Message list with branch navigation, regenerate/continue/edit actions |
| `ChatAttachments/` | File attachment previews, drag-and-drop, PDF/image/audio handling |
| `ChatSettings/` | Parameter sliders (temperature, top-p, etc.) with server default sync |
| `ChatSidebar/` | Conversation list, search, import/export, navigation |
**Dialog Components** (`app/dialogs/`):
| Component | Responsibility |
| ------------------------------- | -------------------------------------------------------- |
| `DialogChatSettings` | Full-screen settings configuration |
| `DialogModelInformation` | Model details (context size, modalities, parallel slots) |
| `DialogChatAttachmentPreview` | Full preview for images, PDFs (text or page view), code |
| `DialogConfirmation` | Generic confirmation for destructive actions |
| `DialogConversationTitleUpdate` | Edit conversation title |
**Server/Model Components** (`app/server/`, `app/models/`):
| Component | Responsibility |
| ------------------- | --------------------------------------------------------- |
| `ServerErrorSplash` | Error display when server is unreachable |
| `ModelsSelector` | Model dropdown with Loaded/Available groups (ROUTER mode) |
**Shared UI Components** (`app/misc/`):
| Component | Responsibility |
| -------------------------------- | ---------------------------------------------------------------- |
| `MarkdownContent` | Markdown rendering with KaTeX, syntax highlighting, copy buttons |
| `SyntaxHighlightedCode` | Code blocks with language detection and highlighting |
| `ActionButton`, `ActionDropdown` | Reusable action buttons and menus |
| `BadgeModality`, `BadgeInfo` | Status and capability badges |
#### Hooks (`src/lib/hooks/`)
- **`useModelChangeValidation`** - Validates model switch against conversation modalities
- **`useProcessingState`** - Tracks streaming progress and token generation
#### Stores (`src/lib/stores/`)
| Store | Responsibility |
| -------------------- | --------------------------------------------------------- |
| `chatStore` | Message sending, streaming, abort control, error handling |
| `conversationsStore` | CRUD for conversations, message branching, navigation |
| `modelsStore` | Model list, selection, loading/unloading (ROUTER) |
| `serverStore` | Server properties, role detection, modalities |
| `settingsStore` | User preferences, parameter sync with server defaults |
#### Services (`src/lib/services/`)
| Service | Responsibility |
| ---------------------- | ----------------------------------------------- |
| `ChatService` | API calls to`/v1/chat/completions`, SSE parsing |
| `ModelsService` | `/models`, `/models/load`, `/models/unload` |
| `PropsService` | `/props`, `/props?model=` |
| `DatabaseService` | IndexedDB operations via Dexie |
| `ParameterSyncService` | Syncs settings with server defaults |
---
## Data Flows
### MODEL Mode (Single Model)
See: [`docs/flows/data-flow-simplified-model-mode.md`](docs/flows/data-flow-simplified-model-mode.md)
```mermaid
sequenceDiagram
participant User
participant UI
participant Stores
participant DB as IndexedDB
participant API as llama-server
Note over User,API: Initialization
UI->>Stores: initialize()
Stores->>DB: load conversations
Stores->>API: GET /props
API-->>Stores: server config
Stores->>API: GET /v1/models
API-->>Stores: single model (auto-selected)
Note over User,API: Chat Flow
User->>UI: send message
Stores->>DB: save user message
Stores->>API: POST /v1/chat/completions (stream)
loop streaming
API-->>Stores: SSE chunks
Stores-->>UI: reactive update
end
Stores->>DB: save assistant message
```
### ROUTER Mode (Multi-Model)
See: [`docs/flows/data-flow-simplified-router-mode.md`](docs/flows/data-flow-simplified-router-mode.md)
```mermaid
sequenceDiagram
participant User
participant UI
participant Stores
participant API as llama-server
Note over User,API: Initialization
Stores->>API: GET /props
API-->>Stores: {role: "router"}
Stores->>API: GET /models
API-->>Stores: models[] with status
Note over User,API: Model Selection
User->>UI: select model
alt model not loaded
Stores->>API: POST /models/load
loop poll status
Stores->>API: GET /models
end
Stores->>API: GET /props?model=X
end
Stores->>Stores: validate modalities
Note over User,API: Chat Flow
Stores->>API: POST /v1/chat/completions {model: X}
loop streaming
API-->>Stores: SSE chunks + model info
end
```
### Detailed Flow Diagrams
| Flow | Description | File |
| ------------- | ------------------------------------------ | ----------------------------------------------------------- |
| Chat | Message lifecycle, streaming, regeneration | [`chat-flow.md`](docs/flows/chat-flow.md) |
| Models | Loading, unloading, modality caching | [`models-flow.md`](docs/flows/models-flow.md) |
| Server | Props fetching, role detection | [`server-flow.md`](docs/flows/server-flow.md) |
| Conversations | CRUD, branching, import/export | [`conversations-flow.md`](docs/flows/conversations-flow.md) |
| Database | IndexedDB schema, operations | [`database-flow.md`](docs/flows/database-flow.md) |
| Settings | Parameter sync, user overrides | [`settings-flow.md`](docs/flows/settings-flow.md) |
---
## Architectural Patterns
### 1. Reactive State with Svelte 5 Runes
All stores use Svelte 5's fine-grained reactivity:
```typescript
// Store with reactive state
class ChatStore {
#isLoading = $state(false);
#currentResponse = $state('');
// Derived values auto-update
get isStreaming() {
return $derived(this.#isLoading && this.#currentResponse.length > 0);
}
}
// Exported reactive accessors
export const isLoading = () => chatStore.isLoading;
export const currentResponse = () => chatStore.currentResponse;
```
### 2. Unidirectional Data Flow
Data flows in one direction, making state predictable:
```mermaid
flowchart LR
subgraph UI["UI Layer"]
A[User Action] --> B[Component]
end
subgraph State["State Layer"]
B --> C[Store Method]
C --> D[State Update]
end
subgraph IO["I/O Layer"]
C --> E[Service]
E --> F[API / IndexedDB]
F -.->|Response| D
end
D -->|Reactive| B
```
Components dispatch actions to stores, stores coordinate with services for I/O, and state updates reactively propagate back to the UI.
### 3. Per-Conversation State
Enables concurrent streaming across multiple conversations:
```typescript
class ChatStore {
chatLoadingStates = new Map<string, boolean>();
chatStreamingStates = new Map<string, { response: string; messageId: string }>();
abortControllers = new Map<string, AbortController>();
}
```
### 4. Message Branching with Tree Structure
Conversations are stored as a tree, not a linear list:
```typescript
interface DatabaseMessage {
id: string;
parent: string | null; // Points to parent message
children: string[]; // List of child message IDs
// ...
}
interface DatabaseConversation {
currentNode: string; // Currently viewed branch tip
// ...
}
```
Navigation between branches updates `currentNode` without losing history.
### 5. Layered Service Architecture
Stores handle state; services handle I/O:
```text
┌─────────────────┐
│ Stores │ Business logic, state management
├─────────────────┤
│ Services │ API calls, database operations
├─────────────────┤
│ Storage/API │ IndexedDB, LocalStorage, HTTP
└─────────────────┘
```
### 6. Server Role Abstraction
Single codebase handles both MODEL and ROUTER modes:
```typescript
// serverStore.ts
get isRouterMode() {
return this.role === ServerRole.ROUTER;
}
// Components conditionally render based on mode
{#if isRouterMode()}
<ModelsSelector />
{/if}
```
### 7. Modality Validation
Prevents sending attachments to incompatible models:
```typescript
// useModelChangeValidation hook
const validate = (modelId: string) => {
const modelModalities = modelsStore.getModelModalities(modelId);
const conversationModalities = conversationsStore.usedModalities;
// Check if model supports all used modalities
if (conversationModalities.hasImages && !modelModalities.vision) {
return { valid: false, reason: 'Model does not support images' };
}
// ...
};
```
### 8. Persistent Storage Strategy
Data is persisted across sessions using two storage mechanisms:
```mermaid
flowchart TB
subgraph Browser["Browser Storage"]
subgraph IDB["IndexedDB (Dexie)"]
C[Conversations]
M[Messages]
end
subgraph LS["LocalStorage"]
S[Settings Config]
O[User Overrides]
T[Theme Preference]
end
end
subgraph Stores["Svelte Stores"]
CS[conversationsStore] --> C
CS --> M
SS[settingsStore] --> S
SS --> O
SS --> T
end
```
- **IndexedDB**: Conversations and messages (large, structured data)
- **LocalStorage**: Settings, user parameter overrides, theme (small key-value data)
- **Memory only**: Server props, model list (fetched fresh on each session)
---
## Testing
### Test Types
| Type | Tool | Location | Command |
| ------------- | ------------------ | ---------------- | ------------------- |
| **Unit** | Vitest | `tests/unit/` | `npm run test:unit` |
| **UI/Visual** | Storybook + Vitest | `tests/stories/` | `npm run test:ui` |
| **E2E** | Playwright | `tests/e2e/` | `npm run test:e2e` |
| **Client** | Vitest | `tests/client/`. | `npm run test:unit` |
### Running Tests
```bash
# All tests
npm run test
# Individual test suites
npm run test:e2e # End-to-end (requires llama-server)
npm run test:client # Client-side unit tests
npm run test:server # Server-side unit tests
npm run test:ui # Storybook visual tests
```
### Storybook Development
```bash
npm run storybook # Start Storybook dev server on :6006
npm run build-storybook # Build static Storybook
```
### Linting and Formatting
```bash
npm run lint # Check code style
npm run format # Auto-format with Prettier
npm run check # TypeScript type checking
```
---
## Project Structure
```text
tools/server/webui/
├── src/
│ ├── lib/
│ │ ├── components/ # UI components (app/, ui/)
│ │ ├── hooks/ # Svelte hooks
│ │ ├── stores/ # State management
│ │ ├── services/ # API and database services
│ │ ├── types/ # TypeScript interfaces
│ │ └── utils/ # Utility functions
│ ├── routes/ # SvelteKit routes
│ └── styles/ # Global styles
├── static/ # Static assets
├── tests/ # Test files
├── docs/ # Architecture diagrams
│ ├── architecture/ # High-level architecture
│ └── flows/ # Feature-specific flows
└── .storybook/ # Storybook configuration
```
---
## Related Documentation
- [llama.cpp Server README](../README.md) - Full server documentation
- [Multimodal Documentation](../../../docs/multimodal.md) - Image and audio support
- [Function Calling](../../../docs/function-calling.md) - Tool use capabilities

View File

@@ -0,0 +1,16 @@
{
"$schema": "https://shadcn-svelte.com/schema.json",
"tailwind": {
"css": "src/app.css",
"baseColor": "neutral"
},
"aliases": {
"components": "$lib/components",
"utils": "$lib/components/ui/utils",
"ui": "$lib/components/ui",
"hooks": "$lib/hooks",
"lib": "$lib"
},
"typescript": true,
"registry": "https://shadcn-svelte.com/registry"
}

View File

@@ -0,0 +1,145 @@
```mermaid
flowchart TB
subgraph Routes["📍 Routes"]
R1["/ (Welcome)"]
R2["/chat/[id]"]
RL["+layout.svelte"]
end
subgraph Components["🧩 Components"]
C_Sidebar["ChatSidebar"]
C_Screen["ChatScreen"]
C_Form["ChatForm"]
C_Messages["ChatMessages"]
C_Message["ChatMessage"]
C_ChatMessageAgenticContent["ChatMessageAgenticContent"]
C_MessageEditForm["ChatMessageEditForm"]
C_ModelsSelector["ModelsSelector"]
C_Settings["ChatSettings"]
C_McpSettings["McpServersSettings"]
C_McpResourceBrowser["McpResourceBrowser"]
C_McpServersSelector["McpServersSelector"]
end
subgraph Hooks["🪝 Hooks"]
H1["useModelChangeValidation"]
H2["useProcessingState"]
end
subgraph Stores["🗄️ Stores"]
S1["chatStore<br/><i>Chat interactions & streaming</i>"]
SA["agenticStore<br/><i>Multi-turn agentic loop orchestration</i>"]
S2["conversationsStore<br/><i>Conversation data, messages & MCP overrides</i>"]
S3["modelsStore<br/><i>Model selection & loading</i>"]
S4["serverStore<br/><i>Server props & role detection</i>"]
S5["settingsStore<br/><i>User configuration incl. MCP</i>"]
S6["mcpStore<br/><i>MCP servers, tools, prompts</i>"]
S7["mcpResourceStore<br/><i>MCP resources & attachments</i>"]
end
subgraph Services["⚙️ Services"]
SV1["ChatService"]
SV2["ModelsService"]
SV3["PropsService"]
SV4["DatabaseService"]
SV5["ParameterSyncService"]
SV6["MCPService<br/><i>protocol operations</i>"]
end
subgraph Storage["💾 Storage"]
ST1["IndexedDB<br/><i>conversations, messages</i>"]
ST2["LocalStorage<br/><i>config, userOverrides, mcpServers</i>"]
end
subgraph APIs["🌐 llama-server API"]
API1["/v1/chat/completions"]
API2["/props"]
API3["/models/*"]
API4["/v1/models"]
end
subgraph ExternalMCP["🔌 External MCP Servers"]
EXT1["MCP Server 1<br/><i>WebSocket/HTTP/SSE</i>"]
EXT2["MCP Server N"]
end
%% Routes → Components
R1 & R2 --> C_Screen
RL --> C_Sidebar
%% Layout runs MCP health checks
RL --> S6
%% Component hierarchy
C_Screen --> C_Form & C_Messages & C_Settings
C_Messages --> C_Message
C_Message --> C_ChatMessageAgenticContent
C_Message --> C_MessageEditForm
C_Form & C_MessageEditForm --> C_ModelsSelector
C_Form --> C_McpServersSelector
C_Settings --> C_McpSettings
C_McpSettings --> C_McpResourceBrowser
%% Components → Hooks → Stores
C_Form & C_Messages --> H1 & H2
H1 --> S3 & S4
H2 --> S1 & S5
%% Components → Stores
C_Screen --> S1 & S2
C_Sidebar --> S2
C_ModelsSelector --> S3 & S4
C_Settings --> S5
C_McpSettings --> S6
C_McpResourceBrowser --> S6 & S7
C_McpServersSelector --> S6
C_Form --> S6
%% chatStore → agenticStore → mcpStore (agentic loop)
S1 --> SA
SA --> SV1
SA --> S6
%% Stores → Services
S1 --> SV1 & SV4
S2 --> SV4
S3 --> SV2 & SV3
S4 --> SV3
S5 --> SV5
S6 --> SV6
S7 --> SV6
%% Services → Storage
SV4 --> ST1
SV5 --> ST2
%% Services → APIs
SV1 --> API1
SV2 --> API3 & API4
SV3 --> API2
%% MCP → External Servers
SV6 --> EXT1 & EXT2
%% Styling
classDef routeStyle fill:#e1f5fe,stroke:#01579b,stroke-width:2px
classDef componentStyle fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px
classDef hookStyle fill:#fff8e1,stroke:#ff8f00,stroke-width:2px
classDef storeStyle fill:#fff3e0,stroke:#e65100,stroke-width:2px
classDef serviceStyle fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px
classDef storageStyle fill:#fce4ec,stroke:#c2185b,stroke-width:2px
classDef apiStyle fill:#e3f2fd,stroke:#1565c0,stroke-width:2px
classDef mcpStyle fill:#e0f2f1,stroke:#00695c,stroke-width:2px
classDef agenticStyle fill:#e8eaf6,stroke:#283593,stroke-width:2px
classDef externalStyle fill:#f3e5f5,stroke:#6a1b9a,stroke-width:2px,stroke-dasharray: 5 5
class R1,R2,RL routeStyle
class C_Sidebar,C_Screen,C_Form,C_Messages,C_Message,C_ChatMessageAgenticContent,C_MessageEditForm,C_ModelsSelector,C_Settings componentStyle
class C_McpSettings,C_McpResourceBrowser,C_McpServersSelector componentStyle
class H1,H2 hookStyle
class S1,S2,S3,S4,S5,SA,S6,S7 storeStyle
class SV1,SV2,SV3,SV4,SV5,SV6 serviceStyle
class ST1,ST2 storageStyle
class API1,API2,API3,API4 apiStyle
class EXT1,EXT2 externalStyle
```

View File

@@ -0,0 +1,373 @@
```mermaid
flowchart TB
subgraph Routes["📍 Routes"]
R1["/ (+page.svelte)"]
R2["/chat/[id]"]
RL["+layout.svelte"]
end
subgraph Components["🧩 Components"]
direction TB
subgraph LayoutComponents["Layout"]
C_Sidebar["ChatSidebar"]
C_Screen["ChatScreen"]
end
subgraph ChatUIComponents["Chat UI"]
C_Form["ChatForm"]
C_Messages["ChatMessages"]
C_Message["ChatMessage"]
C_MessageUser["ChatMessageUser"]
C_MessageEditForm["ChatMessageEditForm"]
C_Attach["ChatAttachments"]
C_ModelsSelector["ModelsSelector"]
C_Settings["ChatSettings"]
end
subgraph MCPComponents["MCP UI"]
C_McpSettings["McpServersSettings"]
C_McpServerCard["McpServerCard"]
C_McpResourceBrowser["McpResourceBrowser"]
C_McpResourcePreview["McpResourcePreview"]
C_McpServersSelector["McpServersSelector"]
end
end
subgraph Hooks["🪝 Hooks"]
H1["useModelChangeValidation"]
H2["useProcessingState"]
H3["isMobile"]
end
subgraph Stores["🗄️ Stores"]
direction TB
subgraph S1["chatStore"]
S1State["<b>State:</b><br/>isLoading, currentResponse<br/>errorDialogState<br/>activeProcessingState<br/>chatLoadingStates<br/>chatStreamingStates<br/>abortControllers<br/>processingStates<br/>activeConversationId<br/>isStreamingActive"]
S1LoadState["<b>Loading State:</b><br/>setChatLoading()<br/>isChatLoading()<br/>syncLoadingStateForChat()<br/>clearUIState()<br/>isChatLoadingPublic()<br/>getAllLoadingChats()<br/>getAllStreamingChats()"]
S1ProcState["<b>Processing State:</b><br/>setActiveProcessingConversation()<br/>getProcessingState()<br/>clearProcessingState()<br/>getActiveProcessingState()<br/>updateProcessingStateFromTimings()<br/>getCurrentProcessingStateSync()<br/>restoreProcessingStateFromMessages()"]
S1Stream["<b>Streaming:</b><br/>streamChatCompletion()<br/>startStreaming()<br/>stopStreaming()<br/>stopGeneration()<br/>isStreaming()"]
S1Error["<b>Error Handling:</b><br/>showErrorDialog()<br/>dismissErrorDialog()<br/>isAbortError()"]
S1Msg["<b>Message Operations:</b><br/>addMessage()<br/>sendMessage()<br/>updateMessage()<br/>deleteMessage()<br/>getDeletionInfo()"]
S1Regen["<b>Regeneration:</b><br/>regenerateMessage()<br/>regenerateMessageWithBranching()<br/>continueAssistantMessage()"]
S1Edit["<b>Editing:</b><br/>editAssistantMessage()<br/>editUserMessagePreserveResponses()<br/>editMessageWithBranching()<br/>clearEditMode()<br/>isEditModeActive()<br/>getAddFilesHandler()<br/>setEditModeActive()"]
S1Utils["<b>Utilities:</b><br/>getApiOptions()<br/>parseTimingData()<br/>getOrCreateAbortController()<br/>getConversationModel()"]
end
subgraph SA["agenticStore"]
SAState["<b>State:</b><br/>sessions (Map)<br/>isAnyRunning"]
SASession["<b>Session Management:</b><br/>getSession()<br/>updateSession()<br/>clearSession()<br/>getActiveSessions()<br/>isRunning()<br/>currentTurn()<br/>totalToolCalls()<br/>lastError()<br/>streamingToolCall()"]
SAConfig["<b>Configuration:</b><br/>getConfig()<br/>maxTurns, maxToolPreviewLines"]
SAFlow["<b>Agentic Loop:</b><br/>runAgenticFlow()<br/>executeAgenticLoop()<br/>normalizeToolCalls()<br/>emitToolCallResult()<br/>extractBase64Attachments()"]
end
subgraph S2["conversationsStore"]
S2State["<b>State:</b><br/>conversations<br/>activeConversation<br/>activeMessages<br/>isInitialized<br/>pendingMcpServerOverrides<br/>titleUpdateConfirmationCallback"]
S2Lifecycle["<b>Lifecycle:</b><br/>initialize()<br/>loadConversations()<br/>clearActiveConversation()"]
S2ConvCRUD["<b>Conversation CRUD:</b><br/>createConversation()<br/>loadConversation()<br/>deleteConversation()<br/>deleteAll()<br/>updateConversationName()<br/>updateConversationTitleWithConfirmation()"]
S2MsgMgmt["<b>Message Management:</b><br/>refreshActiveMessages()<br/>addMessageToActive()<br/>updateMessageAtIndex()<br/>findMessageIndex()<br/>sliceActiveMessages()<br/>removeMessageAtIndex()<br/>getConversationMessages()"]
S2Nav["<b>Navigation:</b><br/>navigateToSibling()<br/>updateCurrentNode()<br/>updateConversationTimestamp()"]
S2McpOverrides["<b>MCP Per-Chat Overrides:</b><br/>getMcpServerOverride()<br/>getAllMcpServerOverrides()<br/>setMcpServerOverride()<br/>toggleMcpServerForChat()<br/>removeMcpServerOverride()<br/>isMcpServerEnabledForChat()<br/>clearPendingMcpServerOverrides()"]
S2Export["<b>Import/Export:</b><br/>downloadConversation()<br/>exportAllConversations()<br/>importConversations()<br/>importConversationsData()<br/>triggerDownload()"]
S2Utils["<b>Utilities:</b><br/>setTitleUpdateConfirmationCallback()"]
end
subgraph S3["modelsStore"]
S3State["<b>State:</b><br/>models, routerModels<br/>selectedModelId<br/>selectedModelName<br/>loading, updating, error<br/>modelLoadingStates<br/>modelPropsCache<br/>modelPropsFetching<br/>propsCacheVersion"]
S3Getters["<b>Computed Getters:</b><br/>selectedModel<br/>loadedModelIds<br/>loadingModelIds<br/>singleModelName"]
S3Modal["<b>Modalities:</b><br/>getModelModalities()<br/>modelSupportsVision()<br/>modelSupportsAudio()<br/>getModelModalitiesArray()<br/>getModelProps()<br/>updateModelModalities()"]
S3Status["<b>Status Queries:</b><br/>isModelLoaded()<br/>isModelOperationInProgress()<br/>getModelStatus()<br/>isModelPropsFetching()"]
S3Fetch["<b>Data Fetching:</b><br/>fetch()<br/>fetchRouterModels()<br/>fetchModelProps()<br/>fetchModalitiesForLoadedModels()"]
S3Select["<b>Model Selection:</b><br/>selectModelById()<br/>selectModelByName()<br/>clearSelection()<br/>findModelByName()<br/>findModelById()<br/>hasModel()"]
S3LoadUnload["<b>Loading/Unloading Models:</b><br/>loadModel()<br/>unloadModel()<br/>ensureModelLoaded()<br/>waitForModelStatus()<br/>pollForModelStatus()"]
S3Utils["<b>Utilities:</b><br/>toDisplayName()<br/>clear()"]
end
subgraph S4["serverStore"]
S4State["<b>State:</b><br/>props<br/>loading, error<br/>role<br/>fetchPromise"]
S4Getters["<b>Getters:</b><br/>defaultParams<br/>contextSize<br/>isRouterMode<br/>isModelMode"]
S4Data["<b>Data Handling:</b><br/>fetch()<br/>getErrorMessage()<br/>clear()"]
S4Utils["<b>Utilities:</b><br/>detectRole()"]
end
subgraph S5["settingsStore"]
S5State["<b>State:</b><br/>config<br/>theme<br/>isInitialized<br/>userOverrides"]
S5Lifecycle["<b>Lifecycle:</b><br/>initialize()<br/>loadConfig()<br/>saveConfig()<br/>loadTheme()<br/>saveTheme()"]
S5Update["<b>Config Updates:</b><br/>updateConfig()<br/>updateMultipleConfig()<br/>updateTheme()"]
S5Reset["<b>Reset:</b><br/>resetConfig()<br/>resetTheme()<br/>resetAll()<br/>resetParameterToServerDefault()"]
S5Sync["<b>Server Sync:</b><br/>syncWithServerDefaults()<br/>forceSyncWithServerDefaults()"]
S5Utils["<b>Utilities:</b><br/>getConfig()<br/>getAllConfig()<br/>getParameterInfo()<br/>getParameterDiff()<br/>getServerDefaults()<br/>clearAllUserOverrides()"]
end
subgraph S6["mcpStore"]
S6State["<b>State:</b><br/>isInitializing, error<br/>toolCount, connectedServers<br/>healthChecks (Map)<br/>connections (Map)<br/>toolsIndex (Map)"]
S6Lifecycle["<b>Lifecycle:</b><br/>ensureInitialized()<br/>initialize()<br/>shutdown()<br/>acquireConnection()<br/>releaseConnection()"]
S6Health["<b>Health Checks:</b><br/>runHealthCheck()<br/>runHealthChecksForServers()<br/>updateHealthCheck()<br/>getHealthCheckState()<br/>clearHealthCheck()"]
S6Servers["<b>Server Management:</b><br/>getServers()<br/>addServer()<br/>updateServer()<br/>removeServer()<br/>getServerById()<br/>getServerDisplayName()"]
S6Tools["<b>Tool Operations:</b><br/>getToolDefinitionsForLLM()<br/>getToolNames()<br/>hasTool()<br/>getToolServer()<br/>executeTool()<br/>executeToolByName()"]
S6Prompts["<b>Prompt Operations:</b><br/>getAllPrompts()<br/>getPrompt()<br/>hasPromptsCapability()<br/>getPromptCompletions()"]
end
subgraph S7["mcpResourceStore"]
S7State["<b>State:</b><br/>serverResources (Map)<br/>cachedResources (Map)<br/>subscriptions (Map)<br/>attachments[]<br/>isLoading"]
S7Resources["<b>Resource Discovery:</b><br/>setServerResources()<br/>getServerResources()<br/>getAllResourceInfos()<br/>getAllTemplateInfos()<br/>clearServerResources()"]
S7Cache["<b>Caching:</b><br/>cacheResourceContent()<br/>getCachedContent()<br/>invalidateCache()<br/>clearCache()"]
S7Subs["<b>Subscriptions:</b><br/>addSubscription()<br/>removeSubscription()<br/>isSubscribed()<br/>handleResourceUpdate()"]
S7Attach["<b>Attachments:</b><br/>addAttachment()<br/>updateAttachmentContent()<br/>removeAttachment()<br/>clearAttachments()<br/>toMessageExtras()"]
end
subgraph ReactiveExports["⚡ Reactive Exports"]
direction LR
subgraph ChatExports["chatStore"]
RE1["isLoading()"]
RE2["currentResponse()"]
RE3["errorDialog()"]
RE4["activeProcessingState()"]
RE5["isChatStreaming()"]
RE6["isChatLoading()"]
RE7["getChatStreaming()"]
RE8["getAllLoadingChats()"]
RE9["getAllStreamingChats()"]
RE9a["isEditModeActive()"]
RE9b["getAddFilesHandler()"]
RE9c["setEditModeActive()"]
RE9d["clearEditMode()"]
end
subgraph AgenticExports["agenticStore"]
REA1["agenticIsRunning()"]
REA2["agenticCurrentTurn()"]
REA3["agenticTotalToolCalls()"]
REA4["agenticLastError()"]
REA5["agenticStreamingToolCall()"]
REA6["agenticIsAnyRunning()"]
end
subgraph ConvExports["conversationsStore"]
RE10["conversations()"]
RE11["activeConversation()"]
RE12["activeMessages()"]
RE13["isConversationsInitialized()"]
end
subgraph ModelsExports["modelsStore"]
RE15["modelOptions()"]
RE16["routerModels()"]
RE17["modelsLoading()"]
RE18["modelsUpdating()"]
RE19["modelsError()"]
RE20["selectedModelId()"]
RE21["selectedModelName()"]
RE22["selectedModelOption()"]
RE23["loadedModelIds()"]
RE24["loadingModelIds()"]
RE25["propsCacheVersion()"]
RE26["singleModelName()"]
end
subgraph ServerExports["serverStore"]
RE27["serverProps()"]
RE28["serverLoading()"]
RE29["serverError()"]
RE30["serverRole()"]
RE31["defaultParams()"]
RE32["contextSize()"]
RE33["isRouterMode()"]
RE34["isModelMode()"]
end
subgraph SettingsExports["settingsStore"]
RE35["config()"]
RE36["theme()"]
RE37["isInitialized()"]
end
subgraph MCPExports["mcpStore / mcpResourceStore"]
RE38["mcpResources()"]
RE39["mcpResourceAttachments()"]
RE40["mcpHasResourceAttachments()"]
RE41["mcpTotalResourceCount()"]
RE42["mcpResourcesLoading()"]
end
end
end
subgraph Services["⚙️ Services"]
direction TB
subgraph SV1["ChatService"]
SV1Msg["<b>Messaging:</b><br/>sendMessage()"]
SV1Stream["<b>Streaming:</b><br/>handleStreamResponse()<br/>handleNonStreamResponse()"]
SV1Convert["<b>Conversion:</b><br/>convertDbMessageToApiChatMessageData()<br/>mergeToolCallDeltas()"]
SV1Utils["<b>Utilities:</b><br/>stripReasoningContent()<br/>extractModelName()<br/>parseErrorResponse()"]
end
subgraph SV2["ModelsService"]
SV2List["<b>Listing:</b><br/>list()<br/>listRouter()"]
SV2LoadUnload["<b>Load/Unload:</b><br/>load()<br/>unload()"]
SV2Status["<b>Status:</b><br/>isModelLoaded()<br/>isModelLoading()"]
end
subgraph SV3["PropsService"]
SV3Fetch["<b>Fetching:</b><br/>fetch()<br/>fetchForModel()"]
end
subgraph SV4["DatabaseService"]
SV4Conv["<b>Conversations:</b><br/>createConversation()<br/>getConversation()<br/>getAllConversations()<br/>updateConversation()<br/>deleteConversation()"]
SV4Msg["<b>Messages:</b><br/>createMessageBranch()<br/>createRootMessage()<br/>createSystemMessage()<br/>getConversationMessages()<br/>updateMessage()<br/>deleteMessage()<br/>deleteMessageCascading()"]
SV4Node["<b>Navigation:</b><br/>updateCurrentNode()"]
SV4Import["<b>Import:</b><br/>importConversations()"]
end
subgraph SV5["ParameterSyncService"]
SV5Extract["<b>Extraction:</b><br/>extractServerDefaults()"]
SV5Merge["<b>Merging:</b><br/>mergeWithServerDefaults()"]
SV5Info["<b>Info:</b><br/>getParameterInfo()<br/>canSyncParameter()<br/>getSyncableParameterKeys()<br/>validateServerParameter()"]
SV5Diff["<b>Diff:</b><br/>createParameterDiff()"]
end
subgraph SV6["MCPService"]
SV6Transport["<b>Transport:</b><br/>createTransport()<br/>WebSocket / StreamableHTTP / SSE"]
SV6Conn["<b>Connection:</b><br/>connect()<br/>disconnect()"]
SV6Tools["<b>Tools:</b><br/>listTools()<br/>callTool()"]
SV6Prompts["<b>Prompts:</b><br/>listPrompts()<br/>getPrompt()"]
SV6Resources["<b>Resources:</b><br/>listResources()<br/>listResourceTemplates()<br/>readResource()<br/>subscribeResource()<br/>unsubscribeResource()"]
SV6Complete["<b>Completions:</b><br/>complete()"]
end
end
subgraph ExternalMCP["🔌 External MCP Servers"]
EXT1["MCP Server 1<br/>(WebSocket/StreamableHTTP/SSE)"]
EXT2["MCP Server N"]
end
subgraph Storage["💾 Storage"]
ST1["IndexedDB"]
ST2["conversations"]
ST3["messages"]
ST5["LocalStorage"]
ST6["config"]
ST7["userOverrides"]
ST8["mcpServers"]
end
subgraph APIs["🌐 llama-server API"]
API1["/v1/chat/completions"]
API2["/props<br/>/props?model="]
API3["/models<br/>/models/load<br/>/models/unload"]
API4["/v1/models"]
end
%% Routes render Components
R1 --> C_Screen
R2 --> C_Screen
RL --> C_Sidebar
%% Layout runs MCP health checks on startup
RL --> S6
%% Component hierarchy
C_Screen --> C_Form & C_Messages & C_Settings
C_Messages --> C_Message
C_Message --> C_MessageUser
C_MessageUser --> C_MessageEditForm
C_MessageEditForm --> C_ModelsSelector
C_MessageEditForm --> C_Attach
C_Form --> C_ModelsSelector
C_Form --> C_Attach
C_Form --> C_McpServersSelector
C_Message --> C_Attach
%% MCP Components hierarchy
C_Settings --> C_McpSettings
C_McpSettings --> C_McpServerCard
C_McpServerCard --> C_McpResourceBrowser
C_McpResourceBrowser --> C_McpResourcePreview
%% Components use Hooks
C_Form --> H1
C_Message --> H1 & H2
C_MessageEditForm --> H1
C_Screen --> H2
%% Hooks use Stores
H1 --> S3 & S4
H2 --> S1 & S5
%% Components use Stores
C_Screen --> S1 & S2
C_Messages --> S2
C_Message --> S1 & S2 & S3
C_Form --> S1 & S3 & S6
C_Sidebar --> S2
C_ModelsSelector --> S3 & S4
C_Settings --> S5
C_McpSettings --> S6
C_McpServerCard --> S6
C_McpResourceBrowser --> S6 & S7
C_McpServersSelector --> S6
%% Stores export Reactive State
S1 -. exports .-> ChatExports
SA -. exports .-> AgenticExports
S2 -. exports .-> ConvExports
S3 -. exports .-> ModelsExports
S4 -. exports .-> ServerExports
S5 -. exports .-> SettingsExports
S6 -. exports .-> MCPExports
S7 -. exports .-> MCPExports
%% chatStore → agenticStore (agentic loop orchestration)
S1 --> SA
SA --> SV1
SA --> S6
%% Stores use Services
S1 --> SV1 & SV4
S2 --> SV4
S3 --> SV2 & SV3
S4 --> SV3
S5 --> SV5
S6 --> SV6
S7 --> SV6
%% Services to Storage
SV4 --> ST1
ST1 --> ST2 & ST3
SV5 --> ST5
ST5 --> ST6 & ST7 & ST8
%% Services to APIs
SV1 --> API1
SV2 --> API3 & API4
SV3 --> API2
%% MCP → External Servers
SV6 --> EXT1 & EXT2
%% Styling
classDef routeStyle fill:#e1f5fe,stroke:#01579b,stroke-width:2px
classDef componentStyle fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px
classDef componentGroupStyle fill:#e1bee7,stroke:#7b1fa2,stroke-width:1px
classDef hookStyle fill:#fff8e1,stroke:#ff8f00,stroke-width:2px
classDef storeStyle fill:#fff3e0,stroke:#e65100,stroke-width:2px
classDef stateStyle fill:#ffe0b2,stroke:#e65100,stroke-width:1px
classDef methodStyle fill:#ffecb3,stroke:#e65100,stroke-width:1px
classDef reactiveStyle fill:#fffde7,stroke:#f9a825,stroke-width:1px
classDef serviceStyle fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px
classDef serviceMStyle fill:#c8e6c9,stroke:#2e7d32,stroke-width:1px
classDef externalStyle fill:#f3e5f5,stroke:#6a1b9a,stroke-width:2px,stroke-dasharray: 5 5
classDef storageStyle fill:#fce4ec,stroke:#c2185b,stroke-width:2px
classDef apiStyle fill:#e3f2fd,stroke:#1565c0,stroke-width:2px
class R1,R2,RL routeStyle
class C_Sidebar,C_Screen,C_Form,C_Messages,C_Message,C_MessageUser,C_MessageEditForm componentStyle
class C_ModelsSelector,C_Settings componentStyle
class C_Attach componentStyle
class C_McpSettings,C_McpServerCard,C_McpResourceBrowser,C_McpResourcePreview,C_McpServersSelector componentStyle
class H1,H2,H3 hookStyle
class LayoutComponents,ChatUIComponents,MCPComponents componentGroupStyle
class Hooks hookStyle
classDef agenticStyle fill:#e8eaf6,stroke:#283593,stroke-width:2px
classDef agenticMethodStyle fill:#c5cae9,stroke:#283593,stroke-width:1px
class S1,S2,S3,S4,S5,SA,S6,S7 storeStyle
class S1State,S2State,S3State,S4State,S5State,SAState,S6State,S7State stateStyle
class S1Msg,S1Regen,S1Edit,S1Stream,S1LoadState,S1ProcState,S1Error,S1Utils methodStyle
class SASession,SAConfig,SAFlow methodStyle
class S2Lifecycle,S2ConvCRUD,S2MsgMgmt,S2Nav,S2McpOverrides,S2Export,S2Utils methodStyle
class S3Getters,S3Modal,S3Status,S3Fetch,S3Select,S3LoadUnload,S3Utils methodStyle
class S4Getters,S4Data,S4Utils methodStyle
class S5Lifecycle,S5Update,S5Reset,S5Sync,S5Utils methodStyle
class S6Lifecycle,S6Health,S6Servers,S6Tools,S6Prompts methodStyle
class S7Resources,S7Cache,S7Subs,S7Attach methodStyle
class ChatExports,AgenticExports,ConvExports,ModelsExports,ServerExports,SettingsExports,MCPExports reactiveStyle
class SV1,SV2,SV3,SV4,SV5,SV6 serviceStyle
class SV6Transport,SV6Conn,SV6Tools,SV6Prompts,SV6Resources,SV6Complete serviceMStyle
class EXT1,EXT2 externalStyle
class SV1Msg,SV1Stream,SV1Convert,SV1Utils serviceMStyle
class SV2List,SV2LoadUnload,SV2Status serviceMStyle
class SV3Fetch serviceMStyle
class SV4Conv,SV4Msg,SV4Node,SV4Import serviceMStyle
class SV5Extract,SV5Merge,SV5Info,SV5Diff serviceMStyle
class ST1,ST2,ST3,ST5,ST6,ST7,ST8 storageStyle
class API1,API2,API3,API4 apiStyle
```

View File

@@ -0,0 +1,228 @@
```mermaid
sequenceDiagram
participant UI as 🧩 ChatForm / ChatMessage
participant chatStore as 🗄️ chatStore
participant agenticStore as 🗄️ agenticStore
participant convStore as 🗄️ conversationsStore
participant settingsStore as 🗄️ settingsStore
participant mcpStore as 🗄️ mcpStore
participant ChatSvc as ⚙️ ChatService
participant DbSvc as ⚙️ DatabaseService
participant API as 🌐 /v1/chat/completions
Note over chatStore: State:<br/>isLoading, currentResponse<br/>errorDialogState, activeProcessingState<br/>chatLoadingStates (Map)<br/>chatStreamingStates (Map)<br/>abortControllers (Map)<br/>processingStates (Map)
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,API: 💬 SEND MESSAGE
%% ═══════════════════════════════════════════════════════════════════════════
UI->>chatStore: sendMessage(content, extras)
activate chatStore
chatStore->>chatStore: setChatLoading(convId, true)
chatStore->>chatStore: clearChatStreaming(convId)
alt no active conversation
chatStore->>convStore: createConversation()
Note over convStore: → see conversations-flow.mmd
end
chatStore->>mcpStore: consumeResourceAttachmentsAsExtras()
Note right of mcpStore: Converts pending MCP resource<br/>attachments into message extras
chatStore->>chatStore: addMessage("user", content, extras)
chatStore->>DbSvc: createMessageBranch(userMsg, parentId)
chatStore->>convStore: addMessageToActive(userMsg)
chatStore->>convStore: updateCurrentNode(userMsg.id)
chatStore->>chatStore: createAssistantMessage(userMsg.id)
chatStore->>DbSvc: createMessageBranch(assistantMsg, userMsg.id)
chatStore->>convStore: addMessageToActive(assistantMsg)
chatStore->>chatStore: streamChatCompletion(messages, assistantMsg)
deactivate chatStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,API: 🌊 STREAMING (with agentic flow detection)
%% ═══════════════════════════════════════════════════════════════════════════
activate chatStore
chatStore->>chatStore: startStreaming()
Note right of chatStore: isStreamingActive = true
chatStore->>chatStore: setActiveProcessingConversation(convId)
chatStore->>chatStore: getOrCreateAbortController(convId)
Note right of chatStore: abortControllers.set(convId, new AbortController())
chatStore->>chatStore: getApiOptions()
Note right of chatStore: Merge from settingsStore.config:<br/>temperature, max_tokens, top_p, etc.
alt agenticConfig.enabled && mcpStore has connected servers
chatStore->>agenticStore: runAgenticFlow(convId, messages, assistantMsg, options, signal)
Note over agenticStore: Multi-turn agentic loop:<br/>1. Call ChatService.sendMessage()<br/>2. If response has tool_calls → execute via mcpStore<br/>3. Append tool results as messages<br/>4. Loop until no more tool_calls or maxTurns<br/>→ see agentic flow details below
agenticStore-->>chatStore: final response with timings
else standard (non-agentic) flow
chatStore->>ChatSvc: sendMessage(messages, options, signal)
end
activate ChatSvc
ChatSvc->>ChatSvc: convertDbMessageToApiChatMessageData(messages)
Note right of ChatSvc: DatabaseMessage[] → ApiChatMessageData[]<br/>Process attachments (images, PDFs, audio)
ChatSvc->>API: POST /v1/chat/completions
Note right of API: {messages, model?, stream: true, ...params}
loop SSE chunks
API-->>ChatSvc: data: {"choices":[{"delta":{...}}]}
ChatSvc->>ChatSvc: handleStreamResponse(response)
alt content chunk
ChatSvc-->>chatStore: onChunk(content)
chatStore->>chatStore: setChatStreaming(convId, response, msgId)
Note right of chatStore: currentResponse = $state(accumulated)
chatStore->>convStore: updateMessageAtIndex(idx, {content})
end
alt reasoning chunk
ChatSvc-->>chatStore: onReasoningChunk(reasoning)
chatStore->>convStore: updateMessageAtIndex(idx, {thinking})
end
alt tool_calls chunk
ChatSvc-->>chatStore: onToolCallChunk(toolCalls)
chatStore->>convStore: updateMessageAtIndex(idx, {toolCalls})
end
alt model info
ChatSvc-->>chatStore: onModel(modelName)
chatStore->>chatStore: recordModel(modelName)
chatStore->>DbSvc: updateMessage(msgId, {model})
end
alt timings (during stream)
ChatSvc-->>chatStore: onTimings(timings, promptProgress)
chatStore->>chatStore: updateProcessingStateFromTimings()
end
chatStore-->>UI: reactive $state update
end
API-->>ChatSvc: data: [DONE]
ChatSvc-->>chatStore: onComplete(content, reasoning, timings, toolCalls)
deactivate ChatSvc
chatStore->>chatStore: stopStreaming()
chatStore->>DbSvc: updateMessage(msgId, {content, timings, model})
chatStore->>convStore: updateCurrentNode(msgId)
chatStore->>chatStore: setChatLoading(convId, false)
chatStore->>chatStore: clearChatStreaming(convId)
chatStore->>chatStore: clearProcessingState(convId)
deactivate chatStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,API: ⏹️ STOP GENERATION
%% ═══════════════════════════════════════════════════════════════════════════
UI->>chatStore: stopGeneration()
activate chatStore
chatStore->>chatStore: savePartialResponseIfNeeded(convId)
Note right of chatStore: Save currentResponse to DB if non-empty
chatStore->>chatStore: abortControllers.get(convId).abort()
Note right of chatStore: fetch throws AbortError → caught by isAbortError()
chatStore->>chatStore: stopStreaming()
chatStore->>chatStore: setChatLoading(convId, false)
chatStore->>chatStore: clearChatStreaming(convId)
chatStore->>chatStore: clearProcessingState(convId)
deactivate chatStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,API: 🔁 REGENERATE
%% ═══════════════════════════════════════════════════════════════════════════
UI->>chatStore: regenerateMessageWithBranching(msgId, model?)
activate chatStore
chatStore->>convStore: findMessageIndex(msgId)
chatStore->>chatStore: Get parent of target message
chatStore->>chatStore: createAssistantMessage(parentId)
chatStore->>DbSvc: createMessageBranch(newAssistantMsg, parentId)
chatStore->>convStore: refreshActiveMessages()
Note right of chatStore: Same streaming flow
chatStore->>chatStore: streamChatCompletion(...)
deactivate chatStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,API: ➡️ CONTINUE
%% ═══════════════════════════════════════════════════════════════════════════
UI->>chatStore: continueAssistantMessage(msgId)
activate chatStore
chatStore->>chatStore: Get existing content from message
chatStore->>chatStore: streamChatCompletion(..., existingContent)
Note right of chatStore: Appends to existing message content
deactivate chatStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,API: ✏️ EDIT USER MESSAGE
%% ═══════════════════════════════════════════════════════════════════════════
UI->>chatStore: editMessageWithBranching(msgId, newContent, extras)
activate chatStore
chatStore->>chatStore: Get parent of target message
chatStore->>DbSvc: createMessageBranch(editedMsg, parentId)
chatStore->>convStore: refreshActiveMessages()
Note right of chatStore: Creates new branch, original preserved
chatStore->>chatStore: createAssistantMessage(editedMsg.id)
chatStore->>chatStore: streamChatCompletion(...)
Note right of chatStore: Automatically regenerates response
deactivate chatStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,API: ❌ ERROR HANDLING
%% ═══════════════════════════════════════════════════════════════════════════
Note over chatStore: On stream error (non-abort):
chatStore->>chatStore: showErrorDialog(type, message)
Note right of chatStore: errorDialogState = {type: 'timeout'|'server', message}
chatStore->>convStore: removeMessageAtIndex(failedMsgIdx)
chatStore->>DbSvc: deleteMessage(failedMsgId)
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,API: 🤖 AGENTIC LOOP (when agenticConfig.enabled)
%% ═══════════════════════════════════════════════════════════════════════════
Note over agenticStore: agenticStore.runAgenticFlow(convId, messages, assistantMsg, options, signal)
activate agenticStore
agenticStore->>agenticStore: getSession(convId) or create new
agenticStore->>agenticStore: updateSession(turn: 0, running: true)
loop executeAgenticLoop (until no tool_calls or maxTurns)
agenticStore->>agenticStore: turn++
agenticStore->>ChatSvc: sendMessage(messages, options, signal)
ChatSvc->>API: POST /v1/chat/completions
API-->>ChatSvc: response with potential tool_calls
ChatSvc-->>agenticStore: onComplete(content, reasoning, timings, toolCalls)
alt response has tool_calls
agenticStore->>agenticStore: normalizeToolCalls(toolCalls)
loop for each tool_call
agenticStore->>agenticStore: updateSession(streamingToolCall)
agenticStore->>mcpStore: executeTool(mcpCall, signal)
mcpStore-->>agenticStore: tool result
agenticStore->>agenticStore: extractBase64Attachments(result)
agenticStore->>agenticStore: emitToolCallResult(convId, ...)
agenticStore->>convStore: addMessageToActive(toolResultMsg)
agenticStore->>DbSvc: createMessageBranch(toolResultMsg)
end
agenticStore->>agenticStore: Create new assistantMsg for next turn
Note right of agenticStore: Continue loop with updated messages
else no tool_calls (final response)
agenticStore->>agenticStore: buildFinalTimings(allTurns)
Note right of agenticStore: Break loop, return final response
end
end
agenticStore->>agenticStore: updateSession(running: false)
agenticStore-->>chatStore: final content, timings, model
deactivate agenticStore
```

View File

@@ -0,0 +1,183 @@
```mermaid
sequenceDiagram
participant UI as 🧩 ChatSidebar / ChatScreen
participant convStore as 🗄️ conversationsStore
participant chatStore as 🗄️ chatStore
participant DbSvc as ⚙️ DatabaseService
participant IDB as 💾 IndexedDB
Note over convStore: State:<br/>conversations: DatabaseConversation[]<br/>activeConversation: DatabaseConversation | null<br/>activeMessages: DatabaseMessage[]<br/>isInitialized: boolean<br/>pendingMcpServerOverrides: Map&lt;string, McpServerOverride&gt;
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,IDB: 🚀 INITIALIZATION
%% ═══════════════════════════════════════════════════════════════════════════
Note over convStore: Auto-initialized in constructor (browser only)
convStore->>convStore: initialize()
activate convStore
convStore->>convStore: loadConversations()
convStore->>DbSvc: getAllConversations()
DbSvc->>IDB: SELECT * FROM conversations ORDER BY lastModified DESC
IDB-->>DbSvc: Conversation[]
DbSvc-->>convStore: conversations
convStore->>convStore: conversations = $state(data)
convStore->>convStore: isInitialized = true
deactivate convStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,IDB: CREATE CONVERSATION
%% ═══════════════════════════════════════════════════════════════════════════
UI->>convStore: createConversation(name?)
activate convStore
convStore->>DbSvc: createConversation(name || "New Chat")
DbSvc->>IDB: INSERT INTO conversations
IDB-->>DbSvc: conversation {id, name, lastModified, currNode: ""}
DbSvc-->>convStore: conversation
convStore->>convStore: conversations.unshift(conversation)
convStore->>convStore: activeConversation = $state(conversation)
convStore->>convStore: activeMessages = $state([])
alt pendingMcpServerOverrides has entries
loop each pending override
convStore->>DbSvc: Store MCP server override for new conversation
end
convStore->>convStore: clearPendingMcpServerOverrides()
end
deactivate convStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,IDB: 📂 LOAD CONVERSATION
%% ═══════════════════════════════════════════════════════════════════════════
UI->>convStore: loadConversation(convId)
activate convStore
convStore->>DbSvc: getConversation(convId)
DbSvc->>IDB: SELECT * FROM conversations WHERE id = ?
IDB-->>DbSvc: conversation
convStore->>convStore: activeConversation = $state(conversation)
convStore->>convStore: refreshActiveMessages()
convStore->>DbSvc: getConversationMessages(convId)
DbSvc->>IDB: SELECT * FROM messages WHERE convId = ?
IDB-->>DbSvc: allMessages[]
convStore->>convStore: filterByLeafNodeId(allMessages, currNode)
Note right of convStore: Filter to show only current branch path
convStore->>convStore: activeMessages = $state(filtered)
Note right of convStore: Route (+page.svelte) then calls:<br/>chatStore.syncLoadingStateForChat(convId)
deactivate convStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,IDB: 🌳 MESSAGE BRANCHING MODEL
%% ═══════════════════════════════════════════════════════════════════════════
Note over IDB: Message Tree Structure:<br/>- Each message has parent (null for root)<br/>- Each message has children[] array<br/>- Conversation.currNode points to active leaf<br/>- filterByLeafNodeId() traverses from root to currNode
rect rgb(240, 240, 255)
Note over convStore: Example Branch Structure:
Note over convStore: root → user1 → assistant1 → user2 → assistant2a (currNode)<br/> ↘ assistant2b (alt branch)
end
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,IDB: ↔️ BRANCH NAVIGATION
%% ═══════════════════════════════════════════════════════════════════════════
UI->>convStore: navigateToSibling(msgId, direction)
activate convStore
convStore->>convStore: Find message in activeMessages
convStore->>convStore: Get parent message
convStore->>convStore: Find sibling in parent.children[]
convStore->>convStore: findLeafNode(siblingId, allMessages)
Note right of convStore: Navigate to leaf of sibling branch
convStore->>convStore: updateCurrentNode(leafId)
convStore->>DbSvc: updateCurrentNode(convId, leafId)
DbSvc->>IDB: UPDATE conversations SET currNode = ?
convStore->>convStore: refreshActiveMessages()
deactivate convStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,IDB: 📝 UPDATE CONVERSATION
%% ═══════════════════════════════════════════════════════════════════════════
UI->>convStore: updateConversationName(convId, newName)
activate convStore
convStore->>DbSvc: updateConversation(convId, {name: newName})
DbSvc->>IDB: UPDATE conversations SET name = ?
convStore->>convStore: Update in conversations array
deactivate convStore
Note over convStore: Auto-title update (after first response):
convStore->>convStore: updateConversationTitleWithConfirmation()
convStore->>convStore: titleUpdateConfirmationCallback?()
Note right of convStore: Shows dialog if title would change
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,IDB: 🗑️ DELETE CONVERSATION
%% ═══════════════════════════════════════════════════════════════════════════
UI->>convStore: deleteConversation(convId)
activate convStore
convStore->>DbSvc: deleteConversation(convId)
DbSvc->>IDB: DELETE FROM conversations WHERE id = ?
DbSvc->>IDB: DELETE FROM messages WHERE convId = ?
convStore->>convStore: conversations.filter(c => c.id !== convId)
alt deleted active conversation
convStore->>convStore: clearActiveConversation()
end
deactivate convStore
UI->>convStore: deleteAll()
activate convStore
convStore->>DbSvc: Delete all conversations and messages
convStore->>convStore: conversations = []
convStore->>convStore: clearActiveConversation()
deactivate convStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,IDB: <20> MCP SERVER PER-CHAT OVERRIDES
%% ═══════════════════════════════════════════════════════════════════════════
Note over convStore: Conversations can override which MCP servers are enabled.
Note over convStore: Uses pendingMcpServerOverrides before conversation<br/>is created, then persists to conversation metadata.
UI->>convStore: setMcpServerOverride(convId, serverName, override)
Note right of convStore: override = {enabled: boolean}
UI->>convStore: toggleMcpServerForChat(convId, serverName, enabled)
activate convStore
convStore->>convStore: setMcpServerOverride(convId, serverName, {enabled})
deactivate convStore
UI->>convStore: isMcpServerEnabledForChat(convId, serverName)
Note right of convStore: Check override → fall back to global MCP config
UI->>convStore: getAllMcpServerOverrides(convId)
Note right of convStore: Returns all overrides for a conversation
UI->>convStore: removeMcpServerOverride(convId, serverName)
UI->>convStore: getMcpServerOverride(convId, serverName)
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,IDB: 📤 EXPORT / 📥 IMPORT
%% ═══════════════════════════════════════════════════════════════════════════
UI->>convStore: exportAllConversations()
activate convStore
convStore->>DbSvc: getAllConversations()
loop each conversation
convStore->>DbSvc: getConversationMessages(convId)
end
convStore->>convStore: triggerDownload(JSON blob)
deactivate convStore
UI->>convStore: importConversations(file)
activate convStore
convStore->>convStore: Parse JSON file
convStore->>convStore: importConversationsData(parsed)
convStore->>DbSvc: importConversations(parsed)
Note right of DbSvc: Skips duplicate conversations<br/>(checks existing by ID)
DbSvc->>IDB: INSERT conversations + messages (skip existing)
convStore->>convStore: loadConversations()
deactivate convStore
```

View File

@@ -0,0 +1,45 @@
```mermaid
%% MODEL Mode Data Flow (single model)
%% Detailed flows: ./flows/server-flow.mmd, ./flows/models-flow.mmd, ./flows/chat-flow.mmd
sequenceDiagram
participant User as 👤 User
participant UI as 🧩 UI
participant Stores as 🗄️ Stores
participant DB as 💾 IndexedDB
participant API as 🌐 llama-server
Note over User,API: 🚀 Initialization (see: server-flow.mmd, models-flow.mmd)
UI->>Stores: initialize()
Stores->>DB: load conversations
Stores->>API: GET /props
API-->>Stores: server config + modalities
Stores->>API: GET /v1/models
API-->>Stores: single model (auto-selected)
Note over User,API: 💬 Chat Flow (see: chat-flow.mmd)
User->>UI: send message
UI->>Stores: sendMessage()
Stores->>DB: save user message
Stores->>API: POST /v1/chat/completions (stream)
loop streaming
API-->>Stores: SSE chunks
Stores-->>UI: reactive update
end
API-->>Stores: done + timings
Stores->>DB: save assistant message
Note over User,API: 🔁 Regenerate
User->>UI: regenerate
Stores->>DB: create message branch
Note right of Stores: same streaming flow
Note over User,API: ⏹️ Stop
User->>UI: stop
Stores->>Stores: abort stream
Stores->>DB: save partial response
```

View File

@@ -0,0 +1,77 @@
```mermaid
%% ROUTER Mode Data Flow (multi-model)
%% Detailed flows: ./flows/server-flow.mmd, ./flows/models-flow.mmd, ./flows/chat-flow.mmd
sequenceDiagram
participant User as 👤 User
participant UI as 🧩 UI
participant Stores as 🗄️ Stores
participant DB as 💾 IndexedDB
participant API as 🌐 llama-server
Note over User,API: 🚀 Initialization (see: server-flow.mmd, models-flow.mmd)
UI->>Stores: initialize()
Stores->>DB: load conversations
Stores->>API: GET /props
API-->>Stores: {role: "router"}
Stores->>API: GET /v1/models
API-->>Stores: models[] with status (loaded/available)
loop each loaded model
Stores->>API: GET /props?model=X
API-->>Stores: modalities (vision/audio)
end
Note over User,API: 🔄 Model Selection (see: models-flow.mmd)
User->>UI: select model
alt model not loaded
Stores->>API: POST /models/load
loop poll status
Stores->>API: GET /v1/models
API-->>Stores: check if loaded
end
Stores->>API: GET /props?model=X
API-->>Stores: cache modalities
end
Stores->>Stores: validate modalities vs conversation
alt valid
Stores->>Stores: select model
else invalid
Stores->>API: POST /models/unload
UI->>User: show error toast
end
Note over User,API: 💬 Chat Flow (see: chat-flow.mmd)
User->>UI: send message
UI->>Stores: sendMessage()
Stores->>DB: save user message
Stores->>API: POST /v1/chat/completions {model: X}
Note right of API: router forwards to model
loop streaming
API-->>Stores: SSE chunks + model info
Stores-->>UI: reactive update
end
API-->>Stores: done + timings
Stores->>DB: save assistant message + model used
Note over User,API: 🔁 Regenerate (optional: different model)
User->>UI: regenerate
Stores->>Stores: validate modalities up to this message
Stores->>DB: create message branch
Note right of Stores: same streaming flow
Note over User,API: ⏹️ Stop
User->>UI: stop
Stores->>Stores: abort stream
Stores->>DB: save partial response
Note over User,API: 🗑️ LRU Unloading
Note right of API: Server auto-unloads LRU models<br/>when cache full
User->>UI: select unloaded model
Note right of Stores: triggers load flow again
```

View File

@@ -0,0 +1,174 @@
```mermaid
sequenceDiagram
participant Store as 🗄️ Stores
participant DbSvc as ⚙️ DatabaseService
participant Dexie as 📦 Dexie ORM
participant IDB as 💾 IndexedDB
Note over DbSvc: Stateless service - all methods static<br/>Database: "LlamacppWebui"
%% ═══════════════════════════════════════════════════════════════════════════
Note over Store,IDB: 📊 SCHEMA
%% ═══════════════════════════════════════════════════════════════════════════
rect rgb(240, 248, 255)
Note over IDB: conversations table:<br/>id (PK), lastModified, currNode, name
end
rect rgb(255, 248, 240)
Note over IDB: messages table:<br/>id (PK), convId (FK), type, role, timestamp,<br/>parent, children[], content, thinking,<br/>toolCalls, extra[], model, timings
end
%% ═══════════════════════════════════════════════════════════════════════════
Note over Store,IDB: 💬 CONVERSATIONS CRUD
%% ═══════════════════════════════════════════════════════════════════════════
Store->>DbSvc: createConversation(name)
activate DbSvc
DbSvc->>DbSvc: Generate UUID
DbSvc->>Dexie: db.conversations.add({id, name, lastModified, currNode: ""})
Dexie->>IDB: INSERT
IDB-->>Dexie: success
DbSvc-->>Store: DatabaseConversation
deactivate DbSvc
Store->>DbSvc: getConversation(convId)
DbSvc->>Dexie: db.conversations.get(convId)
Dexie->>IDB: SELECT WHERE id = ?
IDB-->>DbSvc: DatabaseConversation
Store->>DbSvc: getAllConversations()
DbSvc->>Dexie: db.conversations.orderBy('lastModified').reverse().toArray()
Dexie->>IDB: SELECT ORDER BY lastModified DESC
IDB-->>DbSvc: DatabaseConversation[]
Store->>DbSvc: updateConversation(convId, updates)
DbSvc->>Dexie: db.conversations.update(convId, {...updates, lastModified})
Dexie->>IDB: UPDATE
Store->>DbSvc: deleteConversation(convId)
activate DbSvc
DbSvc->>Dexie: db.conversations.delete(convId)
Dexie->>IDB: DELETE FROM conversations
DbSvc->>Dexie: db.messages.where('convId').equals(convId).delete()
Dexie->>IDB: DELETE FROM messages WHERE convId = ?
deactivate DbSvc
%% ═══════════════════════════════════════════════════════════════════════════
Note over Store,IDB: 📝 MESSAGES CRUD
%% ═══════════════════════════════════════════════════════════════════════════
Store->>DbSvc: createRootMessage(convId)
activate DbSvc
DbSvc->>DbSvc: Create root message {type: "root", parent: null}
DbSvc->>Dexie: db.messages.add(rootMsg)
Dexie->>IDB: INSERT
DbSvc-->>Store: rootMessageId
deactivate DbSvc
Store->>DbSvc: createSystemMessage(convId, content, parentId)
activate DbSvc
DbSvc->>DbSvc: Create message {role: "system", parent: parentId}
DbSvc->>Dexie: db.messages.add(systemMsg)
Dexie->>IDB: INSERT
DbSvc-->>Store: DatabaseMessage
deactivate DbSvc
Store->>DbSvc: createMessageBranch(message, parentId)
activate DbSvc
DbSvc->>DbSvc: Generate UUID for new message
DbSvc->>Dexie: db.messages.add({...message, id, parent: parentId})
Dexie->>IDB: INSERT message
alt parentId exists
DbSvc->>Dexie: db.messages.get(parentId)
Dexie->>IDB: SELECT parent
DbSvc->>DbSvc: parent.children.push(newId)
DbSvc->>Dexie: db.messages.update(parentId, {children})
Dexie->>IDB: UPDATE parent.children
end
DbSvc->>Dexie: db.conversations.update(convId, {currNode: newId})
Dexie->>IDB: UPDATE conversation.currNode
DbSvc-->>Store: DatabaseMessage
deactivate DbSvc
Store->>DbSvc: getConversationMessages(convId)
DbSvc->>Dexie: db.messages.where('convId').equals(convId).toArray()
Dexie->>IDB: SELECT WHERE convId = ?
IDB-->>DbSvc: DatabaseMessage[]
Store->>DbSvc: updateMessage(msgId, updates)
DbSvc->>Dexie: db.messages.update(msgId, updates)
Dexie->>IDB: UPDATE
Store->>DbSvc: deleteMessage(msgId)
DbSvc->>Dexie: db.messages.delete(msgId)
Dexie->>IDB: DELETE
%% ═══════════════════════════════════════════════════════════════════════════
Note over Store,IDB: 🌳 BRANCHING OPERATIONS
%% ═══════════════════════════════════════════════════════════════════════════
Store->>DbSvc: updateCurrentNode(convId, nodeId)
DbSvc->>Dexie: db.conversations.update(convId, {currNode: nodeId, lastModified})
Dexie->>IDB: UPDATE
Store->>DbSvc: deleteMessageCascading(msgId)
activate DbSvc
DbSvc->>DbSvc: findDescendantMessages(msgId, allMessages)
Note right of DbSvc: Recursively find all children
loop each descendant
DbSvc->>Dexie: db.messages.delete(descendantId)
Dexie->>IDB: DELETE
end
DbSvc->>Dexie: db.messages.delete(msgId)
Dexie->>IDB: DELETE target message
alt target message has a parent
DbSvc->>Dexie: db.messages.get(parentId)
DbSvc->>DbSvc: parent.children.filter(id !== msgId)
DbSvc->>Dexie: db.messages.update(parentId, {children})
Note right of DbSvc: Remove deleted message from parent's children[]
end
deactivate DbSvc
%% ═══════════════════════════════════════════════════════════════════════════
Note over Store,IDB: 📥 IMPORT
%% ═══════════════════════════════════════════════════════════════════════════
Store->>DbSvc: importConversations(data)
activate DbSvc
loop each conversation in data
DbSvc->>Dexie: db.conversations.get(conv.id)
alt conversation already exists
Note right of DbSvc: Skip duplicate (keep existing)
else conversation is new
DbSvc->>Dexie: db.conversations.add(conversation)
Dexie->>IDB: INSERT conversation
loop each message
DbSvc->>Dexie: db.messages.add(message)
Dexie->>IDB: INSERT message
end
end
end
deactivate DbSvc
%% ═══════════════════════════════════════════════════════════════════════════
Note over Store,IDB: 🔗 MESSAGE TREE UTILITIES
%% ═══════════════════════════════════════════════════════════════════════════
Note over DbSvc: Used by stores (imported from utils):
rect rgb(240, 255, 240)
Note over DbSvc: filterByLeafNodeId(messages, leafId)<br/>→ Returns path from root to leaf<br/>→ Used to display current branch
end
rect rgb(240, 255, 240)
Note over DbSvc: findLeafNode(startId, messages)<br/>→ Traverse to deepest child<br/>→ Used for branch navigation
end
rect rgb(240, 255, 240)
Note over DbSvc: findDescendantMessages(msgId, messages)<br/>→ Find all children recursively<br/>→ Used for cascading deletes
end
```

View File

@@ -0,0 +1,226 @@
```mermaid
sequenceDiagram
participant UI as 🧩 McpServersSettings / ChatForm
participant chatStore as 🗄️ chatStore
participant mcpStore as 🗄️ mcpStore
participant mcpResStore as 🗄️ mcpResourceStore
participant convStore as 🗄️ conversationsStore
participant MCPSvc as ⚙️ MCPService
participant LS as 💾 LocalStorage
participant ExtMCP as 🔌 External MCP Server
Note over mcpStore: State:<br/>isInitializing, error<br/>toolCount, connectedServers<br/>healthChecks (Map)<br/>connections (Map)<br/>toolsIndex (Map)<br/>serverConfigs (Map)
Note over mcpResStore: State:<br/>serverResources (Map)<br/>cachedResources (Map)<br/>subscriptions (Map)<br/>attachments[]
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,ExtMCP: 🚀 INITIALIZATION (App Startup)
%% ═══════════════════════════════════════════════════════════════════════════
UI->>mcpStore: ensureInitialized()
activate mcpStore
mcpStore->>LS: get(MCP_SERVERS_LOCALSTORAGE_KEY)
LS-->>mcpStore: MCPServerSettingsEntry[]
mcpStore->>mcpStore: parseServerSettings(servers)
Note right of mcpStore: Filter enabled servers<br/>Build MCPServerConfig objects<br/>Per-chat overrides checked via convStore
loop For each enabled server
mcpStore->>mcpStore: runHealthCheck(serverId)
mcpStore->>mcpStore: updateHealthCheck(id, CONNECTING)
mcpStore->>MCPSvc: connect(serverName, config, clientInfo, capabilities, onPhase)
activate MCPSvc
MCPSvc->>MCPSvc: createTransport(config)
Note right of MCPSvc: WebSocket / StreamableHTTP / SSE<br/>with optional CORS proxy
MCPSvc->>ExtMCP: Transport handshake
ExtMCP-->>MCPSvc: Connection established
MCPSvc->>ExtMCP: Initialize request
Note right of ExtMCP: Exchange capabilities<br/>Server info, protocol version
ExtMCP-->>MCPSvc: InitializeResult (serverInfo, capabilities)
MCPSvc->>ExtMCP: listTools()
ExtMCP-->>MCPSvc: Tool[]
MCPSvc-->>mcpStore: MCPConnection
deactivate MCPSvc
mcpStore->>mcpStore: connections.set(serverName, connection)
mcpStore->>mcpStore: indexTools(connection.tools, serverName)
Note right of mcpStore: toolsIndex.set(toolName, serverName)<br/>Handle name conflicts with prefixes
mcpStore->>mcpStore: updateHealthCheck(id, SUCCESS)
mcpStore->>mcpStore: _connectedServers.push(serverName)
alt Server supports resources
mcpStore->>MCPSvc: listAllResources(connection)
MCPSvc->>ExtMCP: listResources()
ExtMCP-->>MCPSvc: MCPResource[]
MCPSvc-->>mcpStore: resources
mcpStore->>MCPSvc: listAllResourceTemplates(connection)
MCPSvc->>ExtMCP: listResourceTemplates()
ExtMCP-->>MCPSvc: MCPResourceTemplate[]
MCPSvc-->>mcpStore: templates
mcpStore->>mcpResStore: setServerResources(serverName, resources, templates)
end
end
mcpStore->>mcpStore: _isInitializing = false
deactivate mcpStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,ExtMCP: 🔧 TOOL EXECUTION (Chat with Tools)
%% ═══════════════════════════════════════════════════════════════════════════
UI->>mcpStore: executeTool(mcpCall: MCPToolCall, signal?)
activate mcpStore
mcpStore->>mcpStore: toolsIndex.get(mcpCall.function.name)
Note right of mcpStore: Resolve serverName from toolsIndex<br/>MCPToolCall = {id, type, function: {name, arguments}}
mcpStore->>mcpStore: acquireConnection()
Note right of mcpStore: activeFlowCount++<br/>Prevent shutdown during execution
mcpStore->>mcpStore: connection = connections.get(serverName)
mcpStore->>MCPSvc: callTool(connection, {name, arguments}, signal)
activate MCPSvc
MCPSvc->>MCPSvc: throwIfAborted(signal)
MCPSvc->>ExtMCP: callTool(name, arguments)
alt Tool execution success
ExtMCP-->>MCPSvc: ToolCallResult (content, isError)
MCPSvc->>MCPSvc: formatToolResult(result)
Note right of MCPSvc: Handle text, image (base64),<br/>embedded resource content
MCPSvc-->>mcpStore: ToolExecutionResult
else Tool execution error
ExtMCP-->>MCPSvc: Error
MCPSvc-->>mcpStore: throw Error
else Aborted
MCPSvc-->>mcpStore: throw AbortError
end
deactivate MCPSvc
mcpStore->>mcpStore: releaseConnection()
Note right of mcpStore: activeFlowCount--
mcpStore-->>UI: ToolExecutionResult
deactivate mcpStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,ExtMCP: <20> RESOURCE ATTACHMENT CONSUMPTION
%% ═══════════════════════════════════════════════════════════════════════════
chatStore->>mcpStore: consumeResourceAttachmentsAsExtras()
activate mcpStore
mcpStore->>mcpResStore: getAttachments()
mcpResStore-->>mcpStore: MCPResourceAttachment[]
mcpStore->>mcpStore: Convert attachments to message extras
mcpStore->>mcpResStore: clearAttachments()
mcpStore-->>chatStore: MessageExtra[] (for user message)
deactivate mcpStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,ExtMCP: <20>📝 PROMPT OPERATIONS
%% ═══════════════════════════════════════════════════════════════════════════
UI->>mcpStore: getAllPrompts()
activate mcpStore
loop For each connected server with prompts capability
mcpStore->>MCPSvc: listPrompts(connection)
MCPSvc->>ExtMCP: listPrompts()
ExtMCP-->>MCPSvc: Prompt[]
MCPSvc-->>mcpStore: prompts
end
mcpStore-->>UI: MCPPromptInfo[] (with serverName)
deactivate mcpStore
UI->>mcpStore: getPrompt(serverName, promptName, args?)
activate mcpStore
mcpStore->>MCPSvc: getPrompt(connection, name, args)
MCPSvc->>ExtMCP: getPrompt({name, arguments})
ExtMCP-->>MCPSvc: GetPromptResult (messages)
MCPSvc-->>mcpStore: GetPromptResult
mcpStore-->>UI: GetPromptResult
deactivate mcpStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,ExtMCP: 📁 RESOURCE OPERATIONS
%% ═══════════════════════════════════════════════════════════════════════════
UI->>mcpResStore: addAttachment(resourceInfo)
activate mcpResStore
mcpResStore->>mcpResStore: Create MCPResourceAttachment (loading: true)
mcpResStore-->>UI: attachment
UI->>mcpStore: readResource(serverName, uri)
activate mcpStore
mcpStore->>MCPSvc: readResource(connection, uri)
MCPSvc->>ExtMCP: readResource({uri})
ExtMCP-->>MCPSvc: MCPReadResourceResult (contents)
MCPSvc-->>mcpStore: contents
mcpStore-->>UI: MCPResourceContent[]
deactivate mcpStore
UI->>mcpResStore: updateAttachmentContent(attachmentId, content)
mcpResStore->>mcpResStore: cacheResourceContent(resource, content)
deactivate mcpResStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,ExtMCP: 🔄 AUTO-RECONNECTION
%% ═══════════════════════════════════════════════════════════════════════════
Note over mcpStore: On WebSocket close or connection error:
mcpStore->>mcpStore: autoReconnect(serverName, attempt)
activate mcpStore
mcpStore->>mcpStore: Calculate backoff delay
Note right of mcpStore: delay = min(30s, 1s * 2^attempt)
mcpStore->>mcpStore: Wait for delay
mcpStore->>mcpStore: reconnectServer(serverName)
alt Reconnection success
mcpStore->>mcpStore: updateHealthCheck(id, SUCCESS)
else Max attempts reached
mcpStore->>mcpStore: updateHealthCheck(id, ERROR)
end
deactivate mcpStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,ExtMCP: 🛑 SHUTDOWN
%% ═══════════════════════════════════════════════════════════════════════════
UI->>mcpStore: shutdown()
activate mcpStore
mcpStore->>mcpStore: Wait for activeFlowCount == 0
loop For each connection
mcpStore->>MCPSvc: disconnect(connection)
MCPSvc->>MCPSvc: transport.onclose = undefined
MCPSvc->>ExtMCP: close()
end
mcpStore->>mcpStore: connections.clear()
mcpStore->>mcpStore: toolsIndex.clear()
mcpStore->>mcpStore: _connectedServers = []
mcpStore->>mcpResStore: clear()
deactivate mcpStore
```

View File

@@ -0,0 +1,181 @@
```mermaid
sequenceDiagram
participant UI as 🧩 ModelsSelector
participant Hooks as 🪝 useModelChangeValidation
participant modelsStore as 🗄️ modelsStore
participant serverStore as 🗄️ serverStore
participant convStore as 🗄️ conversationsStore
participant ModelsSvc as ⚙️ ModelsService
participant PropsSvc as ⚙️ PropsService
participant API as 🌐 llama-server
Note over modelsStore: State:<br/>models: ModelOption[]<br/>routerModels: ApiModelDataEntry[]<br/>selectedModelId, selectedModelName<br/>loading, updating, error<br/>modelLoadingStates (Map)<br/>modelPropsCache (Map)<br/>propsCacheVersion
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,API: 🚀 INITIALIZATION (MODEL mode)
%% ═══════════════════════════════════════════════════════════════════════════
UI->>modelsStore: fetch()
activate modelsStore
modelsStore->>modelsStore: loading = true
alt serverStore.props not loaded
modelsStore->>serverStore: fetch()
Note over serverStore: → see server-flow.mmd
end
modelsStore->>ModelsSvc: list()
ModelsSvc->>API: GET /v1/models
API-->>ModelsSvc: ApiModelListResponse {data: [model]}
modelsStore->>modelsStore: models = $state(mapped)
Note right of modelsStore: Map to ModelOption[]:<br/>{id, name, model, description, capabilities}
Note over modelsStore: MODEL mode: Get modalities from serverStore.props
modelsStore->>modelsStore: modelPropsCache.set(model.id, serverStore.props)
modelsStore->>modelsStore: models[0].modalities = props.modalities
modelsStore->>modelsStore: Auto-select single model
Note right of modelsStore: selectedModelId = models[0].id
modelsStore->>modelsStore: loading = false
deactivate modelsStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,API: 🚀 INITIALIZATION (ROUTER mode)
%% ═══════════════════════════════════════════════════════════════════════════
UI->>modelsStore: fetch()
activate modelsStore
modelsStore->>ModelsSvc: list()
ModelsSvc->>API: GET /v1/models
API-->>ModelsSvc: ApiModelListResponse
modelsStore->>modelsStore: models = $state(mapped)
deactivate modelsStore
Note over UI: After models loaded, layout triggers:
UI->>modelsStore: fetchRouterModels()
activate modelsStore
modelsStore->>ModelsSvc: listRouter()
ModelsSvc->>API: GET /v1/models
API-->>ModelsSvc: ApiRouterModelsListResponse
Note right of API: {data: [{id, status, path, in_cache}]}
modelsStore->>modelsStore: routerModels = $state(data)
modelsStore->>modelsStore: fetchModalitiesForLoadedModels()
loop each model where status === "loaded"
modelsStore->>PropsSvc: fetchForModel(modelId)
PropsSvc->>API: GET /props?model={modelId}
API-->>PropsSvc: ApiLlamaCppServerProps
modelsStore->>modelsStore: modelPropsCache.set(modelId, props)
end
modelsStore->>modelsStore: propsCacheVersion++
deactivate modelsStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,API: 🔄 MODEL SELECTION (ROUTER mode)
%% ═══════════════════════════════════════════════════════════════════════════
UI->>Hooks: useModelChangeValidation({getRequiredModalities, onSuccess?, onValidationFailure?})
Note over Hooks: Hook configured per-component:<br/>ChatForm: getRequiredModalities = usedModalities<br/>ChatMessage: getRequiredModalities = getModalitiesUpToMessage(msgId)
UI->>Hooks: handleModelChange(modelId, modelName)
activate Hooks
Hooks->>Hooks: previousSelectedModelId = modelsStore.selectedModelId
Hooks->>modelsStore: isModelLoaded(modelName)?
alt model NOT loaded
Hooks->>modelsStore: loadModel(modelName)
Note over modelsStore: → see LOAD MODEL section below
end
Note over Hooks: Always fetch props (from cache or API)
Hooks->>modelsStore: fetchModelProps(modelName)
modelsStore-->>Hooks: props
Hooks->>convStore: getRequiredModalities()
convStore-->>Hooks: {vision, audio}
Hooks->>Hooks: Validate: model.modalities ⊇ required?
alt validation PASSED
Hooks->>modelsStore: selectModelById(modelId)
Hooks-->>UI: return true
else validation FAILED
Hooks->>UI: toast.error("Model doesn't support required modalities")
alt model was just loaded
Hooks->>modelsStore: unloadModel(modelName)
end
alt onValidationFailure provided
Hooks->>modelsStore: selectModelById(previousSelectedModelId)
end
Hooks-->>UI: return false
end
deactivate Hooks
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,API: ⬆️ LOAD MODEL (ROUTER mode)
%% ═══════════════════════════════════════════════════════════════════════════
modelsStore->>modelsStore: loadModel(modelId)
activate modelsStore
alt already loaded
modelsStore-->>modelsStore: return (no-op)
end
modelsStore->>modelsStore: modelLoadingStates.set(modelId, true)
modelsStore->>ModelsSvc: load(modelId)
ModelsSvc->>API: POST /models/load {model: modelId}
API-->>ModelsSvc: {status: "loading"}
modelsStore->>modelsStore: pollForModelStatus(modelId, LOADED)
loop poll every 500ms (max 60 attempts)
modelsStore->>modelsStore: fetchRouterModels()
modelsStore->>ModelsSvc: listRouter()
ModelsSvc->>API: GET /v1/models
API-->>ModelsSvc: models[]
modelsStore->>modelsStore: getModelStatus(modelId)
alt status === LOADED
Note right of modelsStore: break loop
else status === LOADING
Note right of modelsStore: wait 500ms, continue
end
end
modelsStore->>modelsStore: updateModelModalities(modelId)
modelsStore->>PropsSvc: fetchForModel(modelId)
PropsSvc->>API: GET /props?model={modelId}
API-->>PropsSvc: props with modalities
modelsStore->>modelsStore: modelPropsCache.set(modelId, props)
modelsStore->>modelsStore: propsCacheVersion++
modelsStore->>modelsStore: modelLoadingStates.set(modelId, false)
deactivate modelsStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,API: ⬇️ UNLOAD MODEL (ROUTER mode)
%% ═══════════════════════════════════════════════════════════════════════════
modelsStore->>modelsStore: unloadModel(modelId)
activate modelsStore
modelsStore->>modelsStore: modelLoadingStates.set(modelId, true)
modelsStore->>ModelsSvc: unload(modelId)
ModelsSvc->>API: POST /models/unload {model: modelId}
modelsStore->>modelsStore: pollForModelStatus(modelId, UNLOADED)
loop poll until unloaded
modelsStore->>ModelsSvc: listRouter()
ModelsSvc->>API: GET /v1/models
end
modelsStore->>modelsStore: modelLoadingStates.set(modelId, false)
deactivate modelsStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,API: 📊 COMPUTED GETTERS
%% ═══════════════════════════════════════════════════════════════════════════
Note over modelsStore: Getters:<br/>- selectedModel: ModelOption | null<br/>- loadedModelIds: string[] (from routerModels)<br/>- loadingModelIds: string[] (from modelLoadingStates)<br/>- singleModelName: string | null (MODEL mode only)
Note over modelsStore: Modality helpers:<br/>- getModelModalities(modelId): {vision, audio}<br/>- modelSupportsVision(modelId): boolean<br/>- modelSupportsAudio(modelId): boolean
```

View File

@@ -0,0 +1,76 @@
```mermaid
sequenceDiagram
participant UI as 🧩 +layout.svelte
participant serverStore as 🗄️ serverStore
participant PropsSvc as ⚙️ PropsService
participant API as 🌐 llama-server
Note over serverStore: State:<br/>props: ApiLlamaCppServerProps | null<br/>loading, error<br/>role: ServerRole | null (MODEL | ROUTER)<br/>fetchPromise (deduplication)
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,API: 🚀 INITIALIZATION
%% ═══════════════════════════════════════════════════════════════════════════
UI->>serverStore: fetch()
activate serverStore
alt fetchPromise exists (already fetching)
serverStore-->>UI: return fetchPromise
Note right of serverStore: Deduplicate concurrent calls
end
serverStore->>serverStore: loading = true
serverStore->>serverStore: fetchPromise = new Promise()
serverStore->>PropsSvc: fetch()
PropsSvc->>API: GET /props
API-->>PropsSvc: ApiLlamaCppServerProps
Note right of API: {role, model_path, model_alias,<br/>modalities, default_generation_settings, ...}
PropsSvc-->>serverStore: props
serverStore->>serverStore: props = $state(data)
serverStore->>serverStore: detectRole(props)
Note right of serverStore: role = props.role === "router"<br/> ? ServerRole.ROUTER<br/> : ServerRole.MODEL
serverStore->>serverStore: loading = false
serverStore->>serverStore: fetchPromise = null
deactivate serverStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,API: 📊 COMPUTED GETTERS
%% ═══════════════════════════════════════════════════════════════════════════
Note over serverStore: Getters from props:
rect rgb(240, 255, 240)
Note over serverStore: defaultParams<br/>→ props.default_generation_settings.params<br/>(temperature, top_p, top_k, etc.)
end
rect rgb(240, 255, 240)
Note over serverStore: contextSize<br/>→ props.default_generation_settings.n_ctx
end
rect rgb(255, 240, 240)
Note over serverStore: isRouterMode<br/>→ role === ServerRole.ROUTER
end
rect rgb(255, 240, 240)
Note over serverStore: isModelMode<br/>→ role === ServerRole.MODEL
end
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,API: 🔗 RELATIONSHIPS
%% ═══════════════════════════════════════════════════════════════════════════
Note over serverStore: Used by:
Note right of serverStore: - modelsStore: role detection, MODEL mode modalities<br/>- settingsStore: syncWithServerDefaults (defaultParams)<br/>- chatStore: contextSize for processing state<br/>- UI components: isRouterMode for conditional rendering
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,API: ❌ ERROR HANDLING
%% ═══════════════════════════════════════════════════════════════════════════
Note over serverStore: getErrorMessage(): string | null<br/>Returns formatted error for UI display
Note over serverStore: clear(): void<br/>Resets all state (props, error, loading, role)
```

View File

@@ -0,0 +1,156 @@
```mermaid
sequenceDiagram
participant UI as 🧩 ChatSettings
participant settingsStore as 🗄️ settingsStore
participant serverStore as 🗄️ serverStore
participant ParamSvc as ⚙️ ParameterSyncService
participant LS as 💾 LocalStorage
Note over settingsStore: State:<br/>config: SettingsConfigType<br/>theme: string ("auto" | "light" | "dark")<br/>isInitialized: boolean<br/>userOverrides: Set&lt;string&gt;
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,LS: 🚀 INITIALIZATION
%% ═══════════════════════════════════════════════════════════════════════════
Note over settingsStore: Auto-initialized in constructor (browser only)
settingsStore->>settingsStore: initialize()
activate settingsStore
settingsStore->>settingsStore: loadConfig()
settingsStore->>LS: get("llama-config")
LS-->>settingsStore: StoredConfig | null
alt config exists
settingsStore->>settingsStore: Merge with SETTING_CONFIG_DEFAULT
Note right of settingsStore: Fill missing keys with defaults
else no config
settingsStore->>settingsStore: config = SETTING_CONFIG_DEFAULT
end
settingsStore->>LS: get("llama-userOverrides")
LS-->>settingsStore: string[] | null
settingsStore->>settingsStore: userOverrides = new Set(data)
settingsStore->>settingsStore: loadTheme()
settingsStore->>LS: get("llama-theme")
LS-->>settingsStore: theme | "auto"
settingsStore->>settingsStore: isInitialized = true
deactivate settingsStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,LS: 🔄 SYNC WITH SERVER DEFAULTS
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI: Triggered from +layout.svelte when serverStore.props loaded
UI->>settingsStore: syncWithServerDefaults()
activate settingsStore
settingsStore->>serverStore: defaultParams
serverStore-->>settingsStore: {temperature, top_p, top_k, ...}
loop each SYNCABLE_PARAMETER
alt key NOT in userOverrides
settingsStore->>settingsStore: config[key] = serverDefault[key]
Note right of settingsStore: Non-overridden params adopt server default
else key in userOverrides
Note right of settingsStore: Keep user value, skip server default
end
end
alt serverStore.props has webuiSettings
settingsStore->>settingsStore: Apply webuiSettings from server
Note right of settingsStore: Server-provided UI settings<br/>(e.g. showRawOutputSwitch)
end
settingsStore->>settingsStore: saveConfig()
deactivate settingsStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,LS: ⚙️ UPDATE CONFIG
%% ═══════════════════════════════════════════════════════════════════════════
UI->>settingsStore: updateConfig(key, value)
activate settingsStore
settingsStore->>settingsStore: config[key] = value
alt value matches server default for key
settingsStore->>settingsStore: userOverrides.delete(key)
Note right of settingsStore: Matches server default, remove override
else value differs from server default
settingsStore->>settingsStore: userOverrides.add(key)
Note right of settingsStore: Mark as user-modified (won't be overwritten)
end
settingsStore->>settingsStore: saveConfig()
settingsStore->>LS: set(CONFIG_LOCALSTORAGE_KEY, config)
settingsStore->>LS: set(USER_OVERRIDES_LOCALSTORAGE_KEY, [...userOverrides])
deactivate settingsStore
UI->>settingsStore: updateMultipleConfig({key1: val1, key2: val2})
activate settingsStore
Note right of settingsStore: Batch update, single save
settingsStore->>settingsStore: For each key: config[key] = value
settingsStore->>settingsStore: For each key: userOverrides.add(key)
settingsStore->>settingsStore: saveConfig()
deactivate settingsStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,LS: 🔄 RESET
%% ═══════════════════════════════════════════════════════════════════════════
UI->>settingsStore: resetConfig()
activate settingsStore
settingsStore->>settingsStore: config = {...SETTING_CONFIG_DEFAULT}
settingsStore->>settingsStore: userOverrides.clear()
Note right of settingsStore: All params reset to defaults<br/>Next syncWithServerDefaults will adopt server values
settingsStore->>settingsStore: saveConfig()
deactivate settingsStore
UI->>settingsStore: resetParameterToServerDefault(key)
activate settingsStore
settingsStore->>settingsStore: userOverrides.delete(key)
settingsStore->>serverStore: defaultParams[key]
settingsStore->>settingsStore: config[key] = serverDefault
settingsStore->>settingsStore: saveConfig()
deactivate settingsStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,LS: 🎨 THEME
%% ═══════════════════════════════════════════════════════════════════════════
UI->>settingsStore: updateTheme(newTheme)
activate settingsStore
settingsStore->>settingsStore: theme = newTheme
settingsStore->>settingsStore: saveTheme()
settingsStore->>LS: set("llama-theme", theme)
deactivate settingsStore
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,LS: 📊 PARAMETER INFO
%% ═══════════════════════════════════════════════════════════════════════════
UI->>settingsStore: getParameterInfo(key)
settingsStore->>ParamSvc: getParameterInfo(key, config, serverDefaults, userOverrides)
ParamSvc-->>settingsStore: ParameterInfo
Note right of ParamSvc: {<br/> currentValue,<br/> serverDefault,<br/> isUserOverride: boolean,<br/> canSync: boolean,<br/> isDifferentFromServer: boolean<br/>}
UI->>settingsStore: getParameterDiff()
settingsStore->>ParamSvc: createParameterDiff(config, serverDefaults, userOverrides)
ParamSvc-->>settingsStore: ParameterDiff[]
Note right of ParamSvc: Array of parameters where user != server
%% ═══════════════════════════════════════════════════════════════════════════
Note over UI,LS: 📋 CONFIG CATEGORIES
%% ═══════════════════════════════════════════════════════════════════════════
Note over settingsStore: Syncable with server (from /props):
rect rgb(240, 255, 240)
Note over settingsStore: temperature, top_p, top_k, min_p<br/>repeat_penalty, presence_penalty, frequency_penalty<br/>dynatemp_range, dynatemp_exponent<br/>typ_p, xtc_probability, xtc_threshold<br/>dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n
end
Note over settingsStore: UI-only (not synced):
rect rgb(255, 240, 240)
Note over settingsStore: systemMessage, custom (JSON)<br/>showStatistics, enableContinueGeneration<br/>autoMicOnEmpty, disableAutoScroll<br/>apiKey, pdfAsImage, disableReasoningParsing, showRawOutputSwitch
end
```

View File

@@ -0,0 +1,51 @@
// For more info, see https://github.com/storybookjs/eslint-plugin-storybook#configuration-flat-config-format
import storybook from 'eslint-plugin-storybook';
import prettier from 'eslint-config-prettier';
import { includeIgnoreFile } from '@eslint/compat';
import js from '@eslint/js';
import svelte from 'eslint-plugin-svelte';
import globals from 'globals';
import { fileURLToPath } from 'node:url';
import ts from 'typescript-eslint';
import svelteConfig from './svelte.config.js';
const gitignorePath = fileURLToPath(new URL('./.gitignore', import.meta.url));
export default ts.config(
includeIgnoreFile(gitignorePath),
js.configs.recommended,
...ts.configs.recommended,
...svelte.configs.recommended,
prettier,
...svelte.configs.prettier,
{
languageOptions: {
globals: { ...globals.browser, ...globals.node }
},
rules: {
// typescript-eslint strongly recommend that you do not use the no-undef lint rule on TypeScript projects.
// see: https://typescript-eslint.io/troubleshooting/faqs/eslint/#i-get-errors-from-the-no-undef-rule-about-global-variables-not-being-defined-even-though-there-are-no-typescript-errors
'no-undef': 'off',
'svelte/no-at-html-tags': 'off',
// This app uses hash-based routing (#/) where resolve() from $app/paths does not apply
'svelte/no-navigation-without-resolve': 'off'
}
},
{
files: ['**/*.svelte', '**/*.svelte.ts', '**/*.svelte.js'],
languageOptions: {
parserOptions: {
projectService: true,
extraFileExtensions: ['.svelte'],
parser: ts.parser,
svelteConfig
}
}
},
{
// Exclude Storybook files from main ESLint rules
ignores: ['.storybook/**/*']
},
storybook.configs['flat/recommended']
);

10704
tools/server/webui/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,96 @@
{
"name": "llama-server-webui",
"private": true,
"version": "1.0.0",
"type": "module",
"scripts": {
"dev": "bash scripts/dev.sh",
"build": "vite build && ./scripts/post-build.sh",
"preview": "vite preview",
"prepare": "svelte-kit sync || echo ''",
"check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json",
"check:watch": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json --watch",
"reset": "rm -rf .svelte-kit node_modules",
"format": "prettier --write .",
"lint": "prettier --check . && eslint .",
"test": "npm run test:ui -- --run && npm run test:client -- --run && npm run test:unit -- --run && npm run test:e2e",
"test:e2e": "playwright test",
"test:client": "vitest --project=client",
"test:unit": "vitest --project=unit",
"test:ui": "vitest --project=ui",
"storybook": "storybook dev -p 6006",
"build-storybook": "storybook build",
"cleanup": "rm -rf .svelte-kit build node_modules test-results"
},
"devDependencies": {
"@chromatic-com/storybook": "^5.0.0",
"@eslint/compat": "^1.2.5",
"@eslint/js": "^9.18.0",
"@internationalized/date": "^3.10.1",
"@lucide/svelte": "^0.515.0",
"@playwright/test": "^1.49.1",
"@storybook/addon-a11y": "^10.2.4",
"@storybook/addon-docs": "^10.2.4",
"@storybook/addon-svelte-csf": "^5.0.10",
"@storybook/addon-vitest": "^10.2.4",
"@storybook/sveltekit": "^10.2.4",
"@sveltejs/adapter-static": "^3.0.10",
"@sveltejs/kit": "^2.48.4",
"@sveltejs/vite-plugin-svelte": "^6.2.1",
"@tailwindcss/forms": "^0.5.9",
"@tailwindcss/typography": "^0.5.15",
"@tailwindcss/vite": "^4.0.0",
"@types/node": "^24",
"@vitest/browser": "^3.2.3",
"@vitest/coverage-v8": "^3.2.3",
"bits-ui": "^2.14.4",
"clsx": "^2.1.1",
"dexie": "^4.0.11",
"eslint": "^9.18.0",
"eslint-config-prettier": "^10.0.1",
"eslint-plugin-storybook": "^10.2.4",
"eslint-plugin-svelte": "^3.0.0",
"globals": "^16.0.0",
"http-server": "^14.1.1",
"mdast": "^3.0.0",
"mdsvex": "^0.12.3",
"playwright": "^1.56.1",
"prettier": "^3.4.2",
"prettier-plugin-svelte": "^3.3.3",
"prettier-plugin-tailwindcss": "^0.6.11",
"rehype-katex": "^7.0.1",
"remark-math": "^6.0.0",
"sass": "^1.93.3",
"storybook": "^10.2.4",
"svelte": "^5.38.2",
"svelte-check": "^4.0.0",
"tailwind-merge": "^3.3.1",
"tailwind-variants": "^3.2.2",
"tailwindcss": "^4.0.0",
"tw-animate-css": "^1.3.5",
"typescript": "^5.0.0",
"typescript-eslint": "^8.20.0",
"unified": "^11.0.5",
"uuid": "^13.0.0",
"vite": "^7.2.2",
"vite-plugin-devtools-json": "^0.2.0",
"vitest": "^3.2.3",
"vitest-browser-svelte": "^0.1.0"
},
"dependencies": {
"@modelcontextprotocol/sdk": "^1.25.1",
"highlight.js": "^11.11.1",
"mode-watcher": "^1.1.0",
"pdfjs-dist": "^5.4.54",
"rehype-highlight": "^7.0.2",
"rehype-stringify": "^10.0.1",
"remark": "^15.0.1",
"remark-breaks": "^4.0.0",
"remark-gfm": "^4.0.1",
"remark-html": "^16.0.1",
"remark-rehype": "^11.1.2",
"svelte-sonner": "^1.0.5",
"unist-util-visit": "^5.0.0",
"zod": "^4.2.1"
}
}

View File

@@ -0,0 +1,11 @@
import { defineConfig } from '@playwright/test';
export default defineConfig({
webServer: {
command: 'npm run build && http-server ../public -p 8181',
port: 8181,
timeout: 120000,
reuseExistingServer: false
},
testDir: 'tests/e2e'
});

View File

@@ -0,0 +1,57 @@
#!/bin/bash
# Development script for llama.cpp webui
#
# This script starts the webui development servers (Storybook and Vite).
# Note: You need to start llama-server separately.
#
# Usage:
# bash scripts/dev.sh
# npm run dev
cd ../../../
# Check and install git hooks if missing
check_and_install_hooks() {
local hooks_missing=false
# Check for required hooks
if [ ! -f ".git/hooks/pre-commit" ] || [ ! -f ".git/hooks/pre-push" ] || [ ! -f ".git/hooks/post-push" ]; then
hooks_missing=true
fi
if [ "$hooks_missing" = true ]; then
echo "🔧 Git hooks missing, installing them..."
cd tools/server/webui
if bash scripts/install-git-hooks.sh; then
echo "✅ Git hooks installed successfully"
else
echo "⚠️ Failed to install git hooks, continuing anyway..."
fi
cd ../../../
else
echo "✅ Git hooks already installed"
fi
}
# Install git hooks if needed
check_and_install_hooks
# Cleanup function
cleanup() {
echo "🧹 Cleaning up..."
exit
}
# Set up signal handlers
trap cleanup SIGINT SIGTERM
echo "🚀 Starting development servers..."
echo "📝 Note: Make sure to start llama-server separately if needed"
cd tools/server/webui
# Use --insecure-http-parser to handle malformed HTTP responses from llama-server
# (some responses have both Content-Length and Transfer-Encoding headers)
storybook dev -p 6006 --ci & NODE_OPTIONS="--insecure-http-parser" vite dev --host 0.0.0.0 &
# Wait for all background processes
wait

View File

@@ -0,0 +1,82 @@
#!/bin/bash
# Script to install pre-commit hook for webui
# Pre-commit: formats, checks, builds, and stages build output
REPO_ROOT=$(git rev-parse --show-toplevel)
PRE_COMMIT_HOOK="$REPO_ROOT/.git/hooks/pre-commit"
echo "Installing pre-commit hook for webui..."
# Create the pre-commit hook
cat > "$PRE_COMMIT_HOOK" << 'EOF'
#!/bin/bash
# Check if there are any changes in the webui directory
if git diff --cached --name-only | grep -q "^tools/server/webui/"; then
REPO_ROOT=$(git rev-parse --show-toplevel)
cd "$REPO_ROOT/tools/server/webui"
# Check if package.json exists
if [ ! -f "package.json" ]; then
echo "Error: package.json not found in tools/server/webui"
exit 1
fi
echo "Formatting and checking webui code..."
# Run the format command
npm run format
if [ $? -ne 0 ]; then
echo "Error: npm run format failed"
exit 1
fi
# Run the lint command
npm run lint
if [ $? -ne 0 ]; then
echo "Error: npm run lint failed"
exit 1
fi
# Run the check command
npm run check
if [ $? -ne 0 ]; then
echo "Error: npm run check failed"
exit 1
fi
echo "✅ Webui code formatted and checked successfully"
# Build the webui
echo "Building webui..."
npm run build
if [ $? -ne 0 ]; then
echo "❌ npm run build failed"
exit 1
fi
# Stage the build output alongside the source changes
cd "$REPO_ROOT"
git add tools/server/public/
echo "✅ Webui built and build output staged"
fi
exit 0
EOF
# Make hook executable
chmod +x "$PRE_COMMIT_HOOK"
if [ $? -eq 0 ]; then
echo "✅ Git hook installed successfully!"
echo " Pre-commit: $PRE_COMMIT_HOOK"
echo ""
echo "The hook will automatically:"
echo " • Format, lint and check webui code before commits"
echo " • Build webui and stage tools/server/public/ into the same commit"
else
echo "❌ Failed to make hook executable"
exit 1
fi

View File

@@ -0,0 +1,3 @@
rm -rf ../public/_app;
rm ../public/favicon.svg;
rm -f ../public/index.html.gz; # deprecated, but may still be generated by older versions of the build process

View File

@@ -0,0 +1,84 @@
import { readFileSync, writeFileSync, existsSync, readdirSync, copyFileSync } from 'fs';
import { resolve } from 'path';
import type { Plugin } from 'vite';
const GUIDE_FOR_FRONTEND = `
<!--
This is a static build of the frontend.
It is automatically generated by the build process.
Do not edit this file directly.
To make changes, refer to the "Web UI" section in the README.
-->
`.trim();
export function llamaCppBuildPlugin(): Plugin {
return {
name: 'llamacpp:build',
apply: 'build',
closeBundle() {
// Ensure the SvelteKit adapter has finished writing to ../public
setTimeout(() => {
try {
const indexPath = resolve('../public/index.html');
if (!existsSync(indexPath)) return;
let content = readFileSync(indexPath, 'utf-8');
const faviconPath = resolve('static/favicon.svg');
if (existsSync(faviconPath)) {
const faviconContent = readFileSync(faviconPath, 'utf-8');
const faviconBase64 = Buffer.from(faviconContent).toString('base64');
const faviconDataUrl = `data:image/svg+xml;base64,${faviconBase64}`;
content = content.replace(/href="[^"]*favicon\.svg"/g, `href="${faviconDataUrl}"`);
console.log('✓ Inlined favicon.svg as base64 data URL');
}
content = content.replace(/\r/g, '');
content = GUIDE_FOR_FRONTEND + '\n' + content;
content = content.replace(/\/_app\/immutable\/bundle\.[^"]+\.js/g, './bundle.js');
content = content.replace(
/\/_app\/immutable\/assets\/bundle\.[^"]+\.css/g,
'./bundle.css'
);
content = content.replace(/__sveltekit_[a-z0-9]+/g, '__sveltekit__');
writeFileSync(indexPath, content, 'utf-8');
console.log('✓ Updated index.html');
// Copy bundle.*.js -> ../public/bundle.js
const immutableDir = resolve('../public/_app/immutable');
const bundleDir = resolve('../public/_app/immutable/assets');
if (existsSync(immutableDir)) {
const jsFiles = readdirSync(immutableDir).filter((f) => f.match(/^bundle\..+\.js$/));
if (jsFiles.length > 0) {
copyFileSync(resolve(immutableDir, jsFiles[0]), resolve('../public/bundle.js'));
// Normalize __sveltekit_<hash> to __sveltekit__ in bundle.js
const bundleJsPath = resolve('../public/bundle.js');
let bundleJs = readFileSync(bundleJsPath, 'utf-8');
bundleJs = bundleJs.replace(/__sveltekit_[a-z0-9]+/g, '__sveltekit__');
writeFileSync(bundleJsPath, bundleJs, 'utf-8');
console.log(`✓ Copied ${jsFiles[0]} -> bundle.js`);
}
}
// Copy bundle.*.css -> ../public/bundle.css
if (existsSync(bundleDir)) {
const cssFiles = readdirSync(bundleDir).filter((f) => f.match(/^bundle\..+\.css$/));
if (cssFiles.length > 0) {
copyFileSync(resolve(bundleDir, cssFiles[0]), resolve('../public/bundle.css'));
console.log(`✓ Copied ${cssFiles[0]} -> bundle.css`);
}
}
} catch (error) {
console.error('Failed to update index.html:', error);
}
}, 100);
}
};
}

View File

@@ -0,0 +1,186 @@
@import 'tailwindcss';
@import 'tw-animate-css';
@custom-variant dark (&:is(.dark *));
:root {
--radius: 0.625rem;
--background: oklch(1 0 0);
--foreground: oklch(0.145 0 0);
--card: oklch(1 0 0);
--card-foreground: oklch(0.145 0 0);
--popover: oklch(1 0 0);
--popover-foreground: oklch(0.145 0 0);
--primary: oklch(0.205 0 0);
--primary-foreground: oklch(0.985 0 0);
--secondary: oklch(0.95 0 0);
--secondary-foreground: oklch(0.205 0 0);
--muted: oklch(0.97 0 0);
--muted-foreground: oklch(0.556 0 0);
--accent: oklch(0.95 0 0);
--accent-foreground: oklch(0.205 0 0);
--destructive: oklch(0.577 0.245 27.325);
--border: oklch(0.875 0 0);
--input: oklch(0.92 0 0);
--ring: oklch(0.708 0 0);
--chart-1: oklch(0.646 0.222 41.116);
--chart-2: oklch(0.6 0.118 184.704);
--chart-3: oklch(0.398 0.07 227.392);
--chart-4: oklch(0.828 0.189 84.429);
--chart-5: oklch(0.769 0.188 70.08);
--sidebar: oklch(0.985 0 0);
--sidebar-foreground: oklch(0.145 0 0);
--sidebar-primary: oklch(0.205 0 0);
--sidebar-primary-foreground: oklch(0.985 0 0);
--sidebar-accent: oklch(0.97 0 0);
--sidebar-accent-foreground: oklch(0.205 0 0);
--sidebar-border: oklch(0.922 0 0);
--sidebar-ring: oklch(0.708 0 0);
--code-background: oklch(0.985 0 0);
--code-foreground: oklch(0.145 0 0);
--layer-popover: 1000000;
--chat-form-area-height: 8rem;
--chat-form-area-offset: 2rem;
--max-message-height: max(24rem, min(80dvh, calc(100dvh - var(--chat-form-area-height) - 12rem)));
}
@media (min-width: 640px) {
:root {
--chat-form-area-height: 24rem;
--chat-form-area-offset: 12rem;
}
}
.dark {
--background: oklch(0.16 0 0);
--foreground: oklch(0.985 0 0);
--card: oklch(0.205 0 0);
--card-foreground: oklch(0.985 0 0);
--popover: oklch(0.205 0 0);
--popover-foreground: oklch(0.985 0 0);
--primary: oklch(0.922 0 0);
--primary-foreground: oklch(0.205 0 0);
--secondary: oklch(0.29 0 0);
--secondary-foreground: oklch(0.985 0 0);
--muted: oklch(0.269 0 0);
--muted-foreground: oklch(0.708 0 0);
--accent: oklch(0.269 0 0);
--accent-foreground: oklch(0.985 0 0);
--destructive: oklch(0.704 0.191 22.216);
--border: oklch(1 0 0 / 30%);
--input: oklch(1 0 0 / 30%);
--ring: oklch(0.556 0 0);
--chart-1: oklch(0.488 0.243 264.376);
--chart-2: oklch(0.696 0.17 162.48);
--chart-3: oklch(0.769 0.188 70.08);
--chart-4: oklch(0.627 0.265 303.9);
--chart-5: oklch(0.645 0.246 16.439);
--sidebar: oklch(0.2 0 0);
--sidebar-foreground: oklch(0.985 0 0);
--sidebar-primary: oklch(0.488 0.243 264.376);
--sidebar-primary-foreground: oklch(0.985 0 0);
--sidebar-accent: oklch(0.269 0 0);
--sidebar-accent-foreground: oklch(0.985 0 0);
--sidebar-border: oklch(1 0 0 / 10%);
--sidebar-ring: oklch(0.556 0 0);
--code-background: oklch(0.225 0 0);
--code-foreground: oklch(0.875 0 0);
}
@theme inline {
--radius-sm: calc(var(--radius) - 4px);
--radius-md: calc(var(--radius) - 2px);
--radius-lg: var(--radius);
--radius-xl: calc(var(--radius) + 4px);
--color-background: var(--background);
--color-foreground: var(--foreground);
--color-card: var(--card);
--color-card-foreground: var(--card-foreground);
--color-popover: var(--popover);
--color-popover-foreground: var(--popover-foreground);
--color-primary: var(--primary);
--color-primary-foreground: var(--primary-foreground);
--color-secondary: var(--secondary);
--color-secondary-foreground: var(--secondary-foreground);
--color-muted: var(--muted);
--color-muted-foreground: var(--muted-foreground);
--color-accent: var(--accent);
--color-accent-foreground: var(--accent-foreground);
--color-destructive: var(--destructive);
--color-border: var(--border);
--color-input: var(--input);
--color-ring: var(--ring);
--color-chart-1: var(--chart-1);
--color-chart-2: var(--chart-2);
--color-chart-3: var(--chart-3);
--color-chart-4: var(--chart-4);
--color-chart-5: var(--chart-5);
--color-sidebar: var(--sidebar);
--color-sidebar-foreground: var(--sidebar-foreground);
--color-sidebar-primary: var(--sidebar-primary);
--color-sidebar-primary-foreground: var(--sidebar-primary-foreground);
--color-sidebar-accent: var(--sidebar-accent);
--color-sidebar-accent-foreground: var(--sidebar-accent-foreground);
--color-sidebar-border: var(--sidebar-border);
--color-sidebar-ring: var(--sidebar-ring);
}
@layer base {
* {
@apply border-border outline-ring/50;
}
body {
@apply bg-background text-foreground;
scrollbar-width: thin;
scrollbar-gutter: stable;
}
/* Global scrollbar styling - visible only on hover */
* {
scrollbar-width: thin;
scrollbar-color: transparent transparent;
transition: scrollbar-color 0.2s ease;
}
*:hover {
scrollbar-color: hsl(var(--muted-foreground) / 0.3) transparent;
}
*::-webkit-scrollbar {
width: 6px;
height: 6px;
}
*::-webkit-scrollbar-track {
background: transparent;
}
*::-webkit-scrollbar-thumb {
background: transparent;
border-radius: 3px;
transition: background 0.2s ease;
}
*:hover::-webkit-scrollbar-thumb {
background: hsl(var(--muted-foreground) / 0.3);
}
*::-webkit-scrollbar-thumb:hover {
background: hsl(var(--muted-foreground) / 0.5);
}
}
@layer utilities {
.scrollbar-hide {
/* Hide scrollbar for Chrome, Safari and Opera */
&::-webkit-scrollbar {
display: none;
}
/* Hide scrollbar for IE, Edge and Firefox */
-ms-overflow-style: none;
scrollbar-width: none;
}
}

131
tools/server/webui/src/app.d.ts vendored Normal file
View File

@@ -0,0 +1,131 @@
// See https://svelte.dev/docs/kit/types#app.d.ts
// for information about these interfaces
// Import chat types from dedicated module
import type {
// API types
ApiChatCompletionRequest,
ApiChatCompletionResponse,
ApiChatCompletionStreamChunk,
ApiChatCompletionToolCall,
ApiChatCompletionToolCallDelta,
ApiChatMessageData,
ApiChatMessageContentPart,
ApiContextSizeError,
ApiErrorResponse,
ApiLlamaCppServerProps,
ApiModelDataEntry,
ApiModelListResponse,
ApiProcessingState,
ApiRouterModelMeta,
ApiRouterModelsLoadRequest,
ApiRouterModelsLoadResponse,
ApiRouterModelsStatusRequest,
ApiRouterModelsStatusResponse,
ApiRouterModelsListResponse,
ApiRouterModelsUnloadRequest,
ApiRouterModelsUnloadResponse,
// Chat types
ChatAttachmentDisplayItem,
ChatMessageType,
ChatRole,
ChatUploadedFile,
ChatMessageSiblingInfo,
ChatMessagePromptProgress,
ChatMessageTimings,
// Database types
DatabaseConversation,
DatabaseMessage,
DatabaseMessageExtra,
DatabaseMessageExtraAudioFile,
DatabaseMessageExtraImageFile,
DatabaseMessageExtraTextFile,
DatabaseMessageExtraPdfFile,
DatabaseMessageExtraLegacyContext,
ExportedConversation,
ExportedConversations,
// Model types
ModelModalities,
ModelOption,
// Settings types
SettingsChatServiceOptions,
SettingsConfigValue,
SettingsFieldConfig,
SettingsConfigType
} from '$lib/types';
import { ServerRole, ServerModelStatus, ModelModality } from '$lib/enums';
declare global {
// namespace App {
// interface Error {}
// interface Locals {}
// interface PageData {}
// interface PageState {}
// interface Platform {}
// }
export {
// API types
ApiChatCompletionRequest,
ApiChatCompletionResponse,
ApiChatCompletionStreamChunk,
ApiChatCompletionToolCall,
ApiChatCompletionToolCallDelta,
ApiChatMessageData,
ApiChatMessageContentPart,
ApiContextSizeError,
ApiErrorResponse,
ApiLlamaCppServerProps,
ApiModelDataEntry,
ApiModelListResponse,
ApiProcessingState,
ApiRouterModelMeta,
ApiRouterModelsLoadRequest,
ApiRouterModelsLoadResponse,
ApiRouterModelsStatusRequest,
ApiRouterModelsStatusResponse,
ApiRouterModelsListResponse,
ApiRouterModelsUnloadRequest,
ApiRouterModelsUnloadResponse,
// Chat types
ChatAttachmentDisplayItem,
ChatMessagePromptProgress,
ChatMessageSiblingInfo,
ChatMessageTimings,
ChatMessageType,
ChatRole,
ChatUploadedFile,
// Database types
DatabaseConversation,
DatabaseMessage,
DatabaseMessageExtra,
DatabaseMessageExtraAudioFile,
DatabaseMessageExtraImageFile,
DatabaseMessageExtraTextFile,
DatabaseMessageExtraPdfFile,
DatabaseMessageExtraLegacyContext,
ExportedConversation,
ExportedConversations,
// Enum types
ModelModality,
ServerRole,
ServerModelStatus,
// Model types
ModelModalities,
ModelOption,
// Settings types
SettingsChatServiceOptions,
SettingsConfigValue,
SettingsFieldConfig,
SettingsConfigType
};
}
declare global {
interface Window {
idxThemeStyle?: number;
idxCodeBlock?: number;
}
}

View File

@@ -0,0 +1,12 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<link rel="icon" href="%sveltekit.assets%/favicon.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
%sveltekit.head%
</head>
<body data-sveltekit-preload-data="hover">
<div style="display: contents">%sveltekit.body%</div>
</body>
</html>

View File

@@ -0,0 +1,47 @@
import { isElementInViewport } from '$lib/utils/viewport';
/**
* Svelte action that fades in an element when it enters the viewport.
* Uses IntersectionObserver for efficient viewport detection.
*
* If skipIfVisible is set and the element is already visible in the viewport
* when the action attaches (e.g. a markdown block promoted from unstable
* during streaming), the fade is skipped entirely to avoid a flash.
*/
export function fadeInView(
node: HTMLElement,
options: { duration?: number; y?: number; skipIfVisible?: boolean } = {}
) {
const { duration = 300, y = 0, skipIfVisible = false } = options;
if (skipIfVisible && isElementInViewport(node)) {
return;
}
node.style.opacity = '0';
node.style.transform = `translateY(${y}px)`;
node.style.transition = `opacity ${duration}ms ease-out, transform ${duration}ms ease-out`;
$effect(() => {
const observer = new IntersectionObserver(
(entries) => {
for (const entry of entries) {
if (entry.isIntersecting) {
requestAnimationFrame(() => {
node.style.opacity = '1';
node.style.transform = 'translateY(0)';
});
observer.disconnect();
}
}
},
{ threshold: 0.05 }
);
observer.observe(node);
return () => {
observer.disconnect();
};
});
}

View File

@@ -0,0 +1,11 @@
---
name: app
description: Opinionated app components building on top of ./ui primitives
---
- Can include business logic and state management
- Can include data fetching and caching logic
- Should use original spelling for HTML-native events and `camelCase` for custom events
- Props and markup attributes should be listed alphabetically
- Use JS Objects and Arrays for CSS classes and styles when they are dynamic
- Whenever there can be repetition in the component's markup, if it's too small to be decoupled as a separate component — use Svelte 5's `{#snippet}` + `{@render}`

View File

@@ -0,0 +1,60 @@
<script lang="ts">
import { Button, type ButtonVariant, type ButtonSize } from '$lib/components/ui/button';
import * as Tooltip from '$lib/components/ui/tooltip';
import type { Component } from 'svelte';
import { TooltipSide } from '$lib/enums';
interface Props {
ariaLabel?: string;
class?: string;
disabled?: boolean;
icon: Component;
iconSize?: string;
onclick: (e?: MouseEvent) => void;
size?: ButtonSize;
stopPropagationOnClick?: boolean;
tooltip: string;
variant?: ButtonVariant;
tooltipSide?: TooltipSide;
}
let {
icon,
tooltip,
variant = 'ghost',
size = 'sm',
class: className = '',
disabled = false,
iconSize = 'h-3 w-3',
tooltipSide = TooltipSide.TOP,
stopPropagationOnClick = false,
onclick,
ariaLabel
}: Props = $props();
</script>
<Tooltip.Root>
<Tooltip.Trigger>
<Button
{variant}
{size}
{disabled}
onclick={(e: MouseEvent) => {
if (stopPropagationOnClick) e.stopPropagation();
onclick?.(e);
}}
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
aria-label={ariaLabel || tooltip}
>
{#if icon}
{@const IconComponent = icon}
<IconComponent class={iconSize} />
{/if}
</Button>
</Tooltip.Trigger>
<Tooltip.Content side={tooltipSide}>
<p>{tooltip}</p>
</Tooltip.Content>
</Tooltip.Root>

View File

@@ -0,0 +1,17 @@
<script lang="ts">
import { Copy } from '@lucide/svelte';
import { copyToClipboard } from '$lib/utils';
import ActionIcon from './ActionIcon.svelte';
export let ariaLabel: string = 'Copy to clipboard';
export let canCopy: boolean = true;
export let text: string;
</script>
<ActionIcon
icon={Copy}
tooltip={ariaLabel}
iconSize="h-4 w-4"
disabled={!canCopy}
onclick={() => canCopy && copyToClipboard(text)}
/>

Some files were not shown because too many files have changed in this diff Show More