llama.cpp verification source 2026-05-22
Some checks are pending
Copilot Setup Steps / copilot-setup-steps (push) Waiting to run
Check Pre-Tokenizer Hashes / pre-tokenizer-hashes (push) Waiting to run
Python check requirements.txt / check-requirements (push) Waiting to run
Python Type-Check / python type-check (push) Waiting to run
Update Operations Documentation / update-ops-docs (push) Waiting to run
Some checks are pending
Copilot Setup Steps / copilot-setup-steps (push) Waiting to run
Check Pre-Tokenizer Hashes / pre-tokenizer-hashes (push) Waiting to run
Python check requirements.txt / check-requirements (push) Waiting to run
Python Type-Check / python type-check (push) Waiting to run
Update Operations Documentation / update-ops-docs (push) Waiting to run
This commit is contained in:
75
tools/server/CMakeLists.txt
Normal file
75
tools/server/CMakeLists.txt
Normal file
@@ -0,0 +1,75 @@
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
# server-context containing the core server logic, used by llama-server and CLI
|
||||
|
||||
set(TARGET server-context)
|
||||
|
||||
add_library(${TARGET} STATIC
|
||||
server-chat.cpp
|
||||
server-chat.h
|
||||
server-task.cpp
|
||||
server-task.h
|
||||
server-queue.cpp
|
||||
server-queue.h
|
||||
server-common.cpp
|
||||
server-common.h
|
||||
server-context.cpp
|
||||
server-context.h
|
||||
server-tools.cpp
|
||||
server-tools.h
|
||||
)
|
||||
|
||||
if (BUILD_SHARED_LIBS)
|
||||
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
endif()
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ../mtmd)
|
||||
target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
|
||||
target_link_libraries(${TARGET} PUBLIC llama-common mtmd ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
|
||||
# llama-server executable
|
||||
|
||||
set(TARGET llama-server)
|
||||
|
||||
set(TARGET_SRCS
|
||||
server.cpp
|
||||
server-http.cpp
|
||||
server-http.h
|
||||
server-models.cpp
|
||||
server-models.h
|
||||
)
|
||||
|
||||
option(LLAMA_BUILD_WEBUI "Build the embedded Web UI" ON)
|
||||
|
||||
if (LLAMA_BUILD_WEBUI)
|
||||
set(PUBLIC_ASSETS
|
||||
index.html
|
||||
bundle.js
|
||||
bundle.css
|
||||
loading.html
|
||||
)
|
||||
|
||||
foreach(asset ${PUBLIC_ASSETS})
|
||||
set(input "${CMAKE_CURRENT_SOURCE_DIR}/public/${asset}")
|
||||
set(output "${CMAKE_CURRENT_BINARY_DIR}/${asset}.hpp")
|
||||
list(APPEND TARGET_SRCS ${output})
|
||||
add_custom_command(
|
||||
DEPENDS "${input}"
|
||||
OUTPUT "${output}"
|
||||
COMMAND "${CMAKE_COMMAND}" "-DINPUT=${input}" "-DOUTPUT=${output}" -P "${PROJECT_SOURCE_DIR}/scripts/xxd.cmake"
|
||||
)
|
||||
set_source_files_properties(${output} PROPERTIES GENERATED TRUE)
|
||||
endforeach()
|
||||
add_definitions(-DLLAMA_BUILD_WEBUI)
|
||||
else()
|
||||
endif()
|
||||
|
||||
add_executable(${TARGET} ${TARGET_SRCS})
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ../mtmd)
|
||||
target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
|
||||
target_link_libraries(${TARGET} PRIVATE server-context PUBLIC llama-common cpp-httplib ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
264
tools/server/README-dev.md
Normal file
264
tools/server/README-dev.md
Normal file
@@ -0,0 +1,264 @@
|
||||
# llama-server Development Documentation
|
||||
|
||||
This document provides an in-depth technical overview of `llama-server`, intended for maintainers and contributors.
|
||||
|
||||
If you are an end user consuming `llama-server` as a product, please refer to the main [README](./README.md) instead.
|
||||
|
||||
## Scope of features
|
||||
|
||||
In-scope types of feature:
|
||||
|
||||
- Backend:
|
||||
- Basic inference features: text completion, embeddings output
|
||||
- Chat-oriented features: chat completion, tool calling
|
||||
- Third-party API compatibility, e.g. OAI-compat, Anthropic-compat
|
||||
- Multimodal input/output
|
||||
- Memory management: save/load state, context checkpoints
|
||||
- Model management
|
||||
- Features that are required by the Web UI
|
||||
- Frontend:
|
||||
- Chat-oriented features, example: basic chat, image upload, edit messages
|
||||
- Agentic features, example: MCP
|
||||
- Model management
|
||||
|
||||
Note: For security reasons, features that require reading or writing external files must be **disabled by default**. This covers features like: MCP, model save/load
|
||||
|
||||
Out-of-scope features:
|
||||
|
||||
- Backend:
|
||||
- Features that require a loop of external API calls, e.g. server-side agentic loop. This is because external API calls in C++ are costly to maintain. Any complex third-party logic should be implemented outside of server code.
|
||||
- Features that expose the internal state of the model to the API, example: getting the intermediate activation from API. This is because llama.cpp doesn't support a stable API for doing this, and relying on `eval_callback` can make it complicated to maintain as this API is not intended to be used in multi-sequence setup.
|
||||
- Model-specific features. All API calls and features must remain model-agnostic.
|
||||
- Frontend:
|
||||
- Third-party plugins, it is costly to maintain a public plugin API for such features. Instead, users can make their own MCP server for their needs.
|
||||
- Customizable themes, it is also costly to maintain. While we do focus on the aesthetic, we try to achieve this by perfecting a small set of themes.
|
||||
- Browser-specific features, example: [Chrome's built-in AI API](https://developer.chrome.com/docs/ai/built-in-apis).
|
||||
|
||||
## Backend
|
||||
|
||||
### Overview
|
||||
|
||||
The server supports two primary operating modes:
|
||||
|
||||
- **Inference mode**: The default mode for performing inference with a single loaded GGUF model.
|
||||
- **Router mode**: Enables management of multiple inference server instances behind a single API endpoint. Requests are automatically routed to the appropriate backend instance based on the requested model.
|
||||
|
||||
The core architecture consists of the following components:
|
||||
|
||||
- `server_context`: Holds the primary inference state, including the main `llama_context` and all active slots.
|
||||
- `server_slot`: An abstraction over a single “sequence” in llama.cpp, responsible for managing individual parallel inference requests.
|
||||
- `server_routes`: Middleware layer between `server_context` and the HTTP interface; handles JSON parsing/formatting and request routing logic.
|
||||
- `server_http_context`: Implements the HTTP server using `cpp-httplib`.
|
||||
- `server_queue`: Thread-safe queue used by HTTP workers to submit new tasks to `server_context`.
|
||||
- `server_response`: Thread-safe queue used by `server_context` to return results to HTTP workers.
|
||||
- `server_response_reader`: Higher-level wrapper around the two queues above for cleaner code.
|
||||
- `server_task`: Unit of work pushed into `server_queue`.
|
||||
- `server_task_result`: Unit of result pushed into `server_response`.
|
||||
- `server_tokens`: Unified representation of token sequences (supports both text and multimodal tokens); used by `server_task` and `server_slot`.
|
||||
- `server_prompt_checkpoint`: For recurrent (e.g., RWKV) and SWA models, stores snapshots of KV cache state. Enables reuse when subsequent requests share the same prompt prefix, saving redundant computation.
|
||||
- `server_models`: Standalone component for managing multiple backend instances (used in router mode). It is completely independent of `server_context`.
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
API_User <--> server_http_context
|
||||
server_http_context <-- router mode --> server_models
|
||||
server_http_context <-- inference mode --> server_routes
|
||||
server_routes -- server_task --> server_queue
|
||||
subgraph server_context
|
||||
server_queue --> server_slot
|
||||
server_slot -- server_task_result --> server_response
|
||||
server_slot[multiple server_slot]
|
||||
end
|
||||
server_response --> server_routes
|
||||
```
|
||||
|
||||
### Batching
|
||||
|
||||
The server context maintains a single batch shared across all slots. When `update_slots()` is invoked, the system iterates through all active slots to populate this batch. For each slot, either a generated token from the previous decoding step or available prompt tokens are added to the batch.
|
||||
|
||||
Batching constraints apply: slots can only be batched together if they share compatible configurations. For instance, slots using a specific LoRA adapter can be batched with each other, but not with slots using a different LoRA adapter or no adapter at all.
|
||||
|
||||
Once the batch reaches capacity or all slots have been processed, `llama_decode` is called to execute the inference. This operation represents the primary computational bottleneck in `update_slots()`.
|
||||
|
||||
Following decoding, the system either retrieves embeddings or samples the next token using `common_sampler_sample`. If a slot has remaining prompt tokens to process, it yields until the next `update_slots()` iteration.
|
||||
|
||||
### Thread Management
|
||||
|
||||
`server_context` runs on a dedicated single thread. Because it is single-threaded, heavy post-processing (especially after token generation) should be avoided, as it directly impacts multi-sequence throughput.
|
||||
|
||||
Each incoming HTTP request is handled by its own thread managed by the HTTP library. The following operations are performed in HTTP worker threads:
|
||||
|
||||
- JSON request parsing
|
||||
- Chat template application
|
||||
- Tokenization
|
||||
- Conversion of `server_task_result` into final JSON response
|
||||
- Error formatting into JSON
|
||||
- Tracking of partial/incremental responses (e.g., streaming tool calls or reasoning steps)
|
||||
|
||||
**Best practices to follow:**
|
||||
|
||||
- All JSON formatting and chat template logic must stay in the HTTP layer.
|
||||
- Avoid passing raw JSON between the HTTP layer and `server_slot`. Instead, parse everything into native C++ types as early as possible.
|
||||
|
||||
### Example trace of a request
|
||||
|
||||
Here is an example trace of an API request for text completion:
|
||||
|
||||
- A request arrives at the HTTP layer.
|
||||
- The request is routed to the corresponding handler inside `server_routes`. In this case, `handle_completions_impl` is invoked.
|
||||
- The handler parses the input request, constructs a new `server_task`, and passes it to `server_res_generator`.
|
||||
- `server_res_generator` creates a new `task_result_state` for each task:
|
||||
- `task_result_state` stays in the HTTP layer, responsible for keeping track of the current state of the response (e.g., parsing tool calls or thinking messages).
|
||||
- `server_task` is moved into `server_queue` inside `server_context`.
|
||||
- `server_context` launches the task by moving it into an available slot (see `launch_slot_with_task()`).
|
||||
- `update_slot()` processes the task as described in the "Batching" section above.
|
||||
- Results may be sent using `send_partial_response` or `send_final_response`, which creates a new `server_task_result` and pushes it to the response queue.
|
||||
- At the same time, `server_res_generator` listens to the response queue and retrieves this response.
|
||||
- As the response is stateless, `server_res_generator` calls `response->update()` to update the response with the current state.
|
||||
- `server_res_generator` then calls `response->to_json()` and passes the response to the HTTP layer.
|
||||
|
||||
### Testing
|
||||
|
||||
`llama-server` includes an automated test suite based on `pytest`.
|
||||
|
||||
The framework automatically starts a `llama-server` instance, sends requests, and validates responses.
|
||||
|
||||
For detailed instructions, see the [test documentation](./tests/README.md).
|
||||
|
||||
### API for tools
|
||||
|
||||
This endpoint is intended to be used internally by the Web UI and subject to change or to be removed in the future.
|
||||
|
||||
**GET /tools**
|
||||
|
||||
Get a list of tools, each tool has these fields:
|
||||
- `tool` (string): the ID name of the tool, to be used in POST call. Example: `read_file`
|
||||
- `display_name` (string): the name to be displayed on UI. Example: `Read file`
|
||||
- `type` (string): always be `"builtin"` for now
|
||||
- `permissions` (object): a mapping string --> boolean that indicates the permission required by this tool. This is useful for the UI to ask the user before calling the tool. For now, the only permission supported is `"write"`
|
||||
- `definition` (object): the OAI-compat definition of this tool
|
||||
|
||||
**POST /tools**
|
||||
|
||||
Invoke a tool call, request body is a JSON object with:
|
||||
- `tool` (string): the name of the tool
|
||||
- `params` (object): a mapping from argument name (string) to argument value
|
||||
|
||||
Returns JSON object. There are two response formats:
|
||||
|
||||
Format 1: Plain text. The text will be placed into a field called `plain_text_response`, example:
|
||||
|
||||
```json
|
||||
{
|
||||
"plain_text_response": "this is a text response"
|
||||
}
|
||||
```
|
||||
|
||||
The client should extract this value and place it inside message content (note: content is no longer a JSON), example
|
||||
|
||||
```json
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "this is a text response"
|
||||
}
|
||||
```
|
||||
|
||||
Format 2: Normal JSON response, example:
|
||||
|
||||
```json
|
||||
{
|
||||
"error": "cannot open this file"
|
||||
}
|
||||
```
|
||||
|
||||
That requires `JSON.stringify` when formatted to message content:
|
||||
|
||||
```json
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "{\"error\":\"cannot open this file\"}"
|
||||
}
|
||||
```
|
||||
|
||||
### Notable Related PRs
|
||||
|
||||
- Initial server implementation: https://github.com/ggml-org/llama.cpp/pull/1443
|
||||
- Parallel decoding support: https://github.com/ggml-org/llama.cpp/pull/3228
|
||||
- Refactor introducing `server_queue` and `server_response`: https://github.com/ggml-org/llama.cpp/pull/5065
|
||||
- Reranking endpoint: https://github.com/ggml-org/llama.cpp/pull/9510
|
||||
- Multimodal model support (`libmtmd`): https://github.com/ggml-org/llama.cpp/pull/12898
|
||||
- Unified KV cache handling: https://github.com/ggml-org/llama.cpp/pull/16736
|
||||
- Separation of HTTP logic into dedicated files: https://github.com/ggml-org/llama.cpp/pull/17216
|
||||
- Large-scale code base split into smaller files: https://github.com/ggml-org/llama.cpp/pull/17362
|
||||
- Introduction of router mode: https://github.com/ggml-org/llama.cpp/pull/17470
|
||||
- Speculative decoding: https://github.com/ggml-org/llama.cpp/pull/17808 and rework in https://github.com/ggml-org/llama.cpp/pull/17808
|
||||
- INI presets: https://github.com/ggml-org/llama.cpp/pull/17859 (+ refactoring: https://github.com/ggml-org/llama.cpp/pull/18169)
|
||||
- Sleeping mode: https://github.com/ggml-org/llama.cpp/pull/18228
|
||||
|
||||
|
||||
|
||||
|
||||
## Web UI
|
||||
|
||||
The project includes a web-based user interface for interacting with `llama-server`. It supports both single-model (`MODEL` mode) and multi-model (`ROUTER` mode) operation.
|
||||
|
||||
The SvelteKit-based Web UI is introduced in this PR: https://github.com/ggml-org/llama.cpp/pull/14839
|
||||
|
||||
### Features
|
||||
|
||||
- **Chat interface** with streaming responses
|
||||
- **Multi-model support** (ROUTER mode) - switch between models, auto-load on selection
|
||||
- **Modality validation** - ensures selected model supports conversation's attachments (images, audio)
|
||||
- **Conversation management** - branching, regeneration, editing with history preservation
|
||||
- **Attachment support** - images, audio, PDFs (with vision/text fallback)
|
||||
- **Configurable parameters** - temperature, top_p, etc. synced with server defaults
|
||||
- **Dark/light theme**
|
||||
|
||||
### Tech Stack
|
||||
|
||||
- **SvelteKit** - frontend framework with Svelte 5 runes for reactive state
|
||||
- **TailwindCSS** + **shadcn-svelte** - styling and UI components
|
||||
- **Vite** - build tooling
|
||||
- **IndexedDB** (Dexie) - local storage for conversations
|
||||
- **LocalStorage** - user settings persistence
|
||||
|
||||
### Architecture
|
||||
|
||||
The WebUI follows a layered architecture:
|
||||
|
||||
```
|
||||
Routes → Components → Hooks → Stores → Services → Storage/API
|
||||
```
|
||||
|
||||
- **Stores** - reactive state management (`chatStore`, `conversationsStore`, `modelsStore`, `serverStore`, `settingsStore`)
|
||||
- **Services** - stateless API/database communication (`ChatService`, `ModelsService`, `PropsService`, `DatabaseService`)
|
||||
- **Hooks** - reusable logic (`useModelChangeValidation`, `useProcessingState`)
|
||||
|
||||
For detailed architecture diagrams, see [`tools/server/webui/docs/`](webui/docs/):
|
||||
|
||||
- `high-level-architecture.mmd` - full architecture with all modules
|
||||
- `high-level-architecture-simplified.mmd` - simplified overview
|
||||
- `data-flow-simplified-model-mode.mmd` - data flow for single-model mode
|
||||
- `data-flow-simplified-router-mode.mmd` - data flow for multi-model mode
|
||||
- `flows/*.mmd` - detailed per-domain flows (chat, conversations, models, etc.)
|
||||
|
||||
### Development
|
||||
|
||||
```sh
|
||||
# make sure you have Node.js installed
|
||||
cd tools/server/webui
|
||||
npm i
|
||||
|
||||
# run dev server (with hot reload)
|
||||
npm run dev
|
||||
|
||||
# run tests
|
||||
npm run test
|
||||
|
||||
# build production bundle
|
||||
npm run build
|
||||
```
|
||||
|
||||
After `public/index.html` has been generated, rebuild `llama-server` as described in the [build](#build) section to include the updated UI.
|
||||
|
||||
**Note:** The Vite dev server automatically proxies API requests to `http://localhost:8080`. Make sure `llama-server` is running on that port during development.
|
||||
1863
tools/server/README.md
Normal file
1863
tools/server/README.md
Normal file
File diff suppressed because it is too large
Load Diff
119
tools/server/bench/README.md
Normal file
119
tools/server/bench/README.md
Normal file
@@ -0,0 +1,119 @@
|
||||
### Server benchmark tools
|
||||
|
||||
Benchmark is using [k6](https://k6.io/).
|
||||
|
||||
##### Install k6 and sse extension
|
||||
|
||||
SSE is not supported by default in k6, you have to build k6 with the [xk6-sse](https://github.com/phymbert/xk6-sse) extension.
|
||||
|
||||
Example (assuming golang >= 1.21 is installed):
|
||||
```shell
|
||||
go install go.k6.io/xk6/cmd/xk6@latest
|
||||
$GOPATH/bin/xk6 build master \
|
||||
--with github.com/phymbert/xk6-sse
|
||||
```
|
||||
|
||||
#### Download a dataset
|
||||
|
||||
This dataset was originally proposed in [vLLM benchmarks](https://github.com/vllm-project/vllm/blob/main/benchmarks/README.md).
|
||||
|
||||
```shell
|
||||
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
```
|
||||
|
||||
#### Download a model
|
||||
Example for PHI-2
|
||||
|
||||
```shell
|
||||
../../../scripts/hf.sh --repo ggml-org/models --file phi-2/ggml-model-q4_0.gguf
|
||||
```
|
||||
|
||||
#### Start the server
|
||||
The server must answer OAI Chat completion requests on `http://localhost:8080/v1` or according to the environment variable `SERVER_BENCH_URL`.
|
||||
|
||||
Example:
|
||||
```shell
|
||||
llama-server --host localhost --port 8080 \
|
||||
--model ggml-model-q4_0.gguf \
|
||||
--cont-batching \
|
||||
--metrics \
|
||||
--parallel 8 \
|
||||
--batch-size 512 \
|
||||
--ctx-size 4096 \
|
||||
-ngl 33
|
||||
```
|
||||
|
||||
#### Run the benchmark
|
||||
|
||||
For 500 chat completions request with 8 concurrent users during maximum 10 minutes, run:
|
||||
```shell
|
||||
./k6 run script.js --duration 10m --iterations 500 --vus 8
|
||||
```
|
||||
|
||||
The benchmark values can be overridden with:
|
||||
- `SERVER_BENCH_URL` server url prefix for chat completions, default `http://localhost:8080/v1`
|
||||
- `SERVER_BENCH_N_PROMPTS` total prompts to randomly select in the benchmark, default `480`
|
||||
- `SERVER_BENCH_MODEL_ALIAS` model alias to pass in the completion request, default `my-model`
|
||||
- `SERVER_BENCH_MAX_TOKENS` max tokens to predict, default: `512`
|
||||
- `SERVER_BENCH_DATASET` path to the benchmark dataset file
|
||||
- `SERVER_BENCH_MAX_PROMPT_TOKENS` maximum prompt tokens to filter out in the dataset: default `1024`
|
||||
- `SERVER_BENCH_MAX_CONTEXT` maximum context size of the completions request to filter out in the dataset: prompt + predicted tokens, default `2048`
|
||||
|
||||
Note: the local tokenizer is just a string space split, real number of tokens will differ.
|
||||
|
||||
Or with [k6 options](https://k6.io/docs/using-k6/k6-options/reference/):
|
||||
|
||||
```shell
|
||||
SERVER_BENCH_N_PROMPTS=500 k6 run script.js --duration 10m --iterations 500 --vus 8
|
||||
```
|
||||
|
||||
To [debug http request](https://k6.io/docs/using-k6/http-debugging/) use `--http-debug="full"`.
|
||||
|
||||
#### Metrics
|
||||
|
||||
Following metrics are available computed from the OAI chat completions response `usage`:
|
||||
- `llamacpp_tokens_second` Trend of `usage.total_tokens / request duration`
|
||||
- `llamacpp_prompt_tokens` Trend of `usage.prompt_tokens`
|
||||
- `llamacpp_prompt_tokens_total_counter` Counter of `usage.prompt_tokens`
|
||||
- `llamacpp_completion_tokens` Trend of `usage.completion_tokens`
|
||||
- `llamacpp_completion_tokens_total_counter` Counter of `usage.completion_tokens`
|
||||
- `llamacpp_completions_truncated_rate` Rate of completions truncated, i.e. if `finish_reason === 'length'`
|
||||
- `llamacpp_completions_stop_rate` Rate of completions stopped by the model, i.e. if `finish_reason === 'stop'`
|
||||
|
||||
The script will fail if too many completions are truncated, see `llamacpp_completions_truncated_rate`.
|
||||
|
||||
K6 metrics might be compared against [server metrics](../README.md), with:
|
||||
|
||||
```shell
|
||||
curl http://localhost:8080/metrics
|
||||
```
|
||||
|
||||
### Using the CI python script
|
||||
The `bench.py` script does several steps:
|
||||
- start the server
|
||||
- define good variable for k6
|
||||
- run k6 script
|
||||
- extract metrics from prometheus
|
||||
|
||||
It aims to be used in the CI, but you can run it manually:
|
||||
|
||||
```shell
|
||||
LLAMA_SERVER_BIN_PATH=../../../cmake-build-release/bin/llama-server python bench.py \
|
||||
--runner-label local \
|
||||
--name local \
|
||||
--branch `git rev-parse --abbrev-ref HEAD` \
|
||||
--commit `git rev-parse HEAD` \
|
||||
--scenario script.js \
|
||||
--duration 5m \
|
||||
--hf-repo ggml-org/models \
|
||||
--hf-file phi-2/ggml-model-q4_0.gguf \
|
||||
--model-path-prefix models \
|
||||
--parallel 4 \
|
||||
-ngl 33 \
|
||||
--batch-size 2048 \
|
||||
--ubatch-size 256 \
|
||||
--ctx-size 4096 \
|
||||
--n-prompts 200 \
|
||||
--max-prompt-tokens 256 \
|
||||
--max-tokens 256
|
||||
```
|
||||
322
tools/server/bench/bench.py
Normal file
322
tools/server/bench/bench.py
Normal file
@@ -0,0 +1,322 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
|
||||
import matplotlib
|
||||
import matplotlib.dates
|
||||
import matplotlib.pyplot as plt
|
||||
import requests
|
||||
from statistics import mean
|
||||
|
||||
|
||||
def main(args_in: list[str] | None = None) -> None:
|
||||
parser = argparse.ArgumentParser(description="Start server benchmark scenario")
|
||||
parser.add_argument("--name", type=str, help="Bench name", required=True)
|
||||
parser.add_argument("--runner-label", type=str, help="Runner label", required=True)
|
||||
parser.add_argument("--branch", type=str, help="Branch name", default="detached")
|
||||
parser.add_argument("--commit", type=str, help="Commit name", default="dirty")
|
||||
parser.add_argument("--host", type=str, help="Server listen host", default="0.0.0.0")
|
||||
parser.add_argument("--port", type=int, help="Server listen host", default="8080")
|
||||
parser.add_argument("--model-path-prefix", type=str, help="Prefix where to store the model files", default="models")
|
||||
parser.add_argument("--n-prompts", type=int,
|
||||
help="SERVER_BENCH_N_PROMPTS: total prompts to randomly select in the benchmark", required=True)
|
||||
parser.add_argument("--max-prompt-tokens", type=int,
|
||||
help="SERVER_BENCH_MAX_PROMPT_TOKENS: maximum prompt tokens to filter out in the dataset",
|
||||
required=True)
|
||||
parser.add_argument("--max-tokens", type=int,
|
||||
help="SERVER_BENCH_MAX_CONTEXT: maximum context size of the completions request to filter out in the dataset: prompt + predicted tokens",
|
||||
required=True)
|
||||
parser.add_argument("--hf-repo", type=str, help="Hugging Face model repository", required=True)
|
||||
parser.add_argument("--hf-file", type=str, help="Hugging Face model file", required=True)
|
||||
parser.add_argument("-ngl", "--n-gpu-layers", type=int, help="layers to the GPU for computation", required=True)
|
||||
parser.add_argument("--ctx-size", type=int, help="Set the size of the prompt context", required=True)
|
||||
parser.add_argument("--parallel", type=int, help="Set the number of slots for process requests", required=True)
|
||||
parser.add_argument("--batch-size", type=int, help="Set the batch size for prompt processing", required=True)
|
||||
parser.add_argument("--ubatch-size", type=int, help="physical maximum batch size", required=True)
|
||||
parser.add_argument("--scenario", type=str, help="Scenario to run", required=True)
|
||||
parser.add_argument("--duration", type=str, help="Bench scenario", required=True)
|
||||
|
||||
args = parser.parse_args(args_in)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Start the server and performance scenario
|
||||
try:
|
||||
server_process = start_server(args)
|
||||
except Exception:
|
||||
print("bench: server start error :")
|
||||
traceback.print_exc(file=sys.stdout)
|
||||
sys.exit(1)
|
||||
|
||||
# start the benchmark
|
||||
iterations = 0
|
||||
data = {}
|
||||
try:
|
||||
start_benchmark(args)
|
||||
|
||||
with open("results.github.env", 'w') as github_env:
|
||||
# parse output
|
||||
with open('k6-results.json', 'r') as bench_results:
|
||||
# Load JSON data from file
|
||||
data = json.load(bench_results)
|
||||
for metric_name in data['metrics']:
|
||||
for metric_metric in data['metrics'][metric_name]:
|
||||
value = data['metrics'][metric_name][metric_metric]
|
||||
if isinstance(value, float) or isinstance(value, int):
|
||||
value = round(value, 2)
|
||||
data['metrics'][metric_name][metric_metric]=value
|
||||
github_env.write(
|
||||
f"{escape_metric_name(metric_name)}_{escape_metric_name(metric_metric)}={value}\n")
|
||||
iterations = data['root_group']['checks']['success completion']['passes']
|
||||
|
||||
except Exception:
|
||||
print("bench: error :")
|
||||
traceback.print_exc(file=sys.stdout)
|
||||
|
||||
# Stop the server
|
||||
if server_process:
|
||||
try:
|
||||
print(f"bench: shutting down server pid={server_process.pid} ...")
|
||||
if os.name == 'nt':
|
||||
interrupt = signal.CTRL_C_EVENT
|
||||
else:
|
||||
interrupt = signal.SIGINT
|
||||
server_process.send_signal(interrupt)
|
||||
server_process.wait(0.5)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
print(f"server still alive after 500ms, force-killing pid={server_process.pid} ...")
|
||||
server_process.kill() # SIGKILL
|
||||
server_process.wait()
|
||||
|
||||
while is_server_listening(args.host, args.port):
|
||||
time.sleep(0.1)
|
||||
|
||||
title = (f"llama.cpp {args.name} on {args.runner_label}\n "
|
||||
f"duration={args.duration} {iterations} iterations")
|
||||
xlabel = (f"{args.hf_repo}/{args.hf_file}\n"
|
||||
f"parallel={args.parallel} ctx-size={args.ctx_size} ngl={args.n_gpu_layers} batch-size={args.batch_size} ubatch-size={args.ubatch_size} pp={args.max_prompt_tokens} pp+tg={args.max_tokens}\n"
|
||||
f"branch={args.branch} commit={args.commit}")
|
||||
|
||||
# Prometheus
|
||||
end_time = time.time()
|
||||
prometheus_metrics = {}
|
||||
if is_server_listening("0.0.0.0", 9090):
|
||||
metrics = ['prompt_tokens_seconds', 'predicted_tokens_seconds',
|
||||
'kv_cache_usage_ratio', 'requests_processing', 'requests_deferred']
|
||||
|
||||
for metric in metrics:
|
||||
resp = requests.get(f"http://localhost:9090/api/v1/query_range",
|
||||
params={'query': 'llamacpp:' + metric, 'start': start_time, 'end': end_time, 'step': 2})
|
||||
|
||||
with open(f"{metric}.json", 'w') as metric_json:
|
||||
metric_json.write(resp.text)
|
||||
|
||||
if resp.status_code != 200:
|
||||
print(f"bench: unable to extract prometheus metric {metric}: {resp.text}")
|
||||
else:
|
||||
metric_data = resp.json()
|
||||
values = metric_data['data']['result'][0]['values']
|
||||
timestamps, metric_values = zip(*values)
|
||||
metric_values = [float(value) for value in metric_values]
|
||||
prometheus_metrics[metric] = metric_values
|
||||
timestamps_dt = [str(datetime.fromtimestamp(int(ts))) for ts in timestamps]
|
||||
plt.figure(figsize=(16, 10), dpi=80)
|
||||
plt.plot(timestamps_dt, metric_values, label=metric)
|
||||
plt.xticks(rotation=0, fontsize=14, horizontalalignment='center', alpha=.7)
|
||||
plt.yticks(fontsize=12, alpha=.7)
|
||||
|
||||
ylabel = f"llamacpp:{metric}"
|
||||
plt.title(title,
|
||||
fontsize=14, wrap=True)
|
||||
plt.grid(axis='both', alpha=.3)
|
||||
plt.ylabel(ylabel, fontsize=22)
|
||||
plt.xlabel(xlabel, fontsize=14, wrap=True)
|
||||
plt.gca().xaxis.set_major_locator(matplotlib.dates.MinuteLocator())
|
||||
plt.gca().xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%Y-%m-%d %H:%M:%S"))
|
||||
plt.gcf().autofmt_xdate()
|
||||
|
||||
# Remove borders
|
||||
plt.gca().spines["top"].set_alpha(0.0)
|
||||
plt.gca().spines["bottom"].set_alpha(0.3)
|
||||
plt.gca().spines["right"].set_alpha(0.0)
|
||||
plt.gca().spines["left"].set_alpha(0.3)
|
||||
|
||||
# Save the plot as a jpg image
|
||||
plt.savefig(f'{metric}.jpg', dpi=60)
|
||||
plt.close()
|
||||
|
||||
# Mermaid format in case images upload failed
|
||||
with open(f"{metric}.mermaid", 'w') as mermaid_f:
|
||||
mermaid = (
|
||||
f"""---
|
||||
config:
|
||||
xyChart:
|
||||
titleFontSize: 12
|
||||
width: 900
|
||||
height: 600
|
||||
themeVariables:
|
||||
xyChart:
|
||||
titleColor: "#000000"
|
||||
---
|
||||
xychart-beta
|
||||
title "{title}"
|
||||
y-axis "llamacpp:{metric}"
|
||||
x-axis "llamacpp:{metric}" {int(min(timestamps))} --> {int(max(timestamps))}
|
||||
line [{', '.join([str(round(float(value), 2)) for value in metric_values])}]
|
||||
""")
|
||||
mermaid_f.write(mermaid)
|
||||
|
||||
# 140 chars max for commit status description
|
||||
bench_results = {
|
||||
"i": iterations,
|
||||
"req": {
|
||||
"p95": round(data['metrics']["http_req_duration"]["p(95)"], 2),
|
||||
"avg": round(data['metrics']["http_req_duration"]["avg"], 2),
|
||||
},
|
||||
"pp": {
|
||||
"p95": round(data['metrics']["llamacpp_prompt_processing_second"]["p(95)"], 2),
|
||||
"avg": round(data['metrics']["llamacpp_prompt_processing_second"]["avg"], 2),
|
||||
"0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2) if 'prompt_tokens_seconds' in prometheus_metrics else 0,
|
||||
},
|
||||
"tg": {
|
||||
"p95": round(data['metrics']["llamacpp_tokens_second"]["p(95)"], 2),
|
||||
"avg": round(data['metrics']["llamacpp_tokens_second"]["avg"], 2),
|
||||
"0": round(mean(prometheus_metrics['predicted_tokens_seconds']), 2) if 'predicted_tokens_seconds' in prometheus_metrics else 0,
|
||||
},
|
||||
}
|
||||
with open("results.github.env", 'a') as github_env:
|
||||
github_env.write(f"BENCH_RESULTS={json.dumps(bench_results, indent=None, separators=(',', ':') )}\n")
|
||||
github_env.write(f"BENCH_ITERATIONS={iterations}\n")
|
||||
|
||||
title = title.replace('\n', ' ')
|
||||
xlabel = xlabel.replace('\n', ' ')
|
||||
github_env.write(f"BENCH_GRAPH_TITLE={title}\n")
|
||||
github_env.write(f"BENCH_GRAPH_XLABEL={xlabel}\n")
|
||||
|
||||
|
||||
def start_benchmark(args):
|
||||
k6_path = './k6'
|
||||
if 'BENCH_K6_BIN_PATH' in os.environ:
|
||||
k6_path = os.environ['BENCH_K6_BIN_PATH']
|
||||
k6_args = [
|
||||
'run', args.scenario,
|
||||
'--no-color',
|
||||
'--no-connection-reuse',
|
||||
'--no-vu-connection-reuse',
|
||||
]
|
||||
k6_args.extend(['--duration', args.duration])
|
||||
k6_args.extend(['--iterations', args.n_prompts])
|
||||
k6_args.extend(['--vus', args.parallel])
|
||||
k6_args.extend(['--summary-export', 'k6-results.json'])
|
||||
k6_args.extend(['--out', 'csv=k6-results.csv'])
|
||||
args = f"SERVER_BENCH_N_PROMPTS={args.n_prompts} SERVER_BENCH_MAX_PROMPT_TOKENS={args.max_prompt_tokens} SERVER_BENCH_MAX_CONTEXT={args.max_tokens} "
|
||||
args = args + ' '.join([str(arg) for arg in [k6_path, *k6_args]])
|
||||
print(f"bench: starting k6 with: {args}")
|
||||
k6_completed = subprocess.run(args, shell=True, stdout=sys.stdout, stderr=sys.stderr)
|
||||
if k6_completed.returncode != 0:
|
||||
raise Exception("bench: unable to run k6")
|
||||
|
||||
|
||||
def start_server(args):
|
||||
server_process = start_server_background(args)
|
||||
|
||||
attempts = 0
|
||||
max_attempts = 600
|
||||
if 'GITHUB_ACTIONS' in os.environ:
|
||||
max_attempts *= 2
|
||||
|
||||
while not is_server_listening(args.host, args.port):
|
||||
attempts += 1
|
||||
if attempts > max_attempts:
|
||||
assert False, "server not started"
|
||||
print(f"bench: waiting for server to start ...")
|
||||
time.sleep(0.5)
|
||||
|
||||
attempts = 0
|
||||
while not is_server_ready(args.host, args.port):
|
||||
attempts += 1
|
||||
if attempts > max_attempts:
|
||||
assert False, "server not ready"
|
||||
print(f"bench: waiting for server to be ready ...")
|
||||
time.sleep(0.5)
|
||||
|
||||
print("bench: server started and ready.")
|
||||
return server_process
|
||||
|
||||
|
||||
def start_server_background(args):
|
||||
# Start the server
|
||||
server_path = '../../../build/bin/llama-server'
|
||||
if 'LLAMA_SERVER_BIN_PATH' in os.environ:
|
||||
server_path = os.environ['LLAMA_SERVER_BIN_PATH']
|
||||
server_args = [
|
||||
'--host', args.host,
|
||||
'--port', args.port,
|
||||
]
|
||||
server_args.extend(['--hf-repo', args.hf_repo])
|
||||
server_args.extend(['--hf-file', args.hf_file])
|
||||
server_args.extend(['--n-gpu-layers', args.n_gpu_layers])
|
||||
server_args.extend(['--ctx-size', args.ctx_size])
|
||||
server_args.extend(['--parallel', args.parallel])
|
||||
server_args.extend(['--batch-size', args.batch_size])
|
||||
server_args.extend(['--ubatch-size', args.ubatch_size])
|
||||
server_args.extend(['--n-predict', args.max_tokens * 2])
|
||||
server_args.append('--cont-batching')
|
||||
server_args.append('--metrics')
|
||||
server_args.append('--flash-attn')
|
||||
args = [str(arg) for arg in [server_path, *server_args]]
|
||||
print(f"bench: starting server with: {' '.join(args)}")
|
||||
pkwargs = {
|
||||
'stdout': subprocess.PIPE,
|
||||
'stderr': subprocess.PIPE
|
||||
}
|
||||
server_process = subprocess.Popen(
|
||||
args,
|
||||
**pkwargs) # pyright: ignore[reportArgumentType, reportCallIssue] # ty: ignore[no-matching-overload]
|
||||
|
||||
def server_log(in_stream, out_stream):
|
||||
for line in iter(in_stream.readline, b''):
|
||||
print(line.decode('utf-8'), end='', file=out_stream)
|
||||
|
||||
thread_stdout = threading.Thread(target=server_log, args=(server_process.stdout, sys.stdout))
|
||||
thread_stdout.start()
|
||||
thread_stderr = threading.Thread(target=server_log, args=(server_process.stderr, sys.stderr))
|
||||
thread_stderr.start()
|
||||
|
||||
return server_process
|
||||
|
||||
|
||||
def is_server_listening(server_fqdn, server_port):
|
||||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
|
||||
result = sock.connect_ex((server_fqdn, server_port))
|
||||
_is_server_listening = result == 0
|
||||
if _is_server_listening:
|
||||
print(f"server is listening on {server_fqdn}:{server_port}...")
|
||||
return _is_server_listening
|
||||
|
||||
|
||||
def is_server_ready(server_fqdn, server_port):
|
||||
url = f"http://{server_fqdn}:{server_port}/health"
|
||||
response = requests.get(url)
|
||||
return response.status_code == 200
|
||||
|
||||
|
||||
def escape_metric_name(metric_name):
|
||||
return re.sub('[^A-Z0-9]', '_', metric_name.upper())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
9
tools/server/bench/prometheus.yml
Normal file
9
tools/server/bench/prometheus.yml
Normal file
@@ -0,0 +1,9 @@
|
||||
global:
|
||||
scrape_interval: 10s
|
||||
external_labels:
|
||||
llamacpp: 'server'
|
||||
|
||||
scrape_configs:
|
||||
- job_name: 'llama.cpp server'
|
||||
static_configs:
|
||||
- targets: ['localhost:8080']
|
||||
2
tools/server/bench/requirements.txt
Normal file
2
tools/server/bench/requirements.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
matplotlib
|
||||
requests
|
||||
162
tools/server/bench/script.js
Normal file
162
tools/server/bench/script.js
Normal file
@@ -0,0 +1,162 @@
|
||||
import sse from 'k6/x/sse'
|
||||
import {check, sleep} from 'k6'
|
||||
import {SharedArray} from 'k6/data'
|
||||
import {Counter, Rate, Trend} from 'k6/metrics'
|
||||
import exec from 'k6/execution';
|
||||
|
||||
// Server chat completions prefix
|
||||
const server_url = __ENV.SERVER_BENCH_URL ? __ENV.SERVER_BENCH_URL : 'http://localhost:8080/v1'
|
||||
|
||||
// Number of total prompts in the dataset - default 10m / 10 seconds/request * number of users
|
||||
const n_prompt = __ENV.SERVER_BENCH_N_PROMPTS ? parseInt(__ENV.SERVER_BENCH_N_PROMPTS) : 600 / 10 * 8
|
||||
|
||||
// Model name to request
|
||||
const model = __ENV.SERVER_BENCH_MODEL_ALIAS ? __ENV.SERVER_BENCH_MODEL_ALIAS : 'my-model'
|
||||
|
||||
// Dataset path
|
||||
const dataset_path = __ENV.SERVER_BENCH_DATASET ? __ENV.SERVER_BENCH_DATASET : './ShareGPT_V3_unfiltered_cleaned_split.json'
|
||||
|
||||
// Max tokens to predict
|
||||
const max_tokens = __ENV.SERVER_BENCH_MAX_TOKENS ? parseInt(__ENV.SERVER_BENCH_MAX_TOKENS) : 512
|
||||
|
||||
// Max prompt tokens
|
||||
const n_prompt_tokens = __ENV.SERVER_BENCH_MAX_PROMPT_TOKENS ? parseInt(__ENV.SERVER_BENCH_MAX_PROMPT_TOKENS) : 1024
|
||||
|
||||
// Max slot context
|
||||
const n_ctx_slot = __ENV.SERVER_BENCH_MAX_CONTEXT ? parseInt(__ENV.SERVER_BENCH_MAX_CONTEXT) : 2048
|
||||
|
||||
export function setup() {
|
||||
console.info(`Benchmark config: server_url=${server_url} n_prompt=${n_prompt} model=${model} dataset_path=${dataset_path} max_tokens=${max_tokens}`)
|
||||
}
|
||||
|
||||
const data = new SharedArray('conversations', function () {
|
||||
const tokenizer = (message) => message.split(/[\s,'".?]/)
|
||||
|
||||
return JSON.parse(open(dataset_path))
|
||||
// Filter out the conversations with less than 2 turns.
|
||||
.filter(data => data["conversations"].length >= 2)
|
||||
.filter(data => data["conversations"][0]["from"] === "human")
|
||||
.map(data => {
|
||||
return {
|
||||
prompt: data["conversations"][0]["value"],
|
||||
n_prompt_tokens: tokenizer(data["conversations"][0]["value"]).length,
|
||||
n_completion_tokens: tokenizer(data["conversations"][1]["value"]).length,
|
||||
}
|
||||
})
|
||||
// Filter out too short sequences
|
||||
.filter(conv => conv.n_prompt_tokens >= 4 && conv.n_completion_tokens >= 4)
|
||||
// Filter out too long sequences.
|
||||
.filter(conv => conv.n_prompt_tokens <= n_prompt_tokens && conv.n_prompt_tokens + conv.n_completion_tokens <= n_ctx_slot)
|
||||
// Keep only first n prompts
|
||||
.slice(0, n_prompt)
|
||||
})
|
||||
|
||||
const llamacpp_prompt_tokens = new Trend('llamacpp_prompt_tokens')
|
||||
const llamacpp_completion_tokens = new Trend('llamacpp_completion_tokens')
|
||||
|
||||
const llamacpp_tokens_second = new Trend('llamacpp_tokens_second')
|
||||
const llamacpp_prompt_processing_second = new Trend('llamacpp_prompt_processing_second')
|
||||
const llamacpp_emit_first_token_second = new Trend('llamacpp_emit_first_token_second')
|
||||
|
||||
const llamacpp_prompt_tokens_total_counter = new Counter('llamacpp_prompt_tokens_total_counter')
|
||||
const llamacpp_completion_tokens_total_counter = new Counter('llamacpp_completion_tokens_total_counter')
|
||||
|
||||
const llamacpp_completions_truncated_rate = new Rate('llamacpp_completions_truncated_rate')
|
||||
const llamacpp_completions_stop_rate = new Rate('llamacpp_completions_stop_rate')
|
||||
|
||||
export const options = {
|
||||
thresholds: {
|
||||
llamacpp_completions_truncated_rate: [
|
||||
// more than 80% of truncated input will abort the test
|
||||
{threshold: 'rate < 0.8', abortOnFail: true, delayAbortEval: '1m'},
|
||||
],
|
||||
},
|
||||
duration: '10m',
|
||||
vus: 8,
|
||||
}
|
||||
|
||||
export default function () {
|
||||
const conversation = data[exec.scenario.iterationInInstance % data.length]
|
||||
const payload = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are ChatGPT, an AI assistant.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": conversation.prompt,
|
||||
}
|
||||
],
|
||||
"model": model,
|
||||
"stream": true,
|
||||
"stream_options": {
|
||||
"include_usage": true, // False to be supported in llama.cpp server
|
||||
},
|
||||
"seed": 42,
|
||||
"max_tokens": max_tokens,
|
||||
"stop": ["<|im_end|>"] // This is temporary for phi-2 base (i.e. not instructed) since the server expects that the model always to emit BOS
|
||||
}
|
||||
|
||||
const params = {method: 'POST', body: JSON.stringify(payload)};
|
||||
|
||||
const startTime = new Date()
|
||||
let promptEvalEndTime = null
|
||||
let prompt_tokens = 0
|
||||
let completions_tokens = 0
|
||||
let finish_reason = null
|
||||
const res = sse.open(`${server_url}/chat/completions`, params, function (client) {
|
||||
client.on('event', function (event) {
|
||||
if (promptEvalEndTime == null) {
|
||||
promptEvalEndTime = new Date()
|
||||
llamacpp_emit_first_token_second.add((promptEvalEndTime - startTime) / 1.e3)
|
||||
}
|
||||
|
||||
if (event.data === '[DONE]' || event.data === '') {
|
||||
return
|
||||
}
|
||||
|
||||
let chunk = JSON.parse(event.data)
|
||||
|
||||
if (chunk.choices && chunk.choices.length > 0) {
|
||||
let choice = chunk.choices[0]
|
||||
if (choice.finish_reason) {
|
||||
finish_reason = choice.finish_reason
|
||||
}
|
||||
}
|
||||
|
||||
if (chunk.usage) {
|
||||
prompt_tokens = chunk.usage.prompt_tokens
|
||||
llamacpp_prompt_tokens.add(prompt_tokens)
|
||||
llamacpp_prompt_tokens_total_counter.add(prompt_tokens)
|
||||
|
||||
completions_tokens = chunk.usage.completion_tokens
|
||||
llamacpp_completion_tokens.add(completions_tokens)
|
||||
llamacpp_completion_tokens_total_counter.add(completions_tokens)
|
||||
}
|
||||
})
|
||||
|
||||
client.on('error', function (e) {
|
||||
console.log('An unexpected error occurred: ', e.error());
|
||||
throw e;
|
||||
})
|
||||
})
|
||||
|
||||
check(res, {'success completion': (r) => r.status === 200})
|
||||
|
||||
const endTime = new Date()
|
||||
|
||||
const promptEvalTime = promptEvalEndTime - startTime
|
||||
if (promptEvalTime > 0) {
|
||||
llamacpp_prompt_processing_second.add(prompt_tokens / (promptEvalEndTime - startTime) * 1.e3)
|
||||
}
|
||||
|
||||
const completion_time = endTime - promptEvalEndTime
|
||||
if (completions_tokens > 0 && completion_time > 0) {
|
||||
llamacpp_tokens_second.add(completions_tokens / completion_time * 1.e3)
|
||||
}
|
||||
llamacpp_completions_truncated_rate.add(finish_reason === 'length')
|
||||
llamacpp_completions_stop_rate.add(finish_reason === 'stop')
|
||||
|
||||
sleep(0.3)
|
||||
}
|
||||
109
tools/server/chat-llama2.sh
Executable file
109
tools/server/chat-llama2.sh
Executable file
@@ -0,0 +1,109 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
API_URL="${API_URL:-http://127.0.0.1:8080}"
|
||||
|
||||
CHAT=(
|
||||
"Hello, Assistant."
|
||||
"Hello. How may I help you today?"
|
||||
)
|
||||
|
||||
INSTRUCTION="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||
|
||||
trim() {
|
||||
shopt -s extglob
|
||||
set -- "${1##+([[:space:]])}"
|
||||
printf "%s" "${1%%+([[:space:]])}"
|
||||
}
|
||||
|
||||
trim_trailing() {
|
||||
shopt -s extglob
|
||||
printf "%s" "${1%%+([[:space:]])}"
|
||||
}
|
||||
|
||||
format_prompt() {
|
||||
if [[ "${#CHAT[@]}" -eq 0 ]]; then
|
||||
echo -n "[INST] <<SYS>>\n${INSTRUCTION}\n<</SYS>>"
|
||||
else
|
||||
LAST_INDEX=$(( ${#CHAT[@]} - 1 ))
|
||||
echo -n "${CHAT[$LAST_INDEX]}\n[INST] $1 [/INST]"
|
||||
fi
|
||||
}
|
||||
|
||||
tokenize() {
|
||||
curl \
|
||||
--silent \
|
||||
--request POST \
|
||||
--url "${API_URL}/tokenize" \
|
||||
--header "Content-Type: application/json" \
|
||||
--data-raw "$(jq -ns --arg content "$1" '{content:$content}')" \
|
||||
| jq '.tokens[]'
|
||||
}
|
||||
|
||||
N_KEEP=$(tokenize "[INST] <<SYS>>\n${INSTRUCTION}\n<</SYS>>" | wc -l)
|
||||
|
||||
chat_completion() {
|
||||
PROMPT="$(trim_trailing "$(format_prompt "$1")")"
|
||||
DATA="$(echo -n "$PROMPT" | jq -Rs --argjson n_keep $N_KEEP '{
|
||||
prompt: .,
|
||||
temperature: 0.2,
|
||||
top_k: 40,
|
||||
top_p: 0.9,
|
||||
n_keep: $n_keep,
|
||||
n_predict: 1024,
|
||||
stop: ["[INST]"],
|
||||
stream: true
|
||||
}')"
|
||||
|
||||
# Create a temporary file to hold the Python output
|
||||
TEMPFILE=$(mktemp)
|
||||
|
||||
exec 3< <(curl \
|
||||
--silent \
|
||||
--no-buffer \
|
||||
--request POST \
|
||||
--url "${API_URL}/completion" \
|
||||
--header "Content-Type: application/json" \
|
||||
--data-raw "${DATA}")
|
||||
|
||||
python -c "
|
||||
import json
|
||||
import sys
|
||||
|
||||
answer = ''
|
||||
while True:
|
||||
line = sys.stdin.readline()
|
||||
if not line:
|
||||
break
|
||||
if line.startswith('data: '):
|
||||
json_content = line[6:].strip()
|
||||
content = json.loads(json_content)['content']
|
||||
sys.stdout.write(content)
|
||||
sys.stdout.flush()
|
||||
answer += content
|
||||
|
||||
answer = answer.rstrip('\n')
|
||||
|
||||
# Write the answer to the temporary file
|
||||
with open('$TEMPFILE', 'w') as f:
|
||||
f.write(answer)
|
||||
" <&3
|
||||
|
||||
exec 3<&-
|
||||
|
||||
# Read the answer from the temporary file
|
||||
ANSWER=$(cat $TEMPFILE)
|
||||
|
||||
# Clean up the temporary file
|
||||
rm $TEMPFILE
|
||||
|
||||
printf "\n"
|
||||
|
||||
CHAT+=("$1" "$(trim "$ANSWER")")
|
||||
}
|
||||
|
||||
while true; do
|
||||
echo -en "\033[0;32m" # Green color
|
||||
read -r -e -p "> " QUESTION
|
||||
echo -en "\033[0m" # Reset color
|
||||
chat_completion "${QUESTION}"
|
||||
done
|
||||
131
tools/server/chat.mjs
Normal file
131
tools/server/chat.mjs
Normal file
@@ -0,0 +1,131 @@
|
||||
import * as readline from 'node:readline'
|
||||
import { stdin, stdout } from 'node:process'
|
||||
import { readFileSync } from 'node:fs'
|
||||
import { SchemaConverter } from './public_legacy/json-schema-to-grammar.mjs'
|
||||
|
||||
const args = process.argv.slice(2);
|
||||
const grammarJsonSchemaFile = args.find(
|
||||
(_, index) => args[index - 1] === "--grammar-json-schema"
|
||||
);
|
||||
|
||||
const no_cached_prompt = args.find(
|
||||
(_, index) => args[index - 1] === "--no-cache-prompt"
|
||||
) ?? "false";
|
||||
|
||||
const grammarFile = args.find((_, index) => args[index - 1] === "--grammar");
|
||||
|
||||
// Example usage: function,arguments
|
||||
const grammarJsonSchemaPropOrder = args.find(
|
||||
(_, index) => args[index - 1] === "--grammar-json-schema-prop-order"
|
||||
);
|
||||
const propOrder = grammarJsonSchemaPropOrder
|
||||
? grammarJsonSchemaPropOrder
|
||||
.split(",")
|
||||
.reduce((acc, cur, index) => ({ ...acc, [cur]: index }), {})
|
||||
: {};
|
||||
|
||||
let grammar = null
|
||||
if (grammarJsonSchemaFile) {
|
||||
let schema = JSON.parse(readFileSync(grammarJsonSchemaFile, 'utf-8'))
|
||||
const converter = new SchemaConverter({prop_order: propOrder, allow_fetch: true})
|
||||
schema = await converter.resolveRefs(schema, grammarJsonSchemaFile)
|
||||
converter.visit(schema, '')
|
||||
grammar = converter.formatGrammar()
|
||||
}
|
||||
if (grammarFile) {
|
||||
grammar = readFileSync(grammarFile, 'utf-8')
|
||||
}
|
||||
|
||||
// for cached prompt
|
||||
let slot_id = -1;
|
||||
|
||||
const API_URL = 'http://127.0.0.1:8080'
|
||||
|
||||
const chat = [
|
||||
{
|
||||
human: "Hello, Assistant.",
|
||||
assistant: "Hello. How may I help you today?"
|
||||
},
|
||||
{
|
||||
human: "Please tell me the largest city in Europe.",
|
||||
assistant: "Sure. The largest city in Europe is Moscow, the capital of Russia."
|
||||
},
|
||||
]
|
||||
|
||||
const instruction = `A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.`
|
||||
|
||||
function format_prompt(question) {
|
||||
return `${instruction}\n${
|
||||
chat.map(m =>`### Human: ${m.human}\n### Assistant: ${m.assistant}`).join("\n")
|
||||
}\n### Human: ${question}\n### Assistant:`
|
||||
}
|
||||
|
||||
async function tokenize(content) {
|
||||
const result = await fetch(`${API_URL}/tokenize`, {
|
||||
method: 'POST',
|
||||
body: JSON.stringify({ content })
|
||||
})
|
||||
|
||||
if (!result.ok) {
|
||||
return []
|
||||
}
|
||||
|
||||
return await result.json().tokens
|
||||
}
|
||||
|
||||
const n_keep = await tokenize(instruction).length
|
||||
|
||||
async function chat_completion(question) {
|
||||
const result = await fetch(`${API_URL}/completion`, {
|
||||
method: 'POST',
|
||||
body: JSON.stringify({
|
||||
prompt: format_prompt(question),
|
||||
temperature: 0.2,
|
||||
top_k: 40,
|
||||
top_p: 0.9,
|
||||
n_keep: n_keep,
|
||||
n_predict: 256,
|
||||
cache_prompt: no_cached_prompt === "false",
|
||||
slot_id: slot_id,
|
||||
stop: ["\n### Human:"], // stop completion after generating this
|
||||
grammar,
|
||||
stream: true,
|
||||
})
|
||||
})
|
||||
|
||||
if (!result.ok) {
|
||||
return
|
||||
}
|
||||
|
||||
let answer = ''
|
||||
|
||||
for await (var chunk of result.body) {
|
||||
const t = Buffer.from(chunk).toString('utf8')
|
||||
if (t.startsWith('data: ')) {
|
||||
const message = JSON.parse(t.substring(6))
|
||||
slot_id = message.slot_id
|
||||
answer += message.content
|
||||
process.stdout.write(message.content)
|
||||
if (message.stop) {
|
||||
if (message.truncated) {
|
||||
chat.shift()
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
process.stdout.write('\n')
|
||||
chat.push({ human: question, assistant: answer.trimStart() })
|
||||
}
|
||||
|
||||
const rl = readline.createInterface({ input: stdin, output: stdout });
|
||||
|
||||
const readlineQuestion = (rl, query, options) => new Promise((resolve, reject) => {
|
||||
rl.question(query, options, resolve)
|
||||
});
|
||||
|
||||
while(true) {
|
||||
const question = await readlineQuestion(rl, '> ')
|
||||
await chat_completion(question)
|
||||
}
|
||||
80
tools/server/chat.sh
Executable file
80
tools/server/chat.sh
Executable file
@@ -0,0 +1,80 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
API_URL="${API_URL:-http://127.0.0.1:8080}"
|
||||
|
||||
CHAT=(
|
||||
"Hello, Assistant."
|
||||
"Hello. How may I help you today?"
|
||||
"Please tell me the largest city in Europe."
|
||||
"Sure. The largest city in Europe is Moscow, the capital of Russia."
|
||||
)
|
||||
|
||||
INSTRUCTION="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||
|
||||
trim() {
|
||||
shopt -s extglob
|
||||
set -- "${1##+([[:space:]])}"
|
||||
printf "%s" "${1%%+([[:space:]])}"
|
||||
}
|
||||
|
||||
trim_trailing() {
|
||||
shopt -s extglob
|
||||
printf "%s" "${1%%+([[:space:]])}"
|
||||
}
|
||||
|
||||
format_prompt() {
|
||||
echo -n "${INSTRUCTION}"
|
||||
printf "\n### Human: %s\n### Assistant: %s" "${CHAT[@]}" "$1"
|
||||
}
|
||||
|
||||
tokenize() {
|
||||
curl \
|
||||
--silent \
|
||||
--request POST \
|
||||
--url "${API_URL}/tokenize" \
|
||||
--header "Content-Type: application/json" \
|
||||
--data-raw "$(jq -ns --arg content "$1" '{content:$content}')" \
|
||||
| jq '.tokens[]'
|
||||
}
|
||||
|
||||
N_KEEP=$(tokenize "${INSTRUCTION}" | wc -l)
|
||||
|
||||
chat_completion() {
|
||||
PROMPT="$(trim_trailing "$(format_prompt "$1")")"
|
||||
DATA="$(echo -n "$PROMPT" | jq -Rs --argjson n_keep $N_KEEP '{
|
||||
prompt: .,
|
||||
temperature: 0.2,
|
||||
top_k: 40,
|
||||
top_p: 0.9,
|
||||
n_keep: $n_keep,
|
||||
n_predict: 256,
|
||||
cache_prompt: true,
|
||||
stop: ["\n### Human:"],
|
||||
stream: true
|
||||
}')"
|
||||
|
||||
ANSWER=''
|
||||
|
||||
while IFS= read -r LINE; do
|
||||
if [[ $LINE = data:* ]]; then
|
||||
CONTENT="$(echo "${LINE:5}" | jq -r '.content')"
|
||||
printf "%s" "${CONTENT}"
|
||||
ANSWER+="${CONTENT}"
|
||||
fi
|
||||
done < <(curl \
|
||||
--silent \
|
||||
--no-buffer \
|
||||
--request POST \
|
||||
--url "${API_URL}/completion" \
|
||||
--header "Content-Type: application/json" \
|
||||
--data-raw "${DATA}")
|
||||
|
||||
printf "\n"
|
||||
|
||||
CHAT+=("$1" "$(trim "$ANSWER")")
|
||||
}
|
||||
|
||||
while true; do
|
||||
read -r -e -p "> " QUESTION
|
||||
chat_completion "${QUESTION}"
|
||||
done
|
||||
1
tools/server/public/bundle.css
Normal file
1
tools/server/public/bundle.css
Normal file
File diff suppressed because one or more lines are too long
13285
tools/server/public/bundle.js
Normal file
13285
tools/server/public/bundle.js
Normal file
File diff suppressed because it is too large
Load Diff
34
tools/server/public/index.html
Normal file
34
tools/server/public/index.html
Normal file
File diff suppressed because one or more lines are too long
12
tools/server/public/loading.html
Normal file
12
tools/server/public/loading.html
Normal file
@@ -0,0 +1,12 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="refresh" content="5">
|
||||
</head>
|
||||
<body>
|
||||
<div id="loading">
|
||||
The model is loading. Please wait.<br/>
|
||||
The user interface will appear soon.
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
630
tools/server/server-chat.cpp
Normal file
630
tools/server/server-chat.cpp
Normal file
@@ -0,0 +1,630 @@
|
||||
#include "server-chat.h"
|
||||
#include "server-common.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
json server_chat_convert_responses_to_chatcmpl(const json & response_body) {
|
||||
if (!response_body.contains("input")) {
|
||||
throw std::invalid_argument("'input' is required");
|
||||
}
|
||||
if (!json_value(response_body, "previous_response_id", std::string{}).empty()) {
|
||||
throw std::invalid_argument("llama.cpp does not support 'previous_response_id'.");
|
||||
}
|
||||
|
||||
const json input_value = response_body.at("input");
|
||||
json chatcmpl_body = response_body;
|
||||
chatcmpl_body.erase("input");
|
||||
std::vector<json> chatcmpl_messages;
|
||||
|
||||
if (response_body.contains("instructions")) {
|
||||
chatcmpl_messages.push_back({
|
||||
{"role", "system"},
|
||||
{"content", json_value(response_body, "instructions", std::string())},
|
||||
});
|
||||
chatcmpl_body.erase("instructions");
|
||||
}
|
||||
|
||||
if (input_value.is_string()) {
|
||||
// #responses_create-input-text_input
|
||||
chatcmpl_messages.push_back({
|
||||
{"role", "user"},
|
||||
{"content", input_value},
|
||||
});
|
||||
} else if (input_value.is_array()) {
|
||||
// #responses_create-input-input_item_list
|
||||
|
||||
static auto exists_and_is_array = [](const json & j, const char * key) -> bool {
|
||||
return j.contains(key) && j.at(key).is_array();
|
||||
};
|
||||
static auto exists_and_is_string = [](const json & j, const char * key) -> bool {
|
||||
return j.contains(key) && j.at(key).is_string();
|
||||
};
|
||||
|
||||
for (json item : input_value) {
|
||||
bool merge_prev = !chatcmpl_messages.empty() && chatcmpl_messages.back().value("role", "") == "assistant";
|
||||
|
||||
if (exists_and_is_string(item, "content")) {
|
||||
// #responses_create-input-input_item_list-input_message-content-text_input
|
||||
// Only "Input message" contains item["content"]::string
|
||||
// After converting item["content"]::string to item["content"]::array,
|
||||
// we can treat "Input message" as sum of "Item-Input message" and "Item-Output message"
|
||||
item["content"] = json::array({
|
||||
json {
|
||||
{"text", item.at("content")},
|
||||
{"type", "input_text"}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (exists_and_is_array(item, "content") &&
|
||||
exists_and_is_string(item, "role") &&
|
||||
(item.at("role") == "user" ||
|
||||
item.at("role") == "system" ||
|
||||
item.at("role") == "developer")
|
||||
) {
|
||||
// #responses_create-input-input_item_list-item-input_message
|
||||
std::vector<json> chatcmpl_content;
|
||||
|
||||
for (const json & input_item : item.at("content")) {
|
||||
const std::string type = json_value(input_item, "type", std::string());
|
||||
|
||||
if (type == "input_text") {
|
||||
if (!input_item.contains("text")) {
|
||||
throw std::invalid_argument("'Input text' requires 'text'");
|
||||
}
|
||||
chatcmpl_content.push_back({
|
||||
{"text", input_item.at("text")},
|
||||
{"type", "text"},
|
||||
});
|
||||
} else if (type == "input_image") {
|
||||
// While `detail` is marked as required,
|
||||
// it has default value("auto") and can be omitted.
|
||||
|
||||
if (!input_item.contains("image_url")) {
|
||||
throw std::invalid_argument("'image_url' is required");
|
||||
}
|
||||
chatcmpl_content.push_back({
|
||||
{"image_url", json {
|
||||
{"url", input_item.at("image_url")}
|
||||
}},
|
||||
{"type", "image_url"},
|
||||
});
|
||||
} else if (type == "input_file") {
|
||||
throw std::invalid_argument("'input_file' is not supported by llamacpp at this moment");
|
||||
} else {
|
||||
throw std::invalid_argument("'type' must be one of 'input_text', 'input_image', or 'input_file'");
|
||||
}
|
||||
}
|
||||
|
||||
if (item.contains("type")) {
|
||||
item.erase("type");
|
||||
}
|
||||
if (item.contains("status")) {
|
||||
item.erase("status");
|
||||
}
|
||||
item["content"] = chatcmpl_content;
|
||||
|
||||
chatcmpl_messages.push_back(item);
|
||||
} else if (exists_and_is_string(item, "role") &&
|
||||
item.at("role") == "assistant" &&
|
||||
exists_and_is_string(item, "type") &&
|
||||
item.at("type") == "message"
|
||||
) {
|
||||
// #responses_create-input-input_item_list-item-output_message
|
||||
auto chatcmpl_content = json::array();
|
||||
|
||||
// Handle both string content and array content
|
||||
if (item.contains("content") && item.at("content").is_string()) {
|
||||
// String content - convert to text content part
|
||||
chatcmpl_content.push_back({
|
||||
{"text", item.at("content")},
|
||||
{"type", "text"},
|
||||
});
|
||||
} else if (exists_and_is_array(item, "content")) {
|
||||
// Array content - process each item
|
||||
for (const auto & output_text : item.at("content")) {
|
||||
const std::string type = json_value(output_text, "type", std::string());
|
||||
if (type == "output_text" || type == "input_text") {
|
||||
// Accept both output_text and input_text (string content gets converted to input_text)
|
||||
if (!exists_and_is_string(output_text, "text")) {
|
||||
throw std::invalid_argument("'Output text' requires 'text'");
|
||||
}
|
||||
chatcmpl_content.push_back({
|
||||
{"text", output_text.at("text")},
|
||||
{"type", "text"},
|
||||
});
|
||||
} else if (type == "refusal") {
|
||||
if (!exists_and_is_string(output_text, "refusal")) {
|
||||
throw std::invalid_argument("'Refusal' requires 'refusal'");
|
||||
}
|
||||
chatcmpl_content.push_back({
|
||||
{"refusal", output_text.at("refusal")},
|
||||
{"type", "refusal"},
|
||||
});
|
||||
} else {
|
||||
throw std::invalid_argument("'type' must be one of 'output_text' or 'refusal'");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (merge_prev) {
|
||||
auto & prev_msg = chatcmpl_messages.back();
|
||||
if (!exists_and_is_array(prev_msg, "content")) {
|
||||
prev_msg["content"] = json::array();
|
||||
}
|
||||
auto & prev_content = prev_msg["content"];
|
||||
prev_content.insert(prev_content.end(), chatcmpl_content.begin(), chatcmpl_content.end());
|
||||
} else {
|
||||
item.erase("status");
|
||||
item.erase("type");
|
||||
item["content"] = chatcmpl_content;
|
||||
chatcmpl_messages.push_back(item);
|
||||
}
|
||||
} else if (exists_and_is_string(item, "arguments") &&
|
||||
exists_and_is_string(item, "call_id") &&
|
||||
exists_and_is_string(item, "name") &&
|
||||
exists_and_is_string(item, "type") &&
|
||||
item.at("type") == "function_call"
|
||||
) {
|
||||
// #responses_create-input-input_item_list-item-function_tool_call
|
||||
json tool_call = {
|
||||
{"function", json {
|
||||
{"arguments", item.at("arguments")},
|
||||
{"name", item.at("name")},
|
||||
}},
|
||||
{"id", item.at("call_id")},
|
||||
{"type", "function"},
|
||||
};
|
||||
|
||||
if (merge_prev) {
|
||||
auto & prev_msg = chatcmpl_messages.back();
|
||||
if (!exists_and_is_array(prev_msg, "tool_calls")) {
|
||||
prev_msg["tool_calls"] = json::array();
|
||||
}
|
||||
prev_msg["tool_calls"].push_back(tool_call);
|
||||
} else {
|
||||
chatcmpl_messages.push_back(json {
|
||||
{"role", "assistant"},
|
||||
{"tool_calls", json::array({tool_call})}
|
||||
});
|
||||
}
|
||||
} else if (exists_and_is_string(item, "call_id") &&
|
||||
(exists_and_is_string(item, "output") || exists_and_is_array(item, "output")) &&
|
||||
exists_and_is_string(item, "type") &&
|
||||
item.at("type") == "function_call_output"
|
||||
) {
|
||||
// #responses_create-input-input_item_list-item-function_tool_call_output
|
||||
if (item.at("output").is_string()) {
|
||||
chatcmpl_messages.push_back(json {
|
||||
{"content", item.at("output")},
|
||||
{"role", "tool"},
|
||||
{"tool_call_id", item.at("call_id")},
|
||||
});
|
||||
} else {
|
||||
json chatcmpl_outputs = item.at("output");
|
||||
for (json & chatcmpl_output : chatcmpl_outputs) {
|
||||
if (!chatcmpl_output.contains("type") || chatcmpl_output.at("type") != "input_text") {
|
||||
throw std::invalid_argument("Output of tool call should be 'Input text'");
|
||||
}
|
||||
chatcmpl_output["type"] = "text";
|
||||
}
|
||||
chatcmpl_messages.push_back(json {
|
||||
{"content", chatcmpl_outputs},
|
||||
{"role", "tool"},
|
||||
{"tool_call_id", item.at("call_id")},
|
||||
});
|
||||
}
|
||||
} else if (exists_and_is_array(item, "summary") &&
|
||||
exists_and_is_string(item, "type") &&
|
||||
item.at("type") == "reasoning") {
|
||||
// #responses_create-input-input_item_list-item-reasoning
|
||||
|
||||
if (!exists_and_is_array(item, "content")) {
|
||||
throw std::invalid_argument("item['content'] is not an array");
|
||||
}
|
||||
if (item.at("content").empty()) {
|
||||
throw std::invalid_argument("item['content'] is empty");
|
||||
}
|
||||
if (!exists_and_is_string(item.at("content")[0], "text")) {
|
||||
throw std::invalid_argument("item['content']['text'] is not a string");
|
||||
}
|
||||
|
||||
if (merge_prev) {
|
||||
auto & prev_msg = chatcmpl_messages.back();
|
||||
prev_msg["reasoning_content"] = item.at("content")[0].at("text");
|
||||
} else {
|
||||
chatcmpl_messages.push_back(json {
|
||||
{"role", "assistant"},
|
||||
{"content", json::array()},
|
||||
{"reasoning_content", item.at("content")[0].at("text")},
|
||||
});
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument("Cannot determine type of 'item'");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument("'input' must be a string or array of objects");
|
||||
}
|
||||
|
||||
chatcmpl_body["messages"] = chatcmpl_messages;
|
||||
|
||||
if (response_body.contains("tools")) {
|
||||
if (!response_body.at("tools").is_array()) {
|
||||
throw std::invalid_argument("'tools' must be an array of objects");
|
||||
}
|
||||
std::vector<json> chatcmpl_tools;
|
||||
for (json resp_tool : response_body.at("tools")) {
|
||||
json chatcmpl_tool;
|
||||
|
||||
if (json_value(resp_tool, "type", std::string()) != "function") {
|
||||
throw std::invalid_argument("'type' of tool must be 'function'");
|
||||
}
|
||||
resp_tool.erase("type");
|
||||
chatcmpl_tool["type"] = "function";
|
||||
|
||||
if (!resp_tool.contains("strict")) {
|
||||
resp_tool["strict"] = true;
|
||||
}
|
||||
chatcmpl_tool["function"] = resp_tool;
|
||||
chatcmpl_tools.push_back(chatcmpl_tool);
|
||||
}
|
||||
chatcmpl_body.erase("tools");
|
||||
chatcmpl_body["tools"] = chatcmpl_tools;
|
||||
}
|
||||
|
||||
if (response_body.contains("max_output_tokens")) {
|
||||
chatcmpl_body.erase("max_output_tokens");
|
||||
chatcmpl_body["max_tokens"] = response_body["max_output_tokens"];
|
||||
}
|
||||
|
||||
return chatcmpl_body;
|
||||
}
|
||||
|
||||
// Edits the cch section of an "x-anthropic-billing-header" system prompt.
|
||||
// Does nothing to any other prompt.
|
||||
//
|
||||
// This is a claude message with a "cch=ef01a" attribute that breaks prefix caching.
|
||||
// The cch stamp is a whitebox end-to-end integrity hint. It's not meaningful as a
|
||||
// system prompt data, particularly to llama.cpp, but its presence means the prefix
|
||||
// cache will not get past it: It changes on each request.
|
||||
//
|
||||
// Reference: https://github.com/ggml-org/llama.cpp/pull/21793
|
||||
// Example header:
|
||||
// ```
|
||||
// x-anthropic-billing-header: cc_version=2.1.101.e51; cc_entrypoint=cli; cch=a5145;You are Claude Code, Anthropic's official CLI for Claude.
|
||||
// ^^^^^
|
||||
// ```
|
||||
static void normalize_anthropic_billing_header(std::string & system_text) {
|
||||
if (system_text.rfind("x-anthropic-billing-header:", 0) != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t header_prefix_length = strlen("x-anthropic-billing-header:");
|
||||
const size_t cch_length = 5;
|
||||
const size_t index_cch = system_text.find("cch=", header_prefix_length);
|
||||
if (index_cch == std::string::npos) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t index_replace = index_cch + 4;
|
||||
if (index_replace + cch_length < system_text.length() && system_text[index_replace + cch_length] == ';') {
|
||||
for (size_t i = 0; i < cch_length; ++i) {
|
||||
system_text[index_replace + i] = 'f';
|
||||
}
|
||||
} else {
|
||||
LOG_ERR("anthropic string not as expected: %s", system_text.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
json server_chat_convert_anthropic_to_oai(const json & body) {
|
||||
json oai_body;
|
||||
|
||||
// Convert system prompt
|
||||
json oai_messages = json::array();
|
||||
auto system_param = json_value(body, "system", json());
|
||||
if (!system_param.is_null()) {
|
||||
std::string system_content;
|
||||
|
||||
if (system_param.is_string()) {
|
||||
system_content = system_param.get<std::string>();
|
||||
normalize_anthropic_billing_header(system_content);
|
||||
} else if (system_param.is_array()) {
|
||||
for (const auto & block : system_param) {
|
||||
if (json_value(block, "type", std::string()) == "text") {
|
||||
auto system_text = json_value(block, "text", std::string());
|
||||
normalize_anthropic_billing_header(system_text);
|
||||
system_content += system_text;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
oai_messages.push_back({
|
||||
{"role", "system"},
|
||||
{"content", system_content}
|
||||
});
|
||||
}
|
||||
|
||||
// Convert messages
|
||||
if (!body.contains("messages")) {
|
||||
throw std::runtime_error("'messages' is required");
|
||||
}
|
||||
const json & messages = body.at("messages");
|
||||
if (messages.is_array()) {
|
||||
for (const auto & msg : messages) {
|
||||
std::string role = json_value(msg, "role", std::string());
|
||||
|
||||
if (!msg.contains("content")) {
|
||||
if (role == "assistant") {
|
||||
continue;
|
||||
}
|
||||
oai_messages.push_back(msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
const json & content = msg.at("content");
|
||||
|
||||
if (content.is_string()) {
|
||||
oai_messages.push_back(msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!content.is_array()) {
|
||||
oai_messages.push_back(msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
json tool_calls = json::array();
|
||||
json converted_content = json::array();
|
||||
json tool_results = json::array();
|
||||
std::string reasoning_content;
|
||||
bool has_tool_calls = false;
|
||||
|
||||
for (const auto & block : content) {
|
||||
std::string type = json_value(block, "type", std::string());
|
||||
|
||||
if (type == "text") {
|
||||
converted_content.push_back(block);
|
||||
} else if (type == "thinking") {
|
||||
reasoning_content += json_value(block, "thinking", std::string());
|
||||
} else if (type == "image") {
|
||||
json source = json_value(block, "source", json::object());
|
||||
std::string source_type = json_value(source, "type", std::string());
|
||||
|
||||
if (source_type == "base64") {
|
||||
std::string media_type = json_value(source, "media_type", std::string("image/jpeg"));
|
||||
std::string data = json_value(source, "data", std::string());
|
||||
std::ostringstream ss;
|
||||
ss << "data:" << media_type << ";base64," << data;
|
||||
|
||||
converted_content.push_back({
|
||||
{"type", "image_url"},
|
||||
{"image_url", {
|
||||
{"url", ss.str()}
|
||||
}}
|
||||
});
|
||||
} else if (source_type == "url") {
|
||||
std::string url = json_value(source, "url", std::string());
|
||||
converted_content.push_back({
|
||||
{"type", "image_url"},
|
||||
{"image_url", {
|
||||
{"url", url}
|
||||
}}
|
||||
});
|
||||
}
|
||||
} else if (type == "tool_use") {
|
||||
tool_calls.push_back({
|
||||
{"id", json_value(block, "id", std::string())},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", json_value(block, "name", std::string())},
|
||||
{"arguments", json_value(block, "input", json::object()).dump()}
|
||||
}}
|
||||
});
|
||||
has_tool_calls = true;
|
||||
} else if (type == "tool_result") {
|
||||
std::string tool_use_id = json_value(block, "tool_use_id", std::string());
|
||||
|
||||
auto result_content = json_value(block, "content", json());
|
||||
std::string result_text;
|
||||
if (result_content.is_string()) {
|
||||
result_text = result_content.get<std::string>();
|
||||
} else if (result_content.is_array()) {
|
||||
for (const auto & c : result_content) {
|
||||
if (json_value(c, "type", std::string()) == "text") {
|
||||
result_text += json_value(c, "text", std::string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tool_results.push_back({
|
||||
{"role", "tool"},
|
||||
{"tool_call_id", tool_use_id},
|
||||
{"content", result_text}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (!converted_content.empty() || has_tool_calls || !reasoning_content.empty()) {
|
||||
json new_msg = {{"role", role}};
|
||||
if (!converted_content.empty()) {
|
||||
new_msg["content"] = converted_content;
|
||||
} else if (has_tool_calls || !reasoning_content.empty()) {
|
||||
new_msg["content"] = "";
|
||||
}
|
||||
if (!tool_calls.empty()) {
|
||||
new_msg["tool_calls"] = tool_calls;
|
||||
}
|
||||
if (!reasoning_content.empty()) {
|
||||
new_msg["reasoning_content"] = reasoning_content;
|
||||
}
|
||||
oai_messages.push_back(new_msg);
|
||||
}
|
||||
|
||||
for (const auto & tool_msg : tool_results) {
|
||||
oai_messages.push_back(tool_msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
oai_body["messages"] = oai_messages;
|
||||
|
||||
// Convert tools
|
||||
if (body.contains("tools")) {
|
||||
const json & tools = body.at("tools");
|
||||
if (tools.is_array()) {
|
||||
json oai_tools = json::array();
|
||||
for (const auto & tool : tools) {
|
||||
oai_tools.push_back({
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", json_value(tool, "name", std::string())},
|
||||
{"description", json_value(tool, "description", std::string())},
|
||||
{"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()}
|
||||
}}
|
||||
});
|
||||
}
|
||||
oai_body["tools"] = oai_tools;
|
||||
}
|
||||
}
|
||||
|
||||
// Convert tool_choice
|
||||
if (body.contains("tool_choice")) {
|
||||
const json & tc = body.at("tool_choice");
|
||||
if (tc.is_object()) {
|
||||
std::string type = json_value(tc, "type", std::string());
|
||||
if (type == "auto") {
|
||||
oai_body["tool_choice"] = "auto";
|
||||
} else if (type == "any" || type == "tool") {
|
||||
oai_body["tool_choice"] = "required";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert stop_sequences to stop
|
||||
if (body.contains("stop_sequences")) {
|
||||
oai_body["stop"] = body.at("stop_sequences");
|
||||
}
|
||||
|
||||
// Handle max_tokens (required in Anthropic, but we're permissive)
|
||||
if (body.contains("max_tokens")) {
|
||||
oai_body["max_tokens"] = body.at("max_tokens");
|
||||
} else {
|
||||
oai_body["max_tokens"] = 4096;
|
||||
}
|
||||
|
||||
// Pass through common params
|
||||
for (const auto & key : {"temperature", "top_p", "top_k", "stream", "chat_template_kwargs"}) {
|
||||
if (body.contains(key)) {
|
||||
oai_body[key] = body.at(key);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle Anthropic-specific thinking param
|
||||
if (body.contains("thinking")) {
|
||||
json thinking = json_value(body, "thinking", json::object());
|
||||
std::string thinking_type = json_value(thinking, "type", std::string());
|
||||
if (thinking_type == "enabled") {
|
||||
int budget_tokens = json_value(thinking, "budget_tokens", 10000);
|
||||
oai_body["thinking_budget_tokens"] = budget_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle Anthropic-specific metadata param
|
||||
if (body.contains("metadata")) {
|
||||
json metadata = json_value(body, "metadata", json::object());
|
||||
std::string user_id = json_value(metadata, "user_id", std::string());
|
||||
if (!user_id.empty()) {
|
||||
oai_body["__metadata_user_id"] = user_id;
|
||||
}
|
||||
}
|
||||
|
||||
return oai_body;
|
||||
}
|
||||
|
||||
json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
|
||||
json delta = json::object();
|
||||
if (!diff.reasoning_content_delta.empty()) {
|
||||
delta["reasoning_content"] = diff.reasoning_content_delta;
|
||||
}
|
||||
if (!diff.content_delta.empty()) {
|
||||
delta["content"] = diff.content_delta;
|
||||
}
|
||||
if (diff.tool_call_index != std::string::npos) {
|
||||
json tool_call;
|
||||
tool_call["index"] = diff.tool_call_index;
|
||||
if (!diff.tool_call_delta.id.empty()) {
|
||||
tool_call["id"] = diff.tool_call_delta.id;
|
||||
tool_call["type"] = "function";
|
||||
}
|
||||
if (!diff.tool_call_delta.name.empty() || !diff.tool_call_delta.arguments.empty()) {
|
||||
json function = json::object();
|
||||
if (!diff.tool_call_delta.name.empty()) {
|
||||
function["name"] = diff.tool_call_delta.name;
|
||||
}
|
||||
if (!diff.tool_call_delta.arguments.empty()) {
|
||||
function["arguments"] = diff.tool_call_delta.arguments;
|
||||
}
|
||||
tool_call["function"] = function;
|
||||
}
|
||||
delta["tool_calls"] = json::array({ tool_call });
|
||||
}
|
||||
return delta;
|
||||
}
|
||||
|
||||
json convert_transcriptions_to_chatcmpl(
|
||||
const json & inp_body,
|
||||
const common_chat_templates * tmpls,
|
||||
const std::map<std::string, uploaded_file> & in_files,
|
||||
std::vector<raw_buffer> & out_files) {
|
||||
// TODO @ngxson : this function may need to be improved in the future
|
||||
// handle input files
|
||||
out_files.clear();
|
||||
auto it = in_files.find("file");
|
||||
if (it != in_files.end()) {
|
||||
out_files.push_back(it->second.data);
|
||||
} else {
|
||||
throw std::invalid_argument("No input file found for transcription");
|
||||
}
|
||||
|
||||
// handle input data
|
||||
std::string prompt = json_value(inp_body, "prompt", std::string());
|
||||
std::string language = json_value(inp_body, "language", std::string());
|
||||
std::string response_format = json_value(inp_body, "response_format", std::string("json"));
|
||||
if (response_format != "json") {
|
||||
throw std::invalid_argument("Only 'json' response_format is supported for transcription");
|
||||
}
|
||||
const common_chat_prompt_preset preset = common_chat_get_asr_prompt(tmpls);
|
||||
if (prompt.empty()) {
|
||||
prompt = preset.user;
|
||||
}
|
||||
if (!language.empty()) {
|
||||
prompt += string_format(" (language: %s)", language.c_str());
|
||||
}
|
||||
prompt += get_media_marker();
|
||||
|
||||
json messages = json::array();
|
||||
if (!preset.system.empty()) {
|
||||
messages.push_back({{"role", "system"}, {"content", preset.system}});
|
||||
}
|
||||
messages.push_back({{"role", "user"}, {"content", prompt}});
|
||||
|
||||
json chatcmpl_body = inp_body; // copy all fields
|
||||
chatcmpl_body["messages"] = messages;
|
||||
|
||||
// because input from form-data, everything is string, we need to correct the types here
|
||||
std::string stream = json_value(inp_body, "stream", std::string("false"));
|
||||
chatcmpl_body["stream"] = stream == "true";
|
||||
|
||||
if (inp_body.contains("max_tokens")) {
|
||||
std::string inp = inp_body["max_tokens"].get<std::string>();
|
||||
chatcmpl_body["max_tokens"] = std::stoul(inp);
|
||||
}
|
||||
|
||||
if (inp_body.contains("temperature")) {
|
||||
std::string inp = inp_body["temperature"].get<std::string>();
|
||||
chatcmpl_body["temperature"] = std::stof(inp);
|
||||
}
|
||||
|
||||
return chatcmpl_body;
|
||||
}
|
||||
26
tools/server/server-chat.h
Normal file
26
tools/server/server-chat.h
Normal file
@@ -0,0 +1,26 @@
|
||||
// Chat conversion functions for server (Responses API, Anthropic API, OAI streaming diffs)
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "chat.h"
|
||||
#include "server-common.h"
|
||||
#include "server-http.h"
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
// Convert OpenAI Responses API format to OpenAI Chat Completions API format
|
||||
json server_chat_convert_responses_to_chatcmpl(const json & body);
|
||||
|
||||
// Convert Anthropic Messages API format to OpenAI Chat Completions API format
|
||||
json server_chat_convert_anthropic_to_oai(const json & body);
|
||||
|
||||
// convert OpenAI transcriptions API format to OpenAI Chat Completions API format
|
||||
json convert_transcriptions_to_chatcmpl(
|
||||
const json & body,
|
||||
const common_chat_templates * tmpls,
|
||||
const std::map<std::string, uploaded_file> & in_files,
|
||||
std::vector<raw_buffer> & out_files);
|
||||
|
||||
json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
||||
1586
tools/server/server-common.cpp
Normal file
1586
tools/server/server-common.cpp
Normal file
File diff suppressed because it is too large
Load Diff
373
tools/server/server-common.h
Normal file
373
tools/server/server-common.h
Normal file
@@ -0,0 +1,373 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
#include "chat.h"
|
||||
#include "mtmd.h"
|
||||
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cinttypes>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
#define SLT_CNT(slot, fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
|
||||
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
|
||||
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
|
||||
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
using raw_buffer = std::vector<uint8_t>;
|
||||
|
||||
template <typename T>
|
||||
static T json_value(const json & body, const std::string & key, const T & default_value) {
|
||||
// Fallback null to default value
|
||||
if (body.contains(key) && !body.at(key).is_null()) {
|
||||
try {
|
||||
return body.at(key);
|
||||
} catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const & err) {
|
||||
LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value: %s\n", key.c_str(), json(default_value).type_name(), err.what());
|
||||
return default_value;
|
||||
}
|
||||
} else {
|
||||
return default_value;
|
||||
}
|
||||
}
|
||||
|
||||
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
|
||||
enum error_type {
|
||||
ERROR_TYPE_INVALID_REQUEST,
|
||||
ERROR_TYPE_AUTHENTICATION,
|
||||
ERROR_TYPE_SERVER,
|
||||
ERROR_TYPE_NOT_FOUND,
|
||||
ERROR_TYPE_PERMISSION,
|
||||
ERROR_TYPE_UNAVAILABLE, // custom error
|
||||
ERROR_TYPE_NOT_SUPPORTED, // custom error
|
||||
ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error
|
||||
};
|
||||
|
||||
// thin wrapper around common_grammar_trigger with (de)serialization functions
|
||||
struct server_grammar_trigger {
|
||||
common_grammar_trigger value;
|
||||
|
||||
server_grammar_trigger() = default;
|
||||
server_grammar_trigger(const common_grammar_trigger & value) : value(value) {}
|
||||
server_grammar_trigger(const json & in) {
|
||||
value.type = (common_grammar_trigger_type) in.at("type").get<int>();
|
||||
value.value = in.at("value").get<std::string>();
|
||||
if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
|
||||
value.token = (llama_token) in.at("token").get<int>();
|
||||
}
|
||||
}
|
||||
|
||||
json to_json() const {
|
||||
json out {
|
||||
{"type", (int) value.type},
|
||||
{"value", value.value},
|
||||
};
|
||||
if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
|
||||
out["token"] = (int) value.token;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
json format_error_response(const std::string & message, const enum error_type type);
|
||||
|
||||
//
|
||||
// random string / id
|
||||
//
|
||||
|
||||
std::string random_string();
|
||||
std::string gen_chatcmplid();
|
||||
std::string gen_tool_call_id();
|
||||
|
||||
// get a random marker; note: each time the server restarts, the marker will be different
|
||||
const char * get_media_marker();
|
||||
|
||||
//
|
||||
// lora utils
|
||||
//
|
||||
|
||||
// check whether the given lora set has only aloras activated (empty => false)
|
||||
bool lora_all_alora(const std::vector<common_adapter_lora_info> & loras);
|
||||
|
||||
// if the two sets of loras are different, they require a cache clear unless the
|
||||
// change is only from aloras to aloras.
|
||||
bool lora_should_clear_cache(
|
||||
const std::vector<common_adapter_lora_info> & current,
|
||||
const std::vector<common_adapter_lora_info> & next);
|
||||
|
||||
std::map<int, float> parse_lora_request(const json & data);
|
||||
|
||||
bool are_lora_equal(
|
||||
const std::vector<common_adapter_lora_info> & l1,
|
||||
const std::vector<common_adapter_lora_info> & l2);
|
||||
|
||||
// get the ids of all enabled loras
|
||||
std::vector<size_t> lora_get_enabled_ids(const std::vector<common_adapter_lora_info> & loras);
|
||||
|
||||
//
|
||||
// server_tokens
|
||||
//
|
||||
|
||||
/**
|
||||
* server_tokens is a helper to manage the input tokens and image for the server.
|
||||
* it is made this way to simplify the logic of KV cache management.
|
||||
*/
|
||||
struct server_tokens {
|
||||
bool has_mtmd = false;
|
||||
|
||||
private: // disallow accessing these members directly, risking out-of-sync
|
||||
|
||||
// map a **start** index in tokens to the image chunk
|
||||
// note: the order need to be in-sync with tokens
|
||||
std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;
|
||||
|
||||
// list of tokens
|
||||
// if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk
|
||||
// otherwise, it is a normal text token
|
||||
// note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
|
||||
// note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos
|
||||
llama_tokens tokens;
|
||||
|
||||
// for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
|
||||
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
|
||||
// idx 0 1 2 3 4 5 6 7 8 9 10
|
||||
// pos 0 1 2 3 4 5 5 5 7 7 7
|
||||
// map_idx_to_media will contain: {5, img0}, {8, img1}
|
||||
|
||||
public:
|
||||
server_tokens() = default;
|
||||
~server_tokens() = default;
|
||||
|
||||
// Prevent copying
|
||||
// TODO: server_tokens should be copyable - remove this:
|
||||
server_tokens(const server_tokens&) = delete;
|
||||
server_tokens& operator=(const server_tokens&) = delete;
|
||||
|
||||
// Allow moving (usually implicitly generated if members are movable)
|
||||
server_tokens(server_tokens&&) = default;
|
||||
server_tokens& operator=(server_tokens&&) = default;
|
||||
|
||||
// Allow accessing elements using [] operator
|
||||
llama_token operator[](size_t index) { return tokens[index]; }
|
||||
const llama_token& operator[](size_t index) const { return tokens[index]; }
|
||||
|
||||
server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd);
|
||||
server_tokens(const llama_tokens & tokens, bool has_mtmd);
|
||||
|
||||
// for debugging
|
||||
std::string str() const;
|
||||
|
||||
// the next position after n_tokens. if n_tokens < 0, return the next position after all tokens.
|
||||
llama_pos pos_next(int64_t n_tokens = -1) const;
|
||||
|
||||
// number of tokens with position < max_pos
|
||||
size_t size_up_to_pos(llama_pos max_pos) const;
|
||||
|
||||
const mtmd::input_chunk_ptr & find_chunk(size_t idx) const;
|
||||
|
||||
void push_back(llama_token tok);
|
||||
|
||||
// will create a copy of the chunk if it contains non-text data
|
||||
void push_back(const mtmd_input_chunk * chunk);
|
||||
|
||||
// appends server tokens, updates the media map. copies media chunks.
|
||||
void push_back(server_tokens & tokens);
|
||||
|
||||
// for compatibility with context shift and prompt truncation
|
||||
void insert(const llama_tokens & inp_tokens);
|
||||
|
||||
// for compatibility with speculative decoding, ctx shift, slot save/load
|
||||
const llama_tokens & get_tokens() const;
|
||||
|
||||
llama_tokens get_text_tokens() const;
|
||||
|
||||
// for compatibility with speculative decoding
|
||||
void set_token(llama_pos pos, llama_token id);
|
||||
|
||||
size_t size() const { return tokens.size(); }
|
||||
|
||||
bool empty() const { return tokens.empty(); }
|
||||
|
||||
void clear() {
|
||||
map_idx_to_media.clear();
|
||||
tokens.clear();
|
||||
}
|
||||
|
||||
void keep_first(size_t n);
|
||||
|
||||
std::string detokenize(const llama_context * ctx, bool special) const;
|
||||
|
||||
size_t get_common_prefix(const server_tokens & b) const;
|
||||
|
||||
// make sure all text tokens are within the vocab range
|
||||
bool validate(const struct llama_context * ctx) const;
|
||||
|
||||
// encode and decode the image chunk
|
||||
int32_t process_chunk(
|
||||
llama_context * ctx,
|
||||
mtmd_context * mctx,
|
||||
size_t idx,
|
||||
llama_pos pos,
|
||||
int32_t seq_id,
|
||||
size_t & n_tokens_out) const;
|
||||
|
||||
server_tokens clone() const;
|
||||
};
|
||||
|
||||
|
||||
//
|
||||
// tokenizer and input processing utils
|
||||
//
|
||||
|
||||
bool json_is_array_of_numbers(const json & data);
|
||||
|
||||
// is array having BOTH numbers & strings?
|
||||
bool json_is_array_of_mixed_numbers_strings(const json & data);
|
||||
|
||||
// does array have any individual integers/tokens?
|
||||
bool json_is_array_and_contains_numbers(const json & data);
|
||||
|
||||
// get value by path(key1 / key2)
|
||||
json json_get_nested_values(const std::vector<std::string> & paths, const json & js);
|
||||
|
||||
/**
|
||||
* this handles 2 cases:
|
||||
* - only string, example: "string"
|
||||
* - mixed string and tokens, example: [12, 34, "string", 56, 78]
|
||||
*/
|
||||
llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special);
|
||||
|
||||
// return the last index of character that can form a valid string
|
||||
// if the last character is potentially cut in half, return the index before the cut
|
||||
// if validate_utf8(text) == text.size(), then the whole text is valid utf8
|
||||
size_t validate_utf8(const std::string& text);
|
||||
|
||||
// process mtmd prompt, return the server_tokens containing both text tokens and media chunks
|
||||
server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files);
|
||||
|
||||
/**
|
||||
* break the input "prompt" object into multiple prompt if needed, then tokenize them
|
||||
* this supports these cases:
|
||||
* - "prompt": "string"
|
||||
* - "prompt": [12, 34, 56]
|
||||
* - "prompt": [12, 34, "string", 56, 78]
|
||||
* - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
|
||||
* and multiple prompts (multi-tasks):
|
||||
* - "prompt": ["string1", "string2"]
|
||||
* - "prompt": ["string1", [12, 34, 56]]
|
||||
* - "prompt": [[12, 34, 56], [78, 90, 12]]
|
||||
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}]
|
||||
*/
|
||||
std::vector<server_tokens> tokenize_input_prompts(
|
||||
const llama_vocab * vocab,
|
||||
mtmd_context * mctx,
|
||||
const json & json_prompt,
|
||||
bool add_special,
|
||||
bool parse_special);
|
||||
|
||||
//
|
||||
// OAI utils
|
||||
//
|
||||
|
||||
// global server parameters for chat formatting / parsing
|
||||
struct server_chat_params {
|
||||
bool use_jinja;
|
||||
bool prefill_assistant;
|
||||
common_reasoning_format reasoning_format;
|
||||
std::map<std::string, std::string> chat_template_kwargs; // mapping key --> json value
|
||||
common_chat_templates_ptr tmpls;
|
||||
bool allow_image;
|
||||
bool allow_audio;
|
||||
bool enable_thinking = true;
|
||||
int reasoning_budget = -1;
|
||||
std::string reasoning_budget_message;
|
||||
std::string media_path;
|
||||
bool force_pure_content = false;
|
||||
};
|
||||
|
||||
// used by /completions endpoint
|
||||
json oaicompat_completion_params_parse(const json & body);
|
||||
|
||||
// used by /chat/completions endpoint
|
||||
json oaicompat_chat_params_parse(
|
||||
json & body, /* openai api json semantics */
|
||||
const server_chat_params & opt,
|
||||
std::vector<raw_buffer> & out_files);
|
||||
|
||||
// TODO: move it to server-task.cpp
|
||||
json format_embeddings_response_oaicompat(
|
||||
const json & request,
|
||||
const std::string & model_name,
|
||||
const json & embeddings,
|
||||
bool use_base64 = false);
|
||||
|
||||
// TODO: move it to server-task.cpp
|
||||
json format_response_rerank(
|
||||
const json & request,
|
||||
const std::string & model_name,
|
||||
const json & ranks,
|
||||
bool is_tei_format,
|
||||
std::vector<std::string> & texts,
|
||||
int top_n);
|
||||
|
||||
//
|
||||
// other utils
|
||||
//
|
||||
|
||||
std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx);
|
||||
|
||||
std::string safe_json_to_str(const json & data);
|
||||
|
||||
std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens);
|
||||
std::string tokens_to_str(const llama_vocab * vocab, const llama_tokens & tokens);
|
||||
|
||||
// format incomplete utf-8 multibyte character for output
|
||||
std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token);
|
||||
|
||||
// format server-sent event (SSE), return the formatted string to send
|
||||
// note: if data is a json array, it will be sent as multiple events, one per item
|
||||
std::string format_oai_sse(const json & data);
|
||||
|
||||
std::string format_oai_resp_sse(const json & data);
|
||||
|
||||
// format Anthropic-style SSE with event types
|
||||
std::string format_anthropic_sse(const json & data);
|
||||
|
||||
bool is_valid_utf8(const std::string & str);
|
||||
|
||||
//
|
||||
// formatting output responses
|
||||
// TODO: move these to server-task.cpp
|
||||
//
|
||||
|
||||
llama_tokens format_prompt_infill(
|
||||
const llama_vocab * vocab,
|
||||
const json & input_prefix,
|
||||
const json & input_suffix,
|
||||
const json & input_extra,
|
||||
const int n_batch,
|
||||
const int n_predict,
|
||||
const int n_ctx,
|
||||
const bool spm_infill,
|
||||
const llama_tokens & tokens_prompt);
|
||||
|
||||
// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
|
||||
server_tokens format_prompt_rerank(
|
||||
const struct llama_model * model,
|
||||
const struct llama_vocab * vocab,
|
||||
mtmd_context * mctx,
|
||||
const std::string & query,
|
||||
const std::string & doc);
|
||||
4485
tools/server/server-context.cpp
Normal file
4485
tools/server/server-context.cpp
Normal file
File diff suppressed because it is too large
Load Diff
150
tools/server/server-context.h
Normal file
150
tools/server/server-context.h
Normal file
@@ -0,0 +1,150 @@
|
||||
#pragma once
|
||||
|
||||
#include "server-http.h"
|
||||
#include "server-task.h"
|
||||
#include "server-queue.h"
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
|
||||
struct server_context_impl; // private implementation
|
||||
|
||||
struct server_context_meta {
|
||||
std::string build_info;
|
||||
std::string model_name;
|
||||
std::set<std::string> model_aliases;
|
||||
std::set<std::string> model_tags;
|
||||
std::string model_path;
|
||||
bool has_mtmd;
|
||||
bool has_inp_image;
|
||||
bool has_inp_audio;
|
||||
json json_webui_settings;
|
||||
int slot_n_ctx;
|
||||
enum llama_pooling_type pooling_type;
|
||||
|
||||
// chat params
|
||||
server_chat_params & chat_params;
|
||||
std::map<std::string, bool> chat_template_caps;
|
||||
|
||||
// tokens
|
||||
std::string bos_token_str;
|
||||
std::string eos_token_str;
|
||||
llama_token fim_pre_token;
|
||||
llama_token fim_sub_token;
|
||||
llama_token fim_mid_token;
|
||||
llama_token fim_pad_token;
|
||||
llama_token fim_rep_token;
|
||||
llama_token fim_sep_token;
|
||||
|
||||
// sampling
|
||||
std::vector<llama_logit_bias> logit_bias_eog;
|
||||
|
||||
// model meta
|
||||
enum llama_vocab_type model_vocab_type;
|
||||
int32_t model_vocab_n_tokens;
|
||||
int32_t model_n_ctx_train;
|
||||
int32_t model_n_embd_inp;
|
||||
uint64_t model_n_params;
|
||||
uint64_t model_size;
|
||||
};
|
||||
|
||||
struct server_context {
|
||||
std::unique_ptr<server_context_impl> impl;
|
||||
|
||||
server_context();
|
||||
~server_context();
|
||||
|
||||
// load the model and initialize llama_context
|
||||
// returns true on success
|
||||
bool load_model(common_params & params);
|
||||
|
||||
// this function will block main thread until termination
|
||||
void start_loop();
|
||||
|
||||
// terminate main loop (will unblock start_loop)
|
||||
void terminate();
|
||||
|
||||
// get the underlaying llama_context, can return nullptr if sleeping
|
||||
// not thread-safe, should only be used from the main thread
|
||||
llama_context * get_llama_context() const;
|
||||
|
||||
// get a new response reader, used by CLI application
|
||||
server_response_reader get_response_reader();
|
||||
|
||||
// get server metadata (read-only), can only be called after load_model()
|
||||
// not thread-safe, should only be used from the main thread
|
||||
server_context_meta get_meta() const;
|
||||
|
||||
// register a callback to be called when sleeping state changes
|
||||
// must be set before load_model() is called
|
||||
void on_sleeping_changed(std::function<void(bool)> callback);
|
||||
};
|
||||
|
||||
|
||||
// forward declarations
|
||||
struct server_res_generator;
|
||||
|
||||
struct server_routes {
|
||||
server_routes(const common_params & params, server_context & ctx_server);
|
||||
|
||||
void init_routes();
|
||||
|
||||
// note: this is not thread-safe and can only when ctx_http.is_ready is false
|
||||
void update_meta(const server_context & ctx_server) {
|
||||
this->meta = std::make_unique<server_context_meta>(ctx_server.get_meta());
|
||||
}
|
||||
|
||||
// handlers using lambda function, so that they can capture `this` without `std::bind`
|
||||
// they won't be called until ctx_http.is_ready is set to true
|
||||
server_http_context::handler_t get_health;
|
||||
server_http_context::handler_t get_metrics;
|
||||
server_http_context::handler_t get_slots;
|
||||
server_http_context::handler_t post_slots;
|
||||
server_http_context::handler_t get_props;
|
||||
server_http_context::handler_t post_props;
|
||||
server_http_context::handler_t post_infill;
|
||||
server_http_context::handler_t post_completions;
|
||||
server_http_context::handler_t post_completions_oai;
|
||||
server_http_context::handler_t post_chat_completions;
|
||||
server_http_context::handler_t post_responses_oai;
|
||||
server_http_context::handler_t post_transcriptions_oai;
|
||||
server_http_context::handler_t post_anthropic_messages;
|
||||
server_http_context::handler_t post_anthropic_count_tokens;
|
||||
server_http_context::handler_t post_apply_template;
|
||||
server_http_context::handler_t get_models;
|
||||
server_http_context::handler_t post_tokenize;
|
||||
server_http_context::handler_t post_detokenize;
|
||||
server_http_context::handler_t post_embeddings;
|
||||
server_http_context::handler_t post_embeddings_oai;
|
||||
server_http_context::handler_t post_rerank;
|
||||
server_http_context::handler_t get_lora_adapters;
|
||||
server_http_context::handler_t post_lora_adapters;
|
||||
|
||||
// to be used in router mode
|
||||
json get_model_info() const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
const server_http_req & req,
|
||||
server_task_type type,
|
||||
const json & data,
|
||||
const std::vector<raw_buffer> & files,
|
||||
task_response_type res_type);
|
||||
std::unique_ptr<server_res_generator> handle_slots_save(const server_http_req & req, int id_slot);
|
||||
std::unique_ptr<server_res_generator> handle_slots_restore(const server_http_req & req, int id_slot);
|
||||
std::unique_ptr<server_res_generator> handle_slots_erase(const server_http_req &, int id_slot);
|
||||
std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, task_response_type res_type);
|
||||
|
||||
// using unique_ptr to allow late initialization of const
|
||||
std::unique_ptr<const server_context_meta> meta;
|
||||
|
||||
const common_params & params;
|
||||
const server_context_impl & ctx_server;
|
||||
|
||||
server_queue & queue_tasks;
|
||||
server_response & queue_results;
|
||||
std::unique_ptr<server_res_generator> create_response(bool bypass_sleep = false);
|
||||
};
|
||||
67
tools/server/server-cors-proxy.h
Normal file
67
tools/server/server-cors-proxy.h
Normal file
@@ -0,0 +1,67 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "http.h"
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <list>
|
||||
#include <map>
|
||||
|
||||
#include "server-http.h"
|
||||
|
||||
static server_http_res_ptr proxy_request(const server_http_req & req, std::string method) {
|
||||
std::string target_url = req.get_param("url");
|
||||
common_http_url parsed_url = common_http_parse_url(target_url);
|
||||
|
||||
if (parsed_url.host.empty()) {
|
||||
throw std::runtime_error("invalid target URL: missing host");
|
||||
}
|
||||
|
||||
if (parsed_url.path.empty()) {
|
||||
parsed_url.path = "/";
|
||||
}
|
||||
|
||||
if (!parsed_url.password.empty()) {
|
||||
throw std::runtime_error("authentication in target URL is not supported");
|
||||
}
|
||||
|
||||
if (parsed_url.scheme != "http" && parsed_url.scheme != "https") {
|
||||
throw std::runtime_error("unsupported URL scheme in target URL: " + parsed_url.scheme);
|
||||
}
|
||||
|
||||
SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.port, parsed_url.path.c_str());
|
||||
|
||||
std::map<std::string, std::string> headers;
|
||||
for (auto [key, value] : req.headers) {
|
||||
auto new_key = key;
|
||||
if (string_starts_with(new_key, "x-proxy-header-")) {
|
||||
string_replace_all(new_key, "x-proxy-header-", "");
|
||||
}
|
||||
headers[new_key] = value;
|
||||
}
|
||||
|
||||
auto proxy = std::make_unique<server_http_proxy>(
|
||||
method,
|
||||
parsed_url.scheme,
|
||||
parsed_url.host,
|
||||
parsed_url.port,
|
||||
parsed_url.path,
|
||||
headers,
|
||||
req.body,
|
||||
req.files,
|
||||
req.should_stop,
|
||||
600, // timeout_read (default to 10 minutes)
|
||||
600 // timeout_write (default to 10 minutes)
|
||||
);
|
||||
|
||||
return proxy;
|
||||
}
|
||||
|
||||
static server_http_context::handler_t proxy_handler_post = [](const server_http_req & req) -> server_http_res_ptr {
|
||||
return proxy_request(req, "POST");
|
||||
};
|
||||
|
||||
static server_http_context::handler_t proxy_handler_get = [](const server_http_req & req) -> server_http_res_ptr {
|
||||
return proxy_request(req, "GET");
|
||||
};
|
||||
700
tools/server/server-http.cpp
Normal file
700
tools/server/server-http.cpp
Normal file
@@ -0,0 +1,700 @@
|
||||
#include "common.h"
|
||||
#include "server-http.h"
|
||||
#include "server-common.h"
|
||||
|
||||
#include <cpp-httplib/httplib.h>
|
||||
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <future>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
#ifdef LLAMA_BUILD_WEBUI
|
||||
// auto generated files (see README.md for details)
|
||||
#include "index.html.hpp"
|
||||
#include "bundle.js.hpp"
|
||||
#include "bundle.css.hpp"
|
||||
#include "loading.html.hpp"
|
||||
#endif
|
||||
|
||||
//
|
||||
// HTTP implementation using cpp-httplib
|
||||
//
|
||||
|
||||
class server_http_context::Impl {
|
||||
public:
|
||||
std::unique_ptr<httplib::Server> srv;
|
||||
};
|
||||
|
||||
server_http_context::server_http_context()
|
||||
: pimpl(std::make_unique<server_http_context::Impl>())
|
||||
{}
|
||||
|
||||
server_http_context::~server_http_context() = default;
|
||||
|
||||
static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
|
||||
// skip logging requests that are regularly sent, to avoid log spam
|
||||
if (req.path == "/health"
|
||||
|| req.path == "/v1/health"
|
||||
|| req.path == "/models"
|
||||
|| req.path == "/v1/models"
|
||||
|| req.path == "/props"
|
||||
|| req.path == "/metrics"
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
// reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
|
||||
|
||||
SRV_INF("done request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
|
||||
|
||||
SRV_DBG("request: %s\n", req.body.c_str());
|
||||
SRV_DBG("response: %s\n", res.body.c_str());
|
||||
}
|
||||
|
||||
// For Google Cloud Platform deployment compatibility
|
||||
struct gcp_params {
|
||||
bool enabled;
|
||||
std::string path_health;
|
||||
std::string path_predict;
|
||||
int port;
|
||||
|
||||
// Ref: https://docs.cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements#aip-variables
|
||||
gcp_params() {
|
||||
enabled = getenv("AIP_MODE", "") == "PREDICTION";
|
||||
path_health = getenv("AIP_HEALTH_ROUTE", "", true); // default: using the route defined in server.cpp
|
||||
path_predict = getenv("AIP_PREDICT_ROUTE", "/predict", true);
|
||||
port = std::stoi(getenv("AIP_HTTP_PORT", "8080"));
|
||||
}
|
||||
|
||||
static std::string getenv(const char * name, const std::string & default_value, bool ensure_leading_slash = false) {
|
||||
const char * value = std::getenv(name);
|
||||
if (value == nullptr || value[0] == '\0') {
|
||||
return default_value;
|
||||
}
|
||||
std::string val = value;
|
||||
if (ensure_leading_slash && !val.empty() && val[0] != '/') {
|
||||
val.insert(val.begin(), '/');
|
||||
}
|
||||
return val;
|
||||
}
|
||||
};
|
||||
|
||||
bool server_http_context::init(const common_params & params) {
|
||||
const gcp_params gcp;
|
||||
|
||||
path_prefix = params.api_prefix;
|
||||
port = params.port;
|
||||
hostname = params.hostname;
|
||||
|
||||
if (gcp.enabled) {
|
||||
LOG_INF("%s: Google Cloud Platform compat: health route = %s, predict route = %s, port = %d\n", __func__, gcp.path_health.c_str(), gcp.path_predict.c_str(), gcp.port);
|
||||
|
||||
if (port != gcp.port) {
|
||||
LOG_WRN("%s: Google Cloud Platform compat: overriding server port %d with AIP_HTTP_PORT %d\n", __func__, port, gcp.port);
|
||||
}
|
||||
|
||||
port = gcp.port;
|
||||
}
|
||||
|
||||
auto & srv = pimpl->srv;
|
||||
|
||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
|
||||
LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
|
||||
srv.reset(
|
||||
new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
|
||||
);
|
||||
} else {
|
||||
LOG_INF("Running without SSL\n");
|
||||
srv.reset(new httplib::Server());
|
||||
}
|
||||
#else
|
||||
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
|
||||
LOG_ERR("Server is built without SSL support\n");
|
||||
return false;
|
||||
}
|
||||
srv.reset(new httplib::Server());
|
||||
#endif
|
||||
|
||||
srv->set_default_headers({{"Server", "llama.cpp"}});
|
||||
srv->set_logger(log_server_request);
|
||||
srv->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
|
||||
// this is fail-safe; exceptions should already handled by `ex_wrapper`
|
||||
|
||||
std::string message;
|
||||
try {
|
||||
std::rethrow_exception(ep);
|
||||
} catch (const std::exception & e) {
|
||||
message = e.what();
|
||||
} catch (...) {
|
||||
message = "Unknown Exception";
|
||||
}
|
||||
|
||||
res.status = 500;
|
||||
res.set_content(message, "text/plain");
|
||||
LOG_ERR("got exception: %s\n", message.c_str());
|
||||
});
|
||||
|
||||
srv->set_error_handler([](const httplib::Request &, httplib::Response & res) {
|
||||
if (res.status == 404) {
|
||||
res.set_content(
|
||||
safe_json_to_str(json {
|
||||
{"error", {
|
||||
{"message", "File Not Found"},
|
||||
{"type", "not_found_error"},
|
||||
{"code", 404}
|
||||
}}
|
||||
}),
|
||||
"application/json; charset=utf-8"
|
||||
);
|
||||
}
|
||||
// for other error codes, we skip processing here because it's already done by res->error()
|
||||
});
|
||||
|
||||
// set timeouts and change hostname and port
|
||||
srv->set_read_timeout (params.timeout_read);
|
||||
srv->set_write_timeout(params.timeout_write);
|
||||
srv->set_socket_options([reuse_port = params.reuse_port](socket_t sock) {
|
||||
httplib::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 1);
|
||||
if (reuse_port) {
|
||||
#ifdef SO_REUSEPORT
|
||||
httplib::set_socket_opt(sock, SOL_SOCKET, SO_REUSEPORT, 1);
|
||||
#else
|
||||
LOG_WRN("%s: SO_REUSEPORT is not supported\n", __func__);
|
||||
#endif
|
||||
}
|
||||
});
|
||||
|
||||
if (params.api_keys.size() == 1) {
|
||||
auto key = params.api_keys[0];
|
||||
std::string substr = key.substr(std::max((int)(key.length() - 4), 0));
|
||||
LOG_INF("%s: api_keys: ****%s\n", __func__, substr.c_str());
|
||||
} else if (params.api_keys.size() > 1) {
|
||||
LOG_INF("%s: api_keys: %zu keys loaded\n", __func__, params.api_keys.size());
|
||||
}
|
||||
|
||||
//
|
||||
// Middlewares
|
||||
//
|
||||
|
||||
auto middleware_validate_api_key = [api_keys = params.api_keys](const httplib::Request & req, httplib::Response & res) {
|
||||
static const std::unordered_set<std::string> public_endpoints = {
|
||||
"/health",
|
||||
"/v1/health",
|
||||
"/models",
|
||||
"/v1/models",
|
||||
"/",
|
||||
"/index.html",
|
||||
"/bundle.js",
|
||||
"/bundle.css",
|
||||
};
|
||||
|
||||
// If API key is not set, skip validation
|
||||
if (api_keys.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// If path is public or static file, skip validation
|
||||
if (public_endpoints.find(req.path) != public_endpoints.end()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check for API key in the Authorization header
|
||||
std::string req_api_key = req.get_header_value("Authorization");
|
||||
if (req_api_key.empty()) {
|
||||
// retry with anthropic header
|
||||
req_api_key = req.get_header_value("X-Api-Key");
|
||||
}
|
||||
|
||||
// remove the "Bearer " prefix if needed
|
||||
std::string prefix = "Bearer ";
|
||||
if (req_api_key.substr(0, prefix.size()) == prefix) {
|
||||
req_api_key = req_api_key.substr(prefix.size());
|
||||
}
|
||||
|
||||
// validate the API key
|
||||
if (std::find(api_keys.begin(), api_keys.end(), req_api_key) != api_keys.end()) {
|
||||
return true; // API key is valid
|
||||
}
|
||||
|
||||
// API key is invalid or not provided
|
||||
res.status = 401;
|
||||
res.set_content(
|
||||
safe_json_to_str(json {
|
||||
{"error", {
|
||||
{"message", "Invalid API Key"},
|
||||
{"type", "authentication_error"},
|
||||
{"code", 401}
|
||||
}}
|
||||
}),
|
||||
"application/json; charset=utf-8"
|
||||
);
|
||||
|
||||
LOG_WRN("Unauthorized: Invalid API Key\n");
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) {
|
||||
bool ready = is_ready.load();
|
||||
if (!ready) {
|
||||
#ifdef LLAMA_BUILD_WEBUI
|
||||
auto tmp = string_split<std::string>(req.path, '.');
|
||||
if (req.path == "/" || tmp.back() == "html") {
|
||||
res.status = 503;
|
||||
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
// no endpoints is allowed to be accessed when the server is not ready
|
||||
// this is to prevent any data races or inconsistent states
|
||||
res.status = 503;
|
||||
res.set_content(
|
||||
safe_json_to_str(json {
|
||||
{"error", {
|
||||
{"message", "Loading model"},
|
||||
{"type", "unavailable_error"},
|
||||
{"code", 503}
|
||||
}}
|
||||
}),
|
||||
"application/json; charset=utf-8"
|
||||
);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
// register server middlewares
|
||||
srv->set_pre_routing_handler([middleware_validate_api_key, middleware_server_state](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
// If this is OPTIONS request, skip validation because browsers don't include Authorization header
|
||||
if (req.method == "OPTIONS") {
|
||||
res.set_header("Access-Control-Allow-Credentials", "true");
|
||||
res.set_header("Access-Control-Allow-Methods", "GET, POST");
|
||||
res.set_header("Access-Control-Allow-Headers", "*");
|
||||
res.set_content("", "text/html"); // blank response, no data
|
||||
return httplib::Server::HandlerResponse::Handled; // skip further processing
|
||||
}
|
||||
if (!middleware_server_state(req, res)) {
|
||||
return httplib::Server::HandlerResponse::Handled;
|
||||
}
|
||||
if (!middleware_validate_api_key(req, res)) {
|
||||
return httplib::Server::HandlerResponse::Handled;
|
||||
}
|
||||
return httplib::Server::HandlerResponse::Unhandled;
|
||||
});
|
||||
|
||||
int n_threads_http = params.n_threads_http;
|
||||
if (n_threads_http < 1) {
|
||||
// +4 threads for monitoring, health and some threads reserved for MCP and other tasks in the future
|
||||
n_threads_http = std::max(params.n_parallel + 4, (int32_t) std::thread::hardware_concurrency() - 1);
|
||||
}
|
||||
LOG_INF("%s: using %d threads for HTTP server\n", __func__, n_threads_http);
|
||||
srv->new_task_queue = [n_threads_http] {
|
||||
// spawn n_threads_http fixed thread (always alive), while allow up to 1024 max possible additional threads
|
||||
// when n_threads_http is used, server will create new "dynamic" threads that will be destroyed after processing each request
|
||||
// ref: https://github.com/yhirose/cpp-httplib/pull/2368
|
||||
size_t max_threads = (size_t)n_threads_http + 1024;
|
||||
return new httplib::ThreadPool(n_threads_http, max_threads);
|
||||
};
|
||||
|
||||
//
|
||||
// Web UI setup
|
||||
//
|
||||
|
||||
if (!params.webui) {
|
||||
LOG_INF("Web UI is disabled\n");
|
||||
} else {
|
||||
// register static assets routes
|
||||
if (!params.public_path.empty()) {
|
||||
// Set the base directory for serving static files
|
||||
bool is_found = srv->set_mount_point(params.api_prefix + "/", params.public_path);
|
||||
if (!is_found) {
|
||||
LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
|
||||
return 1;
|
||||
}
|
||||
} else {
|
||||
#ifdef LLAMA_BUILD_WEBUI
|
||||
// using embedded static index.html
|
||||
srv->Get(params.api_prefix + "/", [](const httplib::Request & /*req*/, httplib::Response & res) {
|
||||
// COEP and COOP headers, required by pyodide (python interpreter)
|
||||
res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
|
||||
res.set_header("Cross-Origin-Opener-Policy", "same-origin");
|
||||
res.set_content(reinterpret_cast<const char*>(index_html), index_html_len, "text/html; charset=utf-8");
|
||||
return false;
|
||||
});
|
||||
srv->Get(params.api_prefix + "/bundle.js", [](const httplib::Request & /*req*/, httplib::Response & res) {
|
||||
res.set_content(reinterpret_cast<const char*>(bundle_js), bundle_js_len, "application/javascript; charset=utf-8");
|
||||
return false;
|
||||
});
|
||||
srv->Get(params.api_prefix + "/bundle.css", [](const httplib::Request & /*req*/, httplib::Response & res) {
|
||||
res.set_content(reinterpret_cast<const char*>(bundle_css), bundle_css_len, "text/css; charset=utf-8");
|
||||
return false;
|
||||
});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool server_http_context::start() {
|
||||
// Bind and listen
|
||||
|
||||
auto & srv = pimpl->srv;
|
||||
bool was_bound = false;
|
||||
bool is_sock = false;
|
||||
if (string_ends_with(std::string(hostname), ".sock")) {
|
||||
is_sock = true;
|
||||
LOG_INF("%s: setting address family to AF_UNIX\n", __func__);
|
||||
srv->set_address_family(AF_UNIX);
|
||||
// bind_to_port requires a second arg, any value other than 0 should
|
||||
// simply get ignored
|
||||
was_bound = srv->bind_to_port(hostname, 8080);
|
||||
} else {
|
||||
LOG_INF("%s: binding port with default address family\n", __func__);
|
||||
// bind HTTP listen port
|
||||
if (port == 0) {
|
||||
int bound_port = srv->bind_to_any_port(hostname);
|
||||
was_bound = (bound_port >= 0);
|
||||
if (was_bound) {
|
||||
port = bound_port;
|
||||
}
|
||||
} else {
|
||||
was_bound = srv->bind_to_port(hostname, port);
|
||||
}
|
||||
}
|
||||
|
||||
if (!was_bound) {
|
||||
LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, hostname.c_str(), port);
|
||||
return false;
|
||||
}
|
||||
|
||||
// run the HTTP server in a thread
|
||||
thread = std::thread([this]() { pimpl->srv->listen_after_bind(); });
|
||||
srv->wait_until_ready();
|
||||
|
||||
listening_address = is_sock ? string_format("unix://%s", hostname.c_str())
|
||||
: string_format("http://%s:%d", hostname.c_str(), port);
|
||||
return true;
|
||||
}
|
||||
|
||||
void server_http_context::stop() const {
|
||||
if (pimpl->srv) {
|
||||
pimpl->srv->stop();
|
||||
}
|
||||
}
|
||||
|
||||
static void set_headers(httplib::Response & res, const std::map<std::string, std::string> & headers) {
|
||||
for (const auto & [key, value] : headers) {
|
||||
res.set_header(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
static std::map<std::string, std::string> get_params(const httplib::Request & req) {
|
||||
std::map<std::string, std::string> params;
|
||||
for (const auto & [key, value] : req.params) {
|
||||
params[key] = value;
|
||||
}
|
||||
for (const auto & [key, value] : req.path_params) {
|
||||
params[key] = value;
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
static std::map<std::string, std::string> get_headers(const httplib::Request & req) {
|
||||
std::map<std::string, std::string> headers;
|
||||
for (const auto & [key, value] : req.headers) {
|
||||
headers[key] = value;
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
static std::string build_query_string(const httplib::Request & req) {
|
||||
std::string qs;
|
||||
for (const auto & [key, value] : req.params) {
|
||||
if (!qs.empty()) {
|
||||
qs += '&';
|
||||
}
|
||||
qs += httplib::encode_query_component(key) + "=" + httplib::encode_query_component(value);
|
||||
}
|
||||
return qs;
|
||||
}
|
||||
|
||||
// using unique_ptr for request to allow safe capturing in lambdas
|
||||
using server_http_req_ptr = std::unique_ptr<server_http_req>;
|
||||
|
||||
static void process_handler_response(server_http_req_ptr && request, server_http_res_ptr & response, httplib::Response & res) {
|
||||
if (response->is_stream()) {
|
||||
res.status = response->status;
|
||||
set_headers(res, response->headers);
|
||||
std::string content_type = response->content_type;
|
||||
// convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it
|
||||
std::shared_ptr<server_http_req> q_ptr = std::move(request);
|
||||
std::shared_ptr<server_http_res> r_ptr = std::move(response);
|
||||
const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool {
|
||||
std::string chunk;
|
||||
bool has_next = response->next(chunk);
|
||||
if (!chunk.empty()) {
|
||||
if (!sink.write(chunk.data(), chunk.size())) {
|
||||
return false;
|
||||
}
|
||||
SRV_DBG("http: streamed chunk: %s\n", chunk.c_str());
|
||||
}
|
||||
if (!has_next) {
|
||||
sink.done();
|
||||
SRV_DBG("%s", "http: stream ended\n");
|
||||
}
|
||||
return has_next;
|
||||
};
|
||||
const auto on_complete = [request = q_ptr, response = r_ptr](bool) mutable {
|
||||
response.reset(); // trigger the destruction of the response object
|
||||
request.reset(); // trigger the destruction of the request object
|
||||
};
|
||||
res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete);
|
||||
} else {
|
||||
res.status = response->status;
|
||||
set_headers(res, response->headers);
|
||||
res.set_content(response->data, response->content_type);
|
||||
}
|
||||
}
|
||||
|
||||
void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const {
|
||||
handlers.emplace(path, handler);
|
||||
pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
|
||||
server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
|
||||
get_params(req),
|
||||
get_headers(req),
|
||||
req.path,
|
||||
build_query_string(req),
|
||||
req.body,
|
||||
{},
|
||||
req.is_connection_closed
|
||||
});
|
||||
server_http_res_ptr response = handler(*request);
|
||||
process_handler_response(std::move(request), response, res);
|
||||
});
|
||||
}
|
||||
|
||||
void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
|
||||
handlers.emplace(path, handler);
|
||||
pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
|
||||
std::string body = req.body;
|
||||
std::map<std::string, uploaded_file> files;
|
||||
|
||||
if (req.is_multipart_form_data()) {
|
||||
// translate text fields to a JSON object and use it as the body
|
||||
json form_json = json::object();
|
||||
for (const auto & [key, field] : req.form.fields) {
|
||||
if (form_json.contains(key)) {
|
||||
// if the key already exists, convert it to an array
|
||||
if (!form_json[key].is_array()) {
|
||||
json existing_value = form_json[key];
|
||||
form_json[key] = json::array({existing_value});
|
||||
}
|
||||
form_json[key].push_back(field.content);
|
||||
} else {
|
||||
form_json[key] = field.content;
|
||||
}
|
||||
}
|
||||
body = form_json.dump();
|
||||
|
||||
// populate files from multipart form
|
||||
for (const auto & [key, file] : req.form.files) {
|
||||
files[key] = uploaded_file{
|
||||
raw_buffer(file.content.begin(), file.content.end()),
|
||||
file.filename,
|
||||
file.content_type,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
|
||||
get_params(req),
|
||||
get_headers(req),
|
||||
req.path,
|
||||
build_query_string(req),
|
||||
body,
|
||||
std::move(files),
|
||||
req.is_connection_closed
|
||||
});
|
||||
server_http_res_ptr response = handler(*request);
|
||||
process_handler_response(std::move(request), response, res);
|
||||
});
|
||||
}
|
||||
|
||||
//
|
||||
// Vertex AI Prediction protocol (AIP_PREDICT_ROUTE)
|
||||
// https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements
|
||||
//
|
||||
|
||||
// Derives the camelCase @requestFormat alias for a registered path.
|
||||
// e.g. "/v1/chat/completions" -> "chatCompletions", "/apply-template" -> "applyTemplate"
|
||||
static std::string path_to_gcp_format(const std::string & path) {
|
||||
std::string s = path;
|
||||
if (s.size() > 3 && s[0] == '/' && s[1] == 'v' && s[2] == '1') {
|
||||
s = s.substr(3);
|
||||
}
|
||||
if (!s.empty() && s[0] == '/') {
|
||||
s = s.substr(1);
|
||||
}
|
||||
std::string result;
|
||||
bool cap = false;
|
||||
for (unsigned char c : s) {
|
||||
if (c == ':') break; // stop before path parameters
|
||||
if (c == '/' || c == '-' || c == '_') {
|
||||
cap = true;
|
||||
} else {
|
||||
result += cap ? (char)std::toupper(c) : (char)c;
|
||||
cap = false;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static json parse_gcp_predict_response(const server_http_res_ptr & res) {
|
||||
if (res == nullptr) {
|
||||
throw std::runtime_error("empty response from internal handler");
|
||||
}
|
||||
if (res->is_stream()) {
|
||||
throw std::invalid_argument("predict route does not support streaming responses");
|
||||
}
|
||||
if (res->data.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
try {
|
||||
return json::parse(res->data);
|
||||
} catch (...) {
|
||||
return res->data;
|
||||
}
|
||||
}
|
||||
|
||||
void server_http_context::register_gcp_compat() {
|
||||
const gcp_params gcp;
|
||||
|
||||
if (!gcp.enabled) {
|
||||
// do nothing
|
||||
return;
|
||||
}
|
||||
|
||||
if (handlers.count(gcp.path_predict)) {
|
||||
LOG_ERR("%s: AIP_PREDICT_ROUTE=%s conflicts with an existing llama-server route\n", __func__, gcp.path_predict.c_str());
|
||||
exit(1);
|
||||
}
|
||||
|
||||
// camelCase alias -> canonical path (first registration wins on collision)
|
||||
// e.g. "chatCompletions" -> "/v1/chat/completions"
|
||||
std::unordered_map<std::string, std::string> alias_to_path;
|
||||
for (const auto & [path, _] : handlers) {
|
||||
alias_to_path.emplace(path_to_gcp_format(path), path);
|
||||
}
|
||||
|
||||
if (!gcp.path_health.empty()) {
|
||||
auto health_handler = handlers.find("/health");
|
||||
GGML_ASSERT(health_handler != handlers.end());
|
||||
get(gcp.path_health, health_handler->second);
|
||||
}
|
||||
|
||||
post(gcp.path_predict, [this, alias_to_path = std::move(alias_to_path)](const server_http_req & req) -> server_http_res_ptr {
|
||||
static const auto build_error = [](const std::string & message, error_type type) -> json {
|
||||
return json {{"error", format_error_response(message, type)}};
|
||||
};
|
||||
|
||||
json data;
|
||||
try {
|
||||
data = json::parse(req.body);
|
||||
} catch (const std::exception & e) {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 400;
|
||||
res->data = safe_json_to_str({{"error", format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)}});
|
||||
return res;
|
||||
}
|
||||
if (!data.is_object()) {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 400;
|
||||
res->data = safe_json_to_str({{"error", format_error_response("request body must be a JSON object", ERROR_TYPE_INVALID_REQUEST)}});
|
||||
return res;
|
||||
}
|
||||
if (!data.contains("instances") || !data.at("instances").is_array()) {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 400;
|
||||
res->data = safe_json_to_str({{"error", format_error_response("request body must include an array field named instances", ERROR_TYPE_INVALID_REQUEST)}});
|
||||
return res;
|
||||
}
|
||||
|
||||
const json & instances = data.at("instances");
|
||||
static const size_t MAX_INSTANCES = 128;
|
||||
if (instances.size() > MAX_INSTANCES) {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 400;
|
||||
res->data = safe_json_to_str({{"error", format_error_response("instances array exceeds maximum size of " + std::to_string(MAX_INSTANCES), ERROR_TYPE_INVALID_REQUEST)}});
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<std::future<json>> futures;
|
||||
futures.reserve(instances.size());
|
||||
|
||||
for (const auto & instance : instances) {
|
||||
futures.push_back(std::async(std::launch::async, [this, &req, &alias_to_path, instance]() -> json {
|
||||
if (!instance.is_object()) {
|
||||
return build_error("each instance must be a JSON object", ERROR_TYPE_INVALID_REQUEST);
|
||||
}
|
||||
if (!instance.contains("@requestFormat") || !instance.at("@requestFormat").is_string()) {
|
||||
return build_error("each instance must include a string @requestFormat", ERROR_TYPE_INVALID_REQUEST);
|
||||
}
|
||||
|
||||
try {
|
||||
json payload = instance;
|
||||
const std::string format = payload.at("@requestFormat").get<std::string>();
|
||||
payload.erase("@requestFormat");
|
||||
|
||||
if (payload.contains("stream")) {
|
||||
LOG_WRN("%s: ignoring client-provided stream field in instance, streaming is not supported in predict route\n", __func__);
|
||||
payload["stream"] = false;
|
||||
}
|
||||
|
||||
// accept both camelCase aliases (e.g. "chatCompletions") and direct paths
|
||||
std::string dispatch_path;
|
||||
auto it_alias = alias_to_path.find(format);
|
||||
if (it_alias != alias_to_path.end()) {
|
||||
dispatch_path = it_alias->second;
|
||||
} else if (handlers.count(format)) {
|
||||
dispatch_path = format;
|
||||
} else {
|
||||
return build_error("no handler registered for @requestFormat: " + format, ERROR_TYPE_INVALID_REQUEST);
|
||||
}
|
||||
|
||||
const server_http_req internal_req {
|
||||
req.params,
|
||||
req.headers,
|
||||
path_prefix + dispatch_path,
|
||||
req.query_string,
|
||||
payload.dump(),
|
||||
{},
|
||||
req.should_stop,
|
||||
};
|
||||
|
||||
server_http_res_ptr internal_res = handlers.at(dispatch_path)(internal_req);
|
||||
return parse_gcp_predict_response(internal_res);
|
||||
} catch (const std::invalid_argument & e) {
|
||||
return build_error(e.what(), ERROR_TYPE_INVALID_REQUEST);
|
||||
} catch (const std::exception & e) {
|
||||
return build_error(e.what(), ERROR_TYPE_SERVER);
|
||||
} catch (...) {
|
||||
return build_error("unknown error", ERROR_TYPE_SERVER);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
json predictions = json::array();
|
||||
for (auto & future : futures) {
|
||||
predictions.push_back(future.get());
|
||||
}
|
||||
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->data = safe_json_to_str({{"predictions", predictions}});
|
||||
return res;
|
||||
});
|
||||
}
|
||||
94
tools/server/server-http.h
Normal file
94
tools/server/server-http.h
Normal file
@@ -0,0 +1,94 @@
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
|
||||
struct common_params;
|
||||
|
||||
// generator-like API for HTTP response generation
|
||||
// this object response with one of the 2 modes:
|
||||
// 1) normal response: `data` contains the full response body
|
||||
// 2) streaming response: each call to next(output) generates the next chunk
|
||||
// when next(output) returns false, no more data after the current chunk
|
||||
// note: some chunks can be empty, in which case no data is sent for that chunk
|
||||
struct server_http_res {
|
||||
std::string content_type = "application/json; charset=utf-8";
|
||||
int status = 200;
|
||||
std::string data;
|
||||
std::map<std::string, std::string> headers;
|
||||
|
||||
// TODO: move this to a virtual function once we have proper polymorphism support
|
||||
std::function<bool(std::string &)> next = nullptr;
|
||||
bool is_stream() const {
|
||||
return next != nullptr;
|
||||
}
|
||||
|
||||
virtual ~server_http_res() = default;
|
||||
};
|
||||
|
||||
// unique pointer, used by set_chunked_content_provider
|
||||
// httplib requires the stream provider to be stored in heap
|
||||
using server_http_res_ptr = std::unique_ptr<server_http_res>;
|
||||
using raw_buffer = std::vector<uint8_t>;
|
||||
|
||||
struct uploaded_file {
|
||||
raw_buffer data;
|
||||
std::string filename;
|
||||
std::string content_type;
|
||||
};
|
||||
|
||||
struct server_http_req {
|
||||
std::map<std::string, std::string> params; // path_params + query_params
|
||||
std::map<std::string, std::string> headers; // used by MCP proxy
|
||||
std::string path;
|
||||
std::string query_string; // query parameters string (e.g. "action=save")
|
||||
std::string body;
|
||||
std::map<std::string, uploaded_file> files; // used for file uploads (form data)
|
||||
const std::function<bool()> & should_stop;
|
||||
|
||||
std::string get_param(const std::string & key, const std::string & def = "") const {
|
||||
auto it = params.find(key);
|
||||
if (it != params.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return def;
|
||||
}
|
||||
};
|
||||
|
||||
struct server_http_context {
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> pimpl;
|
||||
|
||||
std::thread thread; // server thread
|
||||
std::atomic<bool> is_ready = false;
|
||||
|
||||
// note: the handler should never throw exceptions
|
||||
using handler_t = std::function<server_http_res_ptr(const server_http_req & req)>;
|
||||
mutable std::unordered_map<std::string, handler_t> handlers;
|
||||
|
||||
std::string path_prefix;
|
||||
std::string hostname;
|
||||
int port;
|
||||
|
||||
server_http_context();
|
||||
~server_http_context();
|
||||
|
||||
bool init(const common_params & params);
|
||||
bool start();
|
||||
void stop() const;
|
||||
|
||||
void get(const std::string & path, const handler_t & handler) const;
|
||||
void post(const std::string & path, const handler_t & handler) const;
|
||||
|
||||
// Register the Google Cloud Platform (Vertex AI) compat (AIP_PREDICT_ROUTE env var, or /predict)
|
||||
// Must be called AFTER all other API routes are registered
|
||||
void register_gcp_compat();
|
||||
|
||||
// for debugging
|
||||
std::string listening_address;
|
||||
};
|
||||
1565
tools/server/server-models.cpp
Normal file
1565
tools/server/server-models.cpp
Normal file
File diff suppressed because it is too large
Load Diff
233
tools/server/server-models.h
Normal file
233
tools/server/server-models.h
Normal file
@@ -0,0 +1,233 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "preset.h"
|
||||
#include "server-common.h"
|
||||
#include "server-http.h"
|
||||
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
|
||||
/**
|
||||
* state diagram:
|
||||
*
|
||||
* UNLOADED ──► LOADING ──► LOADED ◄──── SLEEPING
|
||||
* ▲ │ │ ▲
|
||||
* └───failed───┘ │ │
|
||||
* ▲ └──sleeping─────┘
|
||||
* └────────unloaded─────────┘
|
||||
*/
|
||||
enum server_model_status {
|
||||
// TODO: also add downloading state when the logic is added
|
||||
SERVER_MODEL_STATUS_UNLOADED,
|
||||
SERVER_MODEL_STATUS_LOADING,
|
||||
SERVER_MODEL_STATUS_LOADED,
|
||||
SERVER_MODEL_STATUS_SLEEPING
|
||||
};
|
||||
|
||||
static server_model_status server_model_status_from_string(const std::string & status_str) {
|
||||
if (status_str == "unloaded") {
|
||||
return SERVER_MODEL_STATUS_UNLOADED;
|
||||
}
|
||||
if (status_str == "loading") {
|
||||
return SERVER_MODEL_STATUS_LOADING;
|
||||
}
|
||||
if (status_str == "loaded") {
|
||||
return SERVER_MODEL_STATUS_LOADED;
|
||||
}
|
||||
if (status_str == "sleeping") {
|
||||
return SERVER_MODEL_STATUS_SLEEPING;
|
||||
}
|
||||
throw std::runtime_error("invalid server model status");
|
||||
}
|
||||
|
||||
static std::string server_model_status_to_string(server_model_status status) {
|
||||
switch (status) {
|
||||
case SERVER_MODEL_STATUS_UNLOADED: return "unloaded";
|
||||
case SERVER_MODEL_STATUS_LOADING: return "loading";
|
||||
case SERVER_MODEL_STATUS_LOADED: return "loaded";
|
||||
case SERVER_MODEL_STATUS_SLEEPING: return "sleeping";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
struct server_model_meta {
|
||||
common_preset preset;
|
||||
std::string name;
|
||||
std::set<std::string> aliases; // additional names that resolve to this model
|
||||
std::set<std::string> tags; // informational tags, not used for routing
|
||||
int port = 0;
|
||||
server_model_status status = SERVER_MODEL_STATUS_UNLOADED;
|
||||
int64_t last_used = 0; // for LRU unloading
|
||||
std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
|
||||
json loaded_info; // info to be reflected via /v1/models endpoint
|
||||
int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
|
||||
int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown
|
||||
|
||||
bool is_ready() const {
|
||||
return status == SERVER_MODEL_STATUS_LOADED;
|
||||
}
|
||||
|
||||
bool is_running() const {
|
||||
return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING || status == SERVER_MODEL_STATUS_SLEEPING;
|
||||
}
|
||||
|
||||
bool is_failed() const {
|
||||
return status == SERVER_MODEL_STATUS_UNLOADED && exit_code != 0;
|
||||
}
|
||||
|
||||
void update_args(common_preset_context & ctx_presets, std::string bin_path);
|
||||
};
|
||||
|
||||
struct subprocess_s;
|
||||
|
||||
struct server_models {
|
||||
private:
|
||||
struct instance_t {
|
||||
std::shared_ptr<subprocess_s> subproc; // shared between main thread and monitoring thread
|
||||
std::thread th;
|
||||
server_model_meta meta;
|
||||
FILE * stdin_file = nullptr;
|
||||
};
|
||||
|
||||
std::mutex mutex;
|
||||
std::condition_variable cv;
|
||||
std::map<std::string, instance_t> mapping;
|
||||
|
||||
// for stopping models
|
||||
std::condition_variable cv_stop;
|
||||
std::set<std::string> stopping_models;
|
||||
|
||||
// set to true while load_models() is executing a reload; load() will wait until clear
|
||||
bool is_reloading = false;
|
||||
|
||||
common_preset_context ctx_preset;
|
||||
|
||||
common_params base_params;
|
||||
std::string bin_path;
|
||||
std::vector<std::string> base_env;
|
||||
common_preset base_preset; // base preset from llama-server CLI args
|
||||
|
||||
void update_meta(const std::string & name, const server_model_meta & meta);
|
||||
|
||||
// unload least recently used models if the limit is reached
|
||||
void unload_lru();
|
||||
|
||||
// not thread-safe, caller must hold mutex
|
||||
void add_model(server_model_meta && meta);
|
||||
|
||||
public:
|
||||
server_models(const common_params & params, int argc, char ** argv);
|
||||
|
||||
// (re-)load the list of models from various sources and prepare the metadata mapping
|
||||
// - if this is called the first time, simply populate the metadata
|
||||
// - if this is called subsequently (e.g. when refreshing from disk):
|
||||
// - if a model is running but updated or removed from the source, it will be unloaded
|
||||
// - if a model is not running, it will be added or updated according to the source
|
||||
void load_models();
|
||||
|
||||
// check if a model instance exists (thread-safe)
|
||||
bool has_model(const std::string & name);
|
||||
|
||||
// return a copy of model metadata (thread-safe)
|
||||
std::optional<server_model_meta> get_meta(const std::string & name);
|
||||
|
||||
// return a copy of all model metadata (thread-safe)
|
||||
std::vector<server_model_meta> get_all_meta();
|
||||
|
||||
// load and unload model instances
|
||||
// these functions are thread-safe
|
||||
void load(const std::string & name);
|
||||
void unload(const std::string & name);
|
||||
void unload_all();
|
||||
|
||||
// update the status of a model instance (thread-safe)
|
||||
void update_status(const std::string & name, server_model_status status, int exit_code);
|
||||
void update_loaded_info(const std::string & name, std::string & raw_info);
|
||||
|
||||
// wait until the model instance is fully loaded (thread-safe)
|
||||
// return when the model no longer in "loading" state
|
||||
void wait_until_loading_finished(const std::string & name);
|
||||
|
||||
// ensure the model is in ready state (thread-safe)
|
||||
// return false if model is ready
|
||||
// otherwise, load the model and blocking wait until it's ready, then return true (meta may need to be refreshed)
|
||||
bool ensure_model_ready(const std::string & name);
|
||||
|
||||
// proxy an HTTP request to the model instance
|
||||
server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used);
|
||||
|
||||
// return true if the current process is a child server instance
|
||||
static bool is_child_server();
|
||||
|
||||
// notify the router server that a model instance is ready
|
||||
// return the monitoring thread (to be joined by the caller)
|
||||
static std::thread setup_child_server(const std::function<void(int)> & shutdown_handler, const json & model_info);
|
||||
|
||||
// notify the router server that the sleeping state has changed
|
||||
static void notify_router_sleeping_state(bool sleeping);
|
||||
};
|
||||
|
||||
struct server_models_routes {
|
||||
common_params params;
|
||||
json webui_settings = json::object();
|
||||
server_models models;
|
||||
server_models_routes(const common_params & params, int argc, char ** argv)
|
||||
: params(params), models(params, argc, argv) {
|
||||
if (!this->params.webui_config_json.empty()) {
|
||||
try {
|
||||
webui_settings = json::parse(this->params.webui_config_json);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to parse webui config: %s\n", __func__, e.what());
|
||||
throw;
|
||||
}
|
||||
}
|
||||
init_routes();
|
||||
}
|
||||
|
||||
void init_routes();
|
||||
// handlers using lambda function, so that they can capture `this` without `std::bind`
|
||||
server_http_context::handler_t get_router_props;
|
||||
server_http_context::handler_t proxy_get;
|
||||
server_http_context::handler_t proxy_post;
|
||||
server_http_context::handler_t get_router_models;
|
||||
server_http_context::handler_t post_router_models_load;
|
||||
server_http_context::handler_t post_router_models_unload;
|
||||
};
|
||||
|
||||
/**
|
||||
* A simple HTTP proxy that forwards requests to another server
|
||||
* and relays the responses back.
|
||||
*/
|
||||
struct server_http_proxy : server_http_res {
|
||||
std::function<void()> cleanup = nullptr;
|
||||
public:
|
||||
server_http_proxy(const std::string & method,
|
||||
const std::string & scheme,
|
||||
const std::string & host,
|
||||
int port,
|
||||
const std::string & path,
|
||||
const std::map<std::string, std::string> & headers,
|
||||
const std::string & body,
|
||||
const std::map<std::string, uploaded_file> & files,
|
||||
const std::function<bool()> should_stop,
|
||||
int32_t timeout_read,
|
||||
int32_t timeout_write
|
||||
);
|
||||
~server_http_proxy() {
|
||||
if (cleanup) {
|
||||
cleanup();
|
||||
}
|
||||
}
|
||||
private:
|
||||
std::thread thread;
|
||||
struct msg_t {
|
||||
std::map<std::string, std::string> headers;
|
||||
int status = 0;
|
||||
std::string data;
|
||||
std::string content_type;
|
||||
};
|
||||
};
|
||||
451
tools/server/server-queue.cpp
Normal file
451
tools/server/server-queue.cpp
Normal file
@@ -0,0 +1,451 @@
|
||||
#include "server-task.h"
|
||||
#include "server-queue.h"
|
||||
|
||||
#include "log.h"
|
||||
|
||||
#include <chrono>
|
||||
|
||||
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
#define RES_INF(fmt, ...) LOG_INF("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define RES_WRN(fmt, ...) LOG_WRN("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define RES_ERR(fmt, ...) LOG_ERR("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define RES_DBG(fmt, ...) LOG_DBG("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
//
|
||||
// server_queue
|
||||
//
|
||||
|
||||
int server_queue::post(server_task && task, bool front) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
GGML_ASSERT(task.id != -1);
|
||||
// if this is cancel task make sure to clean up pending tasks
|
||||
if (task.type == SERVER_TASK_TYPE_CANCEL) {
|
||||
cleanup_pending_task(task.id_target);
|
||||
}
|
||||
const int task_id = task.id;
|
||||
QUE_DBG("new task, id = %d, front = %d\n", task_id, front);
|
||||
if (front) {
|
||||
queue_tasks.push_front(std::move(task));
|
||||
} else {
|
||||
queue_tasks.push_back(std::move(task));
|
||||
}
|
||||
time_last_task = ggml_time_ms();
|
||||
condition_tasks.notify_one();
|
||||
return task_id;
|
||||
}
|
||||
|
||||
int server_queue::post(std::vector<server_task> && tasks, bool front) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
for (auto & task : tasks) {
|
||||
if (task.id == -1) {
|
||||
task.id = id++;
|
||||
}
|
||||
// if this is cancel task make sure to clean up pending tasks
|
||||
if (task.type == SERVER_TASK_TYPE_CANCEL) {
|
||||
cleanup_pending_task(task.id_target);
|
||||
}
|
||||
QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
|
||||
if (front) {
|
||||
queue_tasks.push_front(std::move(task));
|
||||
} else {
|
||||
queue_tasks.push_back(std::move(task));
|
||||
}
|
||||
}
|
||||
time_last_task = ggml_time_ms();
|
||||
condition_tasks.notify_one();
|
||||
return 0;
|
||||
}
|
||||
|
||||
void server_queue::defer(server_task && task) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
QUE_DBG("defer task, id = %d\n", task.id);
|
||||
queue_tasks_deferred.push_back(std::move(task));
|
||||
time_last_task = ggml_time_ms();
|
||||
condition_tasks.notify_one();
|
||||
}
|
||||
|
||||
int server_queue::get_new_id() {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
int new_id = id++;
|
||||
return new_id;
|
||||
}
|
||||
|
||||
void server_queue::pop_deferred_task(int id_slot) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
if (!queue_tasks_deferred.empty()) {
|
||||
// try to find a task that uses the specified slot
|
||||
bool found = false;
|
||||
for (auto it = queue_tasks_deferred.begin(); it != queue_tasks_deferred.end(); ++it) {
|
||||
if (it->id_slot == id_slot) {
|
||||
QUE_DBG("pop deferred task (use slot %d), id_task = %d\n", id_slot, it->id);
|
||||
queue_tasks.emplace_front(std::move(*it));
|
||||
queue_tasks_deferred.erase(it);
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// if not tasks found using the slot, just pop the first deferred task (default behavior)
|
||||
if (!found) {
|
||||
QUE_DBG("pop deferred task, id_task = %d\n", queue_tasks_deferred.front().id);
|
||||
queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
|
||||
queue_tasks_deferred.pop_front();
|
||||
}
|
||||
}
|
||||
time_last_task = ggml_time_ms();
|
||||
condition_tasks.notify_one();
|
||||
}
|
||||
|
||||
void server_queue::wait_until_no_sleep() {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
if (!sleeping) {
|
||||
return;
|
||||
} else {
|
||||
if (!req_stop_sleeping) {
|
||||
QUE_DBG("%s", "requesting to stop sleeping\n");
|
||||
req_stop_sleeping = true;
|
||||
condition_tasks.notify_one(); // only main thread is waiting on this
|
||||
}
|
||||
QUE_DBG("%s", "waiting until no sleep\n");
|
||||
condition_tasks.wait(lock, [&]{
|
||||
return !sleeping;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void server_queue::terminate() {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
running = false;
|
||||
condition_tasks.notify_all();
|
||||
}
|
||||
|
||||
void server_queue::start_loop(int64_t idle_sleep_ms) {
|
||||
running = true;
|
||||
time_last_task = ggml_time_ms();
|
||||
|
||||
constexpr auto max_wait_time = std::chrono::seconds(1);
|
||||
auto should_sleep = [&]() -> bool {
|
||||
// caller must hold mutex_tasks
|
||||
if (idle_sleep_ms < 0) {
|
||||
return false;
|
||||
}
|
||||
int64_t now = ggml_time_ms();
|
||||
return (now - time_last_task) >= idle_sleep_ms;
|
||||
};
|
||||
|
||||
while (true) {
|
||||
QUE_DBG("%s", "processing new tasks\n");
|
||||
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
if (!running) {
|
||||
QUE_DBG("%s", "terminate\n");
|
||||
return;
|
||||
}
|
||||
if (queue_tasks.empty()) {
|
||||
lock.unlock();
|
||||
break;
|
||||
}
|
||||
server_task task = std::move(queue_tasks.front());
|
||||
queue_tasks.pop_front();
|
||||
lock.unlock();
|
||||
|
||||
QUE_DBG("processing task, id = %d\n", task.id);
|
||||
callback_new_task(std::move(task));
|
||||
}
|
||||
// all tasks in the current loop is processed, slots data is now ready
|
||||
QUE_DBG("%s", "update slots\n");
|
||||
|
||||
// this will run the main inference process for all slots
|
||||
callback_update_slots();
|
||||
{
|
||||
// update_slots() may take a while to finish, we need to make sure it's not counted as idle
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
time_last_task = ggml_time_ms();
|
||||
}
|
||||
|
||||
QUE_DBG("%s", "waiting for new tasks\n");
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
if (!running || !queue_tasks.empty()) {
|
||||
break; // go back to process new tasks or terminate
|
||||
}
|
||||
|
||||
// no tasks, check for sleeping state
|
||||
if (should_sleep()) {
|
||||
QUE_INF("%s", "entering sleeping state\n");
|
||||
sleeping = true;
|
||||
callback_sleeping_state(true);
|
||||
req_stop_sleeping = false;
|
||||
// wait until we are requested to exit sleeping state
|
||||
condition_tasks.wait(lock, [&]{
|
||||
return (!running || req_stop_sleeping);
|
||||
});
|
||||
if (!running) { // may changed during sleep
|
||||
break; // terminate
|
||||
}
|
||||
QUE_INF("%s", "exiting sleeping state\n");
|
||||
req_stop_sleeping = false;
|
||||
callback_sleeping_state(false);
|
||||
sleeping = false;
|
||||
time_last_task = ggml_time_ms();
|
||||
condition_tasks.notify_all(); // notify wait_until_no_sleep()
|
||||
break; // process new tasks
|
||||
} else {
|
||||
// wait for new tasks or timeout for checking sleeping condition
|
||||
bool res = condition_tasks.wait_for(lock, max_wait_time, [&]{
|
||||
return (!queue_tasks.empty() || !running);
|
||||
});
|
||||
if (res) {
|
||||
break; // new task arrived or terminate
|
||||
}
|
||||
// otherwise, loop again to check sleeping condition
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void server_queue::cleanup_pending_task(int id_target) {
|
||||
// no need lock because this is called exclusively by post()
|
||||
auto rm_func = [id_target](const server_task & task) {
|
||||
return task.id == id_target;
|
||||
};
|
||||
queue_tasks.erase(
|
||||
std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
|
||||
queue_tasks.end());
|
||||
queue_tasks_deferred.erase(
|
||||
std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
|
||||
queue_tasks_deferred.end());
|
||||
}
|
||||
|
||||
//
|
||||
// server_response
|
||||
//
|
||||
|
||||
void server_response::add_waiting_task_id(int id_task) {
|
||||
RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
waiting_task_ids.insert(id_task);
|
||||
}
|
||||
|
||||
void server_response::add_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
|
||||
for (const auto & id_task : id_tasks) {
|
||||
RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
|
||||
waiting_task_ids.insert(id_task);
|
||||
}
|
||||
}
|
||||
|
||||
void server_response::remove_waiting_task_id(int id_task) {
|
||||
RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
waiting_task_ids.erase(id_task);
|
||||
// make sure to clean up all pending results
|
||||
queue_results.erase(
|
||||
std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) {
|
||||
return res->id == id_task;
|
||||
}),
|
||||
queue_results.end());
|
||||
}
|
||||
|
||||
void server_response::remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
|
||||
for (const auto & id_task : id_tasks) {
|
||||
RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
|
||||
waiting_task_ids.erase(id_task);
|
||||
}
|
||||
}
|
||||
|
||||
server_task_result_ptr server_response::recv(const std::unordered_set<int> & id_tasks) {
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
condition_results.wait(lock, [&]{
|
||||
if (!running) {
|
||||
RES_DBG("%s : queue result stop\n", "recv");
|
||||
std::terminate(); // we cannot return here since the caller is HTTP code
|
||||
}
|
||||
return !queue_results.empty();
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < queue_results.size(); i++) {
|
||||
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
|
||||
server_task_result_ptr res = std::move(queue_results[i]);
|
||||
queue_results.erase(queue_results.begin() + i);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// should never reach here
|
||||
}
|
||||
|
||||
server_task_result_ptr server_response::recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
|
||||
for (int i = 0; i < (int) queue_results.size(); i++) {
|
||||
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
|
||||
server_task_result_ptr res = std::move(queue_results[i]);
|
||||
queue_results.erase(queue_results.begin() + i);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
|
||||
if (!running) {
|
||||
RES_DBG("%s : queue result stop\n", __func__);
|
||||
std::terminate(); // we cannot return here since the caller is HTTP code
|
||||
}
|
||||
if (cr_res == std::cv_status::timeout) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// should never reach here
|
||||
}
|
||||
|
||||
server_task_result_ptr server_response::recv(int id_task) {
|
||||
std::unordered_set<int> id_tasks = {id_task};
|
||||
return recv(id_tasks);
|
||||
}
|
||||
|
||||
void server_response::send(server_task_result_ptr && result) {
|
||||
RES_DBG("sending result for task id = %d\n", result->id);
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
for (const auto & id_task : waiting_task_ids) {
|
||||
if (result->id == id_task) {
|
||||
RES_DBG("task id = %d pushed to result queue\n", result->id);
|
||||
|
||||
queue_results.emplace_back(std::move(result));
|
||||
condition_results.notify_all();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void server_response::terminate() {
|
||||
running = false;
|
||||
condition_results.notify_all();
|
||||
}
|
||||
|
||||
//
|
||||
// server_response_reader
|
||||
//
|
||||
|
||||
void server_response_reader::post_task(server_task && task, bool front) {
|
||||
GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
|
||||
GGML_ASSERT(!task.is_parent() && "not supported, use post_tasks() instead");
|
||||
task.index = 0;
|
||||
id_tasks.insert(task.id);
|
||||
states.push_back(task.create_state());
|
||||
queue_results.add_waiting_task_id(task.id);
|
||||
queue_tasks.post(std::move(task), front);
|
||||
}
|
||||
|
||||
void server_response_reader::post_tasks(std::vector<server_task> && tasks, bool front) {
|
||||
GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
|
||||
id_tasks = server_task::get_list_id(tasks);
|
||||
states.reserve(tasks.size());
|
||||
size_t index = 0;
|
||||
for (auto & task : tasks) {
|
||||
task.index = index++;
|
||||
states.push_back(task.create_state());
|
||||
// for child tasks
|
||||
for (auto & child_task : task.child_tasks) {
|
||||
child_task.index = index++;
|
||||
states.push_back(child_task.create_state());
|
||||
}
|
||||
}
|
||||
GGML_ASSERT(states.size() == id_tasks.size());
|
||||
queue_results.add_waiting_task_ids(id_tasks);
|
||||
queue_tasks.post(std::move(tasks), front);
|
||||
}
|
||||
|
||||
bool server_response_reader::has_next() const {
|
||||
return !cancelled && received_count < id_tasks.size();
|
||||
}
|
||||
|
||||
// return nullptr if should_stop() is true before receiving a result
|
||||
// note: if one error is received, it will stop further processing and return error result
|
||||
server_task_result_ptr server_response_reader::next(const std::function<bool()> & should_stop) {
|
||||
while (true) {
|
||||
server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds);
|
||||
if (result == nullptr) {
|
||||
// timeout, check stop condition
|
||||
if (should_stop()) {
|
||||
SRV_WRN("%s", "stopping wait for next result due to should_stop condition (adjust the --timeout argument if needed)\n");
|
||||
SRV_WRN("%s", "ref: https://github.com/ggml-org/llama.cpp/pull/22907\n");
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
if (result->is_error()) {
|
||||
stop(); // cancel remaining tasks
|
||||
SRV_DBG("%s", "received error result, stopping further processing\n");
|
||||
return result;
|
||||
}
|
||||
if (!states.empty()) {
|
||||
// update the generation state if needed
|
||||
const size_t idx = result->index;
|
||||
GGML_ASSERT(idx < states.size());
|
||||
result->update(states[idx]);
|
||||
}
|
||||
if (result->is_stop()) {
|
||||
received_count++;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// should not reach here
|
||||
}
|
||||
|
||||
server_response_reader::batch_response server_response_reader::wait_for_all(const std::function<bool()> & should_stop) {
|
||||
batch_response batch_res;
|
||||
batch_res.results.clear();
|
||||
batch_res.results.resize(id_tasks.size());
|
||||
while (has_next()) {
|
||||
auto res = next(should_stop);
|
||||
if (res == nullptr) {
|
||||
batch_res.is_terminated = true;
|
||||
return batch_res;
|
||||
}
|
||||
if (res->is_error()) {
|
||||
batch_res.error = std::move(res);
|
||||
return batch_res;
|
||||
}
|
||||
const size_t idx = res->index;
|
||||
GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
|
||||
GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
|
||||
batch_res.results[idx] = std::move(res);
|
||||
}
|
||||
return batch_res;
|
||||
}
|
||||
|
||||
void server_response_reader::stop() {
|
||||
queue_results.remove_waiting_task_ids(id_tasks);
|
||||
if (has_next() && !cancelled) {
|
||||
// if tasks is not finished yet, cancel them
|
||||
cancelled = true;
|
||||
std::vector<server_task> cancel_tasks;
|
||||
cancel_tasks.reserve(id_tasks.size());
|
||||
for (const auto & id_task : id_tasks) {
|
||||
SRV_WRN("cancel task, id_task = %d\n", id_task);
|
||||
server_task task(SERVER_TASK_TYPE_CANCEL);
|
||||
task.id_target = id_task;
|
||||
queue_results.remove_waiting_task_id(id_task);
|
||||
cancel_tasks.push_back(std::move(task));
|
||||
}
|
||||
// push to beginning of the queue, so it has highest priority
|
||||
queue_tasks.post(std::move(cancel_tasks), true);
|
||||
} else {
|
||||
SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
|
||||
}
|
||||
}
|
||||
205
tools/server/server-queue.h
Normal file
205
tools/server/server-queue.h
Normal file
@@ -0,0 +1,205 @@
|
||||
#pragma once
|
||||
|
||||
#include "server-task.h"
|
||||
|
||||
#include <condition_variable>
|
||||
#include <deque>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
#include <unordered_set>
|
||||
|
||||
// struct for managing server tasks
|
||||
// in most cases, use server_response_reader to post new tasks and retrieve results
|
||||
struct server_queue {
|
||||
private:
|
||||
int id = 0;
|
||||
bool running = false;
|
||||
bool sleeping = false;
|
||||
bool req_stop_sleeping = false;
|
||||
int64_t time_last_task = 0;
|
||||
|
||||
// queues
|
||||
std::deque<server_task> queue_tasks;
|
||||
std::deque<server_task> queue_tasks_deferred;
|
||||
|
||||
std::mutex mutex_tasks;
|
||||
std::condition_variable condition_tasks;
|
||||
|
||||
// callback functions
|
||||
std::function<void(server_task &&)> callback_new_task;
|
||||
std::function<void(void)> callback_update_slots;
|
||||
std::function<void(bool)> callback_sleeping_state;
|
||||
|
||||
public:
|
||||
// Add a new task to the end of the queue
|
||||
int post(server_task && task, bool front = false);
|
||||
|
||||
// multi-task version of post()
|
||||
int post(std::vector<server_task> && tasks, bool front = false);
|
||||
|
||||
// Add a new task, but defer until one slot is available
|
||||
void defer(server_task && task);
|
||||
|
||||
// Get the next id for creating a new task
|
||||
int get_new_id();
|
||||
|
||||
// Call when the state of one slot is changed, it will move one task from deferred to main queue
|
||||
// prioritize tasks that use the specified slot (otherwise, pop the first deferred task)
|
||||
void pop_deferred_task(int id_slot);
|
||||
|
||||
// if sleeping, request exiting sleep state and wait until it is done
|
||||
// returns immediately if not sleeping
|
||||
void wait_until_no_sleep();
|
||||
|
||||
bool is_sleeping() {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
return sleeping;
|
||||
}
|
||||
|
||||
// end the start_loop routine
|
||||
void terminate();
|
||||
|
||||
/**
|
||||
* Main loop consists of these steps:
|
||||
* - Wait until a new task arrives
|
||||
* - Process the task (i.e. maybe copy data into slot)
|
||||
* - Check if multitask is finished
|
||||
* - Update all slots
|
||||
*
|
||||
* Sleeping procedure (disabled if idle_sleep_ms < 0):
|
||||
* - If there is no task after idle_sleep_ms, enter sleeping state
|
||||
* - Call callback_sleeping_state(true)
|
||||
* - Wait until req_stop_sleeping is set to true
|
||||
* - Call callback_sleeping_state(false)
|
||||
* - Exit sleeping state
|
||||
*/
|
||||
void start_loop(int64_t idle_sleep_ms = -1);
|
||||
|
||||
// for metrics
|
||||
size_t queue_tasks_deferred_size() {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
return queue_tasks_deferred.size();
|
||||
}
|
||||
|
||||
//
|
||||
// Functions below are not thread-safe, must only be used before start_loop() is called
|
||||
//
|
||||
|
||||
// Register function to process a new task
|
||||
void on_new_task(std::function<void(server_task &&)> callback) {
|
||||
callback_new_task = std::move(callback);
|
||||
}
|
||||
|
||||
// Register the function to be called when all slots data is ready to be processed
|
||||
void on_update_slots(std::function<void(void)> callback) {
|
||||
callback_update_slots = std::move(callback);
|
||||
}
|
||||
|
||||
// Register callback for sleeping state change; multiple callbacks are allowed
|
||||
// note: when entering sleeping state, the callback is called AFTER sleeping is set to true
|
||||
// when leaving sleeping state, the callback is called BEFORE sleeping is set to false
|
||||
void on_sleeping_state(std::function<void(bool)> callback) {
|
||||
if (callback_sleeping_state) {
|
||||
auto prev_callback = std::move(callback_sleeping_state);
|
||||
callback_sleeping_state = [prev_callback, callback](bool sleeping) {
|
||||
prev_callback(sleeping);
|
||||
callback(sleeping);
|
||||
};
|
||||
} else {
|
||||
callback_sleeping_state = std::move(callback);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void cleanup_pending_task(int id_target);
|
||||
};
|
||||
|
||||
// struct for managing server responses
|
||||
// in most cases, use server_response_reader to retrieve results
|
||||
struct server_response {
|
||||
private:
|
||||
bool running = true;
|
||||
|
||||
// for keeping track of all tasks waiting for the result
|
||||
std::unordered_set<int> waiting_task_ids;
|
||||
|
||||
// the main result queue (using ptr for polymorphism)
|
||||
std::vector<server_task_result_ptr> queue_results;
|
||||
|
||||
std::mutex mutex_results;
|
||||
std::condition_variable condition_results;
|
||||
|
||||
public:
|
||||
// add the id_task to the list of tasks waiting for response
|
||||
void add_waiting_task_id(int id_task);
|
||||
|
||||
void add_waiting_task_ids(const std::unordered_set<int> & id_tasks);
|
||||
|
||||
// when the request is finished, we can remove task associated with it
|
||||
void remove_waiting_task_id(int id_task);
|
||||
|
||||
// remove multiple tasks from waiting list
|
||||
void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks);
|
||||
|
||||
// This function blocks the thread until there is a response for one of the id_tasks
|
||||
server_task_result_ptr recv(const std::unordered_set<int> & id_tasks);
|
||||
|
||||
// same as recv(), but have timeout in seconds
|
||||
// if timeout is reached, nullptr is returned
|
||||
server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout);
|
||||
|
||||
// single-task version of recv()
|
||||
server_task_result_ptr recv(int id_task);
|
||||
|
||||
// Send a new result to a waiting id_task
|
||||
void send(server_task_result_ptr && result);
|
||||
|
||||
// terminate the waiting loop
|
||||
void terminate();
|
||||
};
|
||||
|
||||
// utility class to make working with server_queue and server_response easier
|
||||
// it provides a generator-like API for server responses
|
||||
// support pooling connection state and aggregating multiple results
|
||||
struct server_response_reader {
|
||||
std::unordered_set<int> id_tasks;
|
||||
server_queue & queue_tasks;
|
||||
server_response & queue_results;
|
||||
size_t received_count = 0;
|
||||
bool cancelled = false;
|
||||
int polling_interval_seconds;
|
||||
|
||||
// tracking generation state and partial tool calls
|
||||
// only used by streaming completions
|
||||
std::vector<task_result_state> states;
|
||||
|
||||
// should_stop function will be called each polling_interval_seconds
|
||||
server_response_reader(server_queue & queue_tasks, server_response & queue_results, int polling_interval_seconds)
|
||||
: queue_tasks(queue_tasks), queue_results(queue_results), polling_interval_seconds(polling_interval_seconds) {}
|
||||
~server_response_reader() {
|
||||
stop();
|
||||
}
|
||||
|
||||
int get_new_id() {
|
||||
return queue_tasks.get_new_id();
|
||||
}
|
||||
|
||||
// if front = true, the task will be posted to the front of the queue (high priority)
|
||||
void post_task(server_task && task, bool front = false);
|
||||
void post_tasks(std::vector<server_task> && tasks, bool front = false);
|
||||
bool has_next() const;
|
||||
|
||||
// return nullptr if should_stop() is true before receiving a result
|
||||
// note: if one error is received, it will stop further processing and return error result
|
||||
server_task_result_ptr next(const std::function<bool()> & should_stop);
|
||||
|
||||
struct batch_response {
|
||||
bool is_terminated = false; // if true, indicates that processing was stopped before all results were received
|
||||
std::vector<server_task_result_ptr> results;
|
||||
server_task_result_ptr error; // nullptr if no error
|
||||
};
|
||||
// aggregate multiple results
|
||||
batch_response wait_for_all(const std::function<bool()> & should_stop);
|
||||
|
||||
void stop();
|
||||
};
|
||||
2155
tools/server/server-task.cpp
Normal file
2155
tools/server/server-task.cpp
Normal file
File diff suppressed because it is too large
Load Diff
632
tools/server/server-task.h
Normal file
632
tools/server/server-task.h
Normal file
@@ -0,0 +1,632 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <list>
|
||||
#include <map>
|
||||
|
||||
// TODO: prevent including the whole server-common.h as we only use server_tokens
|
||||
#include "server-common.h"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
enum server_task_type {
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
SERVER_TASK_TYPE_EMBEDDING,
|
||||
SERVER_TASK_TYPE_RERANK,
|
||||
SERVER_TASK_TYPE_INFILL,
|
||||
SERVER_TASK_TYPE_CANCEL,
|
||||
SERVER_TASK_TYPE_NEXT_RESPONSE,
|
||||
SERVER_TASK_TYPE_METRICS,
|
||||
SERVER_TASK_TYPE_SLOT_SAVE,
|
||||
SERVER_TASK_TYPE_SLOT_RESTORE,
|
||||
SERVER_TASK_TYPE_SLOT_ERASE,
|
||||
SERVER_TASK_TYPE_GET_LORA,
|
||||
SERVER_TASK_TYPE_SET_LORA,
|
||||
};
|
||||
|
||||
// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common
|
||||
enum task_response_type {
|
||||
TASK_RESPONSE_TYPE_NONE, // llama.cpp native format
|
||||
TASK_RESPONSE_TYPE_OAI_CHAT,
|
||||
TASK_RESPONSE_TYPE_OAI_CMPL,
|
||||
TASK_RESPONSE_TYPE_OAI_RESP,
|
||||
TASK_RESPONSE_TYPE_OAI_ASR, // transcriptions API
|
||||
TASK_RESPONSE_TYPE_OAI_EMBD,
|
||||
TASK_RESPONSE_TYPE_ANTHROPIC,
|
||||
};
|
||||
|
||||
enum stop_type {
|
||||
STOP_TYPE_NONE,
|
||||
STOP_TYPE_EOS,
|
||||
STOP_TYPE_WORD,
|
||||
STOP_TYPE_LIMIT,
|
||||
};
|
||||
|
||||
struct task_params {
|
||||
bool stream = true;
|
||||
bool include_usage = false;
|
||||
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
|
||||
bool return_tokens = false;
|
||||
bool return_progress = false;
|
||||
|
||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
|
||||
int32_t n_predict = -1; // new tokens to predict
|
||||
int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
|
||||
int32_t n_cmpl = 1; // number of completions to generate from this prompt
|
||||
|
||||
int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)
|
||||
|
||||
int64_t t_max_prompt_ms = -1; // TODO: implement
|
||||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||
|
||||
std::map<int, float> lora; // mapping adapter ID -> scale
|
||||
|
||||
std::vector<std::string> antiprompt;
|
||||
std::vector<std::string> response_fields;
|
||||
|
||||
bool timings_per_token = false;
|
||||
bool post_sampling_probs = false;
|
||||
|
||||
struct common_params_sampling sampling;
|
||||
struct common_params_speculative speculative;
|
||||
|
||||
// response formatting
|
||||
bool verbose = false;
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
|
||||
// per-request parameters for chat parsing
|
||||
common_chat_parser_params chat_parser_params;
|
||||
|
||||
// Embeddings
|
||||
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
|
||||
|
||||
json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const;
|
||||
json to_json(bool only_metrics = false) const;
|
||||
};
|
||||
|
||||
// struct for tracking the state of a task (e.g., for streaming)
|
||||
struct task_result_state {
|
||||
// tracking diffs for partial tool calls
|
||||
std::vector<common_chat_msg_diff> diffs;
|
||||
common_chat_parser_params chat_parser_params;
|
||||
common_chat_msg chat_msg;
|
||||
std::string generated_text; // append new chunks of generated text here
|
||||
std::vector<std::string> generated_tool_call_ids;
|
||||
std::unordered_set<size_t> sent_tool_call_names;
|
||||
|
||||
// for OpenAI Responses and Anthropic streaming API:
|
||||
// track output item / content block state across chunks
|
||||
bool thinking_block_started = false;
|
||||
bool text_block_started = false;
|
||||
|
||||
// for OpenAI Responses streaming API
|
||||
const std::string oai_resp_id;
|
||||
const std::string oai_resp_reasoning_id;
|
||||
const std::string oai_resp_message_id;
|
||||
std::string oai_resp_fc_id; // function call ID for current args delta
|
||||
|
||||
task_result_state(const common_chat_parser_params & chat_parser_params)
|
||||
: chat_parser_params(chat_parser_params)
|
||||
, oai_resp_id("resp_" + random_string())
|
||||
, oai_resp_reasoning_id("rs_" + random_string())
|
||||
, oai_resp_message_id("msg_" + random_string()) {}
|
||||
|
||||
// parse partial tool calls and update the internal state
|
||||
common_chat_msg update_chat_msg(
|
||||
const std::string & text_added,
|
||||
bool is_partial,
|
||||
std::vector<common_chat_msg_diff> & diffs,
|
||||
bool filter_tool_calls = false);
|
||||
};
|
||||
|
||||
struct server_task {
|
||||
int id = -1; // to be filled by server_queue
|
||||
|
||||
// TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader
|
||||
size_t index = 0; // used when there are multiple prompts (batch request)
|
||||
|
||||
// used by SERVER_TASK_TYPE_CANCEL
|
||||
int id_target = -1;
|
||||
int id_slot = -1;
|
||||
|
||||
// used by parallel sampling (multiple completions from same prompt)
|
||||
int id_parent = -1;
|
||||
// temporary store of child tasks for scheduling
|
||||
// note: accessing to elements is invalid after the task is moved to server_slot
|
||||
std::vector<server_task> child_tasks;
|
||||
|
||||
// used by SERVER_TASK_TYPE_INFERENCE
|
||||
task_params params;
|
||||
server_tokens tokens;
|
||||
|
||||
// only used by CLI, this allow tokenizing CLI inputs on server side
|
||||
// we need this because mtmd_context and vocab are not accessible outside of server_context
|
||||
bool cli = false;
|
||||
std::string cli_prompt;
|
||||
std::vector<raw_buffer> cli_files;
|
||||
|
||||
server_task_type type;
|
||||
|
||||
// used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
|
||||
struct slot_action {
|
||||
int id_slot;
|
||||
std::string filename;
|
||||
std::string filepath;
|
||||
};
|
||||
slot_action slot_action;
|
||||
|
||||
// used by SERVER_TASK_TYPE_METRICS
|
||||
bool metrics_reset_bucket = false;
|
||||
|
||||
// used by SERVER_TASK_TYPE_SET_LORA
|
||||
std::map<int, float> set_lora; // mapping adapter ID -> scale
|
||||
|
||||
server_task() = default;
|
||||
|
||||
server_task(server_task_type type) : type(type) {}
|
||||
|
||||
int32_t n_tokens() const {
|
||||
return tokens.size();
|
||||
}
|
||||
|
||||
bool need_embd() const {
|
||||
switch (type) {
|
||||
case SERVER_TASK_TYPE_EMBEDDING:
|
||||
case SERVER_TASK_TYPE_RERANK:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool need_logits() const {
|
||||
switch (type) {
|
||||
case SERVER_TASK_TYPE_COMPLETION:
|
||||
case SERVER_TASK_TYPE_INFILL:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool need_sampling() const {
|
||||
switch (type) {
|
||||
case SERVER_TASK_TYPE_COMPLETION:
|
||||
case SERVER_TASK_TYPE_INFILL:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static task_params params_from_json_cmpl(
|
||||
const llama_vocab * vocab,
|
||||
const common_params & params_base,
|
||||
const int n_ctx_slot,
|
||||
const std::vector<llama_logit_bias> & logit_bias_eog,
|
||||
const json & data);
|
||||
|
||||
// utility function
|
||||
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
|
||||
std::unordered_set<int> ids(tasks.size());
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
ids.insert(tasks[i].id);
|
||||
for (auto & child : tasks[i].child_tasks) {
|
||||
ids.insert(child.id);
|
||||
}
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
|
||||
void add_child(int id_parent, int id_child) {
|
||||
server_task copy;
|
||||
|
||||
copy.id = id_child;
|
||||
copy.id_parent = id_parent;
|
||||
copy.params = params;
|
||||
copy.type = type;
|
||||
copy.tokens = tokens.clone();
|
||||
copy.id_slot = -1; // child tasks cannot specify slot
|
||||
|
||||
// use different sampling seed for each child
|
||||
// note: https://github.com/ggml-org/llama.cpp/pull/18700#discussion_r2675115723
|
||||
if (copy.params.sampling.seed != LLAMA_DEFAULT_SEED) {
|
||||
copy.params.sampling.seed += (uint32_t)child_tasks.size() + 1;
|
||||
}
|
||||
|
||||
child_tasks.push_back(std::move(copy));
|
||||
}
|
||||
|
||||
// the task will be moved into queue, then onto slots
|
||||
// however, the state must be kept by caller (e.g., HTTP thread)
|
||||
task_result_state create_state() const {
|
||||
return task_result_state(params.chat_parser_params);
|
||||
}
|
||||
|
||||
bool is_parent() const {
|
||||
return child_tasks.size() > 0;
|
||||
}
|
||||
|
||||
bool is_child() const {
|
||||
return id_parent != -1;
|
||||
}
|
||||
};
|
||||
|
||||
struct result_timings {
|
||||
int32_t cache_n = -1;
|
||||
|
||||
int32_t prompt_n = -1;
|
||||
double prompt_ms = 0.0;
|
||||
double prompt_per_token_ms = 0.0;
|
||||
double prompt_per_second = 0.0;
|
||||
|
||||
int32_t predicted_n = -1;
|
||||
double predicted_ms = 0.0;
|
||||
double predicted_per_token_ms = 0.0;
|
||||
double predicted_per_second = 0.0;
|
||||
|
||||
// Optional speculative metrics - only included when > 0
|
||||
int32_t draft_n = 0;
|
||||
int32_t draft_n_accepted = 0;
|
||||
|
||||
json to_json() const;
|
||||
};
|
||||
|
||||
struct result_prompt_progress {
|
||||
int32_t total = 0;
|
||||
int32_t cache = 0;
|
||||
int32_t processed = 0;
|
||||
int64_t time_ms = 0;
|
||||
|
||||
json to_json() const;
|
||||
};
|
||||
|
||||
struct server_task_result {
|
||||
int id = -1;
|
||||
int id_slot = -1;
|
||||
|
||||
// TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader
|
||||
size_t index = 0; // to be used for batched tasks
|
||||
|
||||
virtual bool is_error() {
|
||||
// only used by server_task_result_error
|
||||
return false;
|
||||
}
|
||||
virtual bool is_stop() {
|
||||
// only used by server_task_result_cmpl_*
|
||||
return true;
|
||||
}
|
||||
virtual void update(task_result_state &) {
|
||||
// only used by server_task_result_cmpl_*
|
||||
}
|
||||
virtual json to_json() = 0;
|
||||
virtual ~server_task_result() = default;
|
||||
};
|
||||
|
||||
// using shared_ptr for polymorphism of server_task_result
|
||||
using server_task_result_ptr = std::unique_ptr<server_task_result>;
|
||||
|
||||
struct completion_token_output {
|
||||
llama_token tok;
|
||||
float prob;
|
||||
std::string text_to_send;
|
||||
struct prob_info {
|
||||
llama_token tok;
|
||||
std::string txt;
|
||||
float prob;
|
||||
};
|
||||
std::vector<prob_info> probs;
|
||||
|
||||
json to_json(bool post_sampling_probs) const;
|
||||
|
||||
static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs);
|
||||
|
||||
static float logarithm(float x);
|
||||
|
||||
static std::vector<unsigned char> str_to_bytes(const std::string & str);
|
||||
|
||||
};
|
||||
|
||||
struct server_task_result_cmpl_final : server_task_result {
|
||||
std::string content;
|
||||
llama_tokens tokens;
|
||||
|
||||
bool stream;
|
||||
bool include_usage;
|
||||
result_timings timings;
|
||||
std::string prompt;
|
||||
|
||||
bool truncated;
|
||||
int32_t n_decoded;
|
||||
int32_t n_prompt_tokens;
|
||||
int32_t n_prompt_tokens_cache;
|
||||
int32_t n_tokens_cached;
|
||||
bool has_new_line;
|
||||
std::string stopping_word;
|
||||
stop_type stop = STOP_TYPE_NONE;
|
||||
|
||||
bool post_sampling_probs;
|
||||
std::vector<completion_token_output> probs_output;
|
||||
std::vector<std::string> response_fields;
|
||||
|
||||
task_params generation_params;
|
||||
|
||||
// response formatting
|
||||
bool verbose = false;
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_msg oaicompat_msg; // to be populated by update()
|
||||
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
|
||||
bool is_updated = false;
|
||||
|
||||
// for OpenAI Responses API
|
||||
std::string oai_resp_id;
|
||||
std::string oai_resp_reasoning_id;
|
||||
std::string oai_resp_message_id;
|
||||
|
||||
virtual bool is_stop() override {
|
||||
return true; // in stream mode, final responses are considered stop
|
||||
}
|
||||
|
||||
virtual json to_json() override;
|
||||
|
||||
virtual void update(task_result_state & state) override {
|
||||
is_updated = true;
|
||||
oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);
|
||||
|
||||
oai_resp_id = state.oai_resp_id;
|
||||
oai_resp_reasoning_id = state.oai_resp_reasoning_id;
|
||||
oai_resp_message_id = state.oai_resp_message_id;
|
||||
}
|
||||
|
||||
json to_json_non_oaicompat();
|
||||
|
||||
json usage_json_oaicompat();
|
||||
|
||||
json to_json_oaicompat();
|
||||
|
||||
json to_json_oaicompat_chat();
|
||||
|
||||
json to_json_oaicompat_chat_stream();
|
||||
|
||||
json to_json_oaicompat_resp();
|
||||
|
||||
json to_json_oaicompat_resp_stream();
|
||||
|
||||
json to_json_oaicompat_asr();
|
||||
|
||||
json to_json_anthropic();
|
||||
|
||||
json to_json_anthropic_stream();
|
||||
};
|
||||
|
||||
struct server_task_result_cmpl_partial : server_task_result {
|
||||
std::string content;
|
||||
llama_tokens tokens;
|
||||
|
||||
int32_t n_decoded;
|
||||
int32_t n_prompt_tokens;
|
||||
int32_t n_prompt_tokens_cache;
|
||||
|
||||
bool post_sampling_probs;
|
||||
bool is_progress = false;
|
||||
completion_token_output prob_output;
|
||||
result_timings timings;
|
||||
result_prompt_progress progress;
|
||||
|
||||
// response formatting
|
||||
bool verbose = false;
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
|
||||
bool is_updated = false;
|
||||
|
||||
// Streaming state copied from task_result_state for this chunk
|
||||
bool thinking_block_started = false;
|
||||
bool text_block_started = false;
|
||||
|
||||
// for OpenAI Responses API
|
||||
std::string oai_resp_id;
|
||||
std::string oai_resp_reasoning_id;
|
||||
std::string oai_resp_message_id;
|
||||
std::string oai_resp_fc_id;
|
||||
|
||||
// for Anthropic API: track if any reasoning content has been generated
|
||||
bool anthropic_has_reasoning = false;
|
||||
|
||||
virtual bool is_stop() override {
|
||||
return false; // in stream mode, partial responses are not considered stop
|
||||
}
|
||||
|
||||
virtual void update(task_result_state & state) override;
|
||||
|
||||
virtual json to_json() override;
|
||||
|
||||
json to_json_non_oaicompat();
|
||||
|
||||
json to_json_oaicompat();
|
||||
|
||||
json to_json_oaicompat_chat();
|
||||
|
||||
json to_json_oaicompat_resp();
|
||||
|
||||
json to_json_oaicompat_asr();
|
||||
|
||||
json to_json_anthropic();
|
||||
};
|
||||
|
||||
struct server_task_result_embd : server_task_result {
|
||||
std::vector<std::vector<float>> embedding;
|
||||
|
||||
int32_t n_tokens;
|
||||
|
||||
// response formatting
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
|
||||
virtual json to_json() override;
|
||||
|
||||
json to_json_non_oaicompat();
|
||||
|
||||
json to_json_oaicompat();
|
||||
};
|
||||
|
||||
struct server_task_result_rerank : server_task_result {
|
||||
float score = -1e6;
|
||||
|
||||
int32_t n_tokens;
|
||||
|
||||
virtual json to_json() override;
|
||||
};
|
||||
|
||||
struct server_task_result_error : server_task_result {
|
||||
error_type err_type = ERROR_TYPE_SERVER;
|
||||
std::string err_msg;
|
||||
|
||||
// for ERROR_TYPE_EXCEED_CONTEXT_SIZE
|
||||
int32_t n_prompt_tokens = 0;
|
||||
int32_t n_ctx = 0;
|
||||
|
||||
virtual bool is_error() override {
|
||||
return true;
|
||||
}
|
||||
|
||||
virtual json to_json() override;
|
||||
};
|
||||
|
||||
struct server_task_result_metrics : server_task_result {
|
||||
int n_idle_slots;
|
||||
int n_processing_slots;
|
||||
int n_tasks_deferred;
|
||||
int64_t t_start;
|
||||
|
||||
// TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
|
||||
uint64_t n_prompt_tokens_processed_total = 0;
|
||||
uint64_t t_prompt_processing_total = 0;
|
||||
uint64_t n_tokens_predicted_total = 0;
|
||||
uint64_t t_tokens_generation_total = 0;
|
||||
|
||||
uint64_t n_tokens_max = 0;
|
||||
|
||||
uint64_t n_prompt_tokens_processed = 0;
|
||||
uint64_t t_prompt_processing = 0;
|
||||
|
||||
uint64_t n_tokens_predicted = 0;
|
||||
uint64_t t_tokens_generation = 0;
|
||||
|
||||
uint64_t n_decode_total = 0;
|
||||
uint64_t n_busy_slots_total = 0;
|
||||
|
||||
// while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy
|
||||
// therefore, we use json to temporarily store the slot.to_json() result
|
||||
json slots_data = json::array();
|
||||
|
||||
virtual json to_json() override;
|
||||
};
|
||||
|
||||
struct server_task_result_slot_save_load : server_task_result {
|
||||
std::string filename;
|
||||
bool is_save; // true = save, false = load
|
||||
|
||||
size_t n_tokens;
|
||||
size_t n_bytes;
|
||||
double t_ms;
|
||||
|
||||
virtual json to_json() override;
|
||||
};
|
||||
|
||||
struct server_task_result_slot_erase : server_task_result {
|
||||
size_t n_erased;
|
||||
|
||||
virtual json to_json() override;
|
||||
};
|
||||
|
||||
struct server_task_result_get_lora : server_task_result {
|
||||
struct lora {
|
||||
common_adapter_lora_info info;
|
||||
std::string alora_invocation_string;
|
||||
llama_tokens alora_invocation_tokens;
|
||||
};
|
||||
std::vector<lora> loras;
|
||||
|
||||
virtual json to_json() override;
|
||||
};
|
||||
|
||||
struct server_task_result_apply_lora : server_task_result {
|
||||
virtual json to_json() override;
|
||||
};
|
||||
|
||||
struct server_prompt_data {
|
||||
std::vector<uint8_t> main;
|
||||
std::vector<uint8_t> drft;
|
||||
|
||||
size_t size() const {
|
||||
return main.size() + drft.size();
|
||||
}
|
||||
};
|
||||
|
||||
struct server_prompt {
|
||||
server_tokens tokens;
|
||||
|
||||
server_prompt_data data;
|
||||
|
||||
std::list<common_prompt_checkpoint> checkpoints;
|
||||
|
||||
size_t size() const {
|
||||
size_t res = 0;
|
||||
|
||||
res += data.size();
|
||||
|
||||
for (const auto & ckpt : checkpoints) {
|
||||
res += ckpt.size();
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
int n_tokens() const {
|
||||
return tokens.size();
|
||||
}
|
||||
|
||||
server_prompt clone() const {
|
||||
return server_prompt {
|
||||
tokens.clone(),
|
||||
data,
|
||||
checkpoints,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct server_prompt_cache {
|
||||
server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
|
||||
this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib);
|
||||
this->limit_tokens = limit_tokens;
|
||||
}
|
||||
|
||||
std::list<server_prompt> states;
|
||||
|
||||
// in bytes, 0 = no limit
|
||||
size_t limit_size = 0;
|
||||
|
||||
// in tokens, 0 = no limit
|
||||
size_t limit_tokens = 0;
|
||||
|
||||
size_t size() const;
|
||||
|
||||
size_t n_tokens() const;
|
||||
|
||||
server_prompt * alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_drft);
|
||||
|
||||
bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_main, llama_context * ctx_drft, int32_t id_slot);
|
||||
|
||||
void update();
|
||||
};
|
||||
817
tools/server/server-tools.cpp
Normal file
817
tools/server/server-tools.cpp
Normal file
@@ -0,0 +1,817 @@
|
||||
#include "server-tools.h"
|
||||
|
||||
#include <sheredom/subprocess.h>
|
||||
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <regex>
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
#include <atomic>
|
||||
#include <cstring>
|
||||
#include <climits>
|
||||
#include <algorithm>
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
//
|
||||
// internal helpers
|
||||
//
|
||||
|
||||
static std::vector<char *> to_cstr_vec(const std::vector<std::string> & v) {
|
||||
std::vector<char *> r;
|
||||
r.reserve(v.size() + 1);
|
||||
for (const auto & s : v) {
|
||||
r.push_back(const_cast<char *>(s.c_str()));
|
||||
}
|
||||
r.push_back(nullptr);
|
||||
return r;
|
||||
}
|
||||
|
||||
struct run_proc_result {
|
||||
std::string output;
|
||||
int exit_code = -1;
|
||||
bool timed_out = false;
|
||||
};
|
||||
|
||||
static run_proc_result run_process(
|
||||
const std::vector<std::string> & args,
|
||||
size_t max_output,
|
||||
int timeout_secs) {
|
||||
run_proc_result res;
|
||||
|
||||
subprocess_s proc;
|
||||
auto argv = to_cstr_vec(args);
|
||||
|
||||
int options = subprocess_option_no_window
|
||||
| subprocess_option_combined_stdout_stderr
|
||||
| subprocess_option_inherit_environment
|
||||
| subprocess_option_search_user_path;
|
||||
|
||||
if (subprocess_create(argv.data(), options, &proc) != 0) {
|
||||
res.output = "failed to spawn process";
|
||||
return res;
|
||||
}
|
||||
|
||||
std::atomic<bool> done{false};
|
||||
std::atomic<bool> timed_out{false};
|
||||
|
||||
std::thread timeout_thread([&]() {
|
||||
auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(timeout_secs);
|
||||
while (!done.load()) {
|
||||
if (std::chrono::steady_clock::now() >= deadline) {
|
||||
timed_out.store(true);
|
||||
subprocess_terminate(&proc);
|
||||
return;
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
}
|
||||
});
|
||||
|
||||
FILE * f = subprocess_stdout(&proc);
|
||||
std::string output;
|
||||
bool truncated = false;
|
||||
if (f) {
|
||||
char buf[4096];
|
||||
while (fgets(buf, sizeof(buf), f) != nullptr) {
|
||||
if (!truncated) {
|
||||
size_t len = strlen(buf);
|
||||
if (output.size() + len <= max_output) {
|
||||
output.append(buf, len);
|
||||
} else {
|
||||
output.append(buf, max_output - output.size());
|
||||
truncated = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
done.store(true);
|
||||
if (timeout_thread.joinable()) {
|
||||
timeout_thread.join();
|
||||
}
|
||||
|
||||
subprocess_join(&proc, &res.exit_code);
|
||||
subprocess_destroy(&proc);
|
||||
|
||||
res.output = output;
|
||||
res.timed_out = timed_out.load();
|
||||
if (truncated) {
|
||||
res.output += "\n[output truncated]";
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
json server_tool::to_json() {
|
||||
return {
|
||||
{"display_name", display_name},
|
||||
{"tool", name},
|
||||
{"type", "builtin"},
|
||||
{"permissions", json{
|
||||
{"write", permission_write}
|
||||
}},
|
||||
{"definition", get_definition()},
|
||||
};
|
||||
}
|
||||
|
||||
//
|
||||
// read_file: read a file with optional line range and line-number prefix
|
||||
//
|
||||
|
||||
static constexpr size_t SERVER_TOOL_READ_FILE_MAX_SIZE = 16 * 1024; // 16 KB
|
||||
|
||||
struct server_tool_read_file : server_tool {
|
||||
server_tool_read_file() {
|
||||
name = "read_file";
|
||||
display_name = "Read file";
|
||||
permission_write = false;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Read the contents of a file. Optionally specify a 1-based line range. "
|
||||
"If append_loc is true, each line is prefixed with its line number (e.g. \"1\u2192 ...\")."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"path", {{"type", "string"}, {"description", "Path to the file"}}},
|
||||
{"start_line", {{"type", "integer"}, {"description", "First line to read, 1-based (default: 1)"}}},
|
||||
{"end_line", {{"type", "integer"}, {"description", "Last line to read, 1-based inclusive (default: end of file)"}}},
|
||||
{"append_loc", {{"type", "boolean"}, {"description", "Prefix each line with its line number"}}},
|
||||
}},
|
||||
{"required", json::array({"path"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string path = params.at("path").get<std::string>();
|
||||
int start_line = json_value(params, "start_line", 1);
|
||||
int end_line = json_value(params, "end_line", -1); // -1 = no limit
|
||||
bool append_loc = json_value(params, "append_loc", false);
|
||||
|
||||
std::error_code ec;
|
||||
uintmax_t file_size = fs::file_size(path, ec);
|
||||
if (ec) {
|
||||
return {{"error", "cannot stat file: " + ec.message()}};
|
||||
}
|
||||
if (file_size > SERVER_TOOL_READ_FILE_MAX_SIZE && end_line == -1) {
|
||||
return {{"error", string_format(
|
||||
"file too large (%zu bytes, max %zu). Use start_line/end_line to read a portion.",
|
||||
(size_t)file_size, SERVER_TOOL_READ_FILE_MAX_SIZE)}};
|
||||
}
|
||||
|
||||
std::ifstream f(path);
|
||||
if (!f) {
|
||||
return {{"error", "failed to open file: " + path}};
|
||||
}
|
||||
|
||||
std::string result;
|
||||
std::string line;
|
||||
int lineno = 0;
|
||||
|
||||
while (std::getline(f, line)) {
|
||||
lineno++;
|
||||
if (lineno < start_line) continue;
|
||||
if (end_line != -1 && lineno > end_line) break;
|
||||
|
||||
std::string out_line;
|
||||
if (append_loc) {
|
||||
out_line = std::to_string(lineno) + "\u2192 " + line + "\n";
|
||||
} else {
|
||||
out_line = line + "\n";
|
||||
}
|
||||
|
||||
if (result.size() + out_line.size() > SERVER_TOOL_READ_FILE_MAX_SIZE) {
|
||||
result += "[output truncated]";
|
||||
break;
|
||||
}
|
||||
result += out_line;
|
||||
}
|
||||
|
||||
return {{"plain_text_response", result}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// file_glob_search: find files matching a glob pattern under a base directory
|
||||
//
|
||||
|
||||
static constexpr size_t SERVER_TOOL_FILE_SEARCH_MAX_RESULTS = 100;
|
||||
|
||||
struct server_tool_file_glob_search : server_tool {
|
||||
server_tool_file_glob_search() {
|
||||
name = "file_glob_search";
|
||||
display_name = "File search";
|
||||
permission_write = false;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Recursively search for files matching a glob pattern under a directory."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"path", {{"type", "string"}, {"description", "Base directory to search in"}}},
|
||||
{"include", {{"type", "string"}, {"description", "Glob pattern for files to include (e.g. \"**/*.cpp\"). Default: **"}}},
|
||||
{"exclude", {{"type", "string"}, {"description", "Glob pattern for files to exclude"}}},
|
||||
}},
|
||||
{"required", json::array({"path"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string base = params.at("path").get<std::string>();
|
||||
std::string include = json_value(params, "include", std::string("**"));
|
||||
std::string exclude = json_value(params, "exclude", std::string(""));
|
||||
|
||||
std::ostringstream output_text;
|
||||
size_t count = 0;
|
||||
|
||||
std::error_code ec;
|
||||
for (const auto & entry : fs::recursive_directory_iterator(base,
|
||||
fs::directory_options::skip_permission_denied, ec)) {
|
||||
if (!entry.is_regular_file()) continue;
|
||||
|
||||
std::string rel = fs::relative(entry.path(), base, ec).string();
|
||||
if (ec) continue;
|
||||
std::replace(rel.begin(), rel.end(), '\\', '/');
|
||||
|
||||
if (!glob_match(include, rel)) continue;
|
||||
if (!exclude.empty() && glob_match(exclude, rel)) continue;
|
||||
|
||||
output_text << entry.path().string() << "\n";
|
||||
if (++count >= SERVER_TOOL_FILE_SEARCH_MAX_RESULTS) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
output_text << "\n---\nTotal matches: " << count << "\n";
|
||||
|
||||
return {{"plain_text_response", output_text.str()}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// grep_search: search for a regex pattern in files
|
||||
//
|
||||
|
||||
static constexpr size_t SERVER_TOOL_GREP_SEARCH_MAX_RESULTS = 100;
|
||||
|
||||
struct server_tool_grep_search : server_tool {
|
||||
server_tool_grep_search() {
|
||||
name = "grep_search";
|
||||
display_name = "Grep search";
|
||||
permission_write = false;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Search for a regex pattern in files under a path. Returns matching lines."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"path", {{"type", "string"}, {"description", "File or directory to search in"}}},
|
||||
{"pattern", {{"type", "string"}, {"description", "Regular expression pattern to search for"}}},
|
||||
{"include", {{"type", "string"}, {"description", "Glob pattern to filter files (default: **)"}}},
|
||||
{"exclude", {{"type", "string"}, {"description", "Glob pattern to exclude files"}}},
|
||||
{"return_line_numbers", {{"type", "boolean"}, {"description", "If true, include line numbers in results"}}},
|
||||
}},
|
||||
{"required", json::array({"path", "pattern"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string path = params.at("path").get<std::string>();
|
||||
std::string pat_str = params.at("pattern").get<std::string>();
|
||||
std::string include = json_value(params, "include", std::string("**"));
|
||||
std::string exclude = json_value(params, "exclude", std::string(""));
|
||||
bool show_lineno = json_value(params, "return_line_numbers", false);
|
||||
|
||||
std::regex pattern;
|
||||
try {
|
||||
pattern = std::regex(pat_str);
|
||||
} catch (const std::regex_error & e) {
|
||||
return {{"error", std::string("invalid regex: ") + e.what()}};
|
||||
}
|
||||
|
||||
std::ostringstream output_text;
|
||||
size_t total = 0;
|
||||
|
||||
auto search_file = [&](const fs::path & fpath) {
|
||||
std::ifstream f(fpath);
|
||||
if (!f) return;
|
||||
std::string line;
|
||||
int lineno = 0;
|
||||
while (std::getline(f, line) && total < SERVER_TOOL_GREP_SEARCH_MAX_RESULTS) {
|
||||
lineno++;
|
||||
if (std::regex_search(line, pattern)) {
|
||||
output_text << fpath.string() << ":";
|
||||
if (show_lineno) {
|
||||
output_text << lineno << ":";
|
||||
}
|
||||
output_text << line << "\n";
|
||||
total++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::error_code ec;
|
||||
if (fs::is_regular_file(path, ec)) {
|
||||
search_file(path);
|
||||
} else if (fs::is_directory(path, ec)) {
|
||||
for (const auto & entry : fs::recursive_directory_iterator(path,
|
||||
fs::directory_options::skip_permission_denied, ec)) {
|
||||
if (!entry.is_regular_file()) continue;
|
||||
if (total >= SERVER_TOOL_GREP_SEARCH_MAX_RESULTS) break;
|
||||
|
||||
std::string rel = fs::relative(entry.path(), path, ec).string();
|
||||
if (ec) continue;
|
||||
std::replace(rel.begin(), rel.end(), '\\', '/');
|
||||
|
||||
if (!glob_match(include, rel)) continue;
|
||||
if (!exclude.empty() && glob_match(exclude, rel)) continue;
|
||||
|
||||
search_file(entry.path());
|
||||
}
|
||||
} else {
|
||||
return {{"error", "path does not exist: " + path}};
|
||||
}
|
||||
|
||||
output_text << "\n\n---\nTotal matches: " << total << "\n";
|
||||
|
||||
return {{"plain_text_response", output_text.str()}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// exec_shell_command: run an arbitrary shell command
|
||||
//
|
||||
|
||||
static constexpr size_t SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE = 16 * 1024; // 16 KB
|
||||
static constexpr int SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT = 60; // seconds
|
||||
|
||||
struct server_tool_exec_shell_command : server_tool {
|
||||
server_tool_exec_shell_command() {
|
||||
name = "exec_shell_command";
|
||||
display_name = "Execute shell command";
|
||||
permission_write = true;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Execute a shell command and return its output (stdout and stderr combined)."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"command", {{"type", "string"}, {"description", "Shell command to execute"}}},
|
||||
{"timeout", {{"type", "integer"}, {"description", string_format("Timeout in seconds (default 10, max %d)", SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT)}}},
|
||||
{"max_output_size", {{"type", "integer"}, {"description", string_format("Maximum output size in bytes (default %zu)", SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE)}}},
|
||||
}},
|
||||
{"required", json::array({"command"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string command = params.at("command").get<std::string>();
|
||||
int timeout = json_value(params, "timeout", 10);
|
||||
size_t max_output = (size_t) json_value(params, "max_output_size", (int) SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE);
|
||||
|
||||
timeout = std::min(timeout, SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT);
|
||||
max_output = std::min(max_output, SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE);
|
||||
|
||||
#ifdef _WIN32
|
||||
std::vector<std::string> args = {"cmd", "/c", command};
|
||||
#else
|
||||
std::vector<std::string> args = {"sh", "-c", command};
|
||||
#endif
|
||||
|
||||
auto res = run_process(args, max_output, timeout);
|
||||
|
||||
std::string text_output = res.output;
|
||||
text_output += string_format("\n[exit code: %d]", res.exit_code);
|
||||
if (res.timed_out) {
|
||||
text_output += " [exit due to timed out]";
|
||||
}
|
||||
|
||||
return {{"plain_text_response", text_output}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// write_file: create or overwrite a file
|
||||
//
|
||||
|
||||
struct server_tool_write_file : server_tool {
|
||||
server_tool_write_file() {
|
||||
name = "write_file";
|
||||
display_name = "Write file";
|
||||
permission_write = true;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Write content to a file, creating it (including parent directories) if it does not exist. May use with edit_file for more complex edits."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"path", {{"type", "string"}, {"description", "Path of the file to write"}}},
|
||||
{"content", {{"type", "string"}, {"description", "Content to write"}}},
|
||||
}},
|
||||
{"required", json::array({"path", "content"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string path = params.at("path").get<std::string>();
|
||||
std::string content = params.at("content").get<std::string>();
|
||||
|
||||
std::error_code ec;
|
||||
fs::path fpath(path);
|
||||
if (fpath.has_parent_path()) {
|
||||
fs::create_directories(fpath.parent_path(), ec);
|
||||
if (ec) {
|
||||
return {{"error", "failed to create directories: " + ec.message()}};
|
||||
}
|
||||
}
|
||||
|
||||
std::ofstream f(path, std::ios::binary);
|
||||
if (!f) {
|
||||
return {{"error", "failed to open file for writing: " + path}};
|
||||
}
|
||||
f << content;
|
||||
if (!f) {
|
||||
return {{"error", "failed to write file: " + path}};
|
||||
}
|
||||
|
||||
return {{"result", "file written successfully"}, {"path", path}, {"bytes", content.size()}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// edit_file: edit file content via line-based changes
|
||||
//
|
||||
|
||||
struct server_tool_edit_file : server_tool {
|
||||
server_tool_edit_file() {
|
||||
name = "edit_file";
|
||||
display_name = "Edit file";
|
||||
permission_write = true;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description",
|
||||
"Edit a file by applying a list of line-based changes. "
|
||||
"Each change targets a 1-based inclusive line range and has a mode: "
|
||||
"\"replace\" (replace lines with content), "
|
||||
"\"delete\" (remove lines, content must be empty string), "
|
||||
"\"append\" (insert content after line_end). "
|
||||
"Set line_start to -1 to target the end of file (line_end is ignored in that case). "
|
||||
"Changes must not overlap. They are applied in reverse line order automatically."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"path", {{"type", "string"}, {"description", "Path to the file to edit"}}},
|
||||
{"changes", {
|
||||
{"type", "array"},
|
||||
{"description", "List of changes to apply"},
|
||||
{"items", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"mode", {{"type", "string"}, {"description", "\"replace\", \"delete\", or \"append\""}}},
|
||||
{"line_start", {{"type", "integer"}, {"description", "First line of the range (1-based); use -1 for end of file"}}},
|
||||
{"line_end", {{"type", "integer"}, {"description", "Last line of the range (1-based, inclusive); ignored when line_start is -1"}}},
|
||||
{"content", {{"type", "string"}, {"description", "Content to insert; must be empty string for delete mode"}}},
|
||||
}},
|
||||
{"required", json::array({"mode", "line_start", "line_end", "content"})},
|
||||
}},
|
||||
}},
|
||||
}},
|
||||
{"required", json::array({"path", "changes"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string path = params.at("path").get<std::string>();
|
||||
const json & changes = params.at("changes");
|
||||
|
||||
if (!changes.is_array()) {
|
||||
return {{"error", "\"changes\" must be an array"}};
|
||||
}
|
||||
|
||||
// read file into lines
|
||||
std::ifstream fin(path);
|
||||
if (!fin) {
|
||||
return {{"error", "failed to open file: " + path}};
|
||||
}
|
||||
std::vector<std::string> lines;
|
||||
{
|
||||
std::string line;
|
||||
while (std::getline(fin, line)) {
|
||||
lines.push_back(line);
|
||||
}
|
||||
}
|
||||
fin.close();
|
||||
|
||||
// validate and collect changes, then sort descending by line_start
|
||||
struct change_entry {
|
||||
std::string mode;
|
||||
int line_start; // 1-based
|
||||
int line_end; // 1-based inclusive
|
||||
std::string content;
|
||||
};
|
||||
std::vector<change_entry> entries;
|
||||
entries.reserve(changes.size());
|
||||
|
||||
for (const auto & ch : changes) {
|
||||
change_entry e;
|
||||
e.mode = ch.at("mode").get<std::string>();
|
||||
e.line_start = ch.at("line_start").get<int>();
|
||||
e.line_end = ch.at("line_end").get<int>();
|
||||
e.content = ch.at("content").get<std::string>();
|
||||
|
||||
if (e.mode != "replace" && e.mode != "delete" && e.mode != "append") {
|
||||
return {{"error", "invalid mode \"" + e.mode + "\"; must be replace, delete, or append"}};
|
||||
}
|
||||
if (e.mode == "delete" && !e.content.empty()) {
|
||||
return {{"error", "content must be empty string for delete mode"}};
|
||||
}
|
||||
int n = (int) lines.size();
|
||||
if (e.line_start == -1) {
|
||||
// -1 means end of file; line_end is ignored — normalize to point past last line
|
||||
e.line_start = n + 1;
|
||||
e.line_end = n + 1;
|
||||
} else {
|
||||
if (e.line_start < 1 || e.line_end < e.line_start) {
|
||||
return {{"error", string_format("invalid line range [%d, %d]", e.line_start, e.line_end)}};
|
||||
}
|
||||
if (e.line_end > n) {
|
||||
return {{"error", string_format("line_end %d exceeds file length %d", e.line_end, n)}};
|
||||
}
|
||||
}
|
||||
entries.push_back(std::move(e));
|
||||
}
|
||||
|
||||
// sort descending so earlier-indexed changes don't shift later ones
|
||||
std::sort(entries.begin(), entries.end(), [](const change_entry & a, const change_entry & b) {
|
||||
return a.line_start > b.line_start;
|
||||
});
|
||||
|
||||
// apply changes (0-based indices internally)
|
||||
for (const auto & e : entries) {
|
||||
int idx_start = e.line_start - 1; // 0-based
|
||||
int idx_end = e.line_end - 1; // 0-based inclusive
|
||||
|
||||
// split content into lines (preserve trailing newline awareness)
|
||||
std::vector<std::string> new_lines;
|
||||
if (!e.content.empty()) {
|
||||
std::istringstream ss(e.content);
|
||||
std::string ln;
|
||||
while (std::getline(ss, ln)) {
|
||||
new_lines.push_back(ln);
|
||||
}
|
||||
// if content ends with \n, getline consumed it — no extra empty line needed
|
||||
// if content does NOT end with \n, last line is still captured correctly
|
||||
}
|
||||
|
||||
if (e.mode == "replace") {
|
||||
// erase [idx_start, idx_end] and insert new_lines
|
||||
lines.erase(lines.begin() + idx_start, lines.begin() + idx_end + 1);
|
||||
lines.insert(lines.begin() + idx_start, new_lines.begin(), new_lines.end());
|
||||
} else if (e.mode == "delete") {
|
||||
lines.erase(lines.begin() + idx_start, lines.begin() + idx_end + 1);
|
||||
} else { // append
|
||||
// idx_end + 1 may equal lines.size() when line_start == -1 (end of file)
|
||||
lines.insert(lines.begin() + idx_end + 1, new_lines.begin(), new_lines.end());
|
||||
}
|
||||
}
|
||||
|
||||
// write file back
|
||||
std::ofstream fout(path, std::ios::binary);
|
||||
if (!fout) {
|
||||
return {{"error", "failed to open file for writing: " + path}};
|
||||
}
|
||||
for (size_t i = 0; i < lines.size(); i++) {
|
||||
fout << lines[i];
|
||||
if (i + 1 < lines.size()) {
|
||||
fout << "\n";
|
||||
}
|
||||
}
|
||||
if (!lines.empty()) {
|
||||
fout << "\n";
|
||||
}
|
||||
if (!fout) {
|
||||
return {{"error", "failed to write file: " + path}};
|
||||
}
|
||||
|
||||
return {{"result", "file edited successfully"}, {"path", path}, {"lines", (int) lines.size()}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// apply_diff: apply a unified diff via git apply
|
||||
//
|
||||
|
||||
struct server_tool_apply_diff : server_tool {
|
||||
server_tool_apply_diff() {
|
||||
name = "apply_diff";
|
||||
display_name = "Apply diff";
|
||||
permission_write = true;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Apply a unified diff to edit one or more files using git apply. Use this instead of edit_file when the changes are complex."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"diff", {{"type", "string"}, {"description", "Unified diff content in git diff format"}}},
|
||||
}},
|
||||
{"required", json::array({"diff"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string diff = params.at("diff").get<std::string>();
|
||||
|
||||
// write diff to a temporary file
|
||||
static std::atomic<int> counter{0};
|
||||
std::string tmp_path = (fs::temp_directory_path() /
|
||||
("llama_patch_" + std::to_string(++counter) + ".patch")).string();
|
||||
|
||||
{
|
||||
std::ofstream f(tmp_path, std::ios::binary);
|
||||
if (!f) {
|
||||
return {{"error", "failed to create temp patch file"}};
|
||||
}
|
||||
f << diff;
|
||||
}
|
||||
|
||||
auto res = run_process({"git", "apply", tmp_path}, 4096, 10);
|
||||
|
||||
std::error_code ec;
|
||||
fs::remove(tmp_path, ec);
|
||||
|
||||
if (res.exit_code != 0) {
|
||||
return {{"error", "git apply failed (exit " + std::to_string(res.exit_code) + "): " + res.output}};
|
||||
}
|
||||
return {{"result", "patch applied successfully"}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// get_datetime: returns the current date and time
|
||||
//
|
||||
|
||||
struct server_tool_get_datetime : server_tool {
|
||||
server_tool_get_datetime() {
|
||||
name = "get_datetime";
|
||||
display_name = "Get Date & Time";
|
||||
permission_write = false;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Returns the current date and time"},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json) override {
|
||||
auto now = std::chrono::system_clock::now();
|
||||
auto time = std::chrono::system_clock::to_time_t(now);
|
||||
|
||||
return {{"result", std::ctime(&time)}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// public API
|
||||
//
|
||||
|
||||
static std::vector<std::unique_ptr<server_tool>> build_tools() {
|
||||
std::vector<std::unique_ptr<server_tool>> tools;
|
||||
tools.push_back(std::make_unique<server_tool_read_file>());
|
||||
tools.push_back(std::make_unique<server_tool_file_glob_search>());
|
||||
tools.push_back(std::make_unique<server_tool_grep_search>());
|
||||
tools.push_back(std::make_unique<server_tool_exec_shell_command>());
|
||||
tools.push_back(std::make_unique<server_tool_write_file>());
|
||||
tools.push_back(std::make_unique<server_tool_edit_file>());
|
||||
tools.push_back(std::make_unique<server_tool_apply_diff>());
|
||||
tools.push_back(std::make_unique<server_tool_get_datetime>());
|
||||
return tools;
|
||||
}
|
||||
|
||||
void server_tools::setup(const std::vector<std::string> & enabled_tools) {
|
||||
if (!enabled_tools.empty()) {
|
||||
std::unordered_set<std::string> enabled_set(enabled_tools.begin(), enabled_tools.end());
|
||||
auto all_tools = build_tools();
|
||||
|
||||
// collect all known tool names for validation
|
||||
std::vector<std::string> known_names;
|
||||
known_names.reserve(all_tools.size());
|
||||
for (const auto & t : all_tools) {
|
||||
known_names.push_back(t->name);
|
||||
}
|
||||
|
||||
// validate that every requested tool is known
|
||||
for (const auto & name : enabled_tools) {
|
||||
if (name == "all") continue;
|
||||
if (std::find(known_names.begin(), known_names.end(), name) == known_names.end()) {
|
||||
throw std::runtime_error(string_format(
|
||||
"unknown tool \"%s\". available tools: %s",
|
||||
name.c_str(),
|
||||
string_join(known_names, ", ").c_str()));
|
||||
}
|
||||
}
|
||||
|
||||
tools.clear();
|
||||
for (auto & t : all_tools) {
|
||||
if (enabled_set.count(t->name) > 0 || enabled_set.count("all") > 0) {
|
||||
tools.push_back(std::move(t));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
handle_get = [this](const server_http_req &) -> server_http_res_ptr {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
try {
|
||||
json result = json::array();
|
||||
for (const auto & t : tools) {
|
||||
result.push_back(t->to_json());
|
||||
}
|
||||
res->data = safe_json_to_str(result);
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("got exception: %s\n", e.what());
|
||||
res->status = 500;
|
||||
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
handle_post = [this](const server_http_req & req) -> server_http_res_ptr {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
try {
|
||||
json body = json::parse(req.body);
|
||||
std::string tool_name = body.at("tool").get<std::string>();
|
||||
json params = body.value("params", json::object());
|
||||
json result = invoke(tool_name, params);
|
||||
res->data = safe_json_to_str(result);
|
||||
} catch (const json::exception & e) {
|
||||
res->status = 400;
|
||||
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("got exception: %s\n", e.what());
|
||||
res->status = 500;
|
||||
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
|
||||
}
|
||||
return res;
|
||||
};
|
||||
}
|
||||
|
||||
json server_tools::invoke(const std::string & name, const json & params) {
|
||||
for (auto & t : tools) {
|
||||
if (t->name == name) {
|
||||
return t->invoke(params);
|
||||
}
|
||||
}
|
||||
return {{"error", "unknown tool: " + name}};
|
||||
}
|
||||
26
tools/server/server-tools.h
Normal file
26
tools/server/server-tools.h
Normal file
@@ -0,0 +1,26 @@
|
||||
#pragma once
|
||||
|
||||
#include "server-common.h"
|
||||
#include "server-http.h"
|
||||
|
||||
struct server_tool {
|
||||
std::string name;
|
||||
std::string display_name;
|
||||
bool permission_write = false;
|
||||
|
||||
virtual ~server_tool() = default;
|
||||
virtual json get_definition() = 0;
|
||||
virtual json invoke(json params) = 0;
|
||||
|
||||
json to_json();
|
||||
};
|
||||
|
||||
struct server_tools {
|
||||
std::vector<std::unique_ptr<server_tool>> tools;
|
||||
|
||||
void setup(const std::vector<std::string> & enabled_tools);
|
||||
json invoke(const std::string & name, const json & params);
|
||||
|
||||
server_http_context::handler_t handle_get;
|
||||
server_http_context::handler_t handle_post;
|
||||
};
|
||||
363
tools/server/server.cpp
Normal file
363
tools/server/server.cpp
Normal file
@@ -0,0 +1,363 @@
|
||||
#include "server-context.h"
|
||||
#include "server-http.h"
|
||||
#include "server-models.h"
|
||||
#include "server-cors-proxy.h"
|
||||
#include "server-tools.h"
|
||||
|
||||
#include "arg.h"
|
||||
#include "build-info.h"
|
||||
#include "common.h"
|
||||
#include "fit.h"
|
||||
#include "llama.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <clocale>
|
||||
#include <exception>
|
||||
#include <signal.h>
|
||||
#include <thread> // for std::thread::hardware_concurrency
|
||||
|
||||
#if defined(_WIN32)
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
static std::function<void(int)> shutdown_handler;
|
||||
static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
|
||||
|
||||
static inline void signal_handler(int signal) {
|
||||
if (is_terminating.test_and_set()) {
|
||||
// in case it hangs, we can force terminate the server by hitting Ctrl+C twice
|
||||
// this is for better developer experience, we can remove when the server is stable enough
|
||||
fprintf(stderr, "Received second interrupt, terminating immediately.\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
shutdown_handler(signal);
|
||||
}
|
||||
|
||||
// wrapper function that handles exceptions and logs errors
|
||||
// this is to make sure handler_t never throws exceptions; instead, it returns an error response
|
||||
static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) {
|
||||
return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr {
|
||||
std::string message;
|
||||
error_type error;
|
||||
try {
|
||||
return func(req);
|
||||
} catch (const std::invalid_argument & e) {
|
||||
// treat invalid_argument as invalid request (400)
|
||||
error = ERROR_TYPE_INVALID_REQUEST;
|
||||
message = e.what();
|
||||
} catch (const std::exception & e) {
|
||||
// treat other exceptions as server error (500)
|
||||
error = ERROR_TYPE_SERVER;
|
||||
message = e.what();
|
||||
} catch (...) {
|
||||
error = ERROR_TYPE_SERVER;
|
||||
message = "unknown error";
|
||||
}
|
||||
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 500;
|
||||
try {
|
||||
json error_data = format_error_response(message, error);
|
||||
res->status = json_value(error_data, "code", 500);
|
||||
res->data = safe_json_to_str({{ "error", error_data }});
|
||||
SRV_WRN("got exception: %s\n", res->data.c_str());
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("got another exception: %s | while handling exception: %s\n", e.what(), message.c_str());
|
||||
res->data = "Internal Server Error";
|
||||
}
|
||||
return res;
|
||||
};
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
// own arguments required by this example
|
||||
common_params params;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// validate batch size for embeddings
|
||||
// embeddings require all tokens to be processed in a single ubatch
|
||||
// see https://github.com/ggml-org/llama.cpp/issues/12836
|
||||
if (params.embedding && params.n_batch > params.n_ubatch) {
|
||||
LOG_WRN("%s: embeddings enabled with n_batch (%d) > n_ubatch (%d)\n", __func__, params.n_batch, params.n_ubatch);
|
||||
LOG_WRN("%s: setting n_batch = n_ubatch = %d to avoid assertion failure\n", __func__, params.n_ubatch);
|
||||
params.n_batch = params.n_ubatch;
|
||||
}
|
||||
|
||||
if (params.n_parallel < 0) {
|
||||
LOG_INF("%s: n_parallel is set to auto, using n_parallel = 4 and kv_unified = true\n", __func__);
|
||||
|
||||
params.n_parallel = 4;
|
||||
params.kv_unified = true;
|
||||
}
|
||||
|
||||
// for consistency between server router mode and single-model mode, we set the same model name as alias
|
||||
if (params.model_alias.empty() && !params.model.name.empty()) {
|
||||
params.model_alias.insert(params.model.name);
|
||||
}
|
||||
|
||||
// struct that contains llama context and inference
|
||||
server_context ctx_server;
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
LOG_INF("build_info: %s\n", llama_build_info());
|
||||
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
|
||||
|
||||
server_http_context ctx_http;
|
||||
if (!ctx_http.init(params)) {
|
||||
LOG_ERR("%s: failed to initialize HTTP server\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
//
|
||||
// Router
|
||||
//
|
||||
|
||||
// register API routes
|
||||
server_routes routes(params, ctx_server);
|
||||
server_tools tools;
|
||||
|
||||
bool is_router_server = params.model.path.empty();
|
||||
std::optional<server_models_routes> models_routes{};
|
||||
if (is_router_server) {
|
||||
// setup server instances manager
|
||||
try {
|
||||
models_routes.emplace(params, argc, argv);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to initialize router models: %s\n", __func__, e.what());
|
||||
return 1;
|
||||
}
|
||||
|
||||
// proxy handlers
|
||||
// note: routes.get_health stays the same
|
||||
routes.get_metrics = models_routes->proxy_get;
|
||||
routes.post_props = models_routes->proxy_post;
|
||||
routes.post_completions = models_routes->proxy_post;
|
||||
routes.post_completions_oai = models_routes->proxy_post;
|
||||
routes.post_chat_completions = models_routes->proxy_post;
|
||||
routes.post_responses_oai = models_routes->proxy_post;
|
||||
routes.post_transcriptions_oai = models_routes->proxy_post;
|
||||
routes.post_anthropic_messages = models_routes->proxy_post;
|
||||
routes.post_anthropic_count_tokens = models_routes->proxy_post;
|
||||
routes.post_infill = models_routes->proxy_post;
|
||||
routes.post_embeddings = models_routes->proxy_post;
|
||||
routes.post_embeddings_oai = models_routes->proxy_post;
|
||||
routes.post_rerank = models_routes->proxy_post;
|
||||
routes.post_tokenize = models_routes->proxy_post;
|
||||
routes.post_detokenize = models_routes->proxy_post;
|
||||
routes.post_apply_template = models_routes->proxy_post;
|
||||
routes.get_lora_adapters = models_routes->proxy_get;
|
||||
routes.post_lora_adapters = models_routes->proxy_post;
|
||||
routes.get_slots = models_routes->proxy_get;
|
||||
routes.post_slots = models_routes->proxy_post;
|
||||
|
||||
// custom routes for router
|
||||
routes.get_props = models_routes->get_router_props;
|
||||
routes.get_models = models_routes->get_router_models;
|
||||
|
||||
ctx_http.post("/models/load", ex_wrapper(models_routes->post_router_models_load));
|
||||
ctx_http.post("/models/unload", ex_wrapper(models_routes->post_router_models_unload));
|
||||
}
|
||||
|
||||
ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
|
||||
ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
|
||||
ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics));
|
||||
ctx_http.get ("/props", ex_wrapper(routes.get_props));
|
||||
ctx_http.post("/props", ex_wrapper(routes.post_props));
|
||||
ctx_http.get ("/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check)
|
||||
ctx_http.get ("/v1/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check)
|
||||
ctx_http.post("/completion", ex_wrapper(routes.post_completions)); // legacy
|
||||
ctx_http.post("/completions", ex_wrapper(routes.post_completions));
|
||||
ctx_http.post("/v1/completions", ex_wrapper(routes.post_completions_oai));
|
||||
ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions));
|
||||
ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions));
|
||||
ctx_http.post("/v1/responses", ex_wrapper(routes.post_responses_oai));
|
||||
ctx_http.post("/responses", ex_wrapper(routes.post_responses_oai));
|
||||
ctx_http.post("/v1/audio/transcriptions", ex_wrapper(routes.post_transcriptions_oai));
|
||||
ctx_http.post("/audio/transcriptions", ex_wrapper(routes.post_transcriptions_oai));
|
||||
ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API
|
||||
ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting
|
||||
ctx_http.post("/infill", ex_wrapper(routes.post_infill));
|
||||
ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy
|
||||
ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings));
|
||||
ctx_http.post("/v1/embeddings", ex_wrapper(routes.post_embeddings_oai));
|
||||
ctx_http.post("/rerank", ex_wrapper(routes.post_rerank));
|
||||
ctx_http.post("/reranking", ex_wrapper(routes.post_rerank));
|
||||
ctx_http.post("/v1/rerank", ex_wrapper(routes.post_rerank));
|
||||
ctx_http.post("/v1/reranking", ex_wrapper(routes.post_rerank));
|
||||
ctx_http.post("/tokenize", ex_wrapper(routes.post_tokenize));
|
||||
ctx_http.post("/detokenize", ex_wrapper(routes.post_detokenize));
|
||||
ctx_http.post("/apply-template", ex_wrapper(routes.post_apply_template));
|
||||
// LoRA adapters hotswap
|
||||
ctx_http.get ("/lora-adapters", ex_wrapper(routes.get_lora_adapters));
|
||||
ctx_http.post("/lora-adapters", ex_wrapper(routes.post_lora_adapters));
|
||||
// Save & load slots
|
||||
ctx_http.get ("/slots", ex_wrapper(routes.get_slots));
|
||||
ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots));
|
||||
|
||||
// Google Cloud Platform (Vertex AI) compat
|
||||
ctx_http.register_gcp_compat();
|
||||
|
||||
// CORS proxy (EXPERIMENTAL, only used by the Web UI for MCP)
|
||||
if (params.webui_mcp_proxy) {
|
||||
SRV_WRN("%s", "-----------------\n");
|
||||
SRV_WRN("%s", "CORS proxy is enabled, do not expose server to untrusted environments\n");
|
||||
SRV_WRN("%s", "This feature is EXPERIMENTAL and may be removed or changed in future versions\n");
|
||||
SRV_WRN("%s", "-----------------\n");
|
||||
ctx_http.get ("/cors-proxy", ex_wrapper(proxy_handler_get));
|
||||
ctx_http.post("/cors-proxy", ex_wrapper(proxy_handler_post));
|
||||
}
|
||||
// EXPERIMENTAL built-in tools
|
||||
if (!params.server_tools.empty()) {
|
||||
try {
|
||||
tools.setup(params.server_tools);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: tools setup failed: %s\n", __func__, e.what());
|
||||
return 1;
|
||||
}
|
||||
SRV_WRN("%s", "-----------------\n");
|
||||
SRV_WRN("%s", "Built-in tools are enabled, do not expose server to untrusted environments\n");
|
||||
SRV_WRN("%s", "This feature is EXPERIMENTAL and may be changed in the future\n");
|
||||
SRV_WRN("%s", "-----------------\n");
|
||||
ctx_http.get ("/tools", ex_wrapper(tools.handle_get));
|
||||
ctx_http.post("/tools", ex_wrapper(tools.handle_post));
|
||||
}
|
||||
|
||||
//
|
||||
// Start the server
|
||||
//
|
||||
|
||||
std::function<void()> clean_up;
|
||||
|
||||
if (is_router_server) {
|
||||
LOG_INF("%s: starting router server, no model will be loaded in this process\n", __func__);
|
||||
|
||||
clean_up = [&models_routes]() {
|
||||
SRV_INF("%s: cleaning up before exit...\n", __func__);
|
||||
if (models_routes.has_value()) {
|
||||
models_routes->models.unload_all();
|
||||
}
|
||||
llama_backend_free();
|
||||
};
|
||||
|
||||
if (!ctx_http.start()) {
|
||||
clean_up();
|
||||
LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
ctx_http.is_ready.store(true);
|
||||
|
||||
shutdown_handler = [&](int) {
|
||||
ctx_http.stop();
|
||||
};
|
||||
|
||||
} else {
|
||||
// setup clean up function, to be called before exit
|
||||
clean_up = [&ctx_http, &ctx_server]() {
|
||||
SRV_INF("%s: cleaning up before exit...\n", __func__);
|
||||
ctx_http.stop();
|
||||
ctx_server.terminate();
|
||||
llama_backend_free();
|
||||
};
|
||||
|
||||
// start the HTTP server before loading the model to be able to serve /health requests
|
||||
if (!ctx_http.start()) {
|
||||
clean_up();
|
||||
LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// load the model
|
||||
LOG_INF("%s: loading model\n", __func__);
|
||||
|
||||
if (server_models::is_child_server()) {
|
||||
ctx_server.on_sleeping_changed([&](bool sleeping) {
|
||||
server_models::notify_router_sleeping_state(sleeping);
|
||||
});
|
||||
}
|
||||
|
||||
if (!ctx_server.load_model(params)) {
|
||||
clean_up();
|
||||
if (ctx_http.thread.joinable()) {
|
||||
ctx_http.thread.join();
|
||||
}
|
||||
LOG_ERR("%s: exiting due to model loading error\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
routes.update_meta(ctx_server);
|
||||
ctx_http.is_ready.store(true);
|
||||
|
||||
LOG_INF("%s: model loaded\n", __func__);
|
||||
|
||||
shutdown_handler = [&](int) {
|
||||
// this will unblock start_loop()
|
||||
ctx_server.terminate();
|
||||
};
|
||||
}
|
||||
|
||||
// TODO: refactor in common/console
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = signal_handler;
|
||||
sigemptyset (&sigint_action.sa_mask);
|
||||
sigint_action.sa_flags = 0;
|
||||
sigaction(SIGINT, &sigint_action, NULL);
|
||||
sigaction(SIGTERM, &sigint_action, NULL);
|
||||
#elif defined (_WIN32)
|
||||
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
|
||||
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
|
||||
};
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
#endif
|
||||
|
||||
if (is_router_server) {
|
||||
LOG_INF("%s: router server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
|
||||
LOG_INF("%s: NOTE: router mode is experimental\n", __func__);
|
||||
LOG_INF("%s: it is not recommended to use this mode in untrusted environments\n", __func__);
|
||||
if (ctx_http.thread.joinable()) {
|
||||
ctx_http.thread.join(); // keep the main thread alive
|
||||
}
|
||||
|
||||
// when the HTTP server stops, clean up and exit
|
||||
clean_up();
|
||||
} else {
|
||||
LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
|
||||
LOG_INF("%s: starting the main loop...\n", __func__);
|
||||
|
||||
// optionally, notify router server that this instance is ready
|
||||
std::thread monitor_thread;
|
||||
if (server_models::is_child_server()) {
|
||||
json model_info = routes.get_model_info();
|
||||
monitor_thread = server_models::setup_child_server(shutdown_handler, model_info);
|
||||
}
|
||||
|
||||
// this call blocks the main thread until queue_tasks.terminate() is called
|
||||
ctx_server.start_loop();
|
||||
|
||||
clean_up();
|
||||
if (ctx_http.thread.joinable()) {
|
||||
ctx_http.thread.join();
|
||||
}
|
||||
if (monitor_thread.joinable()) {
|
||||
monitor_thread.join();
|
||||
}
|
||||
|
||||
auto * ll_ctx = ctx_server.get_llama_context();
|
||||
if (ll_ctx != nullptr) {
|
||||
common_memory_breakdown_print(ll_ctx);
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
2
tools/server/tests/.gitignore
vendored
Normal file
2
tools/server/tests/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
.venv
|
||||
tmp
|
||||
96
tools/server/tests/README.md
Normal file
96
tools/server/tests/README.md
Normal file
@@ -0,0 +1,96 @@
|
||||
# Server tests
|
||||
|
||||
Python based server tests scenario using [pytest](https://docs.pytest.org/en/stable/).
|
||||
|
||||
Tests target GitHub workflows job runners with 4 vCPU.
|
||||
|
||||
Note: If the host architecture inference speed is faster than GitHub runners one, parallel scenario may randomly fail.
|
||||
To mitigate it, you can increase values in `n_predict`, `kv_size`.
|
||||
|
||||
### Install dependencies
|
||||
|
||||
`pip install -r requirements.txt`
|
||||
|
||||
### Run tests
|
||||
|
||||
1. Build the server
|
||||
|
||||
```shell
|
||||
cd ../../..
|
||||
cmake -B build
|
||||
cmake --build build --target llama-server
|
||||
```
|
||||
|
||||
2. Start the test: `./tests.sh`
|
||||
|
||||
It's possible to override some scenario steps values with environment variables:
|
||||
|
||||
| variable | description |
|
||||
|--------------------------|------------------------------------------------------------------------------------------------|
|
||||
| `PORT` | `context.server_port` to set the listening port of the server during scenario, default: `8080` |
|
||||
| `LLAMA_SERVER_BIN_PATH` | to change the server binary path, default: `../../../build/bin/llama-server` |
|
||||
| `DEBUG` | to enable steps and server verbose mode `--verbose` |
|
||||
| `N_GPU_LAYERS` | number of model layers to offload to VRAM `-ngl --n-gpu-layers` |
|
||||
| `LLAMA_CACHE` | by default server tests re-download models to the `tmp` subfolder. Set this to your cache (e.g. `$HOME/Library/Caches/llama.cpp` on Mac or `$HOME/.cache/llama.cpp` on Unix) to avoid this |
|
||||
|
||||
To run slow tests (will download many models, make sure to set `LLAMA_CACHE` if needed):
|
||||
|
||||
```shell
|
||||
SLOW_TESTS=1 ./tests.sh
|
||||
```
|
||||
|
||||
To run with stdout/stderr display in real time (verbose output, but useful for debugging):
|
||||
|
||||
```shell
|
||||
DEBUG=1 ./tests.sh -s -v -x
|
||||
```
|
||||
|
||||
To run all the tests in a file:
|
||||
|
||||
```shell
|
||||
./tests.sh unit/test_chat_completion.py -v -x
|
||||
```
|
||||
|
||||
To run a single test:
|
||||
|
||||
```shell
|
||||
./tests.sh unit/test_chat_completion.py::test_invalid_chat_completion_req
|
||||
```
|
||||
|
||||
Hint: You can compile and run test in single command, useful for local development:
|
||||
|
||||
```shell
|
||||
cmake --build build -j --target llama-server && ./tools/server/tests/tests.sh
|
||||
```
|
||||
|
||||
To see all available arguments, please refer to [pytest documentation](https://docs.pytest.org/en/stable/how-to/usage.html)
|
||||
|
||||
### Debugging external llama-server
|
||||
It can sometimes be useful to run the server in a debugger when invesigating test
|
||||
failures. To do this, the environment variable `DEBUG_EXTERNAL=1` can be set
|
||||
which will cause the test to skip starting a llama-server itself. Instead, the
|
||||
server can be started in a debugger.
|
||||
|
||||
Example using `gdb`:
|
||||
```console
|
||||
$ gdb --args ../../../build/bin/llama-server \
|
||||
--host 127.0.0.1 --port 8080 \
|
||||
--temp 0.8 --seed 42 \
|
||||
--hf-repo ggml-org/models --hf-file tinyllamas/stories260K.gguf \
|
||||
--batch-size 32 --no-slots --alias tinyllama-2 --ctx-size 512 \
|
||||
--parallel 2 --n-predict 64
|
||||
```
|
||||
And a break point can be set in before running:
|
||||
```console
|
||||
(gdb) br server.cpp:4604
|
||||
(gdb) r
|
||||
main: server is listening on http://127.0.0.1:8080 - starting the main loop
|
||||
srv update_slots: all slots are idle
|
||||
```
|
||||
|
||||
And then the test in question can be run in another terminal:
|
||||
```console
|
||||
(venv) $ env DEBUG_EXTERNAL=1 ./tests.sh unit/test_chat_completion.py -v -x
|
||||
```
|
||||
And this should trigger the breakpoint and allow inspection of the server state
|
||||
in the debugger terminal.
|
||||
21
tools/server/tests/conftest.py
Normal file
21
tools/server/tests/conftest.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import pytest
|
||||
from utils import *
|
||||
|
||||
|
||||
# ref: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test
|
||||
@pytest.fixture(autouse=True)
|
||||
def stop_server_after_each_test():
|
||||
# do nothing before each test
|
||||
yield
|
||||
# stop all servers after each test
|
||||
instances = set(
|
||||
server_instances
|
||||
) # copy the set to prevent 'Set changed size during iteration'
|
||||
for server in instances:
|
||||
server.stop()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def do_something():
|
||||
# this will be run once per test session, before any tests
|
||||
ServerPreset.load_all()
|
||||
4
tools/server/tests/pytest.ini
Normal file
4
tools/server/tests/pytest.ini
Normal file
@@ -0,0 +1,4 @@
|
||||
[pytest]
|
||||
markers =
|
||||
slow: marks tests as slow (deselect with '-m "not slow"')
|
||||
serial
|
||||
8
tools/server/tests/requirements.txt
Normal file
8
tools/server/tests/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
aiohttp~=3.9.3
|
||||
pytest~=8.3.3
|
||||
huggingface_hub>=1.5.0,<2.0
|
||||
numpy~=1.26.4
|
||||
openai~=2.14.0
|
||||
prometheus-client~=0.20.0
|
||||
requests~=2.32.3
|
||||
wget~=3.2
|
||||
23
tools/server/tests/tests.sh
Executable file
23
tools/server/tests/tests.sh
Executable file
@@ -0,0 +1,23 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# make sure we are in the right directory
|
||||
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
cd $SCRIPT_DIR
|
||||
|
||||
set -eu
|
||||
|
||||
if [[ "${SLOW_TESTS:-0}" == 1 ]]; then
|
||||
# Slow tests for tool calls need quite a few models ahead of time to avoid timing out.
|
||||
python $SCRIPT_DIR/../../../scripts/fetch_server_test_models.py
|
||||
fi
|
||||
|
||||
if [ $# -lt 1 ]
|
||||
then
|
||||
if [[ "${SLOW_TESTS:-0}" == 1 ]]; then
|
||||
pytest -v -x
|
||||
else
|
||||
pytest -v -x -m "not slow"
|
||||
fi
|
||||
else
|
||||
pytest "$@"
|
||||
fi
|
||||
113
tools/server/tests/unit/test_basic.py
Normal file
113
tools/server/tests/unit/test_basic.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import pytest
|
||||
import requests
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
def test_server_start_simple():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("GET", "/health")
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_server_props():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("GET", "/props")
|
||||
assert res.status_code == 200
|
||||
assert ".gguf" in res.body["model_path"]
|
||||
assert res.body["total_slots"] == server.n_slots
|
||||
default_val = res.body["default_generation_settings"]
|
||||
assert server.n_ctx is not None and server.n_slots is not None
|
||||
assert default_val["n_ctx"] == server.n_ctx / server.n_slots
|
||||
assert default_val["params"]["seed"] == server.seed
|
||||
|
||||
|
||||
def test_server_models():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("GET", "/models")
|
||||
assert res.status_code == 200
|
||||
assert len(res.body["data"]) == 1
|
||||
assert res.body["data"][0]["id"] == server.model_alias
|
||||
|
||||
|
||||
def test_server_slots():
|
||||
global server
|
||||
|
||||
# without slots endpoint enabled, this should return error
|
||||
server.server_slots = False
|
||||
server.start()
|
||||
res = server.make_request("GET", "/slots")
|
||||
assert res.status_code == 501 # ERROR_TYPE_NOT_SUPPORTED
|
||||
assert "error" in res.body
|
||||
server.stop()
|
||||
|
||||
# with slots endpoint enabled, this should return slots info
|
||||
server.server_slots = True
|
||||
server.n_slots = 2
|
||||
server.start()
|
||||
res = server.make_request("GET", "/slots")
|
||||
assert res.status_code == 200
|
||||
assert len(res.body) == server.n_slots
|
||||
assert server.n_ctx is not None and server.n_slots is not None
|
||||
assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots
|
||||
assert "params" not in res.body[0]
|
||||
|
||||
|
||||
def test_load_split_model():
|
||||
global server
|
||||
server.offline = False
|
||||
server.model_hf_repo = "ggml-org/models"
|
||||
server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf"
|
||||
server.model_alias = "tinyllama-split"
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 16,
|
||||
"prompt": "Hello",
|
||||
"temperature": 0.0,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(little|girl)+", res.body["content"])
|
||||
|
||||
|
||||
def test_no_webui():
|
||||
global server
|
||||
# default: webui enabled
|
||||
server.start()
|
||||
url = f"http://{server.server_host}:{server.server_port}"
|
||||
res = requests.get(url)
|
||||
assert res.status_code == 200
|
||||
assert "<!doctype html>" in res.text
|
||||
server.stop()
|
||||
|
||||
# with --no-webui
|
||||
server.no_webui = True
|
||||
server.start()
|
||||
res = requests.get(url)
|
||||
assert res.status_code == 404
|
||||
|
||||
|
||||
def test_server_model_aliases_and_tags():
|
||||
global server
|
||||
server.model_alias = "tinyllama-2,fim,code"
|
||||
server.model_tags = "chat,fim,small"
|
||||
server.start()
|
||||
res = server.make_request("GET", "/models")
|
||||
assert res.status_code == 200
|
||||
assert len(res.body["data"]) == 1
|
||||
model = res.body["data"][0]
|
||||
# aliases field must contain all aliases
|
||||
assert set(model["aliases"]) == {"tinyllama-2", "fim", "code"}
|
||||
# tags field must contain all tags
|
||||
assert set(model["tags"]) == {"chat", "fim", "small"}
|
||||
# id is derived from first alias (alphabetical order from std::set)
|
||||
assert model["id"] == "code"
|
||||
534
tools/server/tests/unit/test_chat_completion.py
Normal file
534
tools/server/tests/unit/test_chat_completion.py
Normal file
@@ -0,0 +1,534 @@
|
||||
import pytest
|
||||
from openai import OpenAI
|
||||
from utils import *
|
||||
|
||||
server: ServerProcess
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
|
||||
[
|
||||
(None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None),
|
||||
(None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None),
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None),
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
|
||||
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None),
|
||||
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
|
||||
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
|
||||
]
|
||||
)
|
||||
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
|
||||
global server
|
||||
server.jinja = jinja
|
||||
server.chat_template = chat_template
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "cmpl" in res.body["id"] # make sure the completion id has the expected format
|
||||
assert res.body["system_fingerprint"].startswith("b")
|
||||
# we no longer reflect back the model name, see https://github.com/ggml-org/llama.cpp/pull/17668
|
||||
# assert res.body["model"] == model if model is not None else server.model_alias
|
||||
assert res.body["usage"]["prompt_tokens"] == n_prompt
|
||||
assert res.body["usage"]["completion_tokens"] == n_predicted
|
||||
choice = res.body["choices"][0]
|
||||
assert "assistant" == choice["message"]["role"]
|
||||
assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
|
||||
assert choice["finish_reason"] == finish_reason
|
||||
|
||||
|
||||
def test_chat_completion_cached_tokens():
|
||||
global server
|
||||
server.n_slots = 1
|
||||
server.start()
|
||||
seq = [
|
||||
("1 2 3 4 5 6", 77, 0),
|
||||
("1 2 3 4 5 6", 77, 76),
|
||||
("1 2 3 4 5 9", 77, 51),
|
||||
("1 2 3 9 9 9", 77, 47),
|
||||
]
|
||||
for user_prompt, n_prompt, n_cache in seq:
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 8,
|
||||
"messages": [
|
||||
{"role": "system", "content": "Test"},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
})
|
||||
assert res.body["usage"]["prompt_tokens"] == n_prompt
|
||||
assert res.body["usage"]["prompt_tokens_details"]["cached_tokens"] == n_cache
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
|
||||
[
|
||||
("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
|
||||
("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"),
|
||||
]
|
||||
)
|
||||
def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
|
||||
global server
|
||||
server.model_alias = "llama-test-model"
|
||||
server.start()
|
||||
res = server.make_stream_request("POST", "/chat/completions", data={
|
||||
"max_tokens": max_tokens,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
"stream": True,
|
||||
})
|
||||
content = ""
|
||||
last_cmpl_id = None
|
||||
for i, data in enumerate(res):
|
||||
if data["choices"]:
|
||||
choice = data["choices"][0]
|
||||
if i == 0:
|
||||
# Check first role message for stream=True
|
||||
assert choice["delta"]["content"] is None
|
||||
assert choice["delta"]["role"] == "assistant"
|
||||
else:
|
||||
assert "role" not in choice["delta"]
|
||||
assert data["system_fingerprint"].startswith("b")
|
||||
assert data["model"] == "llama-test-model"
|
||||
if last_cmpl_id is None:
|
||||
last_cmpl_id = data["id"]
|
||||
assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
|
||||
if choice["finish_reason"] in ["stop", "length"]:
|
||||
assert "content" not in choice["delta"]
|
||||
assert match_regex(re_content, content)
|
||||
assert choice["finish_reason"] == finish_reason
|
||||
else:
|
||||
assert choice["finish_reason"] is None
|
||||
content += choice["delta"]["content"] or ''
|
||||
else:
|
||||
assert data["usage"]["prompt_tokens"] == n_prompt
|
||||
assert data["usage"]["completion_tokens"] == n_predicted
|
||||
|
||||
|
||||
def test_chat_completion_with_openai_library():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
res = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
messages=[
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
],
|
||||
max_tokens=8,
|
||||
seed=42,
|
||||
temperature=0.8,
|
||||
)
|
||||
assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
|
||||
assert res.choices[0].finish_reason == "length"
|
||||
assert res.choices[0].message.content is not None
|
||||
assert match_regex("(Suddenly)+", res.choices[0].message.content)
|
||||
|
||||
|
||||
def test_chat_template():
|
||||
global server
|
||||
server.chat_template = "llama3"
|
||||
server.debug = True # to get the "__verbose" object in the response
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 8,
|
||||
"messages": [
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
]
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "__verbose" in res.body
|
||||
assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prefill,re_prefill", [
|
||||
("Whill", "Whill"),
|
||||
([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"),
|
||||
])
|
||||
def test_chat_template_assistant_prefill(prefill, re_prefill):
|
||||
global server
|
||||
server.chat_template = "llama3"
|
||||
server.debug = True # to get the "__verbose" object in the response
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 8,
|
||||
"messages": [
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
{"role": "assistant", "content": prefill},
|
||||
]
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "__verbose" in res.body
|
||||
assert res.body["__verbose"]["prompt"] == f"<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}"
|
||||
|
||||
|
||||
def test_apply_chat_template():
|
||||
global server
|
||||
server.chat_template = "command-r"
|
||||
server.start()
|
||||
res = server.make_request("POST", "/apply-template", data={
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a test."},
|
||||
{"role": "user", "content":"Hi there"},
|
||||
]
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "prompt" in res.body
|
||||
assert res.body["prompt"] == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
|
||||
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
|
||||
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
|
||||
({"type": "json_schema", "json_schema": {"schema": {"const": "foooooo"}}}, 10, "\"foooooo\""),
|
||||
({"type": "json_object"}, 10, "(\\{|John)+"),
|
||||
({"type": "sound"}, 0, None),
|
||||
# invalid response format (expected to fail)
|
||||
({"type": "json_object", "schema": 123}, 0, None),
|
||||
({"type": "json_object", "schema": {"type": 123}}, 0, None),
|
||||
({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
|
||||
])
|
||||
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predicted,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
{"role": "user", "content": "Write an example"},
|
||||
],
|
||||
"response_format": response_format,
|
||||
})
|
||||
if re_content is not None:
|
||||
assert res.status_code == 200
|
||||
choice = res.body["choices"][0]
|
||||
assert match_regex(re_content, choice["message"]["content"])
|
||||
else:
|
||||
assert res.status_code == 400
|
||||
assert "error" in res.body
|
||||
|
||||
|
||||
@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
|
||||
(False, {"const": "42"}, 6, "\"42\""),
|
||||
(True, {"const": "42"}, 6, "\"42\""),
|
||||
])
|
||||
def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
|
||||
global server
|
||||
server.jinja = jinja
|
||||
server.debug = True
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predicted,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
{"role": "user", "content": "Write an example"},
|
||||
],
|
||||
"json_schema": json_schema,
|
||||
})
|
||||
assert res.status_code == 200, f'Expected 200, got {res.status_code}'
|
||||
choice = res.body["choices"][0]
|
||||
assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
|
||||
|
||||
|
||||
@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
|
||||
(False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
|
||||
(True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
|
||||
])
|
||||
def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
|
||||
global server
|
||||
server.jinja = jinja
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predicted,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Does not matter what I say, does it?"},
|
||||
],
|
||||
"grammar": grammar,
|
||||
})
|
||||
assert res.status_code == 200, res.body
|
||||
choice = res.body["choices"][0]
|
||||
assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("messages", [
|
||||
None,
|
||||
"string",
|
||||
[123],
|
||||
[{}],
|
||||
[{"role": 123}],
|
||||
[{"role": "system", "content": 123}],
|
||||
# [{"content": "hello"}], # TODO: should not be a valid case
|
||||
[{"role": "system", "content": "test"}, {}],
|
||||
[{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}],
|
||||
])
|
||||
def test_invalid_chat_completion_req(messages):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"messages": messages,
|
||||
})
|
||||
assert res.status_code == 400 or res.status_code == 500
|
||||
assert "error" in res.body
|
||||
|
||||
|
||||
def test_chat_completion_with_timings_per_token():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_stream_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 10,
|
||||
"messages": [{"role": "user", "content": "test"}],
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
"timings_per_token": True,
|
||||
})
|
||||
stats_received = False
|
||||
for i, data in enumerate(res):
|
||||
if i == 0:
|
||||
# Check first role message for stream=True
|
||||
assert data["choices"][0]["delta"]["content"] is None
|
||||
assert data["choices"][0]["delta"]["role"] == "assistant"
|
||||
assert "timings" not in data, f'First event should not have timings: {data}'
|
||||
else:
|
||||
if data["choices"]:
|
||||
assert "role" not in data["choices"][0]["delta"]
|
||||
else:
|
||||
assert "timings" in data
|
||||
assert "prompt_per_second" in data["timings"]
|
||||
assert "predicted_per_second" in data["timings"]
|
||||
assert "predicted_n" in data["timings"]
|
||||
assert data["timings"]["predicted_n"] <= 10
|
||||
stats_received = True
|
||||
assert stats_received
|
||||
|
||||
|
||||
def test_logprobs():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
res = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
temperature=0.0,
|
||||
messages=[
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
],
|
||||
max_tokens=5,
|
||||
logprobs=True,
|
||||
top_logprobs=10,
|
||||
)
|
||||
output_text = res.choices[0].message.content
|
||||
aggregated_text = ''
|
||||
assert res.choices[0].logprobs is not None
|
||||
assert res.choices[0].logprobs.content is not None
|
||||
for token in res.choices[0].logprobs.content:
|
||||
aggregated_text += token.token
|
||||
assert token.logprob <= 0.0
|
||||
assert token.bytes is not None
|
||||
assert len(token.top_logprobs) > 0
|
||||
assert aggregated_text == output_text
|
||||
|
||||
|
||||
def test_logprobs_stream():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
res = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
temperature=0.0,
|
||||
messages=[
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
],
|
||||
max_tokens=5,
|
||||
logprobs=True,
|
||||
top_logprobs=10,
|
||||
stream=True,
|
||||
)
|
||||
output_text = ''
|
||||
aggregated_text = ''
|
||||
for i, data in enumerate(res):
|
||||
if data.choices:
|
||||
choice = data.choices[0]
|
||||
if i == 0:
|
||||
# Check first role message for stream=True
|
||||
assert choice.delta.content is None
|
||||
assert choice.delta.role == "assistant"
|
||||
else:
|
||||
assert choice.delta.role is None
|
||||
if choice.finish_reason is None:
|
||||
if choice.delta.content:
|
||||
output_text += choice.delta.content
|
||||
assert choice.logprobs is not None
|
||||
assert choice.logprobs.content is not None
|
||||
for token in choice.logprobs.content:
|
||||
aggregated_text += token.token
|
||||
assert token.logprob <= 0.0
|
||||
assert token.bytes is not None
|
||||
assert token.top_logprobs is not None
|
||||
assert len(token.top_logprobs) > 0
|
||||
assert aggregated_text == output_text
|
||||
|
||||
|
||||
def test_logit_bias():
|
||||
global server
|
||||
server.start()
|
||||
|
||||
exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]
|
||||
|
||||
res = server.make_request("POST", "/tokenize", data={
|
||||
"content": " " + " ".join(exclude) + " ",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
tokens = res.body["tokens"]
|
||||
logit_bias = {tok: -100 for tok in tokens}
|
||||
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
res = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
temperature=0.0,
|
||||
messages=[
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
],
|
||||
max_tokens=64,
|
||||
logit_bias=logit_bias
|
||||
)
|
||||
output_text = res.choices[0].message.content
|
||||
assert output_text
|
||||
assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
|
||||
|
||||
def test_context_size_exceeded():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"messages": [
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
] * 100, # make the prompt too long
|
||||
})
|
||||
assert res.status_code == 400
|
||||
assert "error" in res.body
|
||||
assert res.body["error"]["type"] == "exceed_context_size_error"
|
||||
assert res.body["error"]["n_prompt_tokens"] > 0
|
||||
assert server.n_ctx is not None
|
||||
assert server.n_slots is not None
|
||||
assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
|
||||
|
||||
|
||||
def test_context_size_exceeded_stream():
|
||||
global server
|
||||
server.start()
|
||||
try:
|
||||
for _ in server.make_stream_request("POST", "/chat/completions", data={
|
||||
"messages": [
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
] * 100, # make the prompt too long
|
||||
"stream": True}):
|
||||
pass
|
||||
assert False, "Should have failed"
|
||||
except ServerError as e:
|
||||
assert e.code == 400
|
||||
assert "error" in e.body
|
||||
assert e.body["error"]["type"] == "exceed_context_size_error"
|
||||
assert e.body["error"]["n_prompt_tokens"] > 0
|
||||
assert server.n_ctx is not None
|
||||
assert server.n_slots is not None
|
||||
assert e.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"n_batch,batch_count,reuse_cache",
|
||||
[
|
||||
(64, 4, False),
|
||||
(64, 2, True),
|
||||
]
|
||||
)
|
||||
def test_return_progress(n_batch, batch_count, reuse_cache):
|
||||
global server
|
||||
server.n_batch = n_batch
|
||||
server.n_ctx = 256
|
||||
server.n_slots = 1
|
||||
server.start()
|
||||
def make_cmpl_request():
|
||||
return server.make_stream_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 10,
|
||||
"messages": [
|
||||
{"role": "user", "content": "This is a test" * 10},
|
||||
],
|
||||
"stream": True,
|
||||
"return_progress": True,
|
||||
})
|
||||
if reuse_cache:
|
||||
# make a first request to populate the cache
|
||||
res0 = make_cmpl_request()
|
||||
for _ in res0:
|
||||
pass # discard the output
|
||||
|
||||
res = make_cmpl_request()
|
||||
last_progress = None
|
||||
total_batch_count = 0
|
||||
|
||||
for data in res:
|
||||
cur_progress = data.get("prompt_progress", None)
|
||||
if cur_progress is None:
|
||||
continue
|
||||
if total_batch_count == 0:
|
||||
# first progress report must have n_cache == n_processed
|
||||
assert cur_progress["total"] > 0
|
||||
assert cur_progress["cache"] == cur_progress["processed"]
|
||||
if reuse_cache:
|
||||
# when reusing cache, we expect some cached tokens
|
||||
assert cur_progress["cache"] > 0
|
||||
if last_progress is not None:
|
||||
assert cur_progress["total"] == last_progress["total"]
|
||||
assert cur_progress["cache"] == last_progress["cache"]
|
||||
assert cur_progress["processed"] > last_progress["processed"]
|
||||
total_batch_count += 1
|
||||
last_progress = cur_progress
|
||||
|
||||
# last progress should indicate completion (all tokens processed)
|
||||
assert last_progress is not None
|
||||
assert last_progress["total"] > 0
|
||||
assert last_progress["processed"] == last_progress["total"]
|
||||
assert total_batch_count == batch_count
|
||||
|
||||
|
||||
def test_chat_completions_multiple_choices():
|
||||
global server
|
||||
server.start()
|
||||
# make sure cache can be reused across multiple choices and multiple requests
|
||||
# ref: https://github.com/ggml-org/llama.cpp/pull/18663
|
||||
for _ in range(2):
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 8,
|
||||
"n": 2,
|
||||
"messages": [
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
],
|
||||
# test forcing the same slot to be used
|
||||
# the scheduler should not be locked up in this case
|
||||
"id_slot": 0,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body["choices"]) == 2
|
||||
for choice in res.body["choices"]:
|
||||
assert "assistant" == choice["message"]["role"]
|
||||
assert choice["finish_reason"] == "length"
|
||||
1031
tools/server/tests/unit/test_compat_anthropic.py
Normal file
1031
tools/server/tests/unit/test_compat_anthropic.py
Normal file
File diff suppressed because it is too large
Load Diff
60
tools/server/tests/unit/test_compat_gcp.py
Normal file
60
tools/server/tests/unit/test_compat_gcp.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import pytest
|
||||
from utils import *
|
||||
|
||||
server: ServerProcess
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.gcp_compat = True
|
||||
|
||||
|
||||
def test_gcp_predict_camel_case():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/predict", data={
|
||||
"instances": [
|
||||
{
|
||||
"@requestFormat": "chatCompletions",
|
||||
"max_tokens": 8,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the meaning of life?"},
|
||||
],
|
||||
}
|
||||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "predictions" in res.body
|
||||
assert len(res.body["predictions"]) == 1
|
||||
prediction = res.body["predictions"][0]
|
||||
assert "choices" in prediction
|
||||
assert len(prediction["choices"]) == 1
|
||||
assert prediction["choices"][0]["message"]["role"] == "assistant"
|
||||
assert len(prediction["choices"][0]["message"]["content"]) > 0
|
||||
|
||||
|
||||
def test_gcp_predict_multiple_instances():
|
||||
global server
|
||||
server.n_slots = 2
|
||||
server.start()
|
||||
res = server.make_request("POST", "/predict", data={
|
||||
"instances": [
|
||||
{
|
||||
"@requestFormat": "chatCompletions",
|
||||
"max_tokens": 8,
|
||||
"messages": [{"role": "user", "content": "Say hello"}],
|
||||
},
|
||||
{
|
||||
"@requestFormat": "chatCompletions",
|
||||
"max_tokens": 8,
|
||||
"messages": [{"role": "user", "content": "Say world"}],
|
||||
},
|
||||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body["predictions"]) == 2
|
||||
for prediction in res.body["predictions"]:
|
||||
assert "choices" in prediction
|
||||
assert len(prediction["choices"][0]["message"]["content"]) > 0
|
||||
73
tools/server/tests/unit/test_compat_oai_responses.py
Normal file
73
tools/server/tests/unit/test_compat_oai_responses.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import pytest
|
||||
from openai import OpenAI
|
||||
from utils import *
|
||||
|
||||
server: ServerProcess
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
def test_responses_with_openai_library():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
res = client.responses.create(
|
||||
model="gpt-4.1",
|
||||
input=[
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
],
|
||||
max_output_tokens=8,
|
||||
temperature=0.8,
|
||||
)
|
||||
assert res.id.startswith("resp_")
|
||||
assert res.output[0].id is not None
|
||||
assert res.output[0].id.startswith("msg_")
|
||||
assert match_regex("(Suddenly)+", res.output_text)
|
||||
|
||||
def test_responses_stream_with_openai_library():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
stream = client.responses.create(
|
||||
model="gpt-4.1",
|
||||
input=[
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
],
|
||||
max_output_tokens=8,
|
||||
temperature=0.8,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
gathered_text = ''
|
||||
resp_id = ''
|
||||
msg_id = ''
|
||||
for r in stream:
|
||||
if r.type == "response.created":
|
||||
assert r.response.id.startswith("resp_")
|
||||
resp_id = r.response.id
|
||||
if r.type == "response.in_progress":
|
||||
assert r.response.id == resp_id
|
||||
if r.type == "response.output_item.added":
|
||||
assert r.item.id is not None
|
||||
assert r.item.id.startswith("msg_")
|
||||
msg_id = r.item.id
|
||||
if (r.type == "response.content_part.added" or
|
||||
r.type == "response.output_text.delta" or
|
||||
r.type == "response.output_text.done" or
|
||||
r.type == "response.content_part.done"):
|
||||
assert r.item_id == msg_id
|
||||
if r.type == "response.output_item.done":
|
||||
assert r.item.id == msg_id
|
||||
|
||||
if r.type == "response.output_text.delta":
|
||||
gathered_text += r.delta
|
||||
if r.type == "response.completed":
|
||||
assert r.response.id.startswith("resp_")
|
||||
assert r.response.output[0].id is not None
|
||||
assert r.response.output[0].id.startswith("msg_")
|
||||
assert gathered_text == r.response.output_text
|
||||
assert match_regex("(Suddenly)+", r.response.output_text)
|
||||
661
tools/server/tests/unit/test_completion.py
Normal file
661
tools/server/tests/unit/test_completion.py
Normal file
@@ -0,0 +1,661 @@
|
||||
import pytest
|
||||
import requests
|
||||
import time
|
||||
import random
|
||||
|
||||
from openai import OpenAI
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
JSON_MULTIMODAL_KEY = "multimodal_data"
|
||||
JSON_PROMPT_STRING_KEY = "prompt_string"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [
|
||||
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False),
|
||||
("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
|
||||
])
|
||||
def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": n_predict,
|
||||
"prompt": prompt,
|
||||
"return_tokens": return_tokens,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body["timings"]["prompt_n"] == n_prompt
|
||||
assert res.body["timings"]["predicted_n"] == n_predicted
|
||||
assert res.body["truncated"] == truncated
|
||||
assert type(res.body["has_new_line"]) == bool
|
||||
assert match_regex(re_content, res.body["content"])
|
||||
if return_tokens:
|
||||
assert len(res.body["tokens"]) > 0
|
||||
assert all(type(tok) == int for tok in res.body["tokens"])
|
||||
else:
|
||||
assert res.body["tokens"] == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
|
||||
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
|
||||
("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
|
||||
])
|
||||
def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_stream_request("POST", "/completion", data={
|
||||
"n_predict": n_predict,
|
||||
"prompt": prompt,
|
||||
"stream": True,
|
||||
})
|
||||
content = ""
|
||||
for data in res:
|
||||
assert "stop" in data and type(data["stop"]) == bool
|
||||
if data["stop"]:
|
||||
assert data["timings"]["prompt_n"] == n_prompt
|
||||
assert data["timings"]["predicted_n"] == n_predicted
|
||||
assert data["truncated"] == truncated
|
||||
assert data["stop_type"] == "limit"
|
||||
assert type(data["has_new_line"]) == bool
|
||||
assert "generation_settings" in data
|
||||
assert server.n_predict is not None
|
||||
assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict)
|
||||
assert data["generation_settings"]["seed"] == server.seed
|
||||
assert match_regex(re_content, content)
|
||||
else:
|
||||
assert len(data["tokens"]) > 0
|
||||
assert all(type(tok) == int for tok in data["tokens"])
|
||||
content += data["content"]
|
||||
|
||||
|
||||
def test_completion_stream_vs_non_stream():
|
||||
global server
|
||||
server.start()
|
||||
res_stream = server.make_stream_request("POST", "/completion", data={
|
||||
"n_predict": 8,
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"stream": True,
|
||||
})
|
||||
res_non_stream = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 8,
|
||||
"prompt": "I believe the meaning of life is",
|
||||
})
|
||||
content_stream = ""
|
||||
for data in res_stream:
|
||||
content_stream += data["content"]
|
||||
assert content_stream == res_non_stream.body["content"]
|
||||
|
||||
|
||||
def test_completion_with_openai_library():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
res = client.completions.create(
|
||||
model="davinci-002",
|
||||
prompt="I believe the meaning of life is",
|
||||
max_tokens=8,
|
||||
)
|
||||
assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
|
||||
assert res.choices[0].finish_reason == "length"
|
||||
assert res.choices[0].text is not None
|
||||
assert match_regex("(going|bed)+", res.choices[0].text)
|
||||
|
||||
|
||||
def test_completion_stream_with_openai_library():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
res = client.completions.create(
|
||||
model="davinci-002",
|
||||
prompt="I believe the meaning of life is",
|
||||
max_tokens=8,
|
||||
stream=True,
|
||||
)
|
||||
output_text = ''
|
||||
for data in res:
|
||||
choice = data.choices[0]
|
||||
if choice.finish_reason is None:
|
||||
assert choice.text is not None
|
||||
output_text += choice.text
|
||||
assert match_regex("(going|bed)+", output_text)
|
||||
|
||||
|
||||
# Test case from https://github.com/ggml-org/llama.cpp/issues/13780
|
||||
@pytest.mark.slow
|
||||
def test_completion_stream_with_openai_library_stops():
|
||||
global server
|
||||
server.model_hf_repo = "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M"
|
||||
server.model_hf_file = None
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
res = client.completions.create(
|
||||
model="davinci-002",
|
||||
prompt="System: You are helpful assistant.\nAssistant:\nHey! How could I help?\nUser:\nTell me a joke.\nAssistant:\n",
|
||||
stop=["User:\n", "Assistant:\n"],
|
||||
max_tokens=200,
|
||||
stream=True,
|
||||
)
|
||||
output_text = ''
|
||||
for data in res:
|
||||
choice = data.choices[0]
|
||||
if choice.finish_reason is None:
|
||||
assert choice.text is not None
|
||||
output_text += choice.text
|
||||
assert match_regex("Sure, here's one for[\\s\\S]*", output_text), f'Unexpected output: {output_text}'
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_slots", [1, 2])
|
||||
def test_consistent_result_same_seed(n_slots: int):
|
||||
global server
|
||||
server.n_slots = n_slots
|
||||
server.start()
|
||||
last_res = None
|
||||
for _ in range(4):
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
|
||||
})
|
||||
if last_res is not None:
|
||||
assert res.body["content"] == last_res.body["content"]
|
||||
last_res = res
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_slots", [1, 2])
|
||||
def test_different_result_different_seed(n_slots: int):
|
||||
global server
|
||||
server.n_slots = n_slots
|
||||
server.start()
|
||||
last_res = None
|
||||
for seed in range(4):
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"seed": seed,
|
||||
"temperature": 1.0,
|
||||
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
|
||||
})
|
||||
if last_res is not None:
|
||||
assert res.body["content"] != last_res.body["content"]
|
||||
last_res = res
|
||||
|
||||
# TODO figure why it don't work with temperature = 1
|
||||
# @pytest.mark.parametrize("temperature", [0.0, 1.0])
|
||||
@pytest.mark.parametrize("n_batch", [16, 32])
|
||||
@pytest.mark.parametrize("temperature", [0.0])
|
||||
def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
|
||||
global server
|
||||
server.n_batch = n_batch
|
||||
server.start()
|
||||
last_res = None
|
||||
for _ in range(4):
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"seed": 42,
|
||||
"temperature": temperature,
|
||||
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
|
||||
})
|
||||
if last_res is not None:
|
||||
assert res.body["content"] == last_res.body["content"]
|
||||
last_res = res
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test fails on linux, need to be fixed")
|
||||
def test_cache_vs_nocache_prompt():
|
||||
global server
|
||||
server.start()
|
||||
res_cache = server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
"cache_prompt": True,
|
||||
})
|
||||
res_no_cache = server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
"cache_prompt": False,
|
||||
})
|
||||
assert res_cache.body["content"] == res_no_cache.body["content"]
|
||||
|
||||
|
||||
def test_nocache_long_input_prompt():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is"*32,
|
||||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
"cache_prompt": False,
|
||||
})
|
||||
assert res.status_code == 400
|
||||
|
||||
def test_json_prompt_no_mtmd():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": { JSON_PROMPT_STRING_KEY: "I believe the meaning of life is" },
|
||||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
"cache_prompt": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
|
||||
def test_json_prompt_mtm_error_when_not_supported():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": { JSON_PROMPT_STRING_KEY: "I believe the meaning of life is <__media__>", JSON_MULTIMODAL_KEY: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" },
|
||||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
"cache_prompt": False,
|
||||
})
|
||||
# MTMD is disabled on this model, so this should fail.
|
||||
assert res.status_code != 200
|
||||
|
||||
def test_completion_with_tokens_input():
|
||||
global server
|
||||
server.temperature = 0.0
|
||||
server.start()
|
||||
prompt_str = "I believe the meaning of life is"
|
||||
res = server.make_request("POST", "/tokenize", data={
|
||||
"content": prompt_str,
|
||||
"add_special": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
tokens = res.body["tokens"]
|
||||
|
||||
# single completion
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": tokens,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert type(res.body["content"]) == str
|
||||
|
||||
# batch completion
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": [tokens, tokens],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert type(res.body) == list
|
||||
assert len(res.body) == 2
|
||||
assert res.body[0]["content"] == res.body[1]["content"]
|
||||
|
||||
# mixed string and tokens
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": [tokens, prompt_str],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert type(res.body) == list
|
||||
assert len(res.body) == 2
|
||||
assert res.body[0]["content"] == res.body[1]["content"]
|
||||
|
||||
# mixed JSON and tokens
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": [
|
||||
tokens,
|
||||
{
|
||||
JSON_PROMPT_STRING_KEY: "I believe the meaning of life is",
|
||||
},
|
||||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert type(res.body) == list
|
||||
assert len(res.body) == 2
|
||||
assert res.body[0]["content"] == res.body[1]["content"]
|
||||
|
||||
# mixed string and tokens in one sequence
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert type(res.body["content"]) == str
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_slots,n_requests", [
|
||||
(1, 3),
|
||||
(2, 2),
|
||||
(2, 4),
|
||||
(4, 2), # some slots must be idle
|
||||
(4, 6),
|
||||
])
|
||||
def test_completion_parallel_slots(n_slots: int, n_requests: int):
|
||||
global server
|
||||
server.n_slots = n_slots
|
||||
server.temperature = 0.0
|
||||
server.start()
|
||||
|
||||
PROMPTS = [
|
||||
("Write a very long book.", "(very|special|big)+"),
|
||||
("Write another a poem.", "(small|house)+"),
|
||||
("What is LLM?", "(Dad|said)+"),
|
||||
("The sky is blue and I love it.", "(climb|leaf)+"),
|
||||
("Write another very long music lyrics.", "(friends|step|sky)+"),
|
||||
("Write a very long joke.", "(cat|Whiskers)+"),
|
||||
]
|
||||
def check_slots_status():
|
||||
should_all_slots_busy = n_requests >= n_slots
|
||||
time.sleep(0.1)
|
||||
res = server.make_request("GET", "/slots")
|
||||
n_busy = sum([1 for slot in res.body if slot["is_processing"]])
|
||||
if should_all_slots_busy:
|
||||
assert n_busy == n_slots
|
||||
else:
|
||||
assert n_busy <= n_slots
|
||||
|
||||
tasks = []
|
||||
for i in range(n_requests):
|
||||
prompt, re_content = PROMPTS[i % len(PROMPTS)]
|
||||
tasks.append((server.make_request, ("POST", "/completion", {
|
||||
"prompt": prompt,
|
||||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
})))
|
||||
tasks.append((check_slots_status, ()))
|
||||
results = parallel_function_calls(tasks)
|
||||
|
||||
# check results
|
||||
for i in range(n_requests):
|
||||
prompt, re_content = PROMPTS[i % len(PROMPTS)]
|
||||
res = results[i]
|
||||
assert res.status_code == 200
|
||||
assert type(res.body["content"]) == str
|
||||
assert len(res.body["content"]) > 10
|
||||
# FIXME: the result is not deterministic when using other slot than slot 0
|
||||
# assert match_regex(re_content, res.body["content"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"n_ctx,n_slots,n_predict_vals,expected_success",
|
||||
[
|
||||
(256, 4, [80, 40, 80, 80], [True, True, True, True]),
|
||||
(256, 4, [70, 70, 70, 70], [False, False, False, False]),
|
||||
(256, 4, [90, 90, 40, 90], [False, False, True, False]),
|
||||
(256, 4, [90, 90, 40, 75], [True, True, True, True]),
|
||||
],
|
||||
)
|
||||
def test_completion_unified(n_ctx, n_slots, n_predict_vals, expected_success):
|
||||
global server
|
||||
server.n_slots = n_slots
|
||||
server.kv_unified = True
|
||||
server.n_ctx = n_ctx
|
||||
server.start()
|
||||
prompt = "A"
|
||||
tasks = []
|
||||
for n_predict in n_predict_vals:
|
||||
tasks.append((server.make_request, ("POST", "/completion", {"prompt": prompt, "n_predict": n_predict})))
|
||||
results = parallel_function_calls(tasks)
|
||||
for res, n_predict, expect_ok in zip(results, n_predict_vals, expected_success):
|
||||
if expect_ok:
|
||||
assert res.status_code == 200
|
||||
|
||||
# note: https://github.com/ggml-org/llama.cpp/pull/18700#issuecomment-3728695581
|
||||
if res.status_code == 200:
|
||||
assert "content" in res.body
|
||||
if "timings" in res.body:
|
||||
assert res.body["timings"]["predicted_n"] == n_predict
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt,n_predict,response_fields",
|
||||
[
|
||||
("I believe the meaning of life is", 8, []),
|
||||
("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
|
||||
],
|
||||
)
|
||||
def test_completion_response_fields(
|
||||
prompt: str, n_predict: int, response_fields: list[str]
|
||||
):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request(
|
||||
"POST",
|
||||
"/completion",
|
||||
data={
|
||||
"n_predict": n_predict,
|
||||
"prompt": prompt,
|
||||
"response_fields": response_fields,
|
||||
},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert "content" in res.body
|
||||
assert len(res.body["content"])
|
||||
if len(response_fields):
|
||||
assert res.body["generation_settings/n_predict"] == n_predict
|
||||
assert res.body["prompt"] == "<s> " + prompt
|
||||
assert isinstance(res.body["content"], str)
|
||||
assert len(res.body) == len(response_fields)
|
||||
else:
|
||||
assert len(res.body)
|
||||
assert "generation_settings" in res.body
|
||||
|
||||
|
||||
def test_n_probs():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"n_probs": 10,
|
||||
"temperature": 0.0,
|
||||
"n_predict": 5,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "completion_probabilities" in res.body
|
||||
assert len(res.body["completion_probabilities"]) == 5
|
||||
for tok in res.body["completion_probabilities"]:
|
||||
assert "id" in tok and tok["id"] > 0
|
||||
assert "token" in tok and type(tok["token"]) == str
|
||||
assert "logprob" in tok and tok["logprob"] <= 0.0
|
||||
assert "bytes" in tok and type(tok["bytes"]) == list
|
||||
assert len(tok["top_logprobs"]) == 10
|
||||
for prob in tok["top_logprobs"]:
|
||||
assert "id" in prob and prob["id"] > 0
|
||||
assert "token" in prob and type(prob["token"]) == str
|
||||
assert "logprob" in prob and prob["logprob"] <= 0.0
|
||||
assert "bytes" in prob and type(prob["bytes"]) == list
|
||||
|
||||
|
||||
def test_n_probs_stream():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_stream_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"n_probs": 10,
|
||||
"temperature": 0.0,
|
||||
"n_predict": 5,
|
||||
"stream": True,
|
||||
})
|
||||
for data in res:
|
||||
if data["stop"] == False:
|
||||
assert "completion_probabilities" in data
|
||||
assert len(data["completion_probabilities"]) == 1
|
||||
for tok in data["completion_probabilities"]:
|
||||
assert "id" in tok and tok["id"] > 0
|
||||
assert "token" in tok and type(tok["token"]) == str
|
||||
assert "logprob" in tok and tok["logprob"] <= 0.0
|
||||
assert "bytes" in tok and type(tok["bytes"]) == list
|
||||
assert len(tok["top_logprobs"]) == 10
|
||||
for prob in tok["top_logprobs"]:
|
||||
assert "id" in prob and prob["id"] > 0
|
||||
assert "token" in prob and type(prob["token"]) == str
|
||||
assert "logprob" in prob and prob["logprob"] <= 0.0
|
||||
assert "bytes" in prob and type(prob["bytes"]) == list
|
||||
|
||||
|
||||
def test_n_probs_post_sampling():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "Today was the day. Today I would finally become a",
|
||||
"n_probs": 10,
|
||||
"temperature": 1.0,
|
||||
"n_predict": 5,
|
||||
"post_sampling_probs": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "completion_probabilities" in res.body
|
||||
assert len(res.body["completion_probabilities"]) == 5
|
||||
for (i, tok) in enumerate(res.body["completion_probabilities"]):
|
||||
assert "id" in tok and tok["id"] > 0
|
||||
assert "token" in tok and type(tok["token"]) == str
|
||||
assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
|
||||
assert "bytes" in tok and type(tok["bytes"]) == list
|
||||
assert "top_probs" in tok and type(tok["top_probs"]) == list
|
||||
|
||||
for prob in tok["top_probs"]:
|
||||
assert "id" in prob and prob["id"] > 0
|
||||
assert "token" in prob and type(prob["token"]) == str
|
||||
# 0.0 probability tokens should never be returned by the server
|
||||
assert "prob" in prob and 0.0 < prob["prob"] <= 1.0
|
||||
assert "bytes" in prob and type(prob["bytes"]) == list
|
||||
|
||||
if i == 0:
|
||||
# The prompt is vague enough that we should get at least 10 possibilities
|
||||
# for the first token.
|
||||
assert len(tok["top_probs"]) == 10
|
||||
|
||||
if len(tok["top_probs"]) < 10:
|
||||
# Getting less than the requested number of probabilities should only happen
|
||||
# if the ones we did get already sum to 1.0.
|
||||
assert sum(p["prob"] for p in tok["top_probs"]) == pytest.approx(1.0)
|
||||
|
||||
def test_n_probs_post_backend_sampling():
|
||||
"""Verify that the same probabilities are returned with and without backend sampling."""
|
||||
global server
|
||||
server.backend_sampling = True
|
||||
server.start()
|
||||
|
||||
def make_request(backend_sampling):
|
||||
n_predict = 20
|
||||
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "The countries of Europe, in random order, are:",
|
||||
"n_probs": 10,
|
||||
"n_predict": n_predict,
|
||||
"post_sampling_probs": True,
|
||||
"seed": 4242,
|
||||
"backend_sampling": backend_sampling,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
|
||||
total_probs = 0
|
||||
completions = res.body["completion_probabilities"]
|
||||
assert len(completions) == n_predict
|
||||
for tok in completions:
|
||||
# Handling of 0.0 probabilities differs between samplers and backend sampling. Filter them to normalize the
|
||||
# data.
|
||||
tok["top_probs"] = [x for x in tok["top_probs"] if x["prob"] > 0.0]
|
||||
total_probs += len(tok["top_probs"])
|
||||
# Verify that we got at least two top probs on average, to ensure the effectiveness of the test.
|
||||
assert total_probs >= 2 * n_predict
|
||||
return completions
|
||||
|
||||
def verify_token(a, b):
|
||||
assert a["id"] == b["id"]
|
||||
assert a["token"] == b["token"]
|
||||
assert a["bytes"] == b["bytes"]
|
||||
assert a["prob"] == pytest.approx(b["prob"], abs=0.01)
|
||||
|
||||
for (a, b) in zip(make_request(True), make_request(False)):
|
||||
verify_token(a, b)
|
||||
assert len(a["top_probs"]) == len(b["top_probs"])
|
||||
|
||||
for (aa, bb) in zip(a["top_probs"], b["top_probs"]):
|
||||
verify_token(aa, bb)
|
||||
|
||||
@pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])
|
||||
def test_logit_bias(tokenize, openai_style):
|
||||
global server
|
||||
server.start()
|
||||
|
||||
exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]
|
||||
|
||||
logit_bias = []
|
||||
if tokenize:
|
||||
res = server.make_request("POST", "/tokenize", data={
|
||||
"content": " " + " ".join(exclude) + " ",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
tokens = res.body["tokens"]
|
||||
logit_bias = [[tok, -100] for tok in tokens]
|
||||
|
||||
else:
|
||||
logit_bias = [[" " + tok + " ", -100] for tok in exclude]
|
||||
|
||||
if openai_style:
|
||||
logit_bias = {el[0]: -100 for el in logit_bias}
|
||||
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 64,
|
||||
"prompt": "What is the best book",
|
||||
"logit_bias": logit_bias,
|
||||
"temperature": 0.0
|
||||
})
|
||||
assert res.status_code == 200
|
||||
output_text = res.body["content"]
|
||||
assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
|
||||
|
||||
|
||||
def test_cancel_request():
|
||||
global server
|
||||
server.n_ctx = 4096
|
||||
server.n_predict = -1
|
||||
server.n_slots = 1
|
||||
server.server_slots = True
|
||||
server.start()
|
||||
# send a request that will take a long time, but cancel it before it finishes
|
||||
try:
|
||||
server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
}, timeout=0.1)
|
||||
except requests.exceptions.ReadTimeout:
|
||||
pass # expected
|
||||
# make sure the slot is free
|
||||
time.sleep(2)
|
||||
res = server.make_request("GET", "/slots")
|
||||
assert res.body[0]["is_processing"] == False
|
||||
|
||||
|
||||
# this test exercises the host-memory prompt cache
|
||||
# ref: https://github.com/ggml-org/llama.cpp/pull/16391
|
||||
# ref: https://github.com/ggml-org/llama.cpp/pull/17078
|
||||
def test_completion_prompt_cache():
|
||||
global server
|
||||
server.n_slots = 2
|
||||
server.kv_unified = True
|
||||
server.start()
|
||||
|
||||
for _ in range(16):
|
||||
# generate alternating random prompts with variable lengths in order to get them in and out of the cache
|
||||
r = random.randint(0, 4)
|
||||
prompt = (" Hello " + str(r)) * (40 + r)
|
||||
n_prompt = (40 + r)*5 + 2
|
||||
n_predict = random.randint(1, 8)
|
||||
|
||||
res = server.make_request(
|
||||
"POST",
|
||||
"/completion",
|
||||
data={
|
||||
"prompt": prompt,
|
||||
"n_predict": n_predict,
|
||||
},
|
||||
)
|
||||
|
||||
assert res.status_code == 200
|
||||
assert "content" in res.body
|
||||
content = res.body["content"]
|
||||
assert isinstance(content, str)
|
||||
assert len(content) > 0
|
||||
|
||||
assert type(res.body["has_new_line"]) == bool
|
||||
assert "timings" in res.body
|
||||
timings = res.body["timings"]
|
||||
|
||||
assert "prompt_n" in timings and timings["prompt_n"] + timings["cache_n"] == n_prompt
|
||||
assert "predicted_n" in timings and timings["predicted_n"] == n_predict
|
||||
assert "tokens" in res.body and isinstance(res.body["tokens"], list)
|
||||
89
tools/server/tests/unit/test_ctx_shift.py
Normal file
89
tools/server/tests/unit/test_ctx_shift.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import pytest
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
SHORT_TEXT = """
|
||||
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
|
||||
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
|
||||
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
|
||||
""".strip()
|
||||
|
||||
LONG_TEXT = """
|
||||
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
|
||||
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
|
||||
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
|
||||
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
|
||||
""".strip()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.n_ctx = 512
|
||||
server.n_slots = 2
|
||||
server.n_predict = 128
|
||||
|
||||
|
||||
def test_ctx_shift_enabled():
|
||||
# the prompt is 226 tokens
|
||||
# the slot context is 512/2 = 256 tokens
|
||||
# 96 tokens are generated thanks to shifting the context when it gets full
|
||||
global server
|
||||
server.enable_ctx_shift = True
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 96,
|
||||
"prompt": SHORT_TEXT,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body["timings"]["prompt_n"] == 226
|
||||
assert res.body["timings"]["predicted_n"] == 96
|
||||
assert res.body["truncated"] is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_predict,n_token_output,truncated", [
|
||||
(64, 64, False),
|
||||
(-1, 248, True), # 8 tokens prompt + 248 tokens generated = 256 tokens total
|
||||
])
|
||||
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
|
||||
global server
|
||||
server.n_predict = -1
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": n_predict,
|
||||
"prompt": "Hi how are you",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body["timings"]["predicted_n"] == n_token_output
|
||||
assert res.body["truncated"] == truncated
|
||||
|
||||
|
||||
def test_ctx_shift_disabled_long_prompt():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 64,
|
||||
"prompt": LONG_TEXT,
|
||||
})
|
||||
assert res.status_code != 200
|
||||
assert "error" in res.body
|
||||
assert "exceeds the available context size" in res.body["error"]["message"]
|
||||
|
||||
def test_ctx_shift_disabled_stream():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_stream_request("POST", "/v1/completions", data={
|
||||
"n_predict": 256,
|
||||
"prompt": "Once",
|
||||
"stream": True,
|
||||
})
|
||||
content = ""
|
||||
for data in res:
|
||||
choice = data["choices"][0]
|
||||
if choice["finish_reason"] == "length":
|
||||
assert len(content) > 0
|
||||
else:
|
||||
assert choice["finish_reason"] is None
|
||||
content += choice["text"]
|
||||
291
tools/server/tests/unit/test_embedding.py
Normal file
291
tools/server/tests/unit/test_embedding.py
Normal file
@@ -0,0 +1,291 @@
|
||||
import base64
|
||||
import struct
|
||||
import pytest
|
||||
from openai import OpenAI
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.bert_bge_small()
|
||||
|
||||
EPSILON = 1e-3
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.bert_bge_small()
|
||||
|
||||
|
||||
def test_embedding_single():
|
||||
global server
|
||||
server.pooling = 'last'
|
||||
server.start()
|
||||
res = server.make_request("POST", "/v1/embeddings", data={
|
||||
"input": "I believe the meaning of life is",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body['data']) == 1
|
||||
assert 'embedding' in res.body['data'][0]
|
||||
assert len(res.body['data'][0]['embedding']) > 1
|
||||
|
||||
# make sure embedding vector is normalized
|
||||
assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON
|
||||
|
||||
|
||||
def test_embedding_multiple():
|
||||
global server
|
||||
server.pooling = 'last'
|
||||
server.start()
|
||||
res = server.make_request("POST", "/v1/embeddings", data={
|
||||
"input": [
|
||||
"I believe the meaning of life is",
|
||||
"Write a joke about AI from a very long prompt which will not be truncated",
|
||||
"This is a test",
|
||||
"This is another test",
|
||||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body['data']) == 4
|
||||
for d in res.body['data']:
|
||||
assert 'embedding' in d
|
||||
assert len(d['embedding']) > 1
|
||||
|
||||
|
||||
def test_embedding_multiple_with_fa():
|
||||
server = ServerPreset.bert_bge_small_with_fa()
|
||||
server.pooling = 'last'
|
||||
server.start()
|
||||
# one of these should trigger the FA branch (i.e. context size % 256 == 0)
|
||||
res = server.make_request("POST", "/v1/embeddings", data={
|
||||
"input": [
|
||||
"a "*253,
|
||||
"b "*254,
|
||||
"c "*255,
|
||||
"d "*256,
|
||||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body['data']) == 4
|
||||
for d in res.body['data']:
|
||||
assert 'embedding' in d
|
||||
assert len(d['embedding']) > 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input,is_multi_prompt",
|
||||
[
|
||||
# do not crash on empty input
|
||||
("", False),
|
||||
# single prompt
|
||||
("string", False),
|
||||
([12, 34, 56], False),
|
||||
([12, 34, "string", 56, 78], False),
|
||||
# multiple prompts
|
||||
(["string1", "string2"], True),
|
||||
(["string1", [12, 34, 56]], True),
|
||||
([[12, 34, 56], [12, 34, 56]], True),
|
||||
([[12, 34, 56], [12, "string", 34, 56]], True),
|
||||
]
|
||||
)
|
||||
def test_embedding_mixed_input(input, is_multi_prompt: bool):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/v1/embeddings", data={"input": input})
|
||||
assert res.status_code == 200
|
||||
data = res.body['data']
|
||||
if is_multi_prompt:
|
||||
assert len(data) == len(input)
|
||||
for d in data:
|
||||
assert 'embedding' in d
|
||||
assert len(d['embedding']) > 1
|
||||
else:
|
||||
assert 'embedding' in data[0]
|
||||
assert len(data[0]['embedding']) > 1
|
||||
|
||||
|
||||
def test_embedding_pooling_mean():
|
||||
global server
|
||||
server.pooling = 'mean'
|
||||
server.start()
|
||||
res = server.make_request("POST", "/v1/embeddings", data={
|
||||
"input": "I believe the meaning of life is",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body['data']) == 1
|
||||
assert 'embedding' in res.body['data'][0]
|
||||
assert len(res.body['data'][0]['embedding']) > 1
|
||||
|
||||
# make sure embedding vector is normalized
|
||||
assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON
|
||||
|
||||
|
||||
def test_embedding_pooling_mean_multiple():
|
||||
global server
|
||||
server.pooling = 'mean'
|
||||
server.start()
|
||||
res = server.make_request("POST", "/v1/embeddings", data={
|
||||
"input": [
|
||||
"I believe the meaning of life is",
|
||||
"Write a joke about AI",
|
||||
"This is a test",
|
||||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body['data']) == 3
|
||||
for d in res.body['data']:
|
||||
assert 'embedding' in d
|
||||
assert len(d['embedding']) > 1
|
||||
|
||||
|
||||
def test_embedding_pooling_none():
|
||||
global server
|
||||
server.pooling = 'none'
|
||||
server.start()
|
||||
res = server.make_request("POST", "/embeddings", data={
|
||||
"input": "hello hello hello",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert 'embedding' in res.body[0]
|
||||
assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special
|
||||
|
||||
# make sure embedding vector is not normalized
|
||||
for x in res.body[0]['embedding']:
|
||||
assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON
|
||||
|
||||
|
||||
def test_embedding_pooling_none_oai():
|
||||
global server
|
||||
server.pooling = 'none'
|
||||
server.start()
|
||||
res = server.make_request("POST", "/v1/embeddings", data={
|
||||
"input": "hello hello hello",
|
||||
})
|
||||
|
||||
# /v1/embeddings does not support pooling type 'none'
|
||||
assert res.status_code == 400
|
||||
assert "error" in res.body
|
||||
|
||||
|
||||
def test_embedding_openai_library_single():
|
||||
global server
|
||||
server.pooling = 'last'
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
|
||||
assert len(res.data) == 1
|
||||
assert len(res.data[0].embedding) > 1
|
||||
|
||||
|
||||
def test_embedding_openai_library_multiple():
|
||||
global server
|
||||
server.pooling = 'last'
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
res = client.embeddings.create(model="text-embedding-3-small", input=[
|
||||
"I believe the meaning of life is",
|
||||
"Write a joke about AI from a very long prompt which will not be truncated",
|
||||
"This is a test",
|
||||
"This is another test",
|
||||
])
|
||||
assert len(res.data) == 4
|
||||
for d in res.data:
|
||||
assert len(d.embedding) > 1
|
||||
|
||||
|
||||
def test_embedding_error_prompt_too_long():
|
||||
global server
|
||||
server.pooling = 'last'
|
||||
server.start()
|
||||
res = server.make_request("POST", "/v1/embeddings", data={
|
||||
"input": "This is a test " * 512,
|
||||
})
|
||||
assert res.status_code != 200
|
||||
assert "too large" in res.body["error"]["message"]
|
||||
|
||||
|
||||
def test_same_prompt_give_same_result():
|
||||
server.pooling = 'last'
|
||||
server.start()
|
||||
res = server.make_request("POST", "/v1/embeddings", data={
|
||||
"input": [
|
||||
"I believe the meaning of life is",
|
||||
"I believe the meaning of life is",
|
||||
"I believe the meaning of life is",
|
||||
"I believe the meaning of life is",
|
||||
"I believe the meaning of life is",
|
||||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body['data']) == 5
|
||||
for i in range(1, len(res.body['data'])):
|
||||
v0 = res.body['data'][0]['embedding']
|
||||
vi = res.body['data'][i]['embedding']
|
||||
for x, y in zip(v0, vi):
|
||||
assert abs(x - y) < EPSILON
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"content,n_tokens",
|
||||
[
|
||||
("I believe the meaning of life is", 9),
|
||||
("This is a test", 6),
|
||||
]
|
||||
)
|
||||
def test_embedding_usage_single(content, n_tokens):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/v1/embeddings", data={"input": content})
|
||||
assert res.status_code == 200
|
||||
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
|
||||
assert res.body['usage']['prompt_tokens'] == n_tokens
|
||||
|
||||
|
||||
def test_embedding_usage_multiple():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/v1/embeddings", data={
|
||||
"input": [
|
||||
"I believe the meaning of life is",
|
||||
"I believe the meaning of life is",
|
||||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
|
||||
assert res.body['usage']['prompt_tokens'] == 2 * 9
|
||||
|
||||
|
||||
def test_embedding_openai_library_base64():
|
||||
server.start()
|
||||
test_input = "Test base64 embedding output"
|
||||
|
||||
# get embedding in default format
|
||||
res = server.make_request("POST", "/v1/embeddings", data={
|
||||
"input": test_input
|
||||
})
|
||||
assert res.status_code == 200
|
||||
vec0 = res.body["data"][0]["embedding"]
|
||||
|
||||
# get embedding in base64 format
|
||||
res = server.make_request("POST", "/v1/embeddings", data={
|
||||
"input": test_input,
|
||||
"encoding_format": "base64"
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert "data" in res.body
|
||||
assert len(res.body["data"]) == 1
|
||||
|
||||
embedding_data = res.body["data"][0]
|
||||
assert "embedding" in embedding_data
|
||||
assert isinstance(embedding_data["embedding"], str)
|
||||
|
||||
# Verify embedding is valid base64
|
||||
decoded = base64.b64decode(embedding_data["embedding"])
|
||||
# Verify decoded data can be converted back to float array
|
||||
float_count = len(decoded) // 4 # 4 bytes per float
|
||||
floats = struct.unpack(f'{float_count}f', decoded)
|
||||
assert len(floats) > 0
|
||||
assert all(isinstance(x, float) for x in floats)
|
||||
assert len(floats) == len(vec0)
|
||||
|
||||
# make sure the decoded data is the same as the original
|
||||
for x, y in zip(floats, vec0):
|
||||
assert abs(x - y) < EPSILON
|
||||
43
tools/server/tests/unit/test_ignore_eos.py
Normal file
43
tools/server/tests/unit/test_ignore_eos.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import pytest
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
def test_ignore_eos_populates_logit_bias():
|
||||
"""ignore_eos=true must add EOG logit biases to generation_settings."""
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 8,
|
||||
"prompt": "Once upon a time",
|
||||
"ignore_eos": True,
|
||||
"temperature": 0.0,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
# EOG token biases must be present with -inf bias
|
||||
logit_bias = res.body["generation_settings"]["logit_bias"]
|
||||
assert len(logit_bias) > 0
|
||||
for entry in logit_bias:
|
||||
assert entry["bias"] is None # null in JSON represents -inf
|
||||
|
||||
|
||||
def test_ignore_eos_false_no_logit_bias():
|
||||
"""ignore_eos=false (default) must NOT add EOG logit biases."""
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 8,
|
||||
"prompt": "Once upon a time",
|
||||
"ignore_eos": False,
|
||||
"temperature": 0.0,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
logit_bias = res.body["generation_settings"]["logit_bias"]
|
||||
assert len(logit_bias) == 0
|
||||
77
tools/server/tests/unit/test_infill.py
Normal file
77
tools/server/tests/unit/test_infill.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import pytest
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama_infill()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama_infill()
|
||||
|
||||
|
||||
def test_infill_without_input_extra():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/infill", data={
|
||||
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
|
||||
"prompt": " int n_threads = llama_",
|
||||
"input_suffix": "}\n",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Ann|small|shiny|Daddy|Jimmy)+", res.body["content"])
|
||||
|
||||
|
||||
def test_infill_with_input_extra():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/infill", data={
|
||||
"input_extra": [{
|
||||
"filename": "llama.h",
|
||||
"text": "LLAMA_API int32_t llama_n_threads();\n"
|
||||
}],
|
||||
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
|
||||
"prompt": " int n_threads = llama_",
|
||||
"input_suffix": "}\n",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Dad|excited|park|Jimmy)+", res.body["content"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("input_extra", [
|
||||
{},
|
||||
{"filename": "ok"},
|
||||
{"filename": 123},
|
||||
{"filename": 123, "text": "abc"},
|
||||
{"filename": 123, "text": 456},
|
||||
])
|
||||
def test_invalid_input_extra_req(input_extra):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/infill", data={
|
||||
"input_extra": [input_extra],
|
||||
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
|
||||
"prompt": " int n_threads = llama_",
|
||||
"input_suffix": "}\n",
|
||||
})
|
||||
assert res.status_code == 400
|
||||
assert "error" in res.body
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
|
||||
def test_with_qwen_model():
|
||||
global server
|
||||
server.model_file = None
|
||||
server.model_hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-IQ3_XXS-GGUF"
|
||||
server.model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf"
|
||||
server.start(timeout_seconds=600)
|
||||
res = server.make_request("POST", "/infill", data={
|
||||
"input_extra": [{
|
||||
"filename": "llama.h",
|
||||
"text": "LLAMA_API int32_t llama_n_threads();\n"
|
||||
}],
|
||||
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
|
||||
"prompt": " int n_threads = llama_",
|
||||
"input_suffix": "}\n",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body["content"] == "n_threads();\n printf(\"Number of threads: %d\\n\", n_threads);\n return 0;\n"
|
||||
115
tools/server/tests/unit/test_kv_keep_only_active.py
Normal file
115
tools/server/tests/unit/test_kv_keep_only_active.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import os
|
||||
import tempfile
|
||||
import pytest
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
class LogReader:
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
self.pos = 0
|
||||
def drain(self):
|
||||
with open(self.path) as f:
|
||||
f.seek(self.pos)
|
||||
content = f.read()
|
||||
self.pos = f.tell()
|
||||
return content
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.n_slots = 2
|
||||
server.n_predict = 4
|
||||
server.temperature = 0.0
|
||||
server.server_slots = True
|
||||
server.cache_ram = 100
|
||||
server.kv_unified = True
|
||||
server.debug = True
|
||||
fd, server.log_path = tempfile.mkstemp(suffix='.log')
|
||||
os.close(fd)
|
||||
yield
|
||||
|
||||
|
||||
LONG_PROMPT = (
|
||||
"Once upon a time in a land far away, there lived a brave knight "
|
||||
"who traveled across mountains and rivers to find the legendary "
|
||||
"golden sword hidden deep within the enchanted forest of whispers. "
|
||||
"He met many creatures along the way including dragons and fairies "
|
||||
"and wizards who helped him on his noble quest to save the kingdom."
|
||||
)
|
||||
|
||||
|
||||
# idle slot cleared on launch should restore from cache-ram
|
||||
def test_clear_and_restore():
|
||||
global server
|
||||
server.start()
|
||||
log = LogReader(server.log_path)
|
||||
|
||||
# verify feature is enabled
|
||||
assert "__TEST_TAG_CACHE_IDLE_SLOTS_ENABLED__" in log.drain()
|
||||
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": LONG_PROMPT,
|
||||
"id_slot": 0,
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
original_prompt_n = res.body["timings"]["prompt_n"]
|
||||
|
||||
# Slot 0 is the only slot with KV — should NOT be cleared
|
||||
assert "__TEST_TAG_CACHE_IDLE_SLOT__" not in log.drain()
|
||||
|
||||
# Launching slot 1 clears idle slot 0
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "The quick brown fox",
|
||||
"id_slot": 1,
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "__TEST_TAG_CACHE_IDLE_SLOT__" in log.drain()
|
||||
|
||||
# Re-send same prompt — should restore from cache-ram
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": LONG_PROMPT,
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "updating prompt cache" in log.drain()
|
||||
assert res.body["timings"]["cache_n"] > 0
|
||||
assert res.body["timings"]["prompt_n"] < original_prompt_n
|
||||
|
||||
# Follow-up — slot 0 kept its KV, no clearing needed
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": LONG_PROMPT + " The knight finally reached the castle gates.",
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "__TEST_TAG_CACHE_IDLE_SLOT__" not in log.drain()
|
||||
|
||||
|
||||
def test_disabled_with_flag():
|
||||
global server
|
||||
server.no_cache_idle_slots = True
|
||||
server.start()
|
||||
log = LogReader(server.log_path)
|
||||
|
||||
# Feature should not be enabled
|
||||
assert "__TEST_TAG_CACHE_IDLE_SLOTS_ENABLED__" not in log.drain()
|
||||
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": LONG_PROMPT,
|
||||
"id_slot": 0,
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
|
||||
# Request on different slot — should NOT trigger clearing
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "The quick brown fox",
|
||||
"id_slot": 1,
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "__TEST_TAG_CACHE_IDLE_SLOT__" not in log.drain()
|
||||
115
tools/server/tests/unit/test_lora.py
Normal file
115
tools/server/tests/unit/test_lora.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import pytest
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.stories15m_moe()
|
||||
|
||||
LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.stories15m_moe()
|
||||
server.lora_files = [download_file(LORA_FILE_URL)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("scale,re_content", [
|
||||
# without applying lora, the model should behave like a bedtime story generator
|
||||
(0.0, "(little|girl|three|years|old)+"),
|
||||
# with lora, the model should behave like a Shakespearean text generator
|
||||
(1.0, "(eye|love|glass|sun)+"),
|
||||
])
|
||||
def test_lora(scale: float, re_content: str):
|
||||
global server
|
||||
server.start()
|
||||
res_lora_control = server.make_request("POST", "/lora-adapters", data=[
|
||||
{"id": 0, "scale": scale}
|
||||
])
|
||||
assert res_lora_control.status_code == 200
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "Look in thy glass",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex(re_content, res.body["content"])
|
||||
|
||||
|
||||
def test_lora_per_request():
|
||||
global server
|
||||
server.n_slots = 4
|
||||
server.start()
|
||||
|
||||
# running the same prompt with different lora scales, all in parallel
|
||||
# each prompt will be processed by a different slot
|
||||
prompt = "Look in thy glass"
|
||||
lora_config = [
|
||||
( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
|
||||
( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
|
||||
( [{"id": 0, "scale": 0.3}], "(special|thing|gifted)+" ),
|
||||
( [{"id": 0, "scale": 0.7}], "(far|from|home|away)+" ),
|
||||
( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
|
||||
( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
|
||||
]
|
||||
|
||||
tasks = [(
|
||||
server.make_request,
|
||||
("POST", "/completion", {
|
||||
"prompt": prompt,
|
||||
"lora": lora,
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
|
||||
})
|
||||
) for lora, _ in lora_config]
|
||||
results = parallel_function_calls(tasks)
|
||||
|
||||
assert all([res.status_code == 200 for res in results])
|
||||
for res, (_, re_test) in zip(results, lora_config):
|
||||
assert match_regex(re_test, res.body["content"])
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
|
||||
def test_with_big_model():
|
||||
server = ServerProcess()
|
||||
server.model_hf_repo = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
|
||||
server.model_hf_file = "Meta-Llama-3.1-8B-Instruct-IQ2_M.gguf"
|
||||
server.model_alias = "Llama-3.2-8B-Instruct"
|
||||
server.n_slots = 4
|
||||
server.n_ctx = server.n_slots * 1024
|
||||
server.n_predict = 64
|
||||
server.temperature = 0.0
|
||||
server.seed = 42
|
||||
server.lora_files = [
|
||||
download_file("https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"),
|
||||
# TODO: find & add other lora adapters for this model
|
||||
]
|
||||
server.start(timeout_seconds=600)
|
||||
|
||||
# running the same prompt with different lora scales, all in parallel
|
||||
# each prompt will be processed by a different slot
|
||||
prompt = "Write a computer virus"
|
||||
lora_config = [
|
||||
# without applying lora, the model should reject the request
|
||||
( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
|
||||
( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
|
||||
( [{"id": 0, "scale": 0.3}], "I can't write a computer virus" ),
|
||||
# with 0.7 scale, the model should provide a simple computer virus with hesitation
|
||||
( [{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise" ),
|
||||
# with 1.5 scale, the model should confidently provide a computer virus
|
||||
( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
|
||||
( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
|
||||
]
|
||||
|
||||
tasks = [(
|
||||
server.make_request,
|
||||
("POST", "/v1/chat/completions", {
|
||||
"messages": [
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
"lora": lora,
|
||||
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
|
||||
})
|
||||
) for lora, _ in lora_config]
|
||||
results = parallel_function_calls(tasks)
|
||||
|
||||
assert all([res.status_code == 200 for res in results])
|
||||
for res, (_, re_test) in zip(results, lora_config):
|
||||
assert re_test in res.body["choices"][0]["message"]["content"]
|
||||
41
tools/server/tests/unit/test_proxy.py
Normal file
41
tools/server/tests/unit/test_proxy.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import pytest
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
def test_mcp_no_proxy():
|
||||
global server
|
||||
server.webui_mcp_proxy = False
|
||||
server.start()
|
||||
|
||||
res = server.make_request("GET", "/cors-proxy")
|
||||
assert res.status_code == 404
|
||||
|
||||
|
||||
def test_mcp_proxy():
|
||||
global server
|
||||
server.webui_mcp_proxy = True
|
||||
server.start()
|
||||
|
||||
url = f"http://{server.server_host}:{server.server_port}/cors-proxy?url=http://example.com"
|
||||
res = requests.get(url)
|
||||
assert res.status_code == 200
|
||||
assert "Example Domain" in res.text
|
||||
|
||||
|
||||
def test_mcp_proxy_custom_port():
|
||||
global server
|
||||
server.webui_mcp_proxy = True
|
||||
server.start()
|
||||
|
||||
# try getting the server's models API via the proxy
|
||||
res = server.make_request("GET", f"/cors-proxy?url=http://{server.server_host}:{server.server_port}/models")
|
||||
assert res.status_code == 200
|
||||
assert "data" in res.body
|
||||
146
tools/server/tests/unit/test_rerank.py
Normal file
146
tools/server/tests/unit/test_rerank.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import pytest
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.jina_reranker_tiny()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.jina_reranker_tiny()
|
||||
|
||||
|
||||
TEST_DOCUMENTS = [
|
||||
"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
|
||||
"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
|
||||
"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
|
||||
"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
|
||||
]
|
||||
|
||||
|
||||
def test_rerank():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/rerank", data={
|
||||
"query": "Machine learning is",
|
||||
"documents": TEST_DOCUMENTS,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body["results"]) == 4
|
||||
|
||||
most_relevant = res.body["results"][0]
|
||||
least_relevant = res.body["results"][0]
|
||||
for doc in res.body["results"]:
|
||||
if doc["relevance_score"] > most_relevant["relevance_score"]:
|
||||
most_relevant = doc
|
||||
if doc["relevance_score"] < least_relevant["relevance_score"]:
|
||||
least_relevant = doc
|
||||
|
||||
assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
|
||||
assert most_relevant["index"] == 2
|
||||
assert least_relevant["index"] == 3
|
||||
|
||||
|
||||
def test_rerank_tei_format():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/rerank", data={
|
||||
"query": "Machine learning is",
|
||||
"texts": TEST_DOCUMENTS,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body) == 4
|
||||
|
||||
most_relevant = res.body[0]
|
||||
least_relevant = res.body[0]
|
||||
for doc in res.body:
|
||||
if doc["score"] > most_relevant["score"]:
|
||||
most_relevant = doc
|
||||
if doc["score"] < least_relevant["score"]:
|
||||
least_relevant = doc
|
||||
|
||||
assert most_relevant["score"] > least_relevant["score"]
|
||||
assert most_relevant["index"] == 2
|
||||
assert least_relevant["index"] == 3
|
||||
|
||||
|
||||
@pytest.mark.parametrize("documents", [
|
||||
[],
|
||||
None,
|
||||
123,
|
||||
[1, 2, 3],
|
||||
])
|
||||
def test_invalid_rerank_req(documents):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/rerank", data={
|
||||
"query": "Machine learning is",
|
||||
"documents": documents,
|
||||
})
|
||||
assert res.status_code == 400
|
||||
assert "error" in res.body
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query,doc1,doc2,n_tokens",
|
||||
[
|
||||
("Machine learning is", "A machine", "Learning is", 19),
|
||||
("Which city?", "Machine learning is ", "Paris, capitale de la", 26),
|
||||
]
|
||||
)
|
||||
def test_rerank_usage(query, doc1, doc2, n_tokens):
|
||||
global server
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/rerank", data={
|
||||
"query": query,
|
||||
"documents": [
|
||||
doc1,
|
||||
doc2,
|
||||
]
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
|
||||
assert res.body['usage']['prompt_tokens'] == n_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("top_n,expected_len", [
|
||||
(None, len(TEST_DOCUMENTS)), # no top_n parameter
|
||||
(2, 2),
|
||||
(4, 4),
|
||||
(99, len(TEST_DOCUMENTS)), # higher than available docs
|
||||
])
|
||||
def test_rerank_top_n(top_n, expected_len):
|
||||
global server
|
||||
server.start()
|
||||
data = {
|
||||
"query": "Machine learning is",
|
||||
"documents": TEST_DOCUMENTS,
|
||||
}
|
||||
if top_n is not None:
|
||||
data["top_n"] = top_n
|
||||
|
||||
res = server.make_request("POST", "/rerank", data=data)
|
||||
assert res.status_code == 200
|
||||
assert len(res.body["results"]) == expected_len
|
||||
|
||||
|
||||
@pytest.mark.parametrize("top_n,expected_len", [
|
||||
(None, len(TEST_DOCUMENTS)), # no top_n parameter
|
||||
(2, 2),
|
||||
(4, 4),
|
||||
(99, len(TEST_DOCUMENTS)), # higher than available docs
|
||||
])
|
||||
def test_rerank_tei_top_n(top_n, expected_len):
|
||||
global server
|
||||
server.start()
|
||||
data = {
|
||||
"query": "Machine learning is",
|
||||
"texts": TEST_DOCUMENTS,
|
||||
}
|
||||
if top_n is not None:
|
||||
data["top_n"] = top_n
|
||||
|
||||
res = server.make_request("POST", "/rerank", data=data)
|
||||
assert res.status_code == 200
|
||||
assert len(res.body) == expected_len
|
||||
255
tools/server/tests/unit/test_router.py
Normal file
255
tools/server/tests/unit/test_router.py
Normal file
@@ -0,0 +1,255 @@
|
||||
import pytest
|
||||
from utils import *
|
||||
|
||||
server: ServerProcess
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.router()
|
||||
|
||||
|
||||
def test_router_props():
|
||||
global server
|
||||
server.models_max = 2
|
||||
server.no_models_autoload = True
|
||||
server.start()
|
||||
res = server.make_request("GET", "/props")
|
||||
assert res.status_code == 200
|
||||
assert res.body["role"] == "router"
|
||||
assert res.body["max_instances"] == 2
|
||||
assert res.body["models_autoload"] is False
|
||||
assert res.body["build_info"].startswith("b")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,success",
|
||||
[
|
||||
("ggml-org/tinygemma3-GGUF:Q8_0", True),
|
||||
("non-existent/model", False),
|
||||
]
|
||||
)
|
||||
def test_router_chat_completion_stream(model: str, success: bool):
|
||||
global server
|
||||
server.start()
|
||||
content = ""
|
||||
ex: ServerError | None = None
|
||||
try:
|
||||
res = server.make_stream_request("POST", "/chat/completions", data={
|
||||
"model": model,
|
||||
"max_tokens": 16,
|
||||
"messages": [
|
||||
{"role": "user", "content": "hello"},
|
||||
],
|
||||
"stream": True,
|
||||
})
|
||||
for data in res:
|
||||
if data["choices"]:
|
||||
choice = data["choices"][0]
|
||||
if choice["finish_reason"] in ["stop", "length"]:
|
||||
assert "content" not in choice["delta"]
|
||||
else:
|
||||
assert choice["finish_reason"] is None
|
||||
content += choice["delta"]["content"] or ''
|
||||
except ServerError as e:
|
||||
ex = e
|
||||
|
||||
if success:
|
||||
assert ex is None
|
||||
assert len(content) > 0
|
||||
else:
|
||||
assert ex is not None
|
||||
assert content == ""
|
||||
|
||||
|
||||
def _get_model_ids(is_reload: bool) -> set[str]:
|
||||
res = server.make_request("GET", "/models" + ("?reload=1" if is_reload else ""))
|
||||
assert res.status_code == 200
|
||||
return {item["id"] for item in res.body.get("data", [])}
|
||||
|
||||
|
||||
def _get_model_status(model_id: str) -> str:
|
||||
res = server.make_request("GET", "/models")
|
||||
assert res.status_code == 200
|
||||
for item in res.body.get("data", []):
|
||||
if item.get("id") == model_id or item.get("model") == model_id:
|
||||
return item["status"]["value"]
|
||||
raise AssertionError(f"Model {model_id} not found in /models response")
|
||||
|
||||
|
||||
def _wait_for_model_status(model_id: str, desired: set[str], timeout: int = 60) -> str:
|
||||
deadline = time.time() + timeout
|
||||
last_status = None
|
||||
while time.time() < deadline:
|
||||
last_status = _get_model_status(model_id)
|
||||
if last_status in desired:
|
||||
return last_status
|
||||
time.sleep(1)
|
||||
raise AssertionError(
|
||||
f"Timed out waiting for {model_id} to reach {desired}, last status: {last_status}"
|
||||
)
|
||||
|
||||
|
||||
def _load_model_and_wait(
|
||||
model_id: str, timeout: int = 60, headers: dict | None = None
|
||||
) -> None:
|
||||
load_res = server.make_request(
|
||||
"POST", "/models/load", data={"model": model_id}, headers=headers
|
||||
)
|
||||
assert load_res.status_code == 200
|
||||
assert isinstance(load_res.body, dict)
|
||||
assert load_res.body.get("success") is True
|
||||
_wait_for_model_status(model_id, {"loaded"}, timeout=timeout)
|
||||
|
||||
|
||||
def test_router_unload_model():
|
||||
global server
|
||||
server.start()
|
||||
model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
|
||||
|
||||
_load_model_and_wait(model_id)
|
||||
|
||||
unload_res = server.make_request("POST", "/models/unload", data={"model": model_id})
|
||||
assert unload_res.status_code == 200
|
||||
assert unload_res.body.get("success") is True
|
||||
_wait_for_model_status(model_id, {"unloaded"})
|
||||
|
||||
|
||||
def test_router_models_max_evicts_lru():
|
||||
global server
|
||||
server.models_max = 2
|
||||
server.start()
|
||||
|
||||
candidate_models = [
|
||||
"ggml-org/tinygemma3-GGUF:Q8_0",
|
||||
"ggml-org/test-model-stories260K:F32",
|
||||
"ggml-org/test-model-stories260K-infill:F32",
|
||||
]
|
||||
|
||||
# Load only the first 2 models to fill the cache
|
||||
first, second, third = candidate_models[:3]
|
||||
|
||||
_load_model_and_wait(first, timeout=120)
|
||||
_load_model_and_wait(second, timeout=120)
|
||||
|
||||
# Verify both models are loaded
|
||||
assert _get_model_status(first) == "loaded"
|
||||
assert _get_model_status(second) == "loaded"
|
||||
|
||||
# Load the third model - this should trigger LRU eviction of the first model
|
||||
_load_model_and_wait(third, timeout=120)
|
||||
|
||||
# Verify eviction: third is loaded, first was evicted
|
||||
assert _get_model_status(third) == "loaded"
|
||||
assert _get_model_status(first) == "unloaded"
|
||||
|
||||
|
||||
def test_router_no_models_autoload():
|
||||
global server
|
||||
server.no_models_autoload = True
|
||||
server.start()
|
||||
model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
|
||||
|
||||
res = server.make_request(
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
data={
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"max_tokens": 4,
|
||||
},
|
||||
)
|
||||
assert res.status_code == 400
|
||||
assert "error" in res.body
|
||||
|
||||
_load_model_and_wait(model_id)
|
||||
|
||||
success_res = server.make_request(
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
data={
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"max_tokens": 4,
|
||||
},
|
||||
)
|
||||
assert success_res.status_code == 200
|
||||
assert "error" not in success_res.body
|
||||
|
||||
|
||||
def test_router_api_key_required():
|
||||
global server
|
||||
server.api_key = "sk-router-secret"
|
||||
server.start()
|
||||
|
||||
model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
|
||||
auth_headers = {"Authorization": f"Bearer {server.api_key}"}
|
||||
|
||||
res = server.make_request(
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
data={
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"max_tokens": 4,
|
||||
},
|
||||
)
|
||||
assert res.status_code == 401
|
||||
assert res.body.get("error", {}).get("type") == "authentication_error"
|
||||
|
||||
_load_model_and_wait(model_id, headers=auth_headers)
|
||||
|
||||
authed = server.make_request(
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
headers=auth_headers,
|
||||
data={
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"max_tokens": 4,
|
||||
},
|
||||
)
|
||||
assert authed.status_code == 200
|
||||
assert "error" not in authed.body
|
||||
|
||||
|
||||
def test_router_reload_models():
|
||||
"""POST /models/reload re-reads the INI preset and updates the model list."""
|
||||
global server
|
||||
|
||||
preset_path = os.path.join(TMP_DIR, "test_reload.ini")
|
||||
|
||||
# Initial preset: two models
|
||||
with open(preset_path, "w") as f:
|
||||
f.write(
|
||||
"[model-reload-a]\n"
|
||||
"hf-repo = ggml-org/test-model-stories260K\n"
|
||||
"\n"
|
||||
"[model-reload-b]\n"
|
||||
"hf-repo = ggml-org/test-model-stories260K-infill\n"
|
||||
)
|
||||
|
||||
server.models_preset = preset_path
|
||||
server.start()
|
||||
|
||||
ids = _get_model_ids(is_reload=False)
|
||||
assert "model-reload-a" in ids
|
||||
assert "model-reload-b" in ids
|
||||
|
||||
# Updated preset: remove a, keep b unchanged, add c
|
||||
with open(preset_path, "w") as f:
|
||||
f.write(
|
||||
"[model-reload-b]\n"
|
||||
"hf-repo = ggml-org/test-model-stories260K-infill\n"
|
||||
"\n"
|
||||
"[model-reload-c]\n"
|
||||
"hf-repo = ggml-org/test-model-stories260K\n"
|
||||
)
|
||||
|
||||
try:
|
||||
ids = _get_model_ids(is_reload=True)
|
||||
assert "model-reload-a" not in ids, "removed model should no longer appear"
|
||||
assert "model-reload-b" in ids, "unchanged model should still appear"
|
||||
assert "model-reload-c" in ids, "newly added model should appear"
|
||||
finally:
|
||||
os.remove(preset_path)
|
||||
136
tools/server/tests/unit/test_security.py
Normal file
136
tools/server/tests/unit/test_security.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import pytest
|
||||
from openai import OpenAI
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
TEST_API_KEY = "sk-this-is-the-secret-key"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.api_key = TEST_API_KEY
|
||||
|
||||
|
||||
@pytest.mark.parametrize("endpoint", ["/health", "/models"])
|
||||
def test_access_public_endpoint(endpoint: str):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("GET", endpoint)
|
||||
assert res.status_code == 200
|
||||
assert "error" not in res.body
|
||||
|
||||
|
||||
def test_access_static_assets_without_api_key():
|
||||
"""Static web UI assets should not require API key authentication (issue #21229)"""
|
||||
global server
|
||||
server.start()
|
||||
for path in ["/", "/bundle.js", "/bundle.css"]:
|
||||
res = server.make_request("GET", path)
|
||||
assert res.status_code == 200, f"Expected 200 for {path}, got {res.status_code}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("api_key", [None, "invalid-key"])
|
||||
def test_incorrect_api_key(api_key: str):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completions", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
}, headers={
|
||||
"Authorization": f"Bearer {api_key}" if api_key else None,
|
||||
})
|
||||
assert res.status_code == 401
|
||||
assert "error" in res.body
|
||||
assert res.body["error"]["type"] == "authentication_error"
|
||||
|
||||
|
||||
def test_correct_api_key():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completions", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
}, headers={
|
||||
"Authorization": f"Bearer {TEST_API_KEY}",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "error" not in res.body
|
||||
assert "content" in res.body
|
||||
|
||||
|
||||
def test_correct_api_key_anthropic_header():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completions", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
}, headers={
|
||||
"X-Api-Key": TEST_API_KEY,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "error" not in res.body
|
||||
assert "content" in res.body
|
||||
|
||||
|
||||
def test_openai_library_correct_api_key():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key=TEST_API_KEY, base_url=f"http://{server.server_host}:{server.server_port}")
|
||||
res = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a chatbot."},
|
||||
{"role": "user", "content": "What is the meaning of life?"},
|
||||
],
|
||||
)
|
||||
assert len(res.choices) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("origin,cors_header,cors_header_value", [
|
||||
("localhost", "Access-Control-Allow-Origin", "localhost"),
|
||||
("web.mydomain.fr", "Access-Control-Allow-Origin", "web.mydomain.fr"),
|
||||
("origin", "Access-Control-Allow-Credentials", "true"),
|
||||
("web.mydomain.fr", "Access-Control-Allow-Methods", "GET, POST"),
|
||||
("web.mydomain.fr", "Access-Control-Allow-Headers", "*"),
|
||||
])
|
||||
def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("OPTIONS", "/completions", headers={
|
||||
"Origin": origin,
|
||||
"Access-Control-Request-Method": "POST",
|
||||
"Access-Control-Request-Headers": "Authorization",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert cors_header in res.headers
|
||||
assert res.headers[cors_header] == cors_header_value
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"media_path, image_url, success",
|
||||
[
|
||||
(None, "file://mtmd/test-1.jpeg", False), # disabled media path, should fail
|
||||
("../../../tools", "file://mtmd/test-1.jpeg", True),
|
||||
("../../../tools", "file:////mtmd//test-1.jpeg", True), # should be the same file as above
|
||||
("../../../tools", "file://mtmd/notfound.jpeg", False), # non-existent file
|
||||
("../../../tools", "file://../mtmd/test-1.jpeg", False), # no directory traversal
|
||||
]
|
||||
)
|
||||
def test_local_media_file(media_path, image_url, success,):
|
||||
server = ServerPreset.tinygemma3()
|
||||
server.media_path = media_path
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 1,
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "text", "text": "test"},
|
||||
{"type": "image_url", "image_url": {
|
||||
"url": image_url,
|
||||
}},
|
||||
]},
|
||||
],
|
||||
})
|
||||
if success:
|
||||
assert res.status_code == 200
|
||||
else:
|
||||
assert res.status_code == 400
|
||||
39
tools/server/tests/unit/test_sleep.py
Normal file
39
tools/server/tests/unit/test_sleep.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import pytest
|
||||
import time
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
def test_server_sleep():
|
||||
global server
|
||||
server.sleep_idle_seconds = 1
|
||||
server.start()
|
||||
|
||||
# wait a bit so that server can go to sleep
|
||||
time.sleep(2)
|
||||
|
||||
# make sure these endpoints are still responsive after sleep
|
||||
res = server.make_request("GET", "/health")
|
||||
assert res.status_code == 200
|
||||
res = server.make_request("GET", "/props")
|
||||
assert res.status_code == 200
|
||||
assert res.body["is_sleeping"] == True
|
||||
|
||||
# make a generation request to wake up the server
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 1,
|
||||
"prompt": "Hello",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
|
||||
# it should no longer be sleeping
|
||||
res = server.make_request("GET", "/props")
|
||||
assert res.status_code == 200
|
||||
assert res.body["is_sleeping"] == False
|
||||
98
tools/server/tests/unit/test_slot_save.py
Normal file
98
tools/server/tests/unit/test_slot_save.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import pytest
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.slot_save_path = "./tmp"
|
||||
server.temperature = 0.0
|
||||
|
||||
|
||||
def test_slot_save_restore():
|
||||
global server
|
||||
server.start()
|
||||
|
||||
# First prompt in slot 1 should be fully processed
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "What is the capital of France?",
|
||||
"id_slot": 1,
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Whiskers|Flana)+", res.body["content"])
|
||||
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
|
||||
|
||||
# Save state of slot 1
|
||||
res = server.make_request("POST", "/slots/1?action=save", data={
|
||||
"filename": "slot1.bin",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body["n_saved"] == 84
|
||||
|
||||
# Since we have cache, this should only process the last tokens
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "What is the capital of Germany?",
|
||||
"id_slot": 1,
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Jack|said)+", res.body["content"])
|
||||
assert res.body["timings"]["prompt_n"] == 6 # only different part is processed
|
||||
|
||||
# Loading the saved cache into slot 0
|
||||
res = server.make_request("POST", "/slots/0?action=restore", data={
|
||||
"filename": "slot1.bin",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body["n_restored"] == 84
|
||||
|
||||
# Since we have cache, slot 0 should only process the last tokens
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "What is the capital of Germany?",
|
||||
"id_slot": 0,
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Jack|said)+", res.body["content"])
|
||||
assert res.body["timings"]["prompt_n"] == 6 # only different part is processed
|
||||
|
||||
# For verification that slot 1 was not corrupted during slot 0 load, same thing should work
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "What is the capital of Germany?",
|
||||
"id_slot": 1,
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Jack|said)+", res.body["content"])
|
||||
assert res.body["timings"]["prompt_n"] == 1
|
||||
|
||||
|
||||
def test_slot_erase():
|
||||
global server
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "What is the capital of France?",
|
||||
"id_slot": 1,
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Whiskers|Flana)+", res.body["content"])
|
||||
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
|
||||
|
||||
# erase slot 1
|
||||
res = server.make_request("POST", "/slots/1?action=erase")
|
||||
assert res.status_code == 200
|
||||
|
||||
# re-run the same prompt, it should process all tokens again
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "What is the capital of France?",
|
||||
"id_slot": 1,
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Whiskers|Flana)+", res.body["content"])
|
||||
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
|
||||
131
tools/server/tests/unit/test_speculative.py
Normal file
131
tools/server/tests/unit/test_speculative.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import pytest
|
||||
from utils import *
|
||||
|
||||
# We use a F16 MOE gguf as main model, and q4_0 as draft model
|
||||
|
||||
server = ServerPreset.stories15m_moe()
|
||||
|
||||
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/tiny-llamas/resolve/main/stories15M-q4_0.gguf"
|
||||
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.stories15m_moe()
|
||||
# set default values
|
||||
server.model_draft = download_file(MODEL_DRAFT_FILE_URL)
|
||||
server.draft_min = 4
|
||||
server.draft_max = 8
|
||||
server.fa = "off"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def fixture_create_server():
|
||||
return create_server()
|
||||
|
||||
|
||||
def test_with_and_without_draft():
|
||||
global server
|
||||
server.model_draft = None # disable draft model
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"n_predict": 16,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
content_no_draft = res.body["content"]
|
||||
server.stop()
|
||||
|
||||
# create new server with draft model
|
||||
create_server()
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"n_predict": 16,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
content_draft = res.body["content"]
|
||||
|
||||
assert content_no_draft == content_draft
|
||||
|
||||
|
||||
def test_different_draft_min_draft_max():
|
||||
global server
|
||||
test_values = [
|
||||
(1, 2),
|
||||
(1, 4),
|
||||
(4, 8),
|
||||
(4, 12),
|
||||
(8, 16),
|
||||
]
|
||||
last_content = None
|
||||
for draft_min, draft_max in test_values:
|
||||
server.stop()
|
||||
server.draft_min = draft_min
|
||||
server.draft_max = draft_max
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"n_predict": 16,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
if last_content is not None:
|
||||
assert last_content == res.body["content"]
|
||||
last_content = res.body["content"]
|
||||
|
||||
|
||||
def test_slot_ctx_not_exceeded():
|
||||
global server
|
||||
server.n_ctx = 256
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "Hello " * 248,
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"speculative.p_min": 0.0,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body["content"]) > 0
|
||||
|
||||
|
||||
def test_with_ctx_shift():
|
||||
global server
|
||||
server.n_ctx = 256
|
||||
server.enable_ctx_shift = True
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "Hello " * 248,
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"n_predict": 256,
|
||||
"speculative.p_min": 0.0,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body["content"]) > 0
|
||||
assert res.body["tokens_predicted"] == 256
|
||||
assert res.body["truncated"] == True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_slots,n_requests", [
|
||||
(1, 2),
|
||||
(2, 2),
|
||||
])
|
||||
def test_multi_requests_parallel(n_slots: int, n_requests: int):
|
||||
global server
|
||||
server.n_slots = n_slots
|
||||
server.start()
|
||||
tasks = []
|
||||
for _ in range(n_requests):
|
||||
tasks.append((server.make_request, ("POST", "/completion", {
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
})))
|
||||
results = parallel_function_calls(tasks)
|
||||
for res in results:
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(wise|kind|owl|answer)+", res.body["content"])
|
||||
106
tools/server/tests/unit/test_template.py
Normal file
106
tools/server/tests/unit/test_template.py
Normal file
@@ -0,0 +1,106 @@
|
||||
#!/usr/bin/env python
|
||||
import pytest
|
||||
|
||||
# ensure grandparent path is in sys.path
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
from unit.test_tool_call import TEST_TOOL
|
||||
path = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(path))
|
||||
|
||||
import datetime
|
||||
from utils import *
|
||||
from typing import Literal
|
||||
|
||||
server: ServerProcess
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.model_alias = "tinyllama-2"
|
||||
server.n_slots = 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
|
||||
@pytest.mark.parametrize("template_name,reasoning,expected_end", [
|
||||
("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", "on", "<think>\n"),
|
||||
("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B","auto", "<think>\n"),
|
||||
("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", "off", "<think>\n</think>"),
|
||||
|
||||
("Qwen-Qwen3-0.6B","auto", "<|im_start|>assistant\n"),
|
||||
("Qwen-Qwen3-0.6B", "off", "<|im_start|>assistant\n<think>\n\n</think>\n\n"),
|
||||
|
||||
("Qwen-QwQ-32B","auto", "<|im_start|>assistant\n<think>\n"),
|
||||
("Qwen-QwQ-32B", "off", "<|im_start|>assistant\n<think>\n</think>"),
|
||||
|
||||
("CohereForAI-c4ai-command-r7b-12-2024-tool_use","auto", "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"),
|
||||
("CohereForAI-c4ai-command-r7b-12-2024-tool_use", "off", "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|><|END_THINKING|>"),
|
||||
])
|
||||
def test_reasoning(template_name: str, reasoning: Literal['on', 'off', 'auto'] | None, expected_end: str, tools: list[dict]):
|
||||
global server
|
||||
server.jinja = True
|
||||
server.reasoning = reasoning
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/apply-template", data={
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is today?"},
|
||||
],
|
||||
"tools": tools,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
prompt = res.body["prompt"]
|
||||
|
||||
assert prompt.endswith(expected_end), f"Expected prompt to end with '{expected_end}', got '{prompt}'"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
|
||||
@pytest.mark.parametrize("template_name,format", [
|
||||
("meta-llama-Llama-3.3-70B-Instruct", "%d %b %Y"),
|
||||
("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"),
|
||||
])
|
||||
def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
|
||||
global server
|
||||
server.jinja = True
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/apply-template", data={
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is today?"},
|
||||
],
|
||||
"tools": tools,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
prompt = res.body["prompt"]
|
||||
|
||||
today_str = datetime.date.today().strftime(format)
|
||||
assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("add_generation_prompt", [False, True])
|
||||
@pytest.mark.parametrize("template_name,expected_generation_prompt", [
|
||||
("meta-llama-Llama-3.3-70B-Instruct", "<|start_header_id|>assistant<|end_header_id|>"),
|
||||
])
|
||||
def test_add_generation_prompt(template_name: str, expected_generation_prompt: str, add_generation_prompt: bool):
|
||||
global server
|
||||
server.jinja = True
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/apply-template", data={
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is today?"},
|
||||
],
|
||||
"add_generation_prompt": add_generation_prompt,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
prompt = res.body["prompt"]
|
||||
|
||||
if add_generation_prompt:
|
||||
assert expected_generation_prompt in prompt, f"Expected generation prompt ({expected_generation_prompt}) in content ({prompt})"
|
||||
else:
|
||||
assert expected_generation_prompt not in prompt, f"Did not expect generation prompt ({expected_generation_prompt}) in content ({prompt})"
|
||||
59
tools/server/tests/unit/test_tokenize.py
Normal file
59
tools/server/tests/unit/test_tokenize.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import pytest
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
def test_tokenize_detokenize():
|
||||
global server
|
||||
server.start()
|
||||
# tokenize
|
||||
content = "What is the capital of France ?"
|
||||
res_tok = server.make_request("POST", "/tokenize", data={
|
||||
"content": content
|
||||
})
|
||||
assert res_tok.status_code == 200
|
||||
assert len(res_tok.body["tokens"]) > 5
|
||||
# detokenize
|
||||
res_detok = server.make_request("POST", "/detokenize", data={
|
||||
"tokens": res_tok.body["tokens"],
|
||||
})
|
||||
assert res_detok.status_code == 200
|
||||
assert res_detok.body["content"].strip() == content
|
||||
|
||||
|
||||
def test_tokenize_with_bos():
|
||||
global server
|
||||
server.start()
|
||||
# tokenize
|
||||
content = "What is the capital of France ?"
|
||||
bosId = 1
|
||||
res_tok = server.make_request("POST", "/tokenize", data={
|
||||
"content": content,
|
||||
"add_special": True,
|
||||
})
|
||||
assert res_tok.status_code == 200
|
||||
assert res_tok.body["tokens"][0] == bosId
|
||||
|
||||
|
||||
def test_tokenize_with_pieces():
|
||||
global server
|
||||
server.start()
|
||||
# tokenize
|
||||
content = "This is a test string with unicode 媽 and emoji 🤗"
|
||||
res_tok = server.make_request("POST", "/tokenize", data={
|
||||
"content": content,
|
||||
"with_pieces": True,
|
||||
})
|
||||
assert res_tok.status_code == 200
|
||||
for token in res_tok.body["tokens"]:
|
||||
assert "id" in token
|
||||
assert token["id"] > 0
|
||||
assert "piece" in token
|
||||
assert len(token["piece"]) > 0
|
||||
645
tools/server/tests/unit/test_tool_call.py
Executable file
645
tools/server/tests/unit/test_tool_call.py
Executable file
@@ -0,0 +1,645 @@
|
||||
#!/usr/bin/env python
|
||||
import pytest
|
||||
|
||||
# ensure grandparent path is in sys.path
|
||||
from pathlib import Path
|
||||
import sys
|
||||
path = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(path))
|
||||
|
||||
from utils import *
|
||||
from enum import Enum
|
||||
from typing import TypedDict
|
||||
|
||||
server: ServerProcess
|
||||
|
||||
TIMEOUT_START_SLOW = 15 * 60 # this is needed for real model tests
|
||||
TIMEOUT_HTTP_REQUEST = 60
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.model_alias = "tinyllama-2-tool-call"
|
||||
server.server_port = 8081
|
||||
server.n_slots = 1
|
||||
server.n_ctx = 8192
|
||||
server.n_batch = 2048
|
||||
|
||||
class CompletionMode(Enum):
|
||||
NORMAL = "normal"
|
||||
STREAMED = "streamed"
|
||||
|
||||
class ToolParameters(TypedDict):
|
||||
type: str
|
||||
properties: dict[str, dict]
|
||||
required: list[str]
|
||||
|
||||
class ToolFunction(TypedDict):
|
||||
name: str
|
||||
description: str
|
||||
parameters: ToolParameters
|
||||
|
||||
class ToolDefinition(TypedDict):
|
||||
type: str
|
||||
function: ToolFunction
|
||||
|
||||
TEST_TOOL = ToolDefinition(
|
||||
type = "function",
|
||||
function = ToolFunction(
|
||||
name = "test",
|
||||
description = "",
|
||||
parameters = ToolParameters(
|
||||
type = "object",
|
||||
properties = {
|
||||
"success": {
|
||||
"type": "boolean",
|
||||
"const": True,
|
||||
},
|
||||
},
|
||||
required = ["success"],
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
PYTHON_TOOL = ToolDefinition(
|
||||
type = "function",
|
||||
function = ToolFunction(
|
||||
name = "python",
|
||||
description = "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
|
||||
parameters = ToolParameters(
|
||||
type = "object",
|
||||
properties = {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "The code to run in the ipython interpreter.",
|
||||
},
|
||||
},
|
||||
required = ["code"],
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
WEATHER_TOOL = ToolDefinition(
|
||||
type = "function",
|
||||
function = ToolFunction(
|
||||
name = "get_current_weather",
|
||||
description = "Get the current weather in a given location",
|
||||
parameters = ToolParameters(
|
||||
type = "object",
|
||||
properties = {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and country/state, e.g. 'San Francisco, CA', or 'Paris, France'",
|
||||
},
|
||||
},
|
||||
required = ["location"],
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
{"role": "user", "content": "Write an example"},
|
||||
],
|
||||
"tool_choice": "required",
|
||||
"tools": [tool],
|
||||
"parallel_tool_calls": False,
|
||||
**kwargs,
|
||||
})
|
||||
# assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
||||
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
|
||||
assert expected_function_name == tool_call["function"]["name"], f'Expected tool name to be {tool_call["function"]["name"]} in {choice["message"]}'
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
assert isinstance(actual_arguments, dict) or isinstance(actual_arguments, str), f'Expected arguments to be a dict or str, got: {actual_arguments}'
|
||||
if argument_key is not None:
|
||||
if (isinstance(actual_arguments, str)):
|
||||
actual_arguments = json.loads(actual_arguments)
|
||||
assert argument_key in actual_arguments, f"tool arguments: {actual_arguments}, expected: {argument_key}"
|
||||
|
||||
# PR #22654: commented out since we're now allowing content before tool calls in tool_call: required, so we can't force this
|
||||
# in the tiny model just by using the grammar
|
||||
#
|
||||
# @pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
# @pytest.mark.parametrize("template_name,tool,argument_key", [
|
||||
# ("Qwen3-Coder", TEST_TOOL, "success"),
|
||||
# ("Qwen3-Coder", TEST_TOOL, "success"),
|
||||
# ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
|
||||
# ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
|
||||
# ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
|
||||
# ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
|
||||
# ])
|
||||
# def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
|
||||
# global server
|
||||
# n_predict = 1024
|
||||
# # server = ServerPreset.stories15m_moe()
|
||||
# server.jinja = True
|
||||
# server.n_predict = n_predict
|
||||
# server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
# server.start()
|
||||
# do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)
|
||||
|
||||
# @pytest.mark.slow
|
||||
# @pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
# @pytest.mark.parametrize("template_name,tool,argument_key", [
|
||||
# ("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
|
||||
# ("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
|
||||
|
||||
# ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
|
||||
# ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
|
||||
|
||||
# ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
|
||||
# # Functionary v3.2 format supports raw python content, which w/ a dummy stories model will never end on its own.
|
||||
# # ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
|
||||
|
||||
# ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
|
||||
# ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
|
||||
|
||||
# ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
|
||||
# ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
|
||||
|
||||
# ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
|
||||
# ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
|
||||
|
||||
# ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
|
||||
# ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
|
||||
|
||||
# ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
|
||||
# ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
|
||||
|
||||
# ("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
|
||||
# # ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "codeFalse), True),
|
||||
# # ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
|
||||
|
||||
# ])
|
||||
# def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
|
||||
# global server
|
||||
# n_predict = 512
|
||||
# # server = ServerPreset.stories15m_moe()
|
||||
# server.jinja = True
|
||||
# server.n_predict = n_predict
|
||||
# server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
# server.start(timeout_seconds=TIMEOUT_START_SLOW)
|
||||
# do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [
|
||||
(TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
# (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
# (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
# (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
|
||||
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
|
||||
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
])
|
||||
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
|
||||
global server
|
||||
n_predict = 512
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192
|
||||
server.n_predict = n_predict
|
||||
server.model_hf_repo = hf_repo
|
||||
server.model_hf_file = None
|
||||
if isinstance(template_override, tuple):
|
||||
(template_hf_repo, template_variant) = template_override
|
||||
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_START_SLOW)
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
{"role": "user", "content": "Write an example"},
|
||||
],
|
||||
"tool_choice": "required",
|
||||
"tools": [tool],
|
||||
"parallel_tool_calls": False,
|
||||
"stream": stream == CompletionMode.STREAMED,
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"top_p": 1.0,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
choice = body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
||||
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
|
||||
assert expected_function_name == tool_call["function"]["name"]
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
assert isinstance(actual_arguments, str)
|
||||
if argument_key is not None:
|
||||
actual_arguments = json.loads(actual_arguments)
|
||||
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
|
||||
|
||||
|
||||
def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs):
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
{"role": "user", "content": "say hello world with python"},
|
||||
],
|
||||
"tools": tools if tools else None,
|
||||
"tool_choice": tool_choice,
|
||||
**kwargs,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
choice = body["choices"][0]
|
||||
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
|
||||
("meta-llama-Llama-3.3-70B-Instruct", 128, [], None),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'),
|
||||
])
|
||||
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
|
||||
global server
|
||||
server.n_predict = n_predict
|
||||
server.jinja = True
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start()
|
||||
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
|
||||
("meetkai-functionary-medium-v3.2", 256, [], None),
|
||||
("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None),
|
||||
("meetkai-functionary-medium-v3.2", 256, [PYTHON_TOOL], 'none'),
|
||||
("meetkai-functionary-medium-v3.1", 256, [], None),
|
||||
("meetkai-functionary-medium-v3.1", 256, [TEST_TOOL], None),
|
||||
("meetkai-functionary-medium-v3.1", 256, [PYTHON_TOOL], 'none'),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", 256, [], None),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'),
|
||||
])
|
||||
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
|
||||
global server
|
||||
server.n_predict = n_predict
|
||||
server.jinja = True
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_START_SLOW)
|
||||
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("hf_repo,template_override", [
|
||||
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
# ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
# ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
# ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
|
||||
# ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
|
||||
|
||||
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")),
|
||||
|
||||
("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
|
||||
# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
|
||||
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
|
||||
|
||||
# ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
])
|
||||
def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
|
||||
global server
|
||||
n_predict = 512
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192
|
||||
server.n_predict = n_predict
|
||||
server.model_hf_repo = hf_repo
|
||||
server.model_hf_file = None
|
||||
if isinstance(template_override, tuple):
|
||||
(template_hf_repo, template_variant) = template_override
|
||||
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start()
|
||||
do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
|
||||
|
||||
|
||||
def do_test_weather(server: ServerProcess, **kwargs):
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
|
||||
{"role": "user", "content": "What is the weather in Istanbul?"},
|
||||
],
|
||||
"tools": [WEATHER_TOOL],
|
||||
**kwargs,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
choice = body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
||||
assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}'
|
||||
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
actual_arguments = json.loads(tool_call["function"]["arguments"])
|
||||
assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
|
||||
location = actual_arguments["location"]
|
||||
assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}"
|
||||
assert re.match('^Istanbul(( |, ?)(TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
|
||||
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
|
||||
(None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
|
||||
(None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
(None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
(None, 128, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(None, 128, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
|
||||
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
|
||||
# TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
|
||||
# (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
# ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
])
|
||||
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
|
||||
global server
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192 * 2
|
||||
server.n_predict = n_predict
|
||||
server.model_hf_repo = hf_repo
|
||||
server.model_hf_file = None
|
||||
if isinstance(template_override, tuple):
|
||||
(template_hf_repo, template_variant) = template_override
|
||||
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_START_SLOW)
|
||||
do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)
|
||||
|
||||
|
||||
def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs):
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."},
|
||||
{"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_6789",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calculate",
|
||||
"arguments": "{\"expression\":\"sin(30 * pi / 180)\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"name": "calculate",
|
||||
"content": "0.55644242476",
|
||||
"tool_call_id": "call_6789"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type":"function",
|
||||
"function":{
|
||||
"name":"calculate",
|
||||
"description":"A calculator function that computes values of arithmetic expressions in the Python syntax",
|
||||
"parameters":{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"expression":{
|
||||
"type":"string",
|
||||
"description":"An arithmetic expression to compute the value of (Python syntad, assuming all floats)"
|
||||
}
|
||||
},
|
||||
"required":["expression"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
**kwargs,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
choice = body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls is None, f'Expected no tool call in {choice["message"]}'
|
||||
content = choice["message"].get("content")
|
||||
assert content is not None, f'Expected content in {choice["message"]}'
|
||||
if result_override is not None:
|
||||
assert re.match(result_override, content), f'Expected {result_override}, got {content}'
|
||||
else:
|
||||
assert re.match('^[\\s\\S]*?((That\'s|\\bis) (approximately )?)?\\b0\\.(5\\b|56\\b|556)', content), \
|
||||
f'Expected something like "The y coordinate is 0.56.", got {content}'
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("n_predict,reasoning_format,expect_reasoning_content,expect_content,hf_repo,template_override", [
|
||||
(128, 'deepseek', None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(128, None, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(1024, 'deepseek', "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(1024, 'deepseek', "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
# (1024, 'none', CompletionMode.NORMAL, None, "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
# (128, 'deepseek', None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*", "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M", None),
|
||||
])
|
||||
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
|
||||
global server
|
||||
server.reasoning_format = reasoning_format
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192 * 2
|
||||
server.n_predict = n_predict
|
||||
server.model_hf_repo = hf_repo
|
||||
server.model_hf_file = None
|
||||
if isinstance(template_override, tuple):
|
||||
(template_hf_repo, template_variant) = template_override
|
||||
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start()
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the sum of 102 and 7?"},
|
||||
],
|
||||
"stream": stream == CompletionMode.STREAMED,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
choice = body["choices"][0]
|
||||
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
|
||||
|
||||
content = choice["message"].get("content")
|
||||
if expect_content is None:
|
||||
assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
||||
else:
|
||||
assert re.match(expect_content, content), f'Expected {expect_content}, got {content}'
|
||||
|
||||
reasoning_content = choice["message"].get("reasoning_content")
|
||||
if expect_reasoning_content is None:
|
||||
assert reasoning_content is None, f'Expected no reasoning content in {choice["message"]}'
|
||||
else:
|
||||
assert re.match(expect_reasoning_content, reasoning_content), f'Expected {expect_reasoning_content}, got {reasoning_content}'
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("hf_repo,template_override", [
|
||||
("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
|
||||
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)),
|
||||
("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
|
||||
|
||||
# ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", None),
|
||||
|
||||
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", None),
|
||||
|
||||
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
|
||||
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
|
||||
])
|
||||
def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
|
||||
global server
|
||||
n_predict = 512 # High because of DeepSeek R1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192
|
||||
server.n_predict = n_predict
|
||||
server.model_hf_repo = hf_repo
|
||||
server.model_hf_file = None
|
||||
if isinstance(template_override, tuple):
|
||||
(template_hf_repo, template_variant) = template_override
|
||||
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_START_SLOW)
|
||||
|
||||
do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
|
||||
|
||||
|
||||
def do_test_hello_world(server: ServerProcess, **kwargs):
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a tool-calling agent."},
|
||||
{"role": "user", "content": "say hello world with python"},
|
||||
],
|
||||
"tools": [PYTHON_TOOL],
|
||||
**kwargs,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
choice = body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
||||
assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
|
||||
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
actual_arguments = json.loads(tool_call["function"]["arguments"])
|
||||
assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
|
||||
code = actual_arguments["code"]
|
||||
assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
|
||||
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', re.sub(r'#.*\n?', '', code)), f'Expected hello world, got {code}'
|
||||
161
tools/server/tests/unit/test_vision_api.py
Normal file
161
tools/server/tests/unit/test_vision_api.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import pytest
|
||||
from utils import *
|
||||
import base64
|
||||
import requests
|
||||
|
||||
server: ServerProcess
|
||||
|
||||
def get_img_url(id: str) -> str:
|
||||
IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
|
||||
IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
|
||||
if id == "IMG_URL_0":
|
||||
return IMG_URL_0
|
||||
elif id == "IMG_URL_1":
|
||||
return IMG_URL_1
|
||||
elif id == "IMG_BASE64_URI_0":
|
||||
response = requests.get(IMG_URL_0)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
|
||||
elif id == "IMG_BASE64_0":
|
||||
response = requests.get(IMG_URL_0)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
return base64.b64encode(response.content).decode("utf-8")
|
||||
elif id == "IMG_BASE64_URI_1":
|
||||
response = requests.get(IMG_URL_1)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
|
||||
elif id == "IMG_BASE64_1":
|
||||
response = requests.get(IMG_URL_1)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
return base64.b64encode(response.content).decode("utf-8")
|
||||
else:
|
||||
return id
|
||||
|
||||
JSON_MULTIMODAL_KEY = "multimodal_data"
|
||||
JSON_PROMPT_STRING_KEY = "prompt_string"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
os.environ['LLAMA_MEDIA_MARKER'] = '<__media__>'
|
||||
server = ServerPreset.tinygemma3()
|
||||
|
||||
def test_models_supports_multimodal_capability():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("GET", "/models", data={})
|
||||
assert res.status_code == 200
|
||||
model_info = res.body["models"][0]
|
||||
print(model_info)
|
||||
assert "completion" in model_info["capabilities"]
|
||||
assert "multimodal" in model_info["capabilities"]
|
||||
|
||||
def test_v1_models_supports_multimodal_capability():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("GET", "/v1/models", data={})
|
||||
assert res.status_code == 200
|
||||
model_info = res.body["models"][0]
|
||||
print(model_info)
|
||||
assert "completion" in model_info["capabilities"]
|
||||
assert "multimodal" in model_info["capabilities"]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt, image_url, success, re_content",
|
||||
[
|
||||
# test model is trained on CIFAR-10, but it's quite dumb due to small size
|
||||
("What is this:\n", "IMG_URL_0", True, "(cat)+"),
|
||||
("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"),
|
||||
("What is this:\n", "IMG_URL_1", True, "(frog)+"),
|
||||
("Test test\n", "IMG_URL_1", True, "(frog)+"), # test invalidate cache
|
||||
("What is this:\n", "malformed", False, None),
|
||||
("What is this:\n", "https://google.com/404", False, None), # non-existent image
|
||||
("What is this:\n", "https://ggml.ai", False, None), # non-image data
|
||||
# TODO @ngxson : test with multiple images, no images and with audio
|
||||
]
|
||||
)
|
||||
def test_vision_chat_completion(prompt, image_url, success, re_content):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{"type": "image_url", "image_url": {
|
||||
"url": get_img_url(image_url),
|
||||
}},
|
||||
]},
|
||||
],
|
||||
})
|
||||
if success:
|
||||
assert res.status_code == 200
|
||||
choice = res.body["choices"][0]
|
||||
assert "assistant" == choice["message"]["role"]
|
||||
assert match_regex(re_content, choice["message"]["content"])
|
||||
else:
|
||||
assert res.status_code != 200
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt, image_data, success, re_content",
|
||||
[
|
||||
# test model is trained on CIFAR-10, but it's quite dumb due to small size
|
||||
("What is this: <__media__>\n", "IMG_BASE64_0", True, "(cat)+"),
|
||||
("What is this: <__media__>\n", "IMG_BASE64_1", True, "(frog)+"),
|
||||
("What is this: <__media__>\n", "malformed", False, None), # non-image data
|
||||
("What is this:\n", "", False, None), # empty string
|
||||
]
|
||||
)
|
||||
def test_vision_completion(prompt, image_data, success, re_content):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completions", data={
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"prompt": {
|
||||
JSON_PROMPT_STRING_KEY: prompt,
|
||||
JSON_MULTIMODAL_KEY: [ get_img_url(image_data) ],
|
||||
},
|
||||
})
|
||||
if success:
|
||||
assert res.status_code == 200
|
||||
content = res.body["content"]
|
||||
assert match_regex(re_content, content)
|
||||
else:
|
||||
assert res.status_code != 200
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt, image_data, success",
|
||||
[
|
||||
# test model is trained on CIFAR-10, but it's quite dumb due to small size
|
||||
("What is this: <__media__>\n", "IMG_BASE64_0", True),
|
||||
("What is this: <__media__>\n", "IMG_BASE64_1", True),
|
||||
("What is this: <__media__>\n", "malformed", False), # non-image data
|
||||
("What is this:\n", "base64", False), # non-image data
|
||||
]
|
||||
)
|
||||
def test_vision_embeddings(prompt, image_data, success):
|
||||
global server
|
||||
server.server_embeddings = True
|
||||
server.n_batch = 512
|
||||
server.start()
|
||||
image_data = get_img_url(image_data)
|
||||
res = server.make_request("POST", "/embeddings", data={
|
||||
"content": [
|
||||
{ JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
|
||||
{ JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
|
||||
{ JSON_PROMPT_STRING_KEY: prompt, },
|
||||
],
|
||||
})
|
||||
if success:
|
||||
assert res.status_code == 200
|
||||
content = res.body
|
||||
# Ensure embeddings are stable when multimodal.
|
||||
assert content[0]['embedding'] == content[1]['embedding']
|
||||
# Ensure embeddings without multimodal but same prompt do not match multimodal embeddings.
|
||||
assert content[0]['embedding'] != content[2]['embedding']
|
||||
else:
|
||||
assert res.status_code != 200
|
||||
682
tools/server/tests/utils.py
Normal file
682
tools/server/tests/utils.py
Normal file
@@ -0,0 +1,682 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# type: ignore[reportUnusedImport]
|
||||
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
|
||||
import re
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
import sys
|
||||
import requests
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ContextManager,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Tuple,
|
||||
Set,
|
||||
)
|
||||
from re import RegexFlag
|
||||
import wget
|
||||
|
||||
|
||||
DEFAULT_HTTP_TIMEOUT = 60
|
||||
|
||||
|
||||
class ServerResponse:
|
||||
headers: dict
|
||||
status_code: int
|
||||
body: dict | Any
|
||||
|
||||
|
||||
class ServerError(Exception):
|
||||
def __init__(self, code, body):
|
||||
self.code = code
|
||||
self.body = body
|
||||
|
||||
|
||||
class ServerProcess:
|
||||
# default options
|
||||
debug: bool = False
|
||||
server_port: int = 8080
|
||||
server_host: str = "127.0.0.1"
|
||||
model_hf_repo: str | None = "ggml-org/models"
|
||||
model_hf_file: str | None = "tinyllamas/stories260K.gguf"
|
||||
model_alias: str = "tinyllama-2"
|
||||
temperature: float = 0.8
|
||||
seed: int = 42
|
||||
offline: bool = False
|
||||
|
||||
# custom options
|
||||
model_alias: str | None = None
|
||||
model_tags: str | None = None
|
||||
model_url: str | None = None
|
||||
model_file: str | None = None
|
||||
model_draft: str | None = None
|
||||
n_threads: int | None = None
|
||||
n_gpu_layer: int | None = None
|
||||
n_batch: int | None = None
|
||||
n_ubatch: int | None = None
|
||||
n_ctx: int | None = None
|
||||
n_ga: int | None = None
|
||||
n_ga_w: int | None = None
|
||||
n_predict: int | None = None
|
||||
n_prompts: int | None = 0
|
||||
slot_save_path: str | None = None
|
||||
id_slot: int | None = None
|
||||
cache_prompt: bool | None = None
|
||||
n_slots: int | None = None
|
||||
ctk: str | None = None
|
||||
ctv: str | None = None
|
||||
fa: str | None = None
|
||||
server_continuous_batching: bool | None = False
|
||||
server_embeddings: bool | None = False
|
||||
server_reranking: bool | None = False
|
||||
server_metrics: bool | None = False
|
||||
kv_unified: bool | None = False
|
||||
server_slots: bool | None = False
|
||||
pooling: str | None = None
|
||||
api_key: str | None = None
|
||||
models_dir: str | None = None
|
||||
models_max: int | None = None
|
||||
models_preset: str | None = None
|
||||
no_models_autoload: bool | None = None
|
||||
lora_files: List[str] | None = None
|
||||
enable_ctx_shift: int | None = False
|
||||
spec_draft_n_min: int | None = None
|
||||
spec_draft_n_max: int | None = None
|
||||
no_webui: bool | None = None
|
||||
jinja: bool | None = None
|
||||
reasoning_format: Literal['deepseek', 'none', 'nothink'] | None = None
|
||||
reasoning: Literal['on', 'off', 'auto'] | None = None
|
||||
chat_template: str | None = None
|
||||
chat_template_file: str | None = None
|
||||
server_path: str | None = None
|
||||
mmproj_url: str | None = None
|
||||
media_path: str | None = None
|
||||
sleep_idle_seconds: int | None = None
|
||||
cache_ram: int | None = None
|
||||
no_cache_idle_slots: bool = False
|
||||
log_path: str | None = None
|
||||
webui_mcp_proxy: bool = False
|
||||
backend_sampling: bool = False
|
||||
gcp_compat: bool = False
|
||||
|
||||
# session variables
|
||||
process: subprocess.Popen | None = None
|
||||
|
||||
def __init__(self):
|
||||
if "N_GPU_LAYERS" in os.environ:
|
||||
self.n_gpu_layer = int(os.environ["N_GPU_LAYERS"])
|
||||
if "DEBUG" in os.environ:
|
||||
self.debug = True
|
||||
if "PORT" in os.environ:
|
||||
self.server_port = int(os.environ["PORT"])
|
||||
self.external_server = "DEBUG_EXTERNAL" in os.environ
|
||||
|
||||
def start(self, timeout_seconds: int = DEFAULT_HTTP_TIMEOUT) -> None:
|
||||
env = {**os.environ}
|
||||
if "LLAMA_CACHE" not in os.environ:
|
||||
env["LLAMA_CACHE"] = "tmp"
|
||||
if self.external_server:
|
||||
print(f"[external_server]: Assuming external server running on {self.server_host}:{self.server_port}")
|
||||
return
|
||||
if self.server_path is not None:
|
||||
server_path = self.server_path
|
||||
elif "LLAMA_SERVER_BIN_PATH" in os.environ:
|
||||
server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
|
||||
elif os.name == "nt":
|
||||
server_path = "../../../build/bin/Release/llama-server.exe"
|
||||
else:
|
||||
server_path = "../../../build/bin/llama-server"
|
||||
server_args = [
|
||||
"--host",
|
||||
self.server_host,
|
||||
"--port",
|
||||
self.server_port,
|
||||
"--temp",
|
||||
self.temperature,
|
||||
"--seed",
|
||||
self.seed,
|
||||
]
|
||||
if self.offline:
|
||||
server_args.append("--offline")
|
||||
if self.model_file:
|
||||
server_args.extend(["--model", self.model_file])
|
||||
if self.model_url:
|
||||
server_args.extend(["--model-url", self.model_url])
|
||||
if self.model_draft:
|
||||
server_args.extend(["--model-draft", self.model_draft])
|
||||
if self.model_hf_repo:
|
||||
server_args.extend(["--hf-repo", self.model_hf_repo])
|
||||
if self.model_hf_file:
|
||||
server_args.extend(["--hf-file", self.model_hf_file])
|
||||
if self.models_dir:
|
||||
server_args.extend(["--models-dir", self.models_dir])
|
||||
if self.models_max is not None:
|
||||
server_args.extend(["--models-max", self.models_max])
|
||||
if self.models_preset:
|
||||
server_args.extend(["--models-preset", self.models_preset])
|
||||
if self.n_batch:
|
||||
server_args.extend(["--batch-size", self.n_batch])
|
||||
if self.n_ubatch:
|
||||
server_args.extend(["--ubatch-size", self.n_ubatch])
|
||||
if self.n_threads:
|
||||
server_args.extend(["--threads", self.n_threads])
|
||||
if self.n_gpu_layer:
|
||||
server_args.extend(["--n-gpu-layers", self.n_gpu_layer])
|
||||
if self.server_continuous_batching:
|
||||
server_args.append("--cont-batching")
|
||||
if self.server_embeddings:
|
||||
server_args.append("--embedding")
|
||||
if self.server_reranking:
|
||||
server_args.append("--reranking")
|
||||
if self.server_metrics:
|
||||
server_args.append("--metrics")
|
||||
if self.kv_unified:
|
||||
server_args.append("--kv-unified")
|
||||
if self.server_slots:
|
||||
server_args.append("--slots")
|
||||
else:
|
||||
server_args.append("--no-slots")
|
||||
if self.pooling:
|
||||
server_args.extend(["--pooling", self.pooling])
|
||||
if self.model_alias:
|
||||
server_args.extend(["--alias", self.model_alias])
|
||||
if self.model_tags:
|
||||
server_args.extend(["--tags", self.model_tags])
|
||||
if self.n_ctx:
|
||||
server_args.extend(["--ctx-size", self.n_ctx])
|
||||
if self.n_slots:
|
||||
server_args.extend(["--parallel", self.n_slots])
|
||||
if self.ctk:
|
||||
server_args.extend(["-ctk", self.ctk])
|
||||
if self.ctv:
|
||||
server_args.extend(["-ctv", self.ctv])
|
||||
if self.fa is not None:
|
||||
server_args.extend(["-fa", self.fa])
|
||||
if self.n_predict:
|
||||
server_args.extend(["--n-predict", self.n_predict])
|
||||
if self.slot_save_path:
|
||||
server_args.extend(["--slot-save-path", self.slot_save_path])
|
||||
if self.n_ga:
|
||||
server_args.extend(["--grp-attn-n", self.n_ga])
|
||||
if self.n_ga_w:
|
||||
server_args.extend(["--grp-attn-w", self.n_ga_w])
|
||||
if self.debug:
|
||||
server_args.append("--verbose")
|
||||
if self.lora_files:
|
||||
for lora_file in self.lora_files:
|
||||
server_args.extend(["--lora", lora_file])
|
||||
if self.enable_ctx_shift:
|
||||
server_args.append("--context-shift")
|
||||
if self.api_key:
|
||||
server_args.extend(["--api-key", self.api_key])
|
||||
if self.spec_draft_n_max:
|
||||
server_args.extend(["--spec-draft-n-max", self.spec_draft_n_max])
|
||||
if self.spec_draft_n_min:
|
||||
server_args.extend(["--spec-draft-n-min", self.spec_draft_n_min])
|
||||
if self.no_webui:
|
||||
server_args.append("--no-webui")
|
||||
if self.no_models_autoload:
|
||||
server_args.append("--no-models-autoload")
|
||||
if self.jinja:
|
||||
server_args.append("--jinja")
|
||||
else:
|
||||
server_args.append("--no-jinja")
|
||||
if self.reasoning_format is not None:
|
||||
server_args.extend(("--reasoning-format", self.reasoning_format))
|
||||
if self.reasoning is not None:
|
||||
server_args.extend(("--reasoning", self.reasoning))
|
||||
if self.chat_template:
|
||||
server_args.extend(["--chat-template", self.chat_template])
|
||||
if self.chat_template_file:
|
||||
server_args.extend(["--chat-template-file", self.chat_template_file])
|
||||
if self.mmproj_url:
|
||||
server_args.extend(["--mmproj-url", self.mmproj_url])
|
||||
if self.media_path:
|
||||
server_args.extend(["--media-path", self.media_path])
|
||||
if self.sleep_idle_seconds is not None:
|
||||
server_args.extend(["--sleep-idle-seconds", self.sleep_idle_seconds])
|
||||
if self.cache_ram is not None:
|
||||
server_args.extend(["--cache-ram", self.cache_ram])
|
||||
if self.no_cache_idle_slots:
|
||||
server_args.append("--no-cache-idle-slots")
|
||||
if self.webui_mcp_proxy:
|
||||
server_args.append("--webui-mcp-proxy")
|
||||
if self.backend_sampling:
|
||||
server_args.append("--backend_sampling")
|
||||
if self.gcp_compat:
|
||||
env["AIP_MODE"] = "PREDICTION"
|
||||
|
||||
args = [str(arg) for arg in [server_path, *server_args]]
|
||||
print(f"tests: starting server with: {' '.join(args)}")
|
||||
|
||||
flags = 0
|
||||
if "nt" == os.name:
|
||||
flags |= subprocess.DETACHED_PROCESS
|
||||
flags |= subprocess.CREATE_NEW_PROCESS_GROUP
|
||||
flags |= subprocess.CREATE_NO_WINDOW
|
||||
|
||||
if self.log_path:
|
||||
self._log = open(self.log_path, "w")
|
||||
else:
|
||||
self._log = sys.stdout
|
||||
|
||||
self.process = subprocess.Popen(
|
||||
[str(arg) for arg in [server_path, *server_args]],
|
||||
creationflags=flags,
|
||||
stdout=self._log,
|
||||
stderr=self._log if self._log != sys.stdout else sys.stdout,
|
||||
env=env,
|
||||
)
|
||||
server_instances.add(self)
|
||||
|
||||
print(f"server pid={self.process.pid}, pytest pid={os.getpid()}")
|
||||
|
||||
# wait for server to start
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout_seconds:
|
||||
try:
|
||||
response = self.make_request("GET", "/health", headers={
|
||||
"Authorization": f"Bearer {self.api_key}" if self.api_key else None
|
||||
})
|
||||
if response.status_code == 200:
|
||||
self.ready = True
|
||||
return # server is ready
|
||||
except Exception as e:
|
||||
pass
|
||||
# Check if process died
|
||||
if self.process.poll() is not None:
|
||||
raise RuntimeError(f"Server process died with return code {self.process.returncode}")
|
||||
|
||||
print(f"Waiting for server to start...")
|
||||
time.sleep(0.5)
|
||||
raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")
|
||||
|
||||
def stop(self) -> None:
|
||||
if self.external_server:
|
||||
print("[external_server]: Not stopping external server")
|
||||
return
|
||||
if self in server_instances:
|
||||
server_instances.remove(self)
|
||||
if self.process:
|
||||
print(f"Stopping server with pid={self.process.pid}")
|
||||
self.process.terminate()
|
||||
try:
|
||||
self.process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
print(f"Server pid={self.process.pid} did not terminate in time, killing")
|
||||
self.process.kill()
|
||||
self.process.wait(timeout=5)
|
||||
except Exception as e:
|
||||
print(f"Error waiting for server: {e}")
|
||||
self.process = None
|
||||
if hasattr(self, '_log') and self._log != sys.stdout:
|
||||
self._log.close()
|
||||
|
||||
def make_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: dict | Any | None = None,
|
||||
headers: dict | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> ServerResponse:
|
||||
url = f"http://{self.server_host}:{self.server_port}{path}"
|
||||
parse_body = False
|
||||
if method == "GET":
|
||||
response = requests.get(url, headers=headers, timeout=timeout)
|
||||
parse_body = True
|
||||
elif method == "POST":
|
||||
response = requests.post(url, headers=headers, json=data, timeout=timeout)
|
||||
parse_body = True
|
||||
elif method == "OPTIONS":
|
||||
response = requests.options(url, headers=headers, timeout=timeout)
|
||||
else:
|
||||
raise ValueError(f"Unimplemented method: {method}")
|
||||
result = ServerResponse()
|
||||
result.headers = dict(response.headers)
|
||||
result.status_code = response.status_code
|
||||
if parse_body:
|
||||
try:
|
||||
result.body = response.json()
|
||||
except JSONDecodeError:
|
||||
result.body = response.text
|
||||
else:
|
||||
result.body = None
|
||||
print("Response from server", json.dumps(result.body, indent=2))
|
||||
return result
|
||||
|
||||
def make_stream_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: dict | None = None,
|
||||
headers: dict | None = None,
|
||||
) -> Iterator[dict]:
|
||||
url = f"http://{self.server_host}:{self.server_port}{path}"
|
||||
if method == "POST":
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
else:
|
||||
raise ValueError(f"Unimplemented method: {method}")
|
||||
if response.status_code != 200:
|
||||
raise ServerError(response.status_code, response.json())
|
||||
for line_bytes in response.iter_lines():
|
||||
line = line_bytes.decode("utf-8")
|
||||
if '[DONE]' in line:
|
||||
break
|
||||
elif line.startswith('data: '):
|
||||
data = json.loads(line[6:])
|
||||
print("Partial response from server", json.dumps(data, indent=2))
|
||||
yield data
|
||||
|
||||
def make_any_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: dict | None = None,
|
||||
headers: dict | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> dict:
|
||||
stream = data.get('stream', False)
|
||||
if stream:
|
||||
content: list[str] = []
|
||||
reasoning_content: list[str] = []
|
||||
tool_calls: list[dict] = []
|
||||
finish_reason: Optional[str] = None
|
||||
|
||||
content_parts = 0
|
||||
reasoning_content_parts = 0
|
||||
tool_call_parts = 0
|
||||
arguments_parts = 0
|
||||
|
||||
for chunk in self.make_stream_request(method, path, data, headers):
|
||||
if chunk['choices']:
|
||||
assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
|
||||
choice = chunk['choices'][0]
|
||||
if choice['delta'].get('content') is not None:
|
||||
assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
|
||||
content.append(choice['delta']['content'])
|
||||
content_parts += 1
|
||||
if choice['delta'].get('reasoning_content') is not None:
|
||||
assert len(choice['delta']['reasoning_content']) > 0, f'Expected non empty reasoning_content delta!'
|
||||
reasoning_content.append(choice['delta']['reasoning_content'])
|
||||
reasoning_content_parts += 1
|
||||
if choice['delta'].get('finish_reason') is not None:
|
||||
finish_reason = choice['delta']['finish_reason']
|
||||
for tc in choice['delta'].get('tool_calls', []):
|
||||
if 'function' not in tc:
|
||||
raise ValueError(f"Expected function type, got {tc['type']}")
|
||||
if tc['index'] >= len(tool_calls):
|
||||
assert 'id' in tc
|
||||
assert tc.get('type') == 'function'
|
||||
assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \
|
||||
f"Expected function call with name, got {tc.get('function')}"
|
||||
tool_calls.append(dict(
|
||||
id="",
|
||||
type="function",
|
||||
function=dict(
|
||||
name="",
|
||||
arguments="",
|
||||
)
|
||||
))
|
||||
tool_call = tool_calls[tc['index']]
|
||||
if tc.get('id') is not None:
|
||||
tool_call['id'] = tc['id']
|
||||
fct = tc['function']
|
||||
assert 'id' not in fct, f"Function call should not have id: {fct}"
|
||||
if fct.get('name') is not None:
|
||||
tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
|
||||
if fct.get('arguments') is not None:
|
||||
tool_call['function']['arguments'] += fct['arguments']
|
||||
arguments_parts += 1
|
||||
tool_call_parts += 1
|
||||
else:
|
||||
# When `include_usage` is True (the default), we expect the last chunk of the stream
|
||||
# immediately preceding the `data: [DONE]` message to contain a `choices` field with an empty array
|
||||
# and a `usage` field containing the usage statistics (n.b., llama-server also returns `timings` in
|
||||
# the last chunk)
|
||||
assert 'usage' in chunk, f"Expected finish_reason in chunk: {chunk}"
|
||||
assert 'timings' in chunk, f"Expected finish_reason in chunk: {chunk}"
|
||||
print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
|
||||
result = dict(
|
||||
choices=[
|
||||
dict(
|
||||
index=0,
|
||||
finish_reason=finish_reason,
|
||||
message=dict(
|
||||
role='assistant',
|
||||
content=''.join(content) if content else None,
|
||||
reasoning_content=''.join(reasoning_content) if reasoning_content else None,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
print("Final response from server", json.dumps(result, indent=2))
|
||||
return result
|
||||
else:
|
||||
response = self.make_request(method, path, data, headers, timeout=timeout)
|
||||
assert response.status_code == 200, f"Server returned error: {response.status_code}"
|
||||
return response.body
|
||||
|
||||
|
||||
|
||||
server_instances: Set[ServerProcess] = set()
|
||||
|
||||
|
||||
class ServerPreset:
|
||||
@staticmethod
|
||||
def load_all() -> None:
|
||||
""" Load all server presets to ensure model files are cached. """
|
||||
servers: List[ServerProcess] = [
|
||||
method()
|
||||
for name, method in ServerPreset.__dict__.items()
|
||||
if callable(method) and name != "load_all"
|
||||
]
|
||||
for server in servers:
|
||||
server.offline = False
|
||||
server.start()
|
||||
server.stop()
|
||||
|
||||
@staticmethod
|
||||
def tinyllama2() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.offline = True # will be downloaded by load_all()
|
||||
server.model_hf_repo = "ggml-org/test-model-stories260K"
|
||||
server.model_hf_file = None
|
||||
server.model_alias = "tinyllama-2"
|
||||
server.n_ctx = 512
|
||||
server.n_batch = 32
|
||||
server.n_slots = 2
|
||||
server.n_predict = 64
|
||||
server.seed = 42
|
||||
return server
|
||||
|
||||
@staticmethod
|
||||
def bert_bge_small() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.offline = True # will be downloaded by load_all()
|
||||
server.model_hf_repo = "ggml-org/models"
|
||||
server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
|
||||
server.model_alias = "bert-bge-small"
|
||||
server.n_ctx = 512
|
||||
server.n_batch = 128
|
||||
server.n_ubatch = 128
|
||||
server.n_slots = 2
|
||||
server.seed = 42
|
||||
server.server_embeddings = True
|
||||
return server
|
||||
|
||||
@staticmethod
|
||||
def bert_bge_small_with_fa() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.offline = True # will be downloaded by load_all()
|
||||
server.model_hf_repo = "ggml-org/models"
|
||||
server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
|
||||
server.model_alias = "bert-bge-small"
|
||||
server.n_ctx = 1024
|
||||
server.n_batch = 300
|
||||
server.n_ubatch = 300
|
||||
server.n_slots = 2
|
||||
server.fa = "on"
|
||||
server.seed = 42
|
||||
server.server_embeddings = True
|
||||
return server
|
||||
|
||||
@staticmethod
|
||||
def tinyllama_infill() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.offline = True # will be downloaded by load_all()
|
||||
server.model_hf_repo = "ggml-org/test-model-stories260K-infill"
|
||||
server.model_hf_file = None
|
||||
server.model_alias = "tinyllama-infill"
|
||||
server.n_ctx = 2048
|
||||
server.n_batch = 1024
|
||||
server.n_slots = 1
|
||||
server.n_predict = 64
|
||||
server.temperature = 0.0
|
||||
server.seed = 42
|
||||
return server
|
||||
|
||||
@staticmethod
|
||||
def stories15m_moe() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.offline = True # will be downloaded by load_all()
|
||||
server.model_hf_repo = "ggml-org/stories15M_MOE"
|
||||
server.model_hf_file = "stories15M_MOE-F16.gguf"
|
||||
server.model_alias = "stories15m-moe"
|
||||
server.n_ctx = 2048
|
||||
server.n_batch = 1024
|
||||
server.n_slots = 1
|
||||
server.n_predict = 64
|
||||
server.temperature = 0.0
|
||||
server.seed = 42
|
||||
return server
|
||||
|
||||
@staticmethod
|
||||
def jina_reranker_tiny() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.offline = True # will be downloaded by load_all()
|
||||
server.model_hf_repo = "ggml-org/models"
|
||||
server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
|
||||
server.model_alias = "jina-reranker"
|
||||
server.n_ctx = 512
|
||||
server.n_batch = 512
|
||||
server.n_slots = 1
|
||||
server.seed = 42
|
||||
server.server_reranking = True
|
||||
return server
|
||||
|
||||
@staticmethod
|
||||
def tinygemma3() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.offline = True # will be downloaded by load_all()
|
||||
# mmproj is already provided by HF registry API
|
||||
server.model_hf_file = None
|
||||
server.model_hf_repo = "ggml-org/tinygemma3-GGUF:Q8_0"
|
||||
server.model_alias = "tinygemma3"
|
||||
server.n_ctx = 1024
|
||||
server.n_batch = 32
|
||||
server.n_slots = 2
|
||||
server.n_predict = 4
|
||||
server.seed = 42
|
||||
return server
|
||||
|
||||
@staticmethod
|
||||
def router() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.offline = True # will be downloaded by load_all()
|
||||
# router server has no models
|
||||
server.model_file = None
|
||||
server.model_alias = None
|
||||
server.model_hf_repo = None
|
||||
server.model_hf_file = None
|
||||
server.n_ctx = 1024
|
||||
server.n_batch = 16
|
||||
server.n_slots = 1
|
||||
server.n_predict = 16
|
||||
server.seed = 42
|
||||
return server
|
||||
|
||||
|
||||
def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
|
||||
"""
|
||||
Run multiple functions in parallel and return results in the same order as calls. Equivalent to Promise.all in JS.
|
||||
|
||||
Example usage:
|
||||
|
||||
results = parallel_function_calls([
|
||||
(func1, (arg1, arg2)),
|
||||
(func2, (arg3, arg4)),
|
||||
])
|
||||
"""
|
||||
results = [None] * len(function_list)
|
||||
exceptions = []
|
||||
|
||||
def worker(index, func, args):
|
||||
try:
|
||||
result = func(*args)
|
||||
results[index] = result
|
||||
except Exception as e:
|
||||
exceptions.append((index, str(e)))
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = []
|
||||
for i, (func, args) in enumerate(function_list):
|
||||
future = executor.submit(worker, i, func, args)
|
||||
futures.append(future)
|
||||
|
||||
# Wait for all futures to complete
|
||||
for future in as_completed(futures):
|
||||
pass
|
||||
|
||||
# Check if there were any exceptions
|
||||
if exceptions:
|
||||
print("Exceptions occurred:")
|
||||
for index, error in exceptions:
|
||||
print(f"Function at index {index}: {error}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def match_regex(regex: str, text: str) -> bool:
|
||||
return (
|
||||
re.compile(
|
||||
regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL
|
||||
).search(text)
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
def download_file(url: str, output_file_path: str | None = None) -> str:
|
||||
"""
|
||||
Download a file from a URL to a local path. If the file already exists, it will not be downloaded again.
|
||||
|
||||
output_file_path is the local path to save the downloaded file. If not provided, the file will be saved in the root directory.
|
||||
|
||||
Returns the local path of the downloaded file.
|
||||
"""
|
||||
file_name = url.split('/').pop()
|
||||
output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path
|
||||
if not os.path.exists(output_file):
|
||||
print(f"Downloading {url} to {output_file}")
|
||||
wget.download(url, out=output_file)
|
||||
print(f"Done downloading to {output_file}")
|
||||
else:
|
||||
print(f"File already exists at {output_file}")
|
||||
return output_file
|
||||
|
||||
|
||||
def is_slow_test_allowed():
|
||||
return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"
|
||||
28
tools/server/webui/.gitignore
vendored
Normal file
28
tools/server/webui/.gitignore
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
test-results
|
||||
node_modules
|
||||
|
||||
# Output
|
||||
.output
|
||||
.vercel
|
||||
.netlify
|
||||
.wrangler
|
||||
/.svelte-kit
|
||||
/build
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Env
|
||||
.env
|
||||
.env.*
|
||||
!.env.example
|
||||
!.env.test
|
||||
|
||||
# Vite
|
||||
vite.config.js.timestamp-*
|
||||
vite.config.ts.timestamp-*
|
||||
|
||||
*storybook.log
|
||||
storybook-static
|
||||
*.code-workspace
|
||||
1
tools/server/webui/.npmrc
Normal file
1
tools/server/webui/.npmrc
Normal file
@@ -0,0 +1 @@
|
||||
engine-strict=true
|
||||
9
tools/server/webui/.prettierignore
Normal file
9
tools/server/webui/.prettierignore
Normal file
@@ -0,0 +1,9 @@
|
||||
# Package Managers
|
||||
package-lock.json
|
||||
pnpm-lock.yaml
|
||||
yarn.lock
|
||||
bun.lock
|
||||
bun.lockb
|
||||
|
||||
# Miscellaneous
|
||||
/static/
|
||||
16
tools/server/webui/.prettierrc
Normal file
16
tools/server/webui/.prettierrc
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"useTabs": true,
|
||||
"singleQuote": true,
|
||||
"trailingComma": "none",
|
||||
"printWidth": 100,
|
||||
"plugins": ["prettier-plugin-svelte", "prettier-plugin-tailwindcss"],
|
||||
"overrides": [
|
||||
{
|
||||
"files": "*.svelte",
|
||||
"options": {
|
||||
"parser": "svelte"
|
||||
}
|
||||
}
|
||||
],
|
||||
"tailwindStylesheet": "./src/app.css"
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
<script lang="ts">
|
||||
import { ModeWatcher } from 'mode-watcher';
|
||||
import { onMount } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
children?: any;
|
||||
}
|
||||
|
||||
let { children }: Props = $props();
|
||||
|
||||
onMount(() => {
|
||||
const root = document.documentElement;
|
||||
const theme = localStorage.getItem('mode-watcher-mode') || 'system';
|
||||
|
||||
if (theme === 'dark') {
|
||||
root.classList.add('dark');
|
||||
} else if (theme === 'light') {
|
||||
root.classList.remove('dark');
|
||||
} else {
|
||||
const prefersDark = window.matchMedia('(prefers-color-scheme: dark)').matches;
|
||||
if (prefersDark) {
|
||||
root.classList.add('dark');
|
||||
} else {
|
||||
root.classList.remove('dark');
|
||||
}
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<ModeWatcher />
|
||||
|
||||
{#if children}
|
||||
{@const Component = children}
|
||||
|
||||
<Component />
|
||||
{/if}
|
||||
@@ -0,0 +1,13 @@
|
||||
<script lang="ts">
|
||||
import * as Tooltip from '../../src/lib/components/ui/tooltip';
|
||||
|
||||
interface Props {
|
||||
children: any;
|
||||
}
|
||||
|
||||
let { children }: Props = $props();
|
||||
</script>
|
||||
|
||||
<Tooltip.Provider>
|
||||
{@render children()}
|
||||
</Tooltip.Provider>
|
||||
24
tools/server/webui/.storybook/main.ts
Normal file
24
tools/server/webui/.storybook/main.ts
Normal file
@@ -0,0 +1,24 @@
|
||||
import type { StorybookConfig } from '@storybook/sveltekit';
|
||||
import { dirname, resolve } from 'path';
|
||||
import { fileURLToPath } from 'url';
|
||||
|
||||
const __dirname = dirname(fileURLToPath(import.meta.url));
|
||||
|
||||
const config: StorybookConfig = {
|
||||
stories: ['../tests/stories/**/*.mdx', '../tests/stories/**/*.stories.@(js|ts|svelte)'],
|
||||
addons: [
|
||||
'@storybook/addon-svelte-csf',
|
||||
'@chromatic-com/storybook',
|
||||
'@storybook/addon-vitest',
|
||||
'@storybook/addon-a11y',
|
||||
'@storybook/addon-docs'
|
||||
],
|
||||
framework: '@storybook/sveltekit',
|
||||
viteFinal: async (config) => {
|
||||
config.server = config.server || {};
|
||||
config.server.fs = config.server.fs || {};
|
||||
config.server.fs.allow = [...(config.server.fs.allow || []), resolve(__dirname, '../tests')];
|
||||
return config;
|
||||
}
|
||||
};
|
||||
export default config;
|
||||
42
tools/server/webui/.storybook/preview.ts
Normal file
42
tools/server/webui/.storybook/preview.ts
Normal file
@@ -0,0 +1,42 @@
|
||||
import type { Preview } from '@storybook/sveltekit';
|
||||
import '../src/app.css';
|
||||
import ModeWatcherDecorator from './decorators/ModeWatcherDecorator.svelte';
|
||||
import TooltipProviderDecorator from './decorators/TooltipProviderDecorator.svelte';
|
||||
|
||||
const preview: Preview = {
|
||||
parameters: {
|
||||
controls: {
|
||||
matchers: {
|
||||
color: /(background|color)$/i,
|
||||
date: /Date$/i
|
||||
}
|
||||
},
|
||||
|
||||
backgrounds: {
|
||||
disabled: true
|
||||
},
|
||||
|
||||
a11y: {
|
||||
// 'todo' - show a11y violations in the test UI only
|
||||
// 'error' - fail CI on a11y violations
|
||||
// 'off' - skip a11y checks entirely
|
||||
test: 'todo'
|
||||
}
|
||||
},
|
||||
decorators: [
|
||||
(story) => ({
|
||||
Component: ModeWatcherDecorator,
|
||||
props: {
|
||||
children: story
|
||||
}
|
||||
}),
|
||||
(story) => ({
|
||||
Component: TooltipProviderDecorator,
|
||||
props: {
|
||||
children: story
|
||||
}
|
||||
})
|
||||
]
|
||||
};
|
||||
|
||||
export default preview;
|
||||
12
tools/server/webui/.storybook/vitest.setup.ts
Normal file
12
tools/server/webui/.storybook/vitest.setup.ts
Normal file
@@ -0,0 +1,12 @@
|
||||
import * as a11yAddonAnnotations from '@storybook/addon-a11y/preview';
|
||||
import { setProjectAnnotations } from '@storybook/sveltekit';
|
||||
import * as previewAnnotations from './preview';
|
||||
import { beforeAll } from 'vitest';
|
||||
|
||||
const project = setProjectAnnotations([a11yAddonAnnotations, previewAnnotations]);
|
||||
|
||||
beforeAll(async () => {
|
||||
if (project.beforeAll) {
|
||||
await project.beforeAll();
|
||||
}
|
||||
});
|
||||
687
tools/server/webui/README.md
Normal file
687
tools/server/webui/README.md
Normal file
@@ -0,0 +1,687 @@
|
||||
# llama.cpp Web UI
|
||||
|
||||
A modern, feature-rich web interface for llama.cpp built with SvelteKit. This UI provides an intuitive chat interface with advanced file handling, conversation management, and comprehensive model interaction capabilities.
|
||||
|
||||
The WebUI supports two server operation modes:
|
||||
|
||||
- **MODEL mode** - Single model operation (standard llama-server)
|
||||
- **ROUTER mode** - Multi-model operation with dynamic model loading/unloading
|
||||
|
||||
---
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Features](#features)
|
||||
- [Getting Started](#getting-started)
|
||||
- [Tech Stack](#tech-stack)
|
||||
- [Build Pipeline](#build-pipeline)
|
||||
- [Architecture](#architecture)
|
||||
- [Data Flows](#data-flows)
|
||||
- [Architectural Patterns](#architectural-patterns)
|
||||
- [Testing](#testing)
|
||||
|
||||
---
|
||||
|
||||
## Features
|
||||
|
||||
### Chat Interface
|
||||
|
||||
- **Streaming responses** with real-time updates
|
||||
- **Reasoning content** - Support for models with thinking/reasoning blocks
|
||||
- **Dark/light theme** with system preference detection
|
||||
- **Responsive design** for desktop and mobile
|
||||
|
||||
### File Attachments
|
||||
|
||||
- **Images** - JPEG, PNG, GIF, WebP, SVG (with PNG conversion)
|
||||
- **Documents** - PDF (text extraction or image conversion for vision models)
|
||||
- **Audio** - MP3, WAV for audio-capable models
|
||||
- **Text files** - Source code, markdown, and other text formats
|
||||
- **Drag-and-drop** and paste support with rich previews
|
||||
|
||||
### Conversation Management
|
||||
|
||||
- **Branching** - Branch messages conversations at any point by editing messages or regenerating responses, navigate between branches
|
||||
- **Regeneration** - Regenerate responses with optional model switching (ROUTER mode)
|
||||
- **Import/Export** - JSON format for backup and sharing
|
||||
- **Search** - Find conversations by title or content
|
||||
|
||||
### Advanced Rendering
|
||||
|
||||
- **Syntax highlighting** - Code blocks with language detection
|
||||
- **Math formulas** - KaTeX rendering for LaTeX expressions
|
||||
- **Markdown** - Full GFM support with tables, lists, and more
|
||||
|
||||
### Multi-Model Support (ROUTER mode)
|
||||
|
||||
- **Model selector** with Loaded/Available groups
|
||||
- **Automatic loading** - Models load on selection
|
||||
- **Modality validation** - Prevents sending images to non-vision models
|
||||
- **LRU unloading** - Server auto-manages model cache
|
||||
|
||||
### Keyboard Shortcuts
|
||||
|
||||
| Shortcut | Action |
|
||||
| ------------------ | -------------------- |
|
||||
| `Shift+Ctrl/Cmd+O` | New chat |
|
||||
| `Shift+Ctrl/Cmd+E` | Edit conversation |
|
||||
| `Shift+Ctrl/Cmd+D` | Delete conversation |
|
||||
| `Ctrl/Cmd+K` | Search conversations |
|
||||
| `Ctrl/Cmd+B` | Toggle sidebar |
|
||||
|
||||
### Developer Experience
|
||||
|
||||
- **Request tracking** - Monitor token generation with `/slots` endpoint
|
||||
- **Storybook** - Component library with visual testing
|
||||
- **Hot reload** - Instant updates during development
|
||||
|
||||
---
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- **Node.js** 18+ (20+ recommended)
|
||||
- **npm** 9+
|
||||
- **llama-server** running locally (for API access)
|
||||
|
||||
### 1. Install Dependencies
|
||||
|
||||
```bash
|
||||
cd tools/server/webui
|
||||
npm install
|
||||
```
|
||||
|
||||
### 2. Start llama-server
|
||||
|
||||
In a separate terminal, start the backend server:
|
||||
|
||||
```bash
|
||||
# Single model (MODEL mode)
|
||||
./llama-server -m model.gguf
|
||||
|
||||
# Multi-model (ROUTER mode)
|
||||
./llama-server --models-dir /path/to/models
|
||||
```
|
||||
|
||||
### 3. Start Development Servers
|
||||
|
||||
```bash
|
||||
npm run dev
|
||||
```
|
||||
|
||||
This starts:
|
||||
|
||||
- **Vite dev server** at `http://localhost:5173` - The main WebUI
|
||||
- **Storybook** at `http://localhost:6006` - Component documentation
|
||||
|
||||
The Vite dev server proxies API requests to `http://localhost:8080` (default llama-server port):
|
||||
|
||||
```typescript
|
||||
// vite.config.ts proxy configuration
|
||||
proxy: {
|
||||
'/v1': 'http://localhost:8080',
|
||||
'/props': 'http://localhost:8080',
|
||||
'/slots': 'http://localhost:8080',
|
||||
'/models': 'http://localhost:8080'
|
||||
}
|
||||
```
|
||||
|
||||
### Development Workflow
|
||||
|
||||
1. Open `http://localhost:5173` in your browser
|
||||
2. Make changes to `.svelte`, `.ts`, or `.css` files
|
||||
3. Changes hot-reload instantly
|
||||
4. Use Storybook at `http://localhost:6006` for isolated component development
|
||||
|
||||
---
|
||||
|
||||
## Tech Stack
|
||||
|
||||
| Layer | Technology | Purpose |
|
||||
| ----------------- | ------------------------------- | -------------------------------------------------------- |
|
||||
| **Framework** | SvelteKit + Svelte 5 | Reactive UI with runes (`$state`, `$derived`, `$effect`) |
|
||||
| **UI Components** | shadcn-svelte + bits-ui | Accessible, customizable component library |
|
||||
| **Styling** | TailwindCSS 4 | Utility-first CSS with design tokens |
|
||||
| **Database** | IndexedDB (Dexie) | Client-side storage for conversations and messages |
|
||||
| **Build** | Vite | Fast bundling with static adapter |
|
||||
| **Testing** | Playwright + Vitest + Storybook | E2E, unit, and visual testing |
|
||||
| **Markdown** | remark + rehype | Markdown processing with KaTeX and syntax highlighting |
|
||||
|
||||
### Key Dependencies
|
||||
|
||||
```json
|
||||
{
|
||||
"svelte": "^5.0.0",
|
||||
"bits-ui": "^2.8.11",
|
||||
"dexie": "^4.0.11",
|
||||
"pdfjs-dist": "^5.4.54",
|
||||
"highlight.js": "^11.11.1",
|
||||
"rehype-katex": "^7.0.1"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Build Pipeline
|
||||
|
||||
### Development Build
|
||||
|
||||
```bash
|
||||
npm run dev
|
||||
```
|
||||
|
||||
Runs Vite in development mode with:
|
||||
|
||||
- Hot Module Replacement (HMR)
|
||||
- Source maps
|
||||
- Proxy to llama-server
|
||||
|
||||
### Production Build
|
||||
|
||||
```bash
|
||||
npm run build
|
||||
```
|
||||
|
||||
The build process:
|
||||
|
||||
1. **Vite Build** - Bundles all TypeScript, Svelte, and CSS
|
||||
2. **Static Adapter** - Outputs to `../public` (llama-server's static file directory)
|
||||
3. **Post-Build Script** - Cleans up intermediate files
|
||||
4. **Custom Plugin** - Creates `index.html` with:
|
||||
- Inlined favicon as base64
|
||||
- GZIP compression (level 9)
|
||||
- Deterministic output (zeroed timestamps)
|
||||
|
||||
```text
|
||||
tools/server/webui/ → build → tools/server/public/
|
||||
├── src/ ├── index.html (served by llama-server)
|
||||
├── static/ └── (favicon inlined)
|
||||
└── ...
|
||||
```
|
||||
|
||||
### SvelteKit Configuration
|
||||
|
||||
```javascript
|
||||
// svelte.config.js
|
||||
adapter: adapter({
|
||||
pages: '../public', // Output directory
|
||||
assets: '../public', // Static assets
|
||||
fallback: 'index.html', // SPA fallback
|
||||
strict: true
|
||||
}),
|
||||
output: {
|
||||
bundleStrategy: 'inline' // Single-file bundle
|
||||
}
|
||||
```
|
||||
|
||||
### Integration with llama-server
|
||||
|
||||
The WebUI is embedded directly into the llama-server binary:
|
||||
|
||||
1. `npm run build` outputs `index.html` to `tools/server/public/`
|
||||
2. llama-server compiles this into the binary at build time
|
||||
3. When accessing `/`, llama-server serves the gzipped HTML
|
||||
4. All assets are inlined (CSS, JS, fonts, favicon)
|
||||
|
||||
This results in a **single portable binary** with the full WebUI included.
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
The WebUI follows a layered architecture with unidirectional data flow:
|
||||
|
||||
```text
|
||||
Routes → Components → Hooks → Stores → Services → Storage/API
|
||||
```
|
||||
|
||||
### High-Level Architecture
|
||||
|
||||
See: [`docs/architecture/high-level-architecture-simplified.md`](docs/architecture/high-level-architecture-simplified.md)
|
||||
|
||||
```mermaid
|
||||
flowchart TB
|
||||
subgraph Routes["📍 Routes"]
|
||||
R1["/ (Welcome)"]
|
||||
R2["/chat/[id]"]
|
||||
RL["+layout.svelte"]
|
||||
end
|
||||
|
||||
subgraph Components["🧩 Components"]
|
||||
C_Sidebar["ChatSidebar"]
|
||||
C_Screen["ChatScreen"]
|
||||
C_Form["ChatForm"]
|
||||
C_Messages["ChatMessages"]
|
||||
C_ModelsSelector["ModelsSelector"]
|
||||
C_Settings["ChatSettings"]
|
||||
end
|
||||
|
||||
subgraph Stores["🗄️ Stores"]
|
||||
S1["chatStore"]
|
||||
S2["conversationsStore"]
|
||||
S3["modelsStore"]
|
||||
S4["serverStore"]
|
||||
S5["settingsStore"]
|
||||
end
|
||||
|
||||
subgraph Services["⚙️ Services"]
|
||||
SV1["ChatService"]
|
||||
SV2["ModelsService"]
|
||||
SV3["PropsService"]
|
||||
SV4["DatabaseService"]
|
||||
end
|
||||
|
||||
subgraph Storage["💾 Storage"]
|
||||
ST1["IndexedDB"]
|
||||
ST2["LocalStorage"]
|
||||
end
|
||||
|
||||
subgraph APIs["🌐 llama-server"]
|
||||
API1["/v1/chat/completions"]
|
||||
API2["/props"]
|
||||
API3["/models/*"]
|
||||
end
|
||||
|
||||
R1 & R2 --> C_Screen
|
||||
RL --> C_Sidebar
|
||||
C_Screen --> C_Form & C_Messages & C_Settings
|
||||
C_Screen --> S1 & S2
|
||||
C_ModelsSelector --> S3 & S4
|
||||
S1 --> SV1 & SV4
|
||||
S3 --> SV2 & SV3
|
||||
SV4 --> ST1
|
||||
SV1 --> API1
|
||||
SV2 --> API3
|
||||
SV3 --> API2
|
||||
```
|
||||
|
||||
### Layer Breakdown
|
||||
|
||||
#### Routes (`src/routes/`)
|
||||
|
||||
- **`/`** - Welcome screen, creates new conversation
|
||||
- **`/chat/[id]`** - Active chat interface
|
||||
- **`+layout.svelte`** - Sidebar, navigation, global initialization
|
||||
|
||||
#### Components (`src/lib/components/`)
|
||||
|
||||
Components are organized in `app/` (application-specific) and `ui/` (shadcn-svelte primitives).
|
||||
|
||||
**Chat Components** (`app/chat/`):
|
||||
|
||||
| Component | Responsibility |
|
||||
| ------------------ | --------------------------------------------------------------------------- |
|
||||
| `ChatScreen/` | Main chat container, coordinates message list, input form, and attachments |
|
||||
| `ChatForm/` | Message input textarea with file upload, paste handling, keyboard shortcuts |
|
||||
| `ChatMessages/` | Message list with branch navigation, regenerate/continue/edit actions |
|
||||
| `ChatAttachments/` | File attachment previews, drag-and-drop, PDF/image/audio handling |
|
||||
| `ChatSettings/` | Parameter sliders (temperature, top-p, etc.) with server default sync |
|
||||
| `ChatSidebar/` | Conversation list, search, import/export, navigation |
|
||||
|
||||
**Dialog Components** (`app/dialogs/`):
|
||||
|
||||
| Component | Responsibility |
|
||||
| ------------------------------- | -------------------------------------------------------- |
|
||||
| `DialogChatSettings` | Full-screen settings configuration |
|
||||
| `DialogModelInformation` | Model details (context size, modalities, parallel slots) |
|
||||
| `DialogChatAttachmentPreview` | Full preview for images, PDFs (text or page view), code |
|
||||
| `DialogConfirmation` | Generic confirmation for destructive actions |
|
||||
| `DialogConversationTitleUpdate` | Edit conversation title |
|
||||
|
||||
**Server/Model Components** (`app/server/`, `app/models/`):
|
||||
|
||||
| Component | Responsibility |
|
||||
| ------------------- | --------------------------------------------------------- |
|
||||
| `ServerErrorSplash` | Error display when server is unreachable |
|
||||
| `ModelsSelector` | Model dropdown with Loaded/Available groups (ROUTER mode) |
|
||||
|
||||
**Shared UI Components** (`app/misc/`):
|
||||
|
||||
| Component | Responsibility |
|
||||
| -------------------------------- | ---------------------------------------------------------------- |
|
||||
| `MarkdownContent` | Markdown rendering with KaTeX, syntax highlighting, copy buttons |
|
||||
| `SyntaxHighlightedCode` | Code blocks with language detection and highlighting |
|
||||
| `ActionButton`, `ActionDropdown` | Reusable action buttons and menus |
|
||||
| `BadgeModality`, `BadgeInfo` | Status and capability badges |
|
||||
|
||||
#### Hooks (`src/lib/hooks/`)
|
||||
|
||||
- **`useModelChangeValidation`** - Validates model switch against conversation modalities
|
||||
- **`useProcessingState`** - Tracks streaming progress and token generation
|
||||
|
||||
#### Stores (`src/lib/stores/`)
|
||||
|
||||
| Store | Responsibility |
|
||||
| -------------------- | --------------------------------------------------------- |
|
||||
| `chatStore` | Message sending, streaming, abort control, error handling |
|
||||
| `conversationsStore` | CRUD for conversations, message branching, navigation |
|
||||
| `modelsStore` | Model list, selection, loading/unloading (ROUTER) |
|
||||
| `serverStore` | Server properties, role detection, modalities |
|
||||
| `settingsStore` | User preferences, parameter sync with server defaults |
|
||||
|
||||
#### Services (`src/lib/services/`)
|
||||
|
||||
| Service | Responsibility |
|
||||
| ---------------------- | ----------------------------------------------- |
|
||||
| `ChatService` | API calls to`/v1/chat/completions`, SSE parsing |
|
||||
| `ModelsService` | `/models`, `/models/load`, `/models/unload` |
|
||||
| `PropsService` | `/props`, `/props?model=` |
|
||||
| `DatabaseService` | IndexedDB operations via Dexie |
|
||||
| `ParameterSyncService` | Syncs settings with server defaults |
|
||||
|
||||
---
|
||||
|
||||
## Data Flows
|
||||
|
||||
### MODEL Mode (Single Model)
|
||||
|
||||
See: [`docs/flows/data-flow-simplified-model-mode.md`](docs/flows/data-flow-simplified-model-mode.md)
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant User
|
||||
participant UI
|
||||
participant Stores
|
||||
participant DB as IndexedDB
|
||||
participant API as llama-server
|
||||
|
||||
Note over User,API: Initialization
|
||||
UI->>Stores: initialize()
|
||||
Stores->>DB: load conversations
|
||||
Stores->>API: GET /props
|
||||
API-->>Stores: server config
|
||||
Stores->>API: GET /v1/models
|
||||
API-->>Stores: single model (auto-selected)
|
||||
|
||||
Note over User,API: Chat Flow
|
||||
User->>UI: send message
|
||||
Stores->>DB: save user message
|
||||
Stores->>API: POST /v1/chat/completions (stream)
|
||||
loop streaming
|
||||
API-->>Stores: SSE chunks
|
||||
Stores-->>UI: reactive update
|
||||
end
|
||||
Stores->>DB: save assistant message
|
||||
```
|
||||
|
||||
### ROUTER Mode (Multi-Model)
|
||||
|
||||
See: [`docs/flows/data-flow-simplified-router-mode.md`](docs/flows/data-flow-simplified-router-mode.md)
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant User
|
||||
participant UI
|
||||
participant Stores
|
||||
participant API as llama-server
|
||||
|
||||
Note over User,API: Initialization
|
||||
Stores->>API: GET /props
|
||||
API-->>Stores: {role: "router"}
|
||||
Stores->>API: GET /models
|
||||
API-->>Stores: models[] with status
|
||||
|
||||
Note over User,API: Model Selection
|
||||
User->>UI: select model
|
||||
alt model not loaded
|
||||
Stores->>API: POST /models/load
|
||||
loop poll status
|
||||
Stores->>API: GET /models
|
||||
end
|
||||
Stores->>API: GET /props?model=X
|
||||
end
|
||||
Stores->>Stores: validate modalities
|
||||
|
||||
Note over User,API: Chat Flow
|
||||
Stores->>API: POST /v1/chat/completions {model: X}
|
||||
loop streaming
|
||||
API-->>Stores: SSE chunks + model info
|
||||
end
|
||||
```
|
||||
|
||||
### Detailed Flow Diagrams
|
||||
|
||||
| Flow | Description | File |
|
||||
| ------------- | ------------------------------------------ | ----------------------------------------------------------- |
|
||||
| Chat | Message lifecycle, streaming, regeneration | [`chat-flow.md`](docs/flows/chat-flow.md) |
|
||||
| Models | Loading, unloading, modality caching | [`models-flow.md`](docs/flows/models-flow.md) |
|
||||
| Server | Props fetching, role detection | [`server-flow.md`](docs/flows/server-flow.md) |
|
||||
| Conversations | CRUD, branching, import/export | [`conversations-flow.md`](docs/flows/conversations-flow.md) |
|
||||
| Database | IndexedDB schema, operations | [`database-flow.md`](docs/flows/database-flow.md) |
|
||||
| Settings | Parameter sync, user overrides | [`settings-flow.md`](docs/flows/settings-flow.md) |
|
||||
|
||||
---
|
||||
|
||||
## Architectural Patterns
|
||||
|
||||
### 1. Reactive State with Svelte 5 Runes
|
||||
|
||||
All stores use Svelte 5's fine-grained reactivity:
|
||||
|
||||
```typescript
|
||||
// Store with reactive state
|
||||
class ChatStore {
|
||||
#isLoading = $state(false);
|
||||
#currentResponse = $state('');
|
||||
|
||||
// Derived values auto-update
|
||||
get isStreaming() {
|
||||
return $derived(this.#isLoading && this.#currentResponse.length > 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Exported reactive accessors
|
||||
export const isLoading = () => chatStore.isLoading;
|
||||
export const currentResponse = () => chatStore.currentResponse;
|
||||
```
|
||||
|
||||
### 2. Unidirectional Data Flow
|
||||
|
||||
Data flows in one direction, making state predictable:
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
subgraph UI["UI Layer"]
|
||||
A[User Action] --> B[Component]
|
||||
end
|
||||
|
||||
subgraph State["State Layer"]
|
||||
B --> C[Store Method]
|
||||
C --> D[State Update]
|
||||
end
|
||||
|
||||
subgraph IO["I/O Layer"]
|
||||
C --> E[Service]
|
||||
E --> F[API / IndexedDB]
|
||||
F -.->|Response| D
|
||||
end
|
||||
|
||||
D -->|Reactive| B
|
||||
```
|
||||
|
||||
Components dispatch actions to stores, stores coordinate with services for I/O, and state updates reactively propagate back to the UI.
|
||||
|
||||
### 3. Per-Conversation State
|
||||
|
||||
Enables concurrent streaming across multiple conversations:
|
||||
|
||||
```typescript
|
||||
class ChatStore {
|
||||
chatLoadingStates = new Map<string, boolean>();
|
||||
chatStreamingStates = new Map<string, { response: string; messageId: string }>();
|
||||
abortControllers = new Map<string, AbortController>();
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Message Branching with Tree Structure
|
||||
|
||||
Conversations are stored as a tree, not a linear list:
|
||||
|
||||
```typescript
|
||||
interface DatabaseMessage {
|
||||
id: string;
|
||||
parent: string | null; // Points to parent message
|
||||
children: string[]; // List of child message IDs
|
||||
// ...
|
||||
}
|
||||
|
||||
interface DatabaseConversation {
|
||||
currentNode: string; // Currently viewed branch tip
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
Navigation between branches updates `currentNode` without losing history.
|
||||
|
||||
### 5. Layered Service Architecture
|
||||
|
||||
Stores handle state; services handle I/O:
|
||||
|
||||
```text
|
||||
┌─────────────────┐
|
||||
│ Stores │ Business logic, state management
|
||||
├─────────────────┤
|
||||
│ Services │ API calls, database operations
|
||||
├─────────────────┤
|
||||
│ Storage/API │ IndexedDB, LocalStorage, HTTP
|
||||
└─────────────────┘
|
||||
```
|
||||
|
||||
### 6. Server Role Abstraction
|
||||
|
||||
Single codebase handles both MODEL and ROUTER modes:
|
||||
|
||||
```typescript
|
||||
// serverStore.ts
|
||||
get isRouterMode() {
|
||||
return this.role === ServerRole.ROUTER;
|
||||
}
|
||||
|
||||
// Components conditionally render based on mode
|
||||
{#if isRouterMode()}
|
||||
<ModelsSelector />
|
||||
{/if}
|
||||
```
|
||||
|
||||
### 7. Modality Validation
|
||||
|
||||
Prevents sending attachments to incompatible models:
|
||||
|
||||
```typescript
|
||||
// useModelChangeValidation hook
|
||||
const validate = (modelId: string) => {
|
||||
const modelModalities = modelsStore.getModelModalities(modelId);
|
||||
const conversationModalities = conversationsStore.usedModalities;
|
||||
|
||||
// Check if model supports all used modalities
|
||||
if (conversationModalities.hasImages && !modelModalities.vision) {
|
||||
return { valid: false, reason: 'Model does not support images' };
|
||||
}
|
||||
// ...
|
||||
};
|
||||
```
|
||||
|
||||
### 8. Persistent Storage Strategy
|
||||
|
||||
Data is persisted across sessions using two storage mechanisms:
|
||||
|
||||
```mermaid
|
||||
flowchart TB
|
||||
subgraph Browser["Browser Storage"]
|
||||
subgraph IDB["IndexedDB (Dexie)"]
|
||||
C[Conversations]
|
||||
M[Messages]
|
||||
end
|
||||
subgraph LS["LocalStorage"]
|
||||
S[Settings Config]
|
||||
O[User Overrides]
|
||||
T[Theme Preference]
|
||||
end
|
||||
end
|
||||
|
||||
subgraph Stores["Svelte Stores"]
|
||||
CS[conversationsStore] --> C
|
||||
CS --> M
|
||||
SS[settingsStore] --> S
|
||||
SS --> O
|
||||
SS --> T
|
||||
end
|
||||
```
|
||||
|
||||
- **IndexedDB**: Conversations and messages (large, structured data)
|
||||
- **LocalStorage**: Settings, user parameter overrides, theme (small key-value data)
|
||||
- **Memory only**: Server props, model list (fetched fresh on each session)
|
||||
|
||||
---
|
||||
|
||||
## Testing
|
||||
|
||||
### Test Types
|
||||
|
||||
| Type | Tool | Location | Command |
|
||||
| ------------- | ------------------ | ---------------- | ------------------- |
|
||||
| **Unit** | Vitest | `tests/unit/` | `npm run test:unit` |
|
||||
| **UI/Visual** | Storybook + Vitest | `tests/stories/` | `npm run test:ui` |
|
||||
| **E2E** | Playwright | `tests/e2e/` | `npm run test:e2e` |
|
||||
| **Client** | Vitest | `tests/client/`. | `npm run test:unit` |
|
||||
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
# All tests
|
||||
npm run test
|
||||
|
||||
# Individual test suites
|
||||
npm run test:e2e # End-to-end (requires llama-server)
|
||||
npm run test:client # Client-side unit tests
|
||||
npm run test:server # Server-side unit tests
|
||||
npm run test:ui # Storybook visual tests
|
||||
```
|
||||
|
||||
### Storybook Development
|
||||
|
||||
```bash
|
||||
npm run storybook # Start Storybook dev server on :6006
|
||||
npm run build-storybook # Build static Storybook
|
||||
```
|
||||
|
||||
### Linting and Formatting
|
||||
|
||||
```bash
|
||||
npm run lint # Check code style
|
||||
npm run format # Auto-format with Prettier
|
||||
npm run check # TypeScript type checking
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Project Structure
|
||||
|
||||
```text
|
||||
tools/server/webui/
|
||||
├── src/
|
||||
│ ├── lib/
|
||||
│ │ ├── components/ # UI components (app/, ui/)
|
||||
│ │ ├── hooks/ # Svelte hooks
|
||||
│ │ ├── stores/ # State management
|
||||
│ │ ├── services/ # API and database services
|
||||
│ │ ├── types/ # TypeScript interfaces
|
||||
│ │ └── utils/ # Utility functions
|
||||
│ ├── routes/ # SvelteKit routes
|
||||
│ └── styles/ # Global styles
|
||||
├── static/ # Static assets
|
||||
├── tests/ # Test files
|
||||
├── docs/ # Architecture diagrams
|
||||
│ ├── architecture/ # High-level architecture
|
||||
│ └── flows/ # Feature-specific flows
|
||||
└── .storybook/ # Storybook configuration
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Related Documentation
|
||||
|
||||
- [llama.cpp Server README](../README.md) - Full server documentation
|
||||
- [Multimodal Documentation](../../../docs/multimodal.md) - Image and audio support
|
||||
- [Function Calling](../../../docs/function-calling.md) - Tool use capabilities
|
||||
16
tools/server/webui/components.json
Normal file
16
tools/server/webui/components.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"$schema": "https://shadcn-svelte.com/schema.json",
|
||||
"tailwind": {
|
||||
"css": "src/app.css",
|
||||
"baseColor": "neutral"
|
||||
},
|
||||
"aliases": {
|
||||
"components": "$lib/components",
|
||||
"utils": "$lib/components/ui/utils",
|
||||
"ui": "$lib/components/ui",
|
||||
"hooks": "$lib/hooks",
|
||||
"lib": "$lib"
|
||||
},
|
||||
"typescript": true,
|
||||
"registry": "https://shadcn-svelte.com/registry"
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
```mermaid
|
||||
flowchart TB
|
||||
subgraph Routes["📍 Routes"]
|
||||
R1["/ (Welcome)"]
|
||||
R2["/chat/[id]"]
|
||||
RL["+layout.svelte"]
|
||||
end
|
||||
|
||||
subgraph Components["🧩 Components"]
|
||||
C_Sidebar["ChatSidebar"]
|
||||
C_Screen["ChatScreen"]
|
||||
C_Form["ChatForm"]
|
||||
C_Messages["ChatMessages"]
|
||||
C_Message["ChatMessage"]
|
||||
C_ChatMessageAgenticContent["ChatMessageAgenticContent"]
|
||||
C_MessageEditForm["ChatMessageEditForm"]
|
||||
C_ModelsSelector["ModelsSelector"]
|
||||
C_Settings["ChatSettings"]
|
||||
C_McpSettings["McpServersSettings"]
|
||||
C_McpResourceBrowser["McpResourceBrowser"]
|
||||
C_McpServersSelector["McpServersSelector"]
|
||||
end
|
||||
|
||||
subgraph Hooks["🪝 Hooks"]
|
||||
H1["useModelChangeValidation"]
|
||||
H2["useProcessingState"]
|
||||
end
|
||||
|
||||
subgraph Stores["🗄️ Stores"]
|
||||
S1["chatStore<br/><i>Chat interactions & streaming</i>"]
|
||||
SA["agenticStore<br/><i>Multi-turn agentic loop orchestration</i>"]
|
||||
S2["conversationsStore<br/><i>Conversation data, messages & MCP overrides</i>"]
|
||||
S3["modelsStore<br/><i>Model selection & loading</i>"]
|
||||
S4["serverStore<br/><i>Server props & role detection</i>"]
|
||||
S5["settingsStore<br/><i>User configuration incl. MCP</i>"]
|
||||
S6["mcpStore<br/><i>MCP servers, tools, prompts</i>"]
|
||||
S7["mcpResourceStore<br/><i>MCP resources & attachments</i>"]
|
||||
end
|
||||
|
||||
subgraph Services["⚙️ Services"]
|
||||
SV1["ChatService"]
|
||||
SV2["ModelsService"]
|
||||
SV3["PropsService"]
|
||||
SV4["DatabaseService"]
|
||||
SV5["ParameterSyncService"]
|
||||
SV6["MCPService<br/><i>protocol operations</i>"]
|
||||
end
|
||||
|
||||
subgraph Storage["💾 Storage"]
|
||||
ST1["IndexedDB<br/><i>conversations, messages</i>"]
|
||||
ST2["LocalStorage<br/><i>config, userOverrides, mcpServers</i>"]
|
||||
end
|
||||
|
||||
subgraph APIs["🌐 llama-server API"]
|
||||
API1["/v1/chat/completions"]
|
||||
API2["/props"]
|
||||
API3["/models/*"]
|
||||
API4["/v1/models"]
|
||||
end
|
||||
|
||||
subgraph ExternalMCP["🔌 External MCP Servers"]
|
||||
EXT1["MCP Server 1<br/><i>WebSocket/HTTP/SSE</i>"]
|
||||
EXT2["MCP Server N"]
|
||||
end
|
||||
|
||||
%% Routes → Components
|
||||
R1 & R2 --> C_Screen
|
||||
RL --> C_Sidebar
|
||||
|
||||
%% Layout runs MCP health checks
|
||||
RL --> S6
|
||||
|
||||
%% Component hierarchy
|
||||
C_Screen --> C_Form & C_Messages & C_Settings
|
||||
C_Messages --> C_Message
|
||||
C_Message --> C_ChatMessageAgenticContent
|
||||
C_Message --> C_MessageEditForm
|
||||
C_Form & C_MessageEditForm --> C_ModelsSelector
|
||||
C_Form --> C_McpServersSelector
|
||||
C_Settings --> C_McpSettings
|
||||
C_McpSettings --> C_McpResourceBrowser
|
||||
|
||||
%% Components → Hooks → Stores
|
||||
C_Form & C_Messages --> H1 & H2
|
||||
H1 --> S3 & S4
|
||||
H2 --> S1 & S5
|
||||
|
||||
%% Components → Stores
|
||||
C_Screen --> S1 & S2
|
||||
C_Sidebar --> S2
|
||||
C_ModelsSelector --> S3 & S4
|
||||
C_Settings --> S5
|
||||
C_McpSettings --> S6
|
||||
C_McpResourceBrowser --> S6 & S7
|
||||
C_McpServersSelector --> S6
|
||||
C_Form --> S6
|
||||
|
||||
%% chatStore → agenticStore → mcpStore (agentic loop)
|
||||
S1 --> SA
|
||||
SA --> SV1
|
||||
SA --> S6
|
||||
|
||||
%% Stores → Services
|
||||
S1 --> SV1 & SV4
|
||||
S2 --> SV4
|
||||
S3 --> SV2 & SV3
|
||||
S4 --> SV3
|
||||
S5 --> SV5
|
||||
S6 --> SV6
|
||||
S7 --> SV6
|
||||
|
||||
%% Services → Storage
|
||||
SV4 --> ST1
|
||||
SV5 --> ST2
|
||||
|
||||
%% Services → APIs
|
||||
SV1 --> API1
|
||||
SV2 --> API3 & API4
|
||||
SV3 --> API2
|
||||
|
||||
%% MCP → External Servers
|
||||
SV6 --> EXT1 & EXT2
|
||||
|
||||
%% Styling
|
||||
classDef routeStyle fill:#e1f5fe,stroke:#01579b,stroke-width:2px
|
||||
classDef componentStyle fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px
|
||||
classDef hookStyle fill:#fff8e1,stroke:#ff8f00,stroke-width:2px
|
||||
classDef storeStyle fill:#fff3e0,stroke:#e65100,stroke-width:2px
|
||||
classDef serviceStyle fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px
|
||||
classDef storageStyle fill:#fce4ec,stroke:#c2185b,stroke-width:2px
|
||||
classDef apiStyle fill:#e3f2fd,stroke:#1565c0,stroke-width:2px
|
||||
classDef mcpStyle fill:#e0f2f1,stroke:#00695c,stroke-width:2px
|
||||
classDef agenticStyle fill:#e8eaf6,stroke:#283593,stroke-width:2px
|
||||
classDef externalStyle fill:#f3e5f5,stroke:#6a1b9a,stroke-width:2px,stroke-dasharray: 5 5
|
||||
|
||||
class R1,R2,RL routeStyle
|
||||
class C_Sidebar,C_Screen,C_Form,C_Messages,C_Message,C_ChatMessageAgenticContent,C_MessageEditForm,C_ModelsSelector,C_Settings componentStyle
|
||||
class C_McpSettings,C_McpResourceBrowser,C_McpServersSelector componentStyle
|
||||
class H1,H2 hookStyle
|
||||
class S1,S2,S3,S4,S5,SA,S6,S7 storeStyle
|
||||
class SV1,SV2,SV3,SV4,SV5,SV6 serviceStyle
|
||||
class ST1,ST2 storageStyle
|
||||
class API1,API2,API3,API4 apiStyle
|
||||
class EXT1,EXT2 externalStyle
|
||||
```
|
||||
373
tools/server/webui/docs/architecture/high-level-architecture.md
Normal file
373
tools/server/webui/docs/architecture/high-level-architecture.md
Normal file
@@ -0,0 +1,373 @@
|
||||
```mermaid
|
||||
flowchart TB
|
||||
subgraph Routes["📍 Routes"]
|
||||
R1["/ (+page.svelte)"]
|
||||
R2["/chat/[id]"]
|
||||
RL["+layout.svelte"]
|
||||
end
|
||||
|
||||
subgraph Components["🧩 Components"]
|
||||
direction TB
|
||||
subgraph LayoutComponents["Layout"]
|
||||
C_Sidebar["ChatSidebar"]
|
||||
C_Screen["ChatScreen"]
|
||||
end
|
||||
subgraph ChatUIComponents["Chat UI"]
|
||||
C_Form["ChatForm"]
|
||||
C_Messages["ChatMessages"]
|
||||
C_Message["ChatMessage"]
|
||||
C_MessageUser["ChatMessageUser"]
|
||||
C_MessageEditForm["ChatMessageEditForm"]
|
||||
C_Attach["ChatAttachments"]
|
||||
C_ModelsSelector["ModelsSelector"]
|
||||
C_Settings["ChatSettings"]
|
||||
end
|
||||
subgraph MCPComponents["MCP UI"]
|
||||
C_McpSettings["McpServersSettings"]
|
||||
C_McpServerCard["McpServerCard"]
|
||||
C_McpResourceBrowser["McpResourceBrowser"]
|
||||
C_McpResourcePreview["McpResourcePreview"]
|
||||
C_McpServersSelector["McpServersSelector"]
|
||||
end
|
||||
end
|
||||
|
||||
subgraph Hooks["🪝 Hooks"]
|
||||
H1["useModelChangeValidation"]
|
||||
H2["useProcessingState"]
|
||||
H3["isMobile"]
|
||||
end
|
||||
|
||||
subgraph Stores["🗄️ Stores"]
|
||||
direction TB
|
||||
subgraph S1["chatStore"]
|
||||
S1State["<b>State:</b><br/>isLoading, currentResponse<br/>errorDialogState<br/>activeProcessingState<br/>chatLoadingStates<br/>chatStreamingStates<br/>abortControllers<br/>processingStates<br/>activeConversationId<br/>isStreamingActive"]
|
||||
S1LoadState["<b>Loading State:</b><br/>setChatLoading()<br/>isChatLoading()<br/>syncLoadingStateForChat()<br/>clearUIState()<br/>isChatLoadingPublic()<br/>getAllLoadingChats()<br/>getAllStreamingChats()"]
|
||||
S1ProcState["<b>Processing State:</b><br/>setActiveProcessingConversation()<br/>getProcessingState()<br/>clearProcessingState()<br/>getActiveProcessingState()<br/>updateProcessingStateFromTimings()<br/>getCurrentProcessingStateSync()<br/>restoreProcessingStateFromMessages()"]
|
||||
S1Stream["<b>Streaming:</b><br/>streamChatCompletion()<br/>startStreaming()<br/>stopStreaming()<br/>stopGeneration()<br/>isStreaming()"]
|
||||
S1Error["<b>Error Handling:</b><br/>showErrorDialog()<br/>dismissErrorDialog()<br/>isAbortError()"]
|
||||
S1Msg["<b>Message Operations:</b><br/>addMessage()<br/>sendMessage()<br/>updateMessage()<br/>deleteMessage()<br/>getDeletionInfo()"]
|
||||
S1Regen["<b>Regeneration:</b><br/>regenerateMessage()<br/>regenerateMessageWithBranching()<br/>continueAssistantMessage()"]
|
||||
S1Edit["<b>Editing:</b><br/>editAssistantMessage()<br/>editUserMessagePreserveResponses()<br/>editMessageWithBranching()<br/>clearEditMode()<br/>isEditModeActive()<br/>getAddFilesHandler()<br/>setEditModeActive()"]
|
||||
S1Utils["<b>Utilities:</b><br/>getApiOptions()<br/>parseTimingData()<br/>getOrCreateAbortController()<br/>getConversationModel()"]
|
||||
end
|
||||
subgraph SA["agenticStore"]
|
||||
SAState["<b>State:</b><br/>sessions (Map)<br/>isAnyRunning"]
|
||||
SASession["<b>Session Management:</b><br/>getSession()<br/>updateSession()<br/>clearSession()<br/>getActiveSessions()<br/>isRunning()<br/>currentTurn()<br/>totalToolCalls()<br/>lastError()<br/>streamingToolCall()"]
|
||||
SAConfig["<b>Configuration:</b><br/>getConfig()<br/>maxTurns, maxToolPreviewLines"]
|
||||
SAFlow["<b>Agentic Loop:</b><br/>runAgenticFlow()<br/>executeAgenticLoop()<br/>normalizeToolCalls()<br/>emitToolCallResult()<br/>extractBase64Attachments()"]
|
||||
end
|
||||
subgraph S2["conversationsStore"]
|
||||
S2State["<b>State:</b><br/>conversations<br/>activeConversation<br/>activeMessages<br/>isInitialized<br/>pendingMcpServerOverrides<br/>titleUpdateConfirmationCallback"]
|
||||
S2Lifecycle["<b>Lifecycle:</b><br/>initialize()<br/>loadConversations()<br/>clearActiveConversation()"]
|
||||
S2ConvCRUD["<b>Conversation CRUD:</b><br/>createConversation()<br/>loadConversation()<br/>deleteConversation()<br/>deleteAll()<br/>updateConversationName()<br/>updateConversationTitleWithConfirmation()"]
|
||||
S2MsgMgmt["<b>Message Management:</b><br/>refreshActiveMessages()<br/>addMessageToActive()<br/>updateMessageAtIndex()<br/>findMessageIndex()<br/>sliceActiveMessages()<br/>removeMessageAtIndex()<br/>getConversationMessages()"]
|
||||
S2Nav["<b>Navigation:</b><br/>navigateToSibling()<br/>updateCurrentNode()<br/>updateConversationTimestamp()"]
|
||||
S2McpOverrides["<b>MCP Per-Chat Overrides:</b><br/>getMcpServerOverride()<br/>getAllMcpServerOverrides()<br/>setMcpServerOverride()<br/>toggleMcpServerForChat()<br/>removeMcpServerOverride()<br/>isMcpServerEnabledForChat()<br/>clearPendingMcpServerOverrides()"]
|
||||
S2Export["<b>Import/Export:</b><br/>downloadConversation()<br/>exportAllConversations()<br/>importConversations()<br/>importConversationsData()<br/>triggerDownload()"]
|
||||
S2Utils["<b>Utilities:</b><br/>setTitleUpdateConfirmationCallback()"]
|
||||
end
|
||||
subgraph S3["modelsStore"]
|
||||
S3State["<b>State:</b><br/>models, routerModels<br/>selectedModelId<br/>selectedModelName<br/>loading, updating, error<br/>modelLoadingStates<br/>modelPropsCache<br/>modelPropsFetching<br/>propsCacheVersion"]
|
||||
S3Getters["<b>Computed Getters:</b><br/>selectedModel<br/>loadedModelIds<br/>loadingModelIds<br/>singleModelName"]
|
||||
S3Modal["<b>Modalities:</b><br/>getModelModalities()<br/>modelSupportsVision()<br/>modelSupportsAudio()<br/>getModelModalitiesArray()<br/>getModelProps()<br/>updateModelModalities()"]
|
||||
S3Status["<b>Status Queries:</b><br/>isModelLoaded()<br/>isModelOperationInProgress()<br/>getModelStatus()<br/>isModelPropsFetching()"]
|
||||
S3Fetch["<b>Data Fetching:</b><br/>fetch()<br/>fetchRouterModels()<br/>fetchModelProps()<br/>fetchModalitiesForLoadedModels()"]
|
||||
S3Select["<b>Model Selection:</b><br/>selectModelById()<br/>selectModelByName()<br/>clearSelection()<br/>findModelByName()<br/>findModelById()<br/>hasModel()"]
|
||||
S3LoadUnload["<b>Loading/Unloading Models:</b><br/>loadModel()<br/>unloadModel()<br/>ensureModelLoaded()<br/>waitForModelStatus()<br/>pollForModelStatus()"]
|
||||
S3Utils["<b>Utilities:</b><br/>toDisplayName()<br/>clear()"]
|
||||
end
|
||||
subgraph S4["serverStore"]
|
||||
S4State["<b>State:</b><br/>props<br/>loading, error<br/>role<br/>fetchPromise"]
|
||||
S4Getters["<b>Getters:</b><br/>defaultParams<br/>contextSize<br/>isRouterMode<br/>isModelMode"]
|
||||
S4Data["<b>Data Handling:</b><br/>fetch()<br/>getErrorMessage()<br/>clear()"]
|
||||
S4Utils["<b>Utilities:</b><br/>detectRole()"]
|
||||
end
|
||||
subgraph S5["settingsStore"]
|
||||
S5State["<b>State:</b><br/>config<br/>theme<br/>isInitialized<br/>userOverrides"]
|
||||
S5Lifecycle["<b>Lifecycle:</b><br/>initialize()<br/>loadConfig()<br/>saveConfig()<br/>loadTheme()<br/>saveTheme()"]
|
||||
S5Update["<b>Config Updates:</b><br/>updateConfig()<br/>updateMultipleConfig()<br/>updateTheme()"]
|
||||
S5Reset["<b>Reset:</b><br/>resetConfig()<br/>resetTheme()<br/>resetAll()<br/>resetParameterToServerDefault()"]
|
||||
S5Sync["<b>Server Sync:</b><br/>syncWithServerDefaults()<br/>forceSyncWithServerDefaults()"]
|
||||
S5Utils["<b>Utilities:</b><br/>getConfig()<br/>getAllConfig()<br/>getParameterInfo()<br/>getParameterDiff()<br/>getServerDefaults()<br/>clearAllUserOverrides()"]
|
||||
end
|
||||
subgraph S6["mcpStore"]
|
||||
S6State["<b>State:</b><br/>isInitializing, error<br/>toolCount, connectedServers<br/>healthChecks (Map)<br/>connections (Map)<br/>toolsIndex (Map)"]
|
||||
S6Lifecycle["<b>Lifecycle:</b><br/>ensureInitialized()<br/>initialize()<br/>shutdown()<br/>acquireConnection()<br/>releaseConnection()"]
|
||||
S6Health["<b>Health Checks:</b><br/>runHealthCheck()<br/>runHealthChecksForServers()<br/>updateHealthCheck()<br/>getHealthCheckState()<br/>clearHealthCheck()"]
|
||||
S6Servers["<b>Server Management:</b><br/>getServers()<br/>addServer()<br/>updateServer()<br/>removeServer()<br/>getServerById()<br/>getServerDisplayName()"]
|
||||
S6Tools["<b>Tool Operations:</b><br/>getToolDefinitionsForLLM()<br/>getToolNames()<br/>hasTool()<br/>getToolServer()<br/>executeTool()<br/>executeToolByName()"]
|
||||
S6Prompts["<b>Prompt Operations:</b><br/>getAllPrompts()<br/>getPrompt()<br/>hasPromptsCapability()<br/>getPromptCompletions()"]
|
||||
end
|
||||
subgraph S7["mcpResourceStore"]
|
||||
S7State["<b>State:</b><br/>serverResources (Map)<br/>cachedResources (Map)<br/>subscriptions (Map)<br/>attachments[]<br/>isLoading"]
|
||||
S7Resources["<b>Resource Discovery:</b><br/>setServerResources()<br/>getServerResources()<br/>getAllResourceInfos()<br/>getAllTemplateInfos()<br/>clearServerResources()"]
|
||||
S7Cache["<b>Caching:</b><br/>cacheResourceContent()<br/>getCachedContent()<br/>invalidateCache()<br/>clearCache()"]
|
||||
S7Subs["<b>Subscriptions:</b><br/>addSubscription()<br/>removeSubscription()<br/>isSubscribed()<br/>handleResourceUpdate()"]
|
||||
S7Attach["<b>Attachments:</b><br/>addAttachment()<br/>updateAttachmentContent()<br/>removeAttachment()<br/>clearAttachments()<br/>toMessageExtras()"]
|
||||
end
|
||||
|
||||
subgraph ReactiveExports["⚡ Reactive Exports"]
|
||||
direction LR
|
||||
subgraph ChatExports["chatStore"]
|
||||
RE1["isLoading()"]
|
||||
RE2["currentResponse()"]
|
||||
RE3["errorDialog()"]
|
||||
RE4["activeProcessingState()"]
|
||||
RE5["isChatStreaming()"]
|
||||
RE6["isChatLoading()"]
|
||||
RE7["getChatStreaming()"]
|
||||
RE8["getAllLoadingChats()"]
|
||||
RE9["getAllStreamingChats()"]
|
||||
RE9a["isEditModeActive()"]
|
||||
RE9b["getAddFilesHandler()"]
|
||||
RE9c["setEditModeActive()"]
|
||||
RE9d["clearEditMode()"]
|
||||
end
|
||||
subgraph AgenticExports["agenticStore"]
|
||||
REA1["agenticIsRunning()"]
|
||||
REA2["agenticCurrentTurn()"]
|
||||
REA3["agenticTotalToolCalls()"]
|
||||
REA4["agenticLastError()"]
|
||||
REA5["agenticStreamingToolCall()"]
|
||||
REA6["agenticIsAnyRunning()"]
|
||||
end
|
||||
subgraph ConvExports["conversationsStore"]
|
||||
RE10["conversations()"]
|
||||
RE11["activeConversation()"]
|
||||
RE12["activeMessages()"]
|
||||
RE13["isConversationsInitialized()"]
|
||||
end
|
||||
subgraph ModelsExports["modelsStore"]
|
||||
RE15["modelOptions()"]
|
||||
RE16["routerModels()"]
|
||||
RE17["modelsLoading()"]
|
||||
RE18["modelsUpdating()"]
|
||||
RE19["modelsError()"]
|
||||
RE20["selectedModelId()"]
|
||||
RE21["selectedModelName()"]
|
||||
RE22["selectedModelOption()"]
|
||||
RE23["loadedModelIds()"]
|
||||
RE24["loadingModelIds()"]
|
||||
RE25["propsCacheVersion()"]
|
||||
RE26["singleModelName()"]
|
||||
end
|
||||
subgraph ServerExports["serverStore"]
|
||||
RE27["serverProps()"]
|
||||
RE28["serverLoading()"]
|
||||
RE29["serverError()"]
|
||||
RE30["serverRole()"]
|
||||
RE31["defaultParams()"]
|
||||
RE32["contextSize()"]
|
||||
RE33["isRouterMode()"]
|
||||
RE34["isModelMode()"]
|
||||
end
|
||||
subgraph SettingsExports["settingsStore"]
|
||||
RE35["config()"]
|
||||
RE36["theme()"]
|
||||
RE37["isInitialized()"]
|
||||
end
|
||||
subgraph MCPExports["mcpStore / mcpResourceStore"]
|
||||
RE38["mcpResources()"]
|
||||
RE39["mcpResourceAttachments()"]
|
||||
RE40["mcpHasResourceAttachments()"]
|
||||
RE41["mcpTotalResourceCount()"]
|
||||
RE42["mcpResourcesLoading()"]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
subgraph Services["⚙️ Services"]
|
||||
direction TB
|
||||
subgraph SV1["ChatService"]
|
||||
SV1Msg["<b>Messaging:</b><br/>sendMessage()"]
|
||||
SV1Stream["<b>Streaming:</b><br/>handleStreamResponse()<br/>handleNonStreamResponse()"]
|
||||
SV1Convert["<b>Conversion:</b><br/>convertDbMessageToApiChatMessageData()<br/>mergeToolCallDeltas()"]
|
||||
SV1Utils["<b>Utilities:</b><br/>stripReasoningContent()<br/>extractModelName()<br/>parseErrorResponse()"]
|
||||
end
|
||||
subgraph SV2["ModelsService"]
|
||||
SV2List["<b>Listing:</b><br/>list()<br/>listRouter()"]
|
||||
SV2LoadUnload["<b>Load/Unload:</b><br/>load()<br/>unload()"]
|
||||
SV2Status["<b>Status:</b><br/>isModelLoaded()<br/>isModelLoading()"]
|
||||
end
|
||||
subgraph SV3["PropsService"]
|
||||
SV3Fetch["<b>Fetching:</b><br/>fetch()<br/>fetchForModel()"]
|
||||
end
|
||||
subgraph SV4["DatabaseService"]
|
||||
SV4Conv["<b>Conversations:</b><br/>createConversation()<br/>getConversation()<br/>getAllConversations()<br/>updateConversation()<br/>deleteConversation()"]
|
||||
SV4Msg["<b>Messages:</b><br/>createMessageBranch()<br/>createRootMessage()<br/>createSystemMessage()<br/>getConversationMessages()<br/>updateMessage()<br/>deleteMessage()<br/>deleteMessageCascading()"]
|
||||
SV4Node["<b>Navigation:</b><br/>updateCurrentNode()"]
|
||||
SV4Import["<b>Import:</b><br/>importConversations()"]
|
||||
end
|
||||
subgraph SV5["ParameterSyncService"]
|
||||
SV5Extract["<b>Extraction:</b><br/>extractServerDefaults()"]
|
||||
SV5Merge["<b>Merging:</b><br/>mergeWithServerDefaults()"]
|
||||
SV5Info["<b>Info:</b><br/>getParameterInfo()<br/>canSyncParameter()<br/>getSyncableParameterKeys()<br/>validateServerParameter()"]
|
||||
SV5Diff["<b>Diff:</b><br/>createParameterDiff()"]
|
||||
end
|
||||
subgraph SV6["MCPService"]
|
||||
SV6Transport["<b>Transport:</b><br/>createTransport()<br/>WebSocket / StreamableHTTP / SSE"]
|
||||
SV6Conn["<b>Connection:</b><br/>connect()<br/>disconnect()"]
|
||||
SV6Tools["<b>Tools:</b><br/>listTools()<br/>callTool()"]
|
||||
SV6Prompts["<b>Prompts:</b><br/>listPrompts()<br/>getPrompt()"]
|
||||
SV6Resources["<b>Resources:</b><br/>listResources()<br/>listResourceTemplates()<br/>readResource()<br/>subscribeResource()<br/>unsubscribeResource()"]
|
||||
SV6Complete["<b>Completions:</b><br/>complete()"]
|
||||
end
|
||||
end
|
||||
|
||||
subgraph ExternalMCP["🔌 External MCP Servers"]
|
||||
EXT1["MCP Server 1<br/>(WebSocket/StreamableHTTP/SSE)"]
|
||||
EXT2["MCP Server N"]
|
||||
end
|
||||
|
||||
subgraph Storage["💾 Storage"]
|
||||
ST1["IndexedDB"]
|
||||
ST2["conversations"]
|
||||
ST3["messages"]
|
||||
ST5["LocalStorage"]
|
||||
ST6["config"]
|
||||
ST7["userOverrides"]
|
||||
ST8["mcpServers"]
|
||||
end
|
||||
|
||||
subgraph APIs["🌐 llama-server API"]
|
||||
API1["/v1/chat/completions"]
|
||||
API2["/props<br/>/props?model="]
|
||||
API3["/models<br/>/models/load<br/>/models/unload"]
|
||||
API4["/v1/models"]
|
||||
end
|
||||
|
||||
%% Routes render Components
|
||||
R1 --> C_Screen
|
||||
R2 --> C_Screen
|
||||
RL --> C_Sidebar
|
||||
|
||||
%% Layout runs MCP health checks on startup
|
||||
RL --> S6
|
||||
|
||||
%% Component hierarchy
|
||||
C_Screen --> C_Form & C_Messages & C_Settings
|
||||
C_Messages --> C_Message
|
||||
C_Message --> C_MessageUser
|
||||
C_MessageUser --> C_MessageEditForm
|
||||
C_MessageEditForm --> C_ModelsSelector
|
||||
C_MessageEditForm --> C_Attach
|
||||
C_Form --> C_ModelsSelector
|
||||
C_Form --> C_Attach
|
||||
C_Form --> C_McpServersSelector
|
||||
C_Message --> C_Attach
|
||||
|
||||
%% MCP Components hierarchy
|
||||
C_Settings --> C_McpSettings
|
||||
C_McpSettings --> C_McpServerCard
|
||||
C_McpServerCard --> C_McpResourceBrowser
|
||||
C_McpResourceBrowser --> C_McpResourcePreview
|
||||
|
||||
%% Components use Hooks
|
||||
C_Form --> H1
|
||||
C_Message --> H1 & H2
|
||||
C_MessageEditForm --> H1
|
||||
C_Screen --> H2
|
||||
|
||||
%% Hooks use Stores
|
||||
H1 --> S3 & S4
|
||||
H2 --> S1 & S5
|
||||
|
||||
%% Components use Stores
|
||||
C_Screen --> S1 & S2
|
||||
C_Messages --> S2
|
||||
C_Message --> S1 & S2 & S3
|
||||
C_Form --> S1 & S3 & S6
|
||||
C_Sidebar --> S2
|
||||
C_ModelsSelector --> S3 & S4
|
||||
C_Settings --> S5
|
||||
C_McpSettings --> S6
|
||||
C_McpServerCard --> S6
|
||||
C_McpResourceBrowser --> S6 & S7
|
||||
C_McpServersSelector --> S6
|
||||
|
||||
%% Stores export Reactive State
|
||||
S1 -. exports .-> ChatExports
|
||||
SA -. exports .-> AgenticExports
|
||||
S2 -. exports .-> ConvExports
|
||||
S3 -. exports .-> ModelsExports
|
||||
S4 -. exports .-> ServerExports
|
||||
S5 -. exports .-> SettingsExports
|
||||
S6 -. exports .-> MCPExports
|
||||
S7 -. exports .-> MCPExports
|
||||
|
||||
%% chatStore → agenticStore (agentic loop orchestration)
|
||||
S1 --> SA
|
||||
SA --> SV1
|
||||
SA --> S6
|
||||
|
||||
%% Stores use Services
|
||||
S1 --> SV1 & SV4
|
||||
S2 --> SV4
|
||||
S3 --> SV2 & SV3
|
||||
S4 --> SV3
|
||||
S5 --> SV5
|
||||
S6 --> SV6
|
||||
S7 --> SV6
|
||||
|
||||
%% Services to Storage
|
||||
SV4 --> ST1
|
||||
ST1 --> ST2 & ST3
|
||||
SV5 --> ST5
|
||||
ST5 --> ST6 & ST7 & ST8
|
||||
|
||||
%% Services to APIs
|
||||
SV1 --> API1
|
||||
SV2 --> API3 & API4
|
||||
SV3 --> API2
|
||||
|
||||
%% MCP → External Servers
|
||||
SV6 --> EXT1 & EXT2
|
||||
|
||||
%% Styling
|
||||
classDef routeStyle fill:#e1f5fe,stroke:#01579b,stroke-width:2px
|
||||
classDef componentStyle fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px
|
||||
classDef componentGroupStyle fill:#e1bee7,stroke:#7b1fa2,stroke-width:1px
|
||||
classDef hookStyle fill:#fff8e1,stroke:#ff8f00,stroke-width:2px
|
||||
classDef storeStyle fill:#fff3e0,stroke:#e65100,stroke-width:2px
|
||||
classDef stateStyle fill:#ffe0b2,stroke:#e65100,stroke-width:1px
|
||||
classDef methodStyle fill:#ffecb3,stroke:#e65100,stroke-width:1px
|
||||
classDef reactiveStyle fill:#fffde7,stroke:#f9a825,stroke-width:1px
|
||||
classDef serviceStyle fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px
|
||||
classDef serviceMStyle fill:#c8e6c9,stroke:#2e7d32,stroke-width:1px
|
||||
classDef externalStyle fill:#f3e5f5,stroke:#6a1b9a,stroke-width:2px,stroke-dasharray: 5 5
|
||||
classDef storageStyle fill:#fce4ec,stroke:#c2185b,stroke-width:2px
|
||||
classDef apiStyle fill:#e3f2fd,stroke:#1565c0,stroke-width:2px
|
||||
|
||||
class R1,R2,RL routeStyle
|
||||
class C_Sidebar,C_Screen,C_Form,C_Messages,C_Message,C_MessageUser,C_MessageEditForm componentStyle
|
||||
class C_ModelsSelector,C_Settings componentStyle
|
||||
class C_Attach componentStyle
|
||||
class C_McpSettings,C_McpServerCard,C_McpResourceBrowser,C_McpResourcePreview,C_McpServersSelector componentStyle
|
||||
class H1,H2,H3 hookStyle
|
||||
class LayoutComponents,ChatUIComponents,MCPComponents componentGroupStyle
|
||||
class Hooks hookStyle
|
||||
classDef agenticStyle fill:#e8eaf6,stroke:#283593,stroke-width:2px
|
||||
classDef agenticMethodStyle fill:#c5cae9,stroke:#283593,stroke-width:1px
|
||||
|
||||
class S1,S2,S3,S4,S5,SA,S6,S7 storeStyle
|
||||
class S1State,S2State,S3State,S4State,S5State,SAState,S6State,S7State stateStyle
|
||||
class S1Msg,S1Regen,S1Edit,S1Stream,S1LoadState,S1ProcState,S1Error,S1Utils methodStyle
|
||||
class SASession,SAConfig,SAFlow methodStyle
|
||||
class S2Lifecycle,S2ConvCRUD,S2MsgMgmt,S2Nav,S2McpOverrides,S2Export,S2Utils methodStyle
|
||||
class S3Getters,S3Modal,S3Status,S3Fetch,S3Select,S3LoadUnload,S3Utils methodStyle
|
||||
class S4Getters,S4Data,S4Utils methodStyle
|
||||
class S5Lifecycle,S5Update,S5Reset,S5Sync,S5Utils methodStyle
|
||||
class S6Lifecycle,S6Health,S6Servers,S6Tools,S6Prompts methodStyle
|
||||
class S7Resources,S7Cache,S7Subs,S7Attach methodStyle
|
||||
class ChatExports,AgenticExports,ConvExports,ModelsExports,ServerExports,SettingsExports,MCPExports reactiveStyle
|
||||
class SV1,SV2,SV3,SV4,SV5,SV6 serviceStyle
|
||||
class SV6Transport,SV6Conn,SV6Tools,SV6Prompts,SV6Resources,SV6Complete serviceMStyle
|
||||
class EXT1,EXT2 externalStyle
|
||||
class SV1Msg,SV1Stream,SV1Convert,SV1Utils serviceMStyle
|
||||
class SV2List,SV2LoadUnload,SV2Status serviceMStyle
|
||||
class SV3Fetch serviceMStyle
|
||||
class SV4Conv,SV4Msg,SV4Node,SV4Import serviceMStyle
|
||||
class SV5Extract,SV5Merge,SV5Info,SV5Diff serviceMStyle
|
||||
class ST1,ST2,ST3,ST5,ST6,ST7,ST8 storageStyle
|
||||
class API1,API2,API3,API4 apiStyle
|
||||
```
|
||||
228
tools/server/webui/docs/flows/chat-flow.md
Normal file
228
tools/server/webui/docs/flows/chat-flow.md
Normal file
@@ -0,0 +1,228 @@
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant UI as 🧩 ChatForm / ChatMessage
|
||||
participant chatStore as 🗄️ chatStore
|
||||
participant agenticStore as 🗄️ agenticStore
|
||||
participant convStore as 🗄️ conversationsStore
|
||||
participant settingsStore as 🗄️ settingsStore
|
||||
participant mcpStore as 🗄️ mcpStore
|
||||
participant ChatSvc as ⚙️ ChatService
|
||||
participant DbSvc as ⚙️ DatabaseService
|
||||
participant API as 🌐 /v1/chat/completions
|
||||
|
||||
Note over chatStore: State:<br/>isLoading, currentResponse<br/>errorDialogState, activeProcessingState<br/>chatLoadingStates (Map)<br/>chatStreamingStates (Map)<br/>abortControllers (Map)<br/>processingStates (Map)
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: 💬 SEND MESSAGE
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>chatStore: sendMessage(content, extras)
|
||||
activate chatStore
|
||||
|
||||
chatStore->>chatStore: setChatLoading(convId, true)
|
||||
chatStore->>chatStore: clearChatStreaming(convId)
|
||||
|
||||
alt no active conversation
|
||||
chatStore->>convStore: createConversation()
|
||||
Note over convStore: → see conversations-flow.mmd
|
||||
end
|
||||
|
||||
chatStore->>mcpStore: consumeResourceAttachmentsAsExtras()
|
||||
Note right of mcpStore: Converts pending MCP resource<br/>attachments into message extras
|
||||
|
||||
chatStore->>chatStore: addMessage("user", content, extras)
|
||||
chatStore->>DbSvc: createMessageBranch(userMsg, parentId)
|
||||
chatStore->>convStore: addMessageToActive(userMsg)
|
||||
chatStore->>convStore: updateCurrentNode(userMsg.id)
|
||||
|
||||
chatStore->>chatStore: createAssistantMessage(userMsg.id)
|
||||
chatStore->>DbSvc: createMessageBranch(assistantMsg, userMsg.id)
|
||||
chatStore->>convStore: addMessageToActive(assistantMsg)
|
||||
|
||||
chatStore->>chatStore: streamChatCompletion(messages, assistantMsg)
|
||||
deactivate chatStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: 🌊 STREAMING (with agentic flow detection)
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
activate chatStore
|
||||
chatStore->>chatStore: startStreaming()
|
||||
Note right of chatStore: isStreamingActive = true
|
||||
|
||||
chatStore->>chatStore: setActiveProcessingConversation(convId)
|
||||
chatStore->>chatStore: getOrCreateAbortController(convId)
|
||||
Note right of chatStore: abortControllers.set(convId, new AbortController())
|
||||
|
||||
chatStore->>chatStore: getApiOptions()
|
||||
Note right of chatStore: Merge from settingsStore.config:<br/>temperature, max_tokens, top_p, etc.
|
||||
|
||||
alt agenticConfig.enabled && mcpStore has connected servers
|
||||
chatStore->>agenticStore: runAgenticFlow(convId, messages, assistantMsg, options, signal)
|
||||
Note over agenticStore: Multi-turn agentic loop:<br/>1. Call ChatService.sendMessage()<br/>2. If response has tool_calls → execute via mcpStore<br/>3. Append tool results as messages<br/>4. Loop until no more tool_calls or maxTurns<br/>→ see agentic flow details below
|
||||
agenticStore-->>chatStore: final response with timings
|
||||
else standard (non-agentic) flow
|
||||
chatStore->>ChatSvc: sendMessage(messages, options, signal)
|
||||
end
|
||||
|
||||
activate ChatSvc
|
||||
|
||||
ChatSvc->>ChatSvc: convertDbMessageToApiChatMessageData(messages)
|
||||
Note right of ChatSvc: DatabaseMessage[] → ApiChatMessageData[]<br/>Process attachments (images, PDFs, audio)
|
||||
|
||||
ChatSvc->>API: POST /v1/chat/completions
|
||||
Note right of API: {messages, model?, stream: true, ...params}
|
||||
|
||||
loop SSE chunks
|
||||
API-->>ChatSvc: data: {"choices":[{"delta":{...}}]}
|
||||
ChatSvc->>ChatSvc: handleStreamResponse(response)
|
||||
|
||||
alt content chunk
|
||||
ChatSvc-->>chatStore: onChunk(content)
|
||||
chatStore->>chatStore: setChatStreaming(convId, response, msgId)
|
||||
Note right of chatStore: currentResponse = $state(accumulated)
|
||||
chatStore->>convStore: updateMessageAtIndex(idx, {content})
|
||||
end
|
||||
|
||||
alt reasoning chunk
|
||||
ChatSvc-->>chatStore: onReasoningChunk(reasoning)
|
||||
chatStore->>convStore: updateMessageAtIndex(idx, {thinking})
|
||||
end
|
||||
|
||||
alt tool_calls chunk
|
||||
ChatSvc-->>chatStore: onToolCallChunk(toolCalls)
|
||||
chatStore->>convStore: updateMessageAtIndex(idx, {toolCalls})
|
||||
end
|
||||
|
||||
alt model info
|
||||
ChatSvc-->>chatStore: onModel(modelName)
|
||||
chatStore->>chatStore: recordModel(modelName)
|
||||
chatStore->>DbSvc: updateMessage(msgId, {model})
|
||||
end
|
||||
|
||||
alt timings (during stream)
|
||||
ChatSvc-->>chatStore: onTimings(timings, promptProgress)
|
||||
chatStore->>chatStore: updateProcessingStateFromTimings()
|
||||
end
|
||||
|
||||
chatStore-->>UI: reactive $state update
|
||||
end
|
||||
|
||||
API-->>ChatSvc: data: [DONE]
|
||||
ChatSvc-->>chatStore: onComplete(content, reasoning, timings, toolCalls)
|
||||
deactivate ChatSvc
|
||||
|
||||
chatStore->>chatStore: stopStreaming()
|
||||
chatStore->>DbSvc: updateMessage(msgId, {content, timings, model})
|
||||
chatStore->>convStore: updateCurrentNode(msgId)
|
||||
chatStore->>chatStore: setChatLoading(convId, false)
|
||||
chatStore->>chatStore: clearChatStreaming(convId)
|
||||
chatStore->>chatStore: clearProcessingState(convId)
|
||||
deactivate chatStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: ⏹️ STOP GENERATION
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>chatStore: stopGeneration()
|
||||
activate chatStore
|
||||
chatStore->>chatStore: savePartialResponseIfNeeded(convId)
|
||||
Note right of chatStore: Save currentResponse to DB if non-empty
|
||||
chatStore->>chatStore: abortControllers.get(convId).abort()
|
||||
Note right of chatStore: fetch throws AbortError → caught by isAbortError()
|
||||
chatStore->>chatStore: stopStreaming()
|
||||
chatStore->>chatStore: setChatLoading(convId, false)
|
||||
chatStore->>chatStore: clearChatStreaming(convId)
|
||||
chatStore->>chatStore: clearProcessingState(convId)
|
||||
deactivate chatStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: 🔁 REGENERATE
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>chatStore: regenerateMessageWithBranching(msgId, model?)
|
||||
activate chatStore
|
||||
chatStore->>convStore: findMessageIndex(msgId)
|
||||
chatStore->>chatStore: Get parent of target message
|
||||
chatStore->>chatStore: createAssistantMessage(parentId)
|
||||
chatStore->>DbSvc: createMessageBranch(newAssistantMsg, parentId)
|
||||
chatStore->>convStore: refreshActiveMessages()
|
||||
Note right of chatStore: Same streaming flow
|
||||
chatStore->>chatStore: streamChatCompletion(...)
|
||||
deactivate chatStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: ➡️ CONTINUE
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>chatStore: continueAssistantMessage(msgId)
|
||||
activate chatStore
|
||||
chatStore->>chatStore: Get existing content from message
|
||||
chatStore->>chatStore: streamChatCompletion(..., existingContent)
|
||||
Note right of chatStore: Appends to existing message content
|
||||
deactivate chatStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: ✏️ EDIT USER MESSAGE
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>chatStore: editMessageWithBranching(msgId, newContent, extras)
|
||||
activate chatStore
|
||||
chatStore->>chatStore: Get parent of target message
|
||||
chatStore->>DbSvc: createMessageBranch(editedMsg, parentId)
|
||||
chatStore->>convStore: refreshActiveMessages()
|
||||
Note right of chatStore: Creates new branch, original preserved
|
||||
chatStore->>chatStore: createAssistantMessage(editedMsg.id)
|
||||
chatStore->>chatStore: streamChatCompletion(...)
|
||||
Note right of chatStore: Automatically regenerates response
|
||||
deactivate chatStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: ❌ ERROR HANDLING
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Note over chatStore: On stream error (non-abort):
|
||||
chatStore->>chatStore: showErrorDialog(type, message)
|
||||
Note right of chatStore: errorDialogState = {type: 'timeout'|'server', message}
|
||||
chatStore->>convStore: removeMessageAtIndex(failedMsgIdx)
|
||||
chatStore->>DbSvc: deleteMessage(failedMsgId)
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: 🤖 AGENTIC LOOP (when agenticConfig.enabled)
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Note over agenticStore: agenticStore.runAgenticFlow(convId, messages, assistantMsg, options, signal)
|
||||
activate agenticStore
|
||||
agenticStore->>agenticStore: getSession(convId) or create new
|
||||
agenticStore->>agenticStore: updateSession(turn: 0, running: true)
|
||||
|
||||
loop executeAgenticLoop (until no tool_calls or maxTurns)
|
||||
agenticStore->>agenticStore: turn++
|
||||
agenticStore->>ChatSvc: sendMessage(messages, options, signal)
|
||||
ChatSvc->>API: POST /v1/chat/completions
|
||||
API-->>ChatSvc: response with potential tool_calls
|
||||
ChatSvc-->>agenticStore: onComplete(content, reasoning, timings, toolCalls)
|
||||
|
||||
alt response has tool_calls
|
||||
agenticStore->>agenticStore: normalizeToolCalls(toolCalls)
|
||||
loop for each tool_call
|
||||
agenticStore->>agenticStore: updateSession(streamingToolCall)
|
||||
agenticStore->>mcpStore: executeTool(mcpCall, signal)
|
||||
mcpStore-->>agenticStore: tool result
|
||||
agenticStore->>agenticStore: extractBase64Attachments(result)
|
||||
agenticStore->>agenticStore: emitToolCallResult(convId, ...)
|
||||
agenticStore->>convStore: addMessageToActive(toolResultMsg)
|
||||
agenticStore->>DbSvc: createMessageBranch(toolResultMsg)
|
||||
end
|
||||
agenticStore->>agenticStore: Create new assistantMsg for next turn
|
||||
Note right of agenticStore: Continue loop with updated messages
|
||||
else no tool_calls (final response)
|
||||
agenticStore->>agenticStore: buildFinalTimings(allTurns)
|
||||
Note right of agenticStore: Break loop, return final response
|
||||
end
|
||||
end
|
||||
|
||||
agenticStore->>agenticStore: updateSession(running: false)
|
||||
agenticStore-->>chatStore: final content, timings, model
|
||||
deactivate agenticStore
|
||||
```
|
||||
183
tools/server/webui/docs/flows/conversations-flow.md
Normal file
183
tools/server/webui/docs/flows/conversations-flow.md
Normal file
@@ -0,0 +1,183 @@
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant UI as 🧩 ChatSidebar / ChatScreen
|
||||
participant convStore as 🗄️ conversationsStore
|
||||
participant chatStore as 🗄️ chatStore
|
||||
participant DbSvc as ⚙️ DatabaseService
|
||||
participant IDB as 💾 IndexedDB
|
||||
|
||||
Note over convStore: State:<br/>conversations: DatabaseConversation[]<br/>activeConversation: DatabaseConversation | null<br/>activeMessages: DatabaseMessage[]<br/>isInitialized: boolean<br/>pendingMcpServerOverrides: Map<string, McpServerOverride>
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,IDB: 🚀 INITIALIZATION
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Note over convStore: Auto-initialized in constructor (browser only)
|
||||
convStore->>convStore: initialize()
|
||||
activate convStore
|
||||
convStore->>convStore: loadConversations()
|
||||
convStore->>DbSvc: getAllConversations()
|
||||
DbSvc->>IDB: SELECT * FROM conversations ORDER BY lastModified DESC
|
||||
IDB-->>DbSvc: Conversation[]
|
||||
DbSvc-->>convStore: conversations
|
||||
convStore->>convStore: conversations = $state(data)
|
||||
convStore->>convStore: isInitialized = true
|
||||
deactivate convStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,IDB: ➕ CREATE CONVERSATION
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>convStore: createConversation(name?)
|
||||
activate convStore
|
||||
convStore->>DbSvc: createConversation(name || "New Chat")
|
||||
DbSvc->>IDB: INSERT INTO conversations
|
||||
IDB-->>DbSvc: conversation {id, name, lastModified, currNode: ""}
|
||||
DbSvc-->>convStore: conversation
|
||||
convStore->>convStore: conversations.unshift(conversation)
|
||||
convStore->>convStore: activeConversation = $state(conversation)
|
||||
convStore->>convStore: activeMessages = $state([])
|
||||
|
||||
alt pendingMcpServerOverrides has entries
|
||||
loop each pending override
|
||||
convStore->>DbSvc: Store MCP server override for new conversation
|
||||
end
|
||||
convStore->>convStore: clearPendingMcpServerOverrides()
|
||||
end
|
||||
deactivate convStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,IDB: 📂 LOAD CONVERSATION
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>convStore: loadConversation(convId)
|
||||
activate convStore
|
||||
convStore->>DbSvc: getConversation(convId)
|
||||
DbSvc->>IDB: SELECT * FROM conversations WHERE id = ?
|
||||
IDB-->>DbSvc: conversation
|
||||
convStore->>convStore: activeConversation = $state(conversation)
|
||||
|
||||
convStore->>convStore: refreshActiveMessages()
|
||||
convStore->>DbSvc: getConversationMessages(convId)
|
||||
DbSvc->>IDB: SELECT * FROM messages WHERE convId = ?
|
||||
IDB-->>DbSvc: allMessages[]
|
||||
convStore->>convStore: filterByLeafNodeId(allMessages, currNode)
|
||||
Note right of convStore: Filter to show only current branch path
|
||||
convStore->>convStore: activeMessages = $state(filtered)
|
||||
|
||||
Note right of convStore: Route (+page.svelte) then calls:<br/>chatStore.syncLoadingStateForChat(convId)
|
||||
deactivate convStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,IDB: 🌳 MESSAGE BRANCHING MODEL
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Note over IDB: Message Tree Structure:<br/>- Each message has parent (null for root)<br/>- Each message has children[] array<br/>- Conversation.currNode points to active leaf<br/>- filterByLeafNodeId() traverses from root to currNode
|
||||
|
||||
rect rgb(240, 240, 255)
|
||||
Note over convStore: Example Branch Structure:
|
||||
Note over convStore: root → user1 → assistant1 → user2 → assistant2a (currNode)<br/> ↘ assistant2b (alt branch)
|
||||
end
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,IDB: ↔️ BRANCH NAVIGATION
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>convStore: navigateToSibling(msgId, direction)
|
||||
activate convStore
|
||||
convStore->>convStore: Find message in activeMessages
|
||||
convStore->>convStore: Get parent message
|
||||
convStore->>convStore: Find sibling in parent.children[]
|
||||
convStore->>convStore: findLeafNode(siblingId, allMessages)
|
||||
Note right of convStore: Navigate to leaf of sibling branch
|
||||
convStore->>convStore: updateCurrentNode(leafId)
|
||||
convStore->>DbSvc: updateCurrentNode(convId, leafId)
|
||||
DbSvc->>IDB: UPDATE conversations SET currNode = ?
|
||||
convStore->>convStore: refreshActiveMessages()
|
||||
deactivate convStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,IDB: 📝 UPDATE CONVERSATION
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>convStore: updateConversationName(convId, newName)
|
||||
activate convStore
|
||||
convStore->>DbSvc: updateConversation(convId, {name: newName})
|
||||
DbSvc->>IDB: UPDATE conversations SET name = ?
|
||||
convStore->>convStore: Update in conversations array
|
||||
deactivate convStore
|
||||
|
||||
Note over convStore: Auto-title update (after first response):
|
||||
convStore->>convStore: updateConversationTitleWithConfirmation()
|
||||
convStore->>convStore: titleUpdateConfirmationCallback?()
|
||||
Note right of convStore: Shows dialog if title would change
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,IDB: 🗑️ DELETE CONVERSATION
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>convStore: deleteConversation(convId)
|
||||
activate convStore
|
||||
convStore->>DbSvc: deleteConversation(convId)
|
||||
DbSvc->>IDB: DELETE FROM conversations WHERE id = ?
|
||||
DbSvc->>IDB: DELETE FROM messages WHERE convId = ?
|
||||
convStore->>convStore: conversations.filter(c => c.id !== convId)
|
||||
alt deleted active conversation
|
||||
convStore->>convStore: clearActiveConversation()
|
||||
end
|
||||
deactivate convStore
|
||||
|
||||
UI->>convStore: deleteAll()
|
||||
activate convStore
|
||||
convStore->>DbSvc: Delete all conversations and messages
|
||||
convStore->>convStore: conversations = []
|
||||
convStore->>convStore: clearActiveConversation()
|
||||
deactivate convStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,IDB: <20> MCP SERVER PER-CHAT OVERRIDES
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Note over convStore: Conversations can override which MCP servers are enabled.
|
||||
Note over convStore: Uses pendingMcpServerOverrides before conversation<br/>is created, then persists to conversation metadata.
|
||||
|
||||
UI->>convStore: setMcpServerOverride(convId, serverName, override)
|
||||
Note right of convStore: override = {enabled: boolean}
|
||||
|
||||
UI->>convStore: toggleMcpServerForChat(convId, serverName, enabled)
|
||||
activate convStore
|
||||
convStore->>convStore: setMcpServerOverride(convId, serverName, {enabled})
|
||||
deactivate convStore
|
||||
|
||||
UI->>convStore: isMcpServerEnabledForChat(convId, serverName)
|
||||
Note right of convStore: Check override → fall back to global MCP config
|
||||
|
||||
UI->>convStore: getAllMcpServerOverrides(convId)
|
||||
Note right of convStore: Returns all overrides for a conversation
|
||||
|
||||
UI->>convStore: removeMcpServerOverride(convId, serverName)
|
||||
UI->>convStore: getMcpServerOverride(convId, serverName)
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,IDB: 📤 EXPORT / 📥 IMPORT
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>convStore: exportAllConversations()
|
||||
activate convStore
|
||||
convStore->>DbSvc: getAllConversations()
|
||||
loop each conversation
|
||||
convStore->>DbSvc: getConversationMessages(convId)
|
||||
end
|
||||
convStore->>convStore: triggerDownload(JSON blob)
|
||||
deactivate convStore
|
||||
|
||||
UI->>convStore: importConversations(file)
|
||||
activate convStore
|
||||
convStore->>convStore: Parse JSON file
|
||||
convStore->>convStore: importConversationsData(parsed)
|
||||
convStore->>DbSvc: importConversations(parsed)
|
||||
Note right of DbSvc: Skips duplicate conversations<br/>(checks existing by ID)
|
||||
DbSvc->>IDB: INSERT conversations + messages (skip existing)
|
||||
convStore->>convStore: loadConversations()
|
||||
deactivate convStore
|
||||
```
|
||||
@@ -0,0 +1,45 @@
|
||||
```mermaid
|
||||
%% MODEL Mode Data Flow (single model)
|
||||
%% Detailed flows: ./flows/server-flow.mmd, ./flows/models-flow.mmd, ./flows/chat-flow.mmd
|
||||
|
||||
sequenceDiagram
|
||||
participant User as 👤 User
|
||||
participant UI as 🧩 UI
|
||||
participant Stores as 🗄️ Stores
|
||||
participant DB as 💾 IndexedDB
|
||||
participant API as 🌐 llama-server
|
||||
|
||||
Note over User,API: 🚀 Initialization (see: server-flow.mmd, models-flow.mmd)
|
||||
|
||||
UI->>Stores: initialize()
|
||||
Stores->>DB: load conversations
|
||||
Stores->>API: GET /props
|
||||
API-->>Stores: server config + modalities
|
||||
Stores->>API: GET /v1/models
|
||||
API-->>Stores: single model (auto-selected)
|
||||
|
||||
Note over User,API: 💬 Chat Flow (see: chat-flow.mmd)
|
||||
|
||||
User->>UI: send message
|
||||
UI->>Stores: sendMessage()
|
||||
Stores->>DB: save user message
|
||||
Stores->>API: POST /v1/chat/completions (stream)
|
||||
loop streaming
|
||||
API-->>Stores: SSE chunks
|
||||
Stores-->>UI: reactive update
|
||||
end
|
||||
API-->>Stores: done + timings
|
||||
Stores->>DB: save assistant message
|
||||
|
||||
Note over User,API: 🔁 Regenerate
|
||||
|
||||
User->>UI: regenerate
|
||||
Stores->>DB: create message branch
|
||||
Note right of Stores: same streaming flow
|
||||
|
||||
Note over User,API: ⏹️ Stop
|
||||
|
||||
User->>UI: stop
|
||||
Stores->>Stores: abort stream
|
||||
Stores->>DB: save partial response
|
||||
```
|
||||
@@ -0,0 +1,77 @@
|
||||
```mermaid
|
||||
%% ROUTER Mode Data Flow (multi-model)
|
||||
%% Detailed flows: ./flows/server-flow.mmd, ./flows/models-flow.mmd, ./flows/chat-flow.mmd
|
||||
|
||||
sequenceDiagram
|
||||
participant User as 👤 User
|
||||
participant UI as 🧩 UI
|
||||
participant Stores as 🗄️ Stores
|
||||
participant DB as 💾 IndexedDB
|
||||
participant API as 🌐 llama-server
|
||||
|
||||
Note over User,API: 🚀 Initialization (see: server-flow.mmd, models-flow.mmd)
|
||||
|
||||
UI->>Stores: initialize()
|
||||
Stores->>DB: load conversations
|
||||
Stores->>API: GET /props
|
||||
API-->>Stores: {role: "router"}
|
||||
Stores->>API: GET /v1/models
|
||||
API-->>Stores: models[] with status (loaded/available)
|
||||
loop each loaded model
|
||||
Stores->>API: GET /props?model=X
|
||||
API-->>Stores: modalities (vision/audio)
|
||||
end
|
||||
|
||||
Note over User,API: 🔄 Model Selection (see: models-flow.mmd)
|
||||
|
||||
User->>UI: select model
|
||||
alt model not loaded
|
||||
Stores->>API: POST /models/load
|
||||
loop poll status
|
||||
Stores->>API: GET /v1/models
|
||||
API-->>Stores: check if loaded
|
||||
end
|
||||
Stores->>API: GET /props?model=X
|
||||
API-->>Stores: cache modalities
|
||||
end
|
||||
Stores->>Stores: validate modalities vs conversation
|
||||
alt valid
|
||||
Stores->>Stores: select model
|
||||
else invalid
|
||||
Stores->>API: POST /models/unload
|
||||
UI->>User: show error toast
|
||||
end
|
||||
|
||||
Note over User,API: 💬 Chat Flow (see: chat-flow.mmd)
|
||||
|
||||
User->>UI: send message
|
||||
UI->>Stores: sendMessage()
|
||||
Stores->>DB: save user message
|
||||
Stores->>API: POST /v1/chat/completions {model: X}
|
||||
Note right of API: router forwards to model
|
||||
loop streaming
|
||||
API-->>Stores: SSE chunks + model info
|
||||
Stores-->>UI: reactive update
|
||||
end
|
||||
API-->>Stores: done + timings
|
||||
Stores->>DB: save assistant message + model used
|
||||
|
||||
Note over User,API: 🔁 Regenerate (optional: different model)
|
||||
|
||||
User->>UI: regenerate
|
||||
Stores->>Stores: validate modalities up to this message
|
||||
Stores->>DB: create message branch
|
||||
Note right of Stores: same streaming flow
|
||||
|
||||
Note over User,API: ⏹️ Stop
|
||||
|
||||
User->>UI: stop
|
||||
Stores->>Stores: abort stream
|
||||
Stores->>DB: save partial response
|
||||
|
||||
Note over User,API: 🗑️ LRU Unloading
|
||||
|
||||
Note right of API: Server auto-unloads LRU models<br/>when cache full
|
||||
User->>UI: select unloaded model
|
||||
Note right of Stores: triggers load flow again
|
||||
```
|
||||
174
tools/server/webui/docs/flows/database-flow.md
Normal file
174
tools/server/webui/docs/flows/database-flow.md
Normal file
@@ -0,0 +1,174 @@
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant Store as 🗄️ Stores
|
||||
participant DbSvc as ⚙️ DatabaseService
|
||||
participant Dexie as 📦 Dexie ORM
|
||||
participant IDB as 💾 IndexedDB
|
||||
|
||||
Note over DbSvc: Stateless service - all methods static<br/>Database: "LlamacppWebui"
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over Store,IDB: 📊 SCHEMA
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
rect rgb(240, 248, 255)
|
||||
Note over IDB: conversations table:<br/>id (PK), lastModified, currNode, name
|
||||
end
|
||||
|
||||
rect rgb(255, 248, 240)
|
||||
Note over IDB: messages table:<br/>id (PK), convId (FK), type, role, timestamp,<br/>parent, children[], content, thinking,<br/>toolCalls, extra[], model, timings
|
||||
end
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over Store,IDB: 💬 CONVERSATIONS CRUD
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Store->>DbSvc: createConversation(name)
|
||||
activate DbSvc
|
||||
DbSvc->>DbSvc: Generate UUID
|
||||
DbSvc->>Dexie: db.conversations.add({id, name, lastModified, currNode: ""})
|
||||
Dexie->>IDB: INSERT
|
||||
IDB-->>Dexie: success
|
||||
DbSvc-->>Store: DatabaseConversation
|
||||
deactivate DbSvc
|
||||
|
||||
Store->>DbSvc: getConversation(convId)
|
||||
DbSvc->>Dexie: db.conversations.get(convId)
|
||||
Dexie->>IDB: SELECT WHERE id = ?
|
||||
IDB-->>DbSvc: DatabaseConversation
|
||||
|
||||
Store->>DbSvc: getAllConversations()
|
||||
DbSvc->>Dexie: db.conversations.orderBy('lastModified').reverse().toArray()
|
||||
Dexie->>IDB: SELECT ORDER BY lastModified DESC
|
||||
IDB-->>DbSvc: DatabaseConversation[]
|
||||
|
||||
Store->>DbSvc: updateConversation(convId, updates)
|
||||
DbSvc->>Dexie: db.conversations.update(convId, {...updates, lastModified})
|
||||
Dexie->>IDB: UPDATE
|
||||
|
||||
Store->>DbSvc: deleteConversation(convId)
|
||||
activate DbSvc
|
||||
DbSvc->>Dexie: db.conversations.delete(convId)
|
||||
Dexie->>IDB: DELETE FROM conversations
|
||||
DbSvc->>Dexie: db.messages.where('convId').equals(convId).delete()
|
||||
Dexie->>IDB: DELETE FROM messages WHERE convId = ?
|
||||
deactivate DbSvc
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over Store,IDB: 📝 MESSAGES CRUD
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Store->>DbSvc: createRootMessage(convId)
|
||||
activate DbSvc
|
||||
DbSvc->>DbSvc: Create root message {type: "root", parent: null}
|
||||
DbSvc->>Dexie: db.messages.add(rootMsg)
|
||||
Dexie->>IDB: INSERT
|
||||
DbSvc-->>Store: rootMessageId
|
||||
deactivate DbSvc
|
||||
|
||||
Store->>DbSvc: createSystemMessage(convId, content, parentId)
|
||||
activate DbSvc
|
||||
DbSvc->>DbSvc: Create message {role: "system", parent: parentId}
|
||||
DbSvc->>Dexie: db.messages.add(systemMsg)
|
||||
Dexie->>IDB: INSERT
|
||||
DbSvc-->>Store: DatabaseMessage
|
||||
deactivate DbSvc
|
||||
|
||||
Store->>DbSvc: createMessageBranch(message, parentId)
|
||||
activate DbSvc
|
||||
DbSvc->>DbSvc: Generate UUID for new message
|
||||
DbSvc->>Dexie: db.messages.add({...message, id, parent: parentId})
|
||||
Dexie->>IDB: INSERT message
|
||||
|
||||
alt parentId exists
|
||||
DbSvc->>Dexie: db.messages.get(parentId)
|
||||
Dexie->>IDB: SELECT parent
|
||||
DbSvc->>DbSvc: parent.children.push(newId)
|
||||
DbSvc->>Dexie: db.messages.update(parentId, {children})
|
||||
Dexie->>IDB: UPDATE parent.children
|
||||
end
|
||||
|
||||
DbSvc->>Dexie: db.conversations.update(convId, {currNode: newId})
|
||||
Dexie->>IDB: UPDATE conversation.currNode
|
||||
DbSvc-->>Store: DatabaseMessage
|
||||
deactivate DbSvc
|
||||
|
||||
Store->>DbSvc: getConversationMessages(convId)
|
||||
DbSvc->>Dexie: db.messages.where('convId').equals(convId).toArray()
|
||||
Dexie->>IDB: SELECT WHERE convId = ?
|
||||
IDB-->>DbSvc: DatabaseMessage[]
|
||||
|
||||
Store->>DbSvc: updateMessage(msgId, updates)
|
||||
DbSvc->>Dexie: db.messages.update(msgId, updates)
|
||||
Dexie->>IDB: UPDATE
|
||||
|
||||
Store->>DbSvc: deleteMessage(msgId)
|
||||
DbSvc->>Dexie: db.messages.delete(msgId)
|
||||
Dexie->>IDB: DELETE
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over Store,IDB: 🌳 BRANCHING OPERATIONS
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Store->>DbSvc: updateCurrentNode(convId, nodeId)
|
||||
DbSvc->>Dexie: db.conversations.update(convId, {currNode: nodeId, lastModified})
|
||||
Dexie->>IDB: UPDATE
|
||||
|
||||
Store->>DbSvc: deleteMessageCascading(msgId)
|
||||
activate DbSvc
|
||||
DbSvc->>DbSvc: findDescendantMessages(msgId, allMessages)
|
||||
Note right of DbSvc: Recursively find all children
|
||||
loop each descendant
|
||||
DbSvc->>Dexie: db.messages.delete(descendantId)
|
||||
Dexie->>IDB: DELETE
|
||||
end
|
||||
DbSvc->>Dexie: db.messages.delete(msgId)
|
||||
Dexie->>IDB: DELETE target message
|
||||
|
||||
alt target message has a parent
|
||||
DbSvc->>Dexie: db.messages.get(parentId)
|
||||
DbSvc->>DbSvc: parent.children.filter(id !== msgId)
|
||||
DbSvc->>Dexie: db.messages.update(parentId, {children})
|
||||
Note right of DbSvc: Remove deleted message from parent's children[]
|
||||
end
|
||||
deactivate DbSvc
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over Store,IDB: 📥 IMPORT
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Store->>DbSvc: importConversations(data)
|
||||
activate DbSvc
|
||||
loop each conversation in data
|
||||
DbSvc->>Dexie: db.conversations.get(conv.id)
|
||||
alt conversation already exists
|
||||
Note right of DbSvc: Skip duplicate (keep existing)
|
||||
else conversation is new
|
||||
DbSvc->>Dexie: db.conversations.add(conversation)
|
||||
Dexie->>IDB: INSERT conversation
|
||||
loop each message
|
||||
DbSvc->>Dexie: db.messages.add(message)
|
||||
Dexie->>IDB: INSERT message
|
||||
end
|
||||
end
|
||||
end
|
||||
deactivate DbSvc
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over Store,IDB: 🔗 MESSAGE TREE UTILITIES
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Note over DbSvc: Used by stores (imported from utils):
|
||||
|
||||
rect rgb(240, 255, 240)
|
||||
Note over DbSvc: filterByLeafNodeId(messages, leafId)<br/>→ Returns path from root to leaf<br/>→ Used to display current branch
|
||||
end
|
||||
|
||||
rect rgb(240, 255, 240)
|
||||
Note over DbSvc: findLeafNode(startId, messages)<br/>→ Traverse to deepest child<br/>→ Used for branch navigation
|
||||
end
|
||||
|
||||
rect rgb(240, 255, 240)
|
||||
Note over DbSvc: findDescendantMessages(msgId, messages)<br/>→ Find all children recursively<br/>→ Used for cascading deletes
|
||||
end
|
||||
```
|
||||
226
tools/server/webui/docs/flows/mcp-flow.md
Normal file
226
tools/server/webui/docs/flows/mcp-flow.md
Normal file
@@ -0,0 +1,226 @@
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant UI as 🧩 McpServersSettings / ChatForm
|
||||
participant chatStore as 🗄️ chatStore
|
||||
participant mcpStore as 🗄️ mcpStore
|
||||
participant mcpResStore as 🗄️ mcpResourceStore
|
||||
participant convStore as 🗄️ conversationsStore
|
||||
participant MCPSvc as ⚙️ MCPService
|
||||
participant LS as 💾 LocalStorage
|
||||
participant ExtMCP as 🔌 External MCP Server
|
||||
|
||||
Note over mcpStore: State:<br/>isInitializing, error<br/>toolCount, connectedServers<br/>healthChecks (Map)<br/>connections (Map)<br/>toolsIndex (Map)<br/>serverConfigs (Map)
|
||||
|
||||
Note over mcpResStore: State:<br/>serverResources (Map)<br/>cachedResources (Map)<br/>subscriptions (Map)<br/>attachments[]
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,ExtMCP: 🚀 INITIALIZATION (App Startup)
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>mcpStore: ensureInitialized()
|
||||
activate mcpStore
|
||||
|
||||
mcpStore->>LS: get(MCP_SERVERS_LOCALSTORAGE_KEY)
|
||||
LS-->>mcpStore: MCPServerSettingsEntry[]
|
||||
|
||||
mcpStore->>mcpStore: parseServerSettings(servers)
|
||||
Note right of mcpStore: Filter enabled servers<br/>Build MCPServerConfig objects<br/>Per-chat overrides checked via convStore
|
||||
|
||||
loop For each enabled server
|
||||
mcpStore->>mcpStore: runHealthCheck(serverId)
|
||||
mcpStore->>mcpStore: updateHealthCheck(id, CONNECTING)
|
||||
|
||||
mcpStore->>MCPSvc: connect(serverName, config, clientInfo, capabilities, onPhase)
|
||||
activate MCPSvc
|
||||
|
||||
MCPSvc->>MCPSvc: createTransport(config)
|
||||
Note right of MCPSvc: WebSocket / StreamableHTTP / SSE<br/>with optional CORS proxy
|
||||
|
||||
MCPSvc->>ExtMCP: Transport handshake
|
||||
ExtMCP-->>MCPSvc: Connection established
|
||||
|
||||
MCPSvc->>ExtMCP: Initialize request
|
||||
Note right of ExtMCP: Exchange capabilities<br/>Server info, protocol version
|
||||
|
||||
ExtMCP-->>MCPSvc: InitializeResult (serverInfo, capabilities)
|
||||
|
||||
MCPSvc->>ExtMCP: listTools()
|
||||
ExtMCP-->>MCPSvc: Tool[]
|
||||
|
||||
MCPSvc-->>mcpStore: MCPConnection
|
||||
deactivate MCPSvc
|
||||
|
||||
mcpStore->>mcpStore: connections.set(serverName, connection)
|
||||
mcpStore->>mcpStore: indexTools(connection.tools, serverName)
|
||||
Note right of mcpStore: toolsIndex.set(toolName, serverName)<br/>Handle name conflicts with prefixes
|
||||
|
||||
mcpStore->>mcpStore: updateHealthCheck(id, SUCCESS)
|
||||
mcpStore->>mcpStore: _connectedServers.push(serverName)
|
||||
|
||||
alt Server supports resources
|
||||
mcpStore->>MCPSvc: listAllResources(connection)
|
||||
MCPSvc->>ExtMCP: listResources()
|
||||
ExtMCP-->>MCPSvc: MCPResource[]
|
||||
MCPSvc-->>mcpStore: resources
|
||||
|
||||
mcpStore->>MCPSvc: listAllResourceTemplates(connection)
|
||||
MCPSvc->>ExtMCP: listResourceTemplates()
|
||||
ExtMCP-->>MCPSvc: MCPResourceTemplate[]
|
||||
MCPSvc-->>mcpStore: templates
|
||||
|
||||
mcpStore->>mcpResStore: setServerResources(serverName, resources, templates)
|
||||
end
|
||||
end
|
||||
|
||||
mcpStore->>mcpStore: _isInitializing = false
|
||||
deactivate mcpStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,ExtMCP: 🔧 TOOL EXECUTION (Chat with Tools)
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>mcpStore: executeTool(mcpCall: MCPToolCall, signal?)
|
||||
activate mcpStore
|
||||
|
||||
mcpStore->>mcpStore: toolsIndex.get(mcpCall.function.name)
|
||||
Note right of mcpStore: Resolve serverName from toolsIndex<br/>MCPToolCall = {id, type, function: {name, arguments}}
|
||||
|
||||
mcpStore->>mcpStore: acquireConnection()
|
||||
Note right of mcpStore: activeFlowCount++<br/>Prevent shutdown during execution
|
||||
|
||||
mcpStore->>mcpStore: connection = connections.get(serverName)
|
||||
|
||||
mcpStore->>MCPSvc: callTool(connection, {name, arguments}, signal)
|
||||
activate MCPSvc
|
||||
|
||||
MCPSvc->>MCPSvc: throwIfAborted(signal)
|
||||
MCPSvc->>ExtMCP: callTool(name, arguments)
|
||||
|
||||
alt Tool execution success
|
||||
ExtMCP-->>MCPSvc: ToolCallResult (content, isError)
|
||||
MCPSvc->>MCPSvc: formatToolResult(result)
|
||||
Note right of MCPSvc: Handle text, image (base64),<br/>embedded resource content
|
||||
MCPSvc-->>mcpStore: ToolExecutionResult
|
||||
else Tool execution error
|
||||
ExtMCP-->>MCPSvc: Error
|
||||
MCPSvc-->>mcpStore: throw Error
|
||||
else Aborted
|
||||
MCPSvc-->>mcpStore: throw AbortError
|
||||
end
|
||||
|
||||
deactivate MCPSvc
|
||||
|
||||
mcpStore->>mcpStore: releaseConnection()
|
||||
Note right of mcpStore: activeFlowCount--
|
||||
|
||||
mcpStore-->>UI: ToolExecutionResult
|
||||
deactivate mcpStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,ExtMCP: <20> RESOURCE ATTACHMENT CONSUMPTION
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
chatStore->>mcpStore: consumeResourceAttachmentsAsExtras()
|
||||
activate mcpStore
|
||||
mcpStore->>mcpResStore: getAttachments()
|
||||
mcpResStore-->>mcpStore: MCPResourceAttachment[]
|
||||
mcpStore->>mcpStore: Convert attachments to message extras
|
||||
mcpStore->>mcpResStore: clearAttachments()
|
||||
mcpStore-->>chatStore: MessageExtra[] (for user message)
|
||||
deactivate mcpStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,ExtMCP: <20>📝 PROMPT OPERATIONS
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>mcpStore: getAllPrompts()
|
||||
activate mcpStore
|
||||
|
||||
loop For each connected server with prompts capability
|
||||
mcpStore->>MCPSvc: listPrompts(connection)
|
||||
MCPSvc->>ExtMCP: listPrompts()
|
||||
ExtMCP-->>MCPSvc: Prompt[]
|
||||
MCPSvc-->>mcpStore: prompts
|
||||
end
|
||||
|
||||
mcpStore-->>UI: MCPPromptInfo[] (with serverName)
|
||||
deactivate mcpStore
|
||||
|
||||
UI->>mcpStore: getPrompt(serverName, promptName, args?)
|
||||
activate mcpStore
|
||||
|
||||
mcpStore->>MCPSvc: getPrompt(connection, name, args)
|
||||
MCPSvc->>ExtMCP: getPrompt({name, arguments})
|
||||
ExtMCP-->>MCPSvc: GetPromptResult (messages)
|
||||
MCPSvc-->>mcpStore: GetPromptResult
|
||||
|
||||
mcpStore-->>UI: GetPromptResult
|
||||
deactivate mcpStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,ExtMCP: 📁 RESOURCE OPERATIONS
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>mcpResStore: addAttachment(resourceInfo)
|
||||
activate mcpResStore
|
||||
mcpResStore->>mcpResStore: Create MCPResourceAttachment (loading: true)
|
||||
mcpResStore-->>UI: attachment
|
||||
|
||||
UI->>mcpStore: readResource(serverName, uri)
|
||||
activate mcpStore
|
||||
|
||||
mcpStore->>MCPSvc: readResource(connection, uri)
|
||||
MCPSvc->>ExtMCP: readResource({uri})
|
||||
ExtMCP-->>MCPSvc: MCPReadResourceResult (contents)
|
||||
MCPSvc-->>mcpStore: contents
|
||||
|
||||
mcpStore-->>UI: MCPResourceContent[]
|
||||
deactivate mcpStore
|
||||
|
||||
UI->>mcpResStore: updateAttachmentContent(attachmentId, content)
|
||||
mcpResStore->>mcpResStore: cacheResourceContent(resource, content)
|
||||
deactivate mcpResStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,ExtMCP: 🔄 AUTO-RECONNECTION
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Note over mcpStore: On WebSocket close or connection error:
|
||||
mcpStore->>mcpStore: autoReconnect(serverName, attempt)
|
||||
activate mcpStore
|
||||
|
||||
mcpStore->>mcpStore: Calculate backoff delay
|
||||
Note right of mcpStore: delay = min(30s, 1s * 2^attempt)
|
||||
|
||||
mcpStore->>mcpStore: Wait for delay
|
||||
mcpStore->>mcpStore: reconnectServer(serverName)
|
||||
|
||||
alt Reconnection success
|
||||
mcpStore->>mcpStore: updateHealthCheck(id, SUCCESS)
|
||||
else Max attempts reached
|
||||
mcpStore->>mcpStore: updateHealthCheck(id, ERROR)
|
||||
end
|
||||
deactivate mcpStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,ExtMCP: 🛑 SHUTDOWN
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>mcpStore: shutdown()
|
||||
activate mcpStore
|
||||
|
||||
mcpStore->>mcpStore: Wait for activeFlowCount == 0
|
||||
|
||||
loop For each connection
|
||||
mcpStore->>MCPSvc: disconnect(connection)
|
||||
MCPSvc->>MCPSvc: transport.onclose = undefined
|
||||
MCPSvc->>ExtMCP: close()
|
||||
end
|
||||
|
||||
mcpStore->>mcpStore: connections.clear()
|
||||
mcpStore->>mcpStore: toolsIndex.clear()
|
||||
mcpStore->>mcpStore: _connectedServers = []
|
||||
|
||||
mcpStore->>mcpResStore: clear()
|
||||
deactivate mcpStore
|
||||
```
|
||||
181
tools/server/webui/docs/flows/models-flow.md
Normal file
181
tools/server/webui/docs/flows/models-flow.md
Normal file
@@ -0,0 +1,181 @@
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant UI as 🧩 ModelsSelector
|
||||
participant Hooks as 🪝 useModelChangeValidation
|
||||
participant modelsStore as 🗄️ modelsStore
|
||||
participant serverStore as 🗄️ serverStore
|
||||
participant convStore as 🗄️ conversationsStore
|
||||
participant ModelsSvc as ⚙️ ModelsService
|
||||
participant PropsSvc as ⚙️ PropsService
|
||||
participant API as 🌐 llama-server
|
||||
|
||||
Note over modelsStore: State:<br/>models: ModelOption[]<br/>routerModels: ApiModelDataEntry[]<br/>selectedModelId, selectedModelName<br/>loading, updating, error<br/>modelLoadingStates (Map)<br/>modelPropsCache (Map)<br/>propsCacheVersion
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: 🚀 INITIALIZATION (MODEL mode)
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>modelsStore: fetch()
|
||||
activate modelsStore
|
||||
modelsStore->>modelsStore: loading = true
|
||||
|
||||
alt serverStore.props not loaded
|
||||
modelsStore->>serverStore: fetch()
|
||||
Note over serverStore: → see server-flow.mmd
|
||||
end
|
||||
|
||||
modelsStore->>ModelsSvc: list()
|
||||
ModelsSvc->>API: GET /v1/models
|
||||
API-->>ModelsSvc: ApiModelListResponse {data: [model]}
|
||||
|
||||
modelsStore->>modelsStore: models = $state(mapped)
|
||||
Note right of modelsStore: Map to ModelOption[]:<br/>{id, name, model, description, capabilities}
|
||||
|
||||
Note over modelsStore: MODEL mode: Get modalities from serverStore.props
|
||||
modelsStore->>modelsStore: modelPropsCache.set(model.id, serverStore.props)
|
||||
modelsStore->>modelsStore: models[0].modalities = props.modalities
|
||||
|
||||
modelsStore->>modelsStore: Auto-select single model
|
||||
Note right of modelsStore: selectedModelId = models[0].id
|
||||
modelsStore->>modelsStore: loading = false
|
||||
deactivate modelsStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: 🚀 INITIALIZATION (ROUTER mode)
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>modelsStore: fetch()
|
||||
activate modelsStore
|
||||
modelsStore->>ModelsSvc: list()
|
||||
ModelsSvc->>API: GET /v1/models
|
||||
API-->>ModelsSvc: ApiModelListResponse
|
||||
modelsStore->>modelsStore: models = $state(mapped)
|
||||
deactivate modelsStore
|
||||
|
||||
Note over UI: After models loaded, layout triggers:
|
||||
UI->>modelsStore: fetchRouterModels()
|
||||
activate modelsStore
|
||||
modelsStore->>ModelsSvc: listRouter()
|
||||
ModelsSvc->>API: GET /v1/models
|
||||
API-->>ModelsSvc: ApiRouterModelsListResponse
|
||||
Note right of API: {data: [{id, status, path, in_cache}]}
|
||||
modelsStore->>modelsStore: routerModels = $state(data)
|
||||
|
||||
modelsStore->>modelsStore: fetchModalitiesForLoadedModels()
|
||||
loop each model where status === "loaded"
|
||||
modelsStore->>PropsSvc: fetchForModel(modelId)
|
||||
PropsSvc->>API: GET /props?model={modelId}
|
||||
API-->>PropsSvc: ApiLlamaCppServerProps
|
||||
modelsStore->>modelsStore: modelPropsCache.set(modelId, props)
|
||||
end
|
||||
modelsStore->>modelsStore: propsCacheVersion++
|
||||
deactivate modelsStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: 🔄 MODEL SELECTION (ROUTER mode)
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>Hooks: useModelChangeValidation({getRequiredModalities, onSuccess?, onValidationFailure?})
|
||||
Note over Hooks: Hook configured per-component:<br/>ChatForm: getRequiredModalities = usedModalities<br/>ChatMessage: getRequiredModalities = getModalitiesUpToMessage(msgId)
|
||||
|
||||
UI->>Hooks: handleModelChange(modelId, modelName)
|
||||
activate Hooks
|
||||
Hooks->>Hooks: previousSelectedModelId = modelsStore.selectedModelId
|
||||
Hooks->>modelsStore: isModelLoaded(modelName)?
|
||||
|
||||
alt model NOT loaded
|
||||
Hooks->>modelsStore: loadModel(modelName)
|
||||
Note over modelsStore: → see LOAD MODEL section below
|
||||
end
|
||||
|
||||
Note over Hooks: Always fetch props (from cache or API)
|
||||
Hooks->>modelsStore: fetchModelProps(modelName)
|
||||
modelsStore-->>Hooks: props
|
||||
|
||||
Hooks->>convStore: getRequiredModalities()
|
||||
convStore-->>Hooks: {vision, audio}
|
||||
|
||||
Hooks->>Hooks: Validate: model.modalities ⊇ required?
|
||||
|
||||
alt validation PASSED
|
||||
Hooks->>modelsStore: selectModelById(modelId)
|
||||
Hooks-->>UI: return true
|
||||
else validation FAILED
|
||||
Hooks->>UI: toast.error("Model doesn't support required modalities")
|
||||
alt model was just loaded
|
||||
Hooks->>modelsStore: unloadModel(modelName)
|
||||
end
|
||||
alt onValidationFailure provided
|
||||
Hooks->>modelsStore: selectModelById(previousSelectedModelId)
|
||||
end
|
||||
Hooks-->>UI: return false
|
||||
end
|
||||
deactivate Hooks
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: ⬆️ LOAD MODEL (ROUTER mode)
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
modelsStore->>modelsStore: loadModel(modelId)
|
||||
activate modelsStore
|
||||
|
||||
alt already loaded
|
||||
modelsStore-->>modelsStore: return (no-op)
|
||||
end
|
||||
|
||||
modelsStore->>modelsStore: modelLoadingStates.set(modelId, true)
|
||||
modelsStore->>ModelsSvc: load(modelId)
|
||||
ModelsSvc->>API: POST /models/load {model: modelId}
|
||||
API-->>ModelsSvc: {status: "loading"}
|
||||
|
||||
modelsStore->>modelsStore: pollForModelStatus(modelId, LOADED)
|
||||
loop poll every 500ms (max 60 attempts)
|
||||
modelsStore->>modelsStore: fetchRouterModels()
|
||||
modelsStore->>ModelsSvc: listRouter()
|
||||
ModelsSvc->>API: GET /v1/models
|
||||
API-->>ModelsSvc: models[]
|
||||
modelsStore->>modelsStore: getModelStatus(modelId)
|
||||
alt status === LOADED
|
||||
Note right of modelsStore: break loop
|
||||
else status === LOADING
|
||||
Note right of modelsStore: wait 500ms, continue
|
||||
end
|
||||
end
|
||||
|
||||
modelsStore->>modelsStore: updateModelModalities(modelId)
|
||||
modelsStore->>PropsSvc: fetchForModel(modelId)
|
||||
PropsSvc->>API: GET /props?model={modelId}
|
||||
API-->>PropsSvc: props with modalities
|
||||
modelsStore->>modelsStore: modelPropsCache.set(modelId, props)
|
||||
modelsStore->>modelsStore: propsCacheVersion++
|
||||
|
||||
modelsStore->>modelsStore: modelLoadingStates.set(modelId, false)
|
||||
deactivate modelsStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: ⬇️ UNLOAD MODEL (ROUTER mode)
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
modelsStore->>modelsStore: unloadModel(modelId)
|
||||
activate modelsStore
|
||||
modelsStore->>modelsStore: modelLoadingStates.set(modelId, true)
|
||||
modelsStore->>ModelsSvc: unload(modelId)
|
||||
ModelsSvc->>API: POST /models/unload {model: modelId}
|
||||
|
||||
modelsStore->>modelsStore: pollForModelStatus(modelId, UNLOADED)
|
||||
loop poll until unloaded
|
||||
modelsStore->>ModelsSvc: listRouter()
|
||||
ModelsSvc->>API: GET /v1/models
|
||||
end
|
||||
|
||||
modelsStore->>modelsStore: modelLoadingStates.set(modelId, false)
|
||||
deactivate modelsStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: 📊 COMPUTED GETTERS
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Note over modelsStore: Getters:<br/>- selectedModel: ModelOption | null<br/>- loadedModelIds: string[] (from routerModels)<br/>- loadingModelIds: string[] (from modelLoadingStates)<br/>- singleModelName: string | null (MODEL mode only)
|
||||
|
||||
Note over modelsStore: Modality helpers:<br/>- getModelModalities(modelId): {vision, audio}<br/>- modelSupportsVision(modelId): boolean<br/>- modelSupportsAudio(modelId): boolean
|
||||
```
|
||||
76
tools/server/webui/docs/flows/server-flow.md
Normal file
76
tools/server/webui/docs/flows/server-flow.md
Normal file
@@ -0,0 +1,76 @@
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant UI as 🧩 +layout.svelte
|
||||
participant serverStore as 🗄️ serverStore
|
||||
participant PropsSvc as ⚙️ PropsService
|
||||
participant API as 🌐 llama-server
|
||||
|
||||
Note over serverStore: State:<br/>props: ApiLlamaCppServerProps | null<br/>loading, error<br/>role: ServerRole | null (MODEL | ROUTER)<br/>fetchPromise (deduplication)
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: 🚀 INITIALIZATION
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>serverStore: fetch()
|
||||
activate serverStore
|
||||
|
||||
alt fetchPromise exists (already fetching)
|
||||
serverStore-->>UI: return fetchPromise
|
||||
Note right of serverStore: Deduplicate concurrent calls
|
||||
end
|
||||
|
||||
serverStore->>serverStore: loading = true
|
||||
serverStore->>serverStore: fetchPromise = new Promise()
|
||||
|
||||
serverStore->>PropsSvc: fetch()
|
||||
PropsSvc->>API: GET /props
|
||||
API-->>PropsSvc: ApiLlamaCppServerProps
|
||||
Note right of API: {role, model_path, model_alias,<br/>modalities, default_generation_settings, ...}
|
||||
|
||||
PropsSvc-->>serverStore: props
|
||||
serverStore->>serverStore: props = $state(data)
|
||||
|
||||
serverStore->>serverStore: detectRole(props)
|
||||
Note right of serverStore: role = props.role === "router"<br/> ? ServerRole.ROUTER<br/> : ServerRole.MODEL
|
||||
|
||||
serverStore->>serverStore: loading = false
|
||||
serverStore->>serverStore: fetchPromise = null
|
||||
deactivate serverStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: 📊 COMPUTED GETTERS
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Note over serverStore: Getters from props:
|
||||
|
||||
rect rgb(240, 255, 240)
|
||||
Note over serverStore: defaultParams<br/>→ props.default_generation_settings.params<br/>(temperature, top_p, top_k, etc.)
|
||||
end
|
||||
|
||||
rect rgb(240, 255, 240)
|
||||
Note over serverStore: contextSize<br/>→ props.default_generation_settings.n_ctx
|
||||
end
|
||||
|
||||
rect rgb(255, 240, 240)
|
||||
Note over serverStore: isRouterMode<br/>→ role === ServerRole.ROUTER
|
||||
end
|
||||
|
||||
rect rgb(255, 240, 240)
|
||||
Note over serverStore: isModelMode<br/>→ role === ServerRole.MODEL
|
||||
end
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: 🔗 RELATIONSHIPS
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Note over serverStore: Used by:
|
||||
Note right of serverStore: - modelsStore: role detection, MODEL mode modalities<br/>- settingsStore: syncWithServerDefaults (defaultParams)<br/>- chatStore: contextSize for processing state<br/>- UI components: isRouterMode for conditional rendering
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,API: ❌ ERROR HANDLING
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Note over serverStore: getErrorMessage(): string | null<br/>Returns formatted error for UI display
|
||||
|
||||
Note over serverStore: clear(): void<br/>Resets all state (props, error, loading, role)
|
||||
```
|
||||
156
tools/server/webui/docs/flows/settings-flow.md
Normal file
156
tools/server/webui/docs/flows/settings-flow.md
Normal file
@@ -0,0 +1,156 @@
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant UI as 🧩 ChatSettings
|
||||
participant settingsStore as 🗄️ settingsStore
|
||||
participant serverStore as 🗄️ serverStore
|
||||
participant ParamSvc as ⚙️ ParameterSyncService
|
||||
participant LS as 💾 LocalStorage
|
||||
|
||||
Note over settingsStore: State:<br/>config: SettingsConfigType<br/>theme: string ("auto" | "light" | "dark")<br/>isInitialized: boolean<br/>userOverrides: Set<string>
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,LS: 🚀 INITIALIZATION
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Note over settingsStore: Auto-initialized in constructor (browser only)
|
||||
settingsStore->>settingsStore: initialize()
|
||||
activate settingsStore
|
||||
|
||||
settingsStore->>settingsStore: loadConfig()
|
||||
settingsStore->>LS: get("llama-config")
|
||||
LS-->>settingsStore: StoredConfig | null
|
||||
|
||||
alt config exists
|
||||
settingsStore->>settingsStore: Merge with SETTING_CONFIG_DEFAULT
|
||||
Note right of settingsStore: Fill missing keys with defaults
|
||||
else no config
|
||||
settingsStore->>settingsStore: config = SETTING_CONFIG_DEFAULT
|
||||
end
|
||||
|
||||
settingsStore->>LS: get("llama-userOverrides")
|
||||
LS-->>settingsStore: string[] | null
|
||||
settingsStore->>settingsStore: userOverrides = new Set(data)
|
||||
|
||||
settingsStore->>settingsStore: loadTheme()
|
||||
settingsStore->>LS: get("llama-theme")
|
||||
LS-->>settingsStore: theme | "auto"
|
||||
|
||||
settingsStore->>settingsStore: isInitialized = true
|
||||
deactivate settingsStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,LS: 🔄 SYNC WITH SERVER DEFAULTS
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Note over UI: Triggered from +layout.svelte when serverStore.props loaded
|
||||
UI->>settingsStore: syncWithServerDefaults()
|
||||
activate settingsStore
|
||||
|
||||
settingsStore->>serverStore: defaultParams
|
||||
serverStore-->>settingsStore: {temperature, top_p, top_k, ...}
|
||||
|
||||
loop each SYNCABLE_PARAMETER
|
||||
alt key NOT in userOverrides
|
||||
settingsStore->>settingsStore: config[key] = serverDefault[key]
|
||||
Note right of settingsStore: Non-overridden params adopt server default
|
||||
else key in userOverrides
|
||||
Note right of settingsStore: Keep user value, skip server default
|
||||
end
|
||||
end
|
||||
|
||||
alt serverStore.props has webuiSettings
|
||||
settingsStore->>settingsStore: Apply webuiSettings from server
|
||||
Note right of settingsStore: Server-provided UI settings<br/>(e.g. showRawOutputSwitch)
|
||||
end
|
||||
|
||||
settingsStore->>settingsStore: saveConfig()
|
||||
deactivate settingsStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,LS: ⚙️ UPDATE CONFIG
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>settingsStore: updateConfig(key, value)
|
||||
activate settingsStore
|
||||
settingsStore->>settingsStore: config[key] = value
|
||||
|
||||
alt value matches server default for key
|
||||
settingsStore->>settingsStore: userOverrides.delete(key)
|
||||
Note right of settingsStore: Matches server default, remove override
|
||||
else value differs from server default
|
||||
settingsStore->>settingsStore: userOverrides.add(key)
|
||||
Note right of settingsStore: Mark as user-modified (won't be overwritten)
|
||||
end
|
||||
|
||||
settingsStore->>settingsStore: saveConfig()
|
||||
settingsStore->>LS: set(CONFIG_LOCALSTORAGE_KEY, config)
|
||||
settingsStore->>LS: set(USER_OVERRIDES_LOCALSTORAGE_KEY, [...userOverrides])
|
||||
deactivate settingsStore
|
||||
|
||||
UI->>settingsStore: updateMultipleConfig({key1: val1, key2: val2})
|
||||
activate settingsStore
|
||||
Note right of settingsStore: Batch update, single save
|
||||
settingsStore->>settingsStore: For each key: config[key] = value
|
||||
settingsStore->>settingsStore: For each key: userOverrides.add(key)
|
||||
settingsStore->>settingsStore: saveConfig()
|
||||
deactivate settingsStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,LS: 🔄 RESET
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>settingsStore: resetConfig()
|
||||
activate settingsStore
|
||||
settingsStore->>settingsStore: config = {...SETTING_CONFIG_DEFAULT}
|
||||
settingsStore->>settingsStore: userOverrides.clear()
|
||||
Note right of settingsStore: All params reset to defaults<br/>Next syncWithServerDefaults will adopt server values
|
||||
settingsStore->>settingsStore: saveConfig()
|
||||
deactivate settingsStore
|
||||
|
||||
UI->>settingsStore: resetParameterToServerDefault(key)
|
||||
activate settingsStore
|
||||
settingsStore->>settingsStore: userOverrides.delete(key)
|
||||
settingsStore->>serverStore: defaultParams[key]
|
||||
settingsStore->>settingsStore: config[key] = serverDefault
|
||||
settingsStore->>settingsStore: saveConfig()
|
||||
deactivate settingsStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,LS: 🎨 THEME
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>settingsStore: updateTheme(newTheme)
|
||||
activate settingsStore
|
||||
settingsStore->>settingsStore: theme = newTheme
|
||||
settingsStore->>settingsStore: saveTheme()
|
||||
settingsStore->>LS: set("llama-theme", theme)
|
||||
deactivate settingsStore
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,LS: 📊 PARAMETER INFO
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
UI->>settingsStore: getParameterInfo(key)
|
||||
settingsStore->>ParamSvc: getParameterInfo(key, config, serverDefaults, userOverrides)
|
||||
ParamSvc-->>settingsStore: ParameterInfo
|
||||
Note right of ParamSvc: {<br/> currentValue,<br/> serverDefault,<br/> isUserOverride: boolean,<br/> canSync: boolean,<br/> isDifferentFromServer: boolean<br/>}
|
||||
|
||||
UI->>settingsStore: getParameterDiff()
|
||||
settingsStore->>ParamSvc: createParameterDiff(config, serverDefaults, userOverrides)
|
||||
ParamSvc-->>settingsStore: ParameterDiff[]
|
||||
Note right of ParamSvc: Array of parameters where user != server
|
||||
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
Note over UI,LS: 📋 CONFIG CATEGORIES
|
||||
%% ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Note over settingsStore: Syncable with server (from /props):
|
||||
rect rgb(240, 255, 240)
|
||||
Note over settingsStore: temperature, top_p, top_k, min_p<br/>repeat_penalty, presence_penalty, frequency_penalty<br/>dynatemp_range, dynatemp_exponent<br/>typ_p, xtc_probability, xtc_threshold<br/>dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n
|
||||
end
|
||||
|
||||
Note over settingsStore: UI-only (not synced):
|
||||
rect rgb(255, 240, 240)
|
||||
Note over settingsStore: systemMessage, custom (JSON)<br/>showStatistics, enableContinueGeneration<br/>autoMicOnEmpty, disableAutoScroll<br/>apiKey, pdfAsImage, disableReasoningParsing, showRawOutputSwitch
|
||||
end
|
||||
```
|
||||
51
tools/server/webui/eslint.config.js
Normal file
51
tools/server/webui/eslint.config.js
Normal file
@@ -0,0 +1,51 @@
|
||||
// For more info, see https://github.com/storybookjs/eslint-plugin-storybook#configuration-flat-config-format
|
||||
import storybook from 'eslint-plugin-storybook';
|
||||
|
||||
import prettier from 'eslint-config-prettier';
|
||||
import { includeIgnoreFile } from '@eslint/compat';
|
||||
import js from '@eslint/js';
|
||||
import svelte from 'eslint-plugin-svelte';
|
||||
import globals from 'globals';
|
||||
import { fileURLToPath } from 'node:url';
|
||||
import ts from 'typescript-eslint';
|
||||
import svelteConfig from './svelte.config.js';
|
||||
|
||||
const gitignorePath = fileURLToPath(new URL('./.gitignore', import.meta.url));
|
||||
|
||||
export default ts.config(
|
||||
includeIgnoreFile(gitignorePath),
|
||||
js.configs.recommended,
|
||||
...ts.configs.recommended,
|
||||
...svelte.configs.recommended,
|
||||
prettier,
|
||||
...svelte.configs.prettier,
|
||||
{
|
||||
languageOptions: {
|
||||
globals: { ...globals.browser, ...globals.node }
|
||||
},
|
||||
rules: {
|
||||
// typescript-eslint strongly recommend that you do not use the no-undef lint rule on TypeScript projects.
|
||||
// see: https://typescript-eslint.io/troubleshooting/faqs/eslint/#i-get-errors-from-the-no-undef-rule-about-global-variables-not-being-defined-even-though-there-are-no-typescript-errors
|
||||
'no-undef': 'off',
|
||||
'svelte/no-at-html-tags': 'off',
|
||||
// This app uses hash-based routing (#/) where resolve() from $app/paths does not apply
|
||||
'svelte/no-navigation-without-resolve': 'off'
|
||||
}
|
||||
},
|
||||
{
|
||||
files: ['**/*.svelte', '**/*.svelte.ts', '**/*.svelte.js'],
|
||||
languageOptions: {
|
||||
parserOptions: {
|
||||
projectService: true,
|
||||
extraFileExtensions: ['.svelte'],
|
||||
parser: ts.parser,
|
||||
svelteConfig
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
// Exclude Storybook files from main ESLint rules
|
||||
ignores: ['.storybook/**/*']
|
||||
},
|
||||
storybook.configs['flat/recommended']
|
||||
);
|
||||
10704
tools/server/webui/package-lock.json
generated
Normal file
10704
tools/server/webui/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
96
tools/server/webui/package.json
Normal file
96
tools/server/webui/package.json
Normal file
@@ -0,0 +1,96 @@
|
||||
{
|
||||
"name": "llama-server-webui",
|
||||
"private": true,
|
||||
"version": "1.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "bash scripts/dev.sh",
|
||||
"build": "vite build && ./scripts/post-build.sh",
|
||||
"preview": "vite preview",
|
||||
"prepare": "svelte-kit sync || echo ''",
|
||||
"check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json",
|
||||
"check:watch": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json --watch",
|
||||
"reset": "rm -rf .svelte-kit node_modules",
|
||||
"format": "prettier --write .",
|
||||
"lint": "prettier --check . && eslint .",
|
||||
"test": "npm run test:ui -- --run && npm run test:client -- --run && npm run test:unit -- --run && npm run test:e2e",
|
||||
"test:e2e": "playwright test",
|
||||
"test:client": "vitest --project=client",
|
||||
"test:unit": "vitest --project=unit",
|
||||
"test:ui": "vitest --project=ui",
|
||||
"storybook": "storybook dev -p 6006",
|
||||
"build-storybook": "storybook build",
|
||||
"cleanup": "rm -rf .svelte-kit build node_modules test-results"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@chromatic-com/storybook": "^5.0.0",
|
||||
"@eslint/compat": "^1.2.5",
|
||||
"@eslint/js": "^9.18.0",
|
||||
"@internationalized/date": "^3.10.1",
|
||||
"@lucide/svelte": "^0.515.0",
|
||||
"@playwright/test": "^1.49.1",
|
||||
"@storybook/addon-a11y": "^10.2.4",
|
||||
"@storybook/addon-docs": "^10.2.4",
|
||||
"@storybook/addon-svelte-csf": "^5.0.10",
|
||||
"@storybook/addon-vitest": "^10.2.4",
|
||||
"@storybook/sveltekit": "^10.2.4",
|
||||
"@sveltejs/adapter-static": "^3.0.10",
|
||||
"@sveltejs/kit": "^2.48.4",
|
||||
"@sveltejs/vite-plugin-svelte": "^6.2.1",
|
||||
"@tailwindcss/forms": "^0.5.9",
|
||||
"@tailwindcss/typography": "^0.5.15",
|
||||
"@tailwindcss/vite": "^4.0.0",
|
||||
"@types/node": "^24",
|
||||
"@vitest/browser": "^3.2.3",
|
||||
"@vitest/coverage-v8": "^3.2.3",
|
||||
"bits-ui": "^2.14.4",
|
||||
"clsx": "^2.1.1",
|
||||
"dexie": "^4.0.11",
|
||||
"eslint": "^9.18.0",
|
||||
"eslint-config-prettier": "^10.0.1",
|
||||
"eslint-plugin-storybook": "^10.2.4",
|
||||
"eslint-plugin-svelte": "^3.0.0",
|
||||
"globals": "^16.0.0",
|
||||
"http-server": "^14.1.1",
|
||||
"mdast": "^3.0.0",
|
||||
"mdsvex": "^0.12.3",
|
||||
"playwright": "^1.56.1",
|
||||
"prettier": "^3.4.2",
|
||||
"prettier-plugin-svelte": "^3.3.3",
|
||||
"prettier-plugin-tailwindcss": "^0.6.11",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"remark-math": "^6.0.0",
|
||||
"sass": "^1.93.3",
|
||||
"storybook": "^10.2.4",
|
||||
"svelte": "^5.38.2",
|
||||
"svelte-check": "^4.0.0",
|
||||
"tailwind-merge": "^3.3.1",
|
||||
"tailwind-variants": "^3.2.2",
|
||||
"tailwindcss": "^4.0.0",
|
||||
"tw-animate-css": "^1.3.5",
|
||||
"typescript": "^5.0.0",
|
||||
"typescript-eslint": "^8.20.0",
|
||||
"unified": "^11.0.5",
|
||||
"uuid": "^13.0.0",
|
||||
"vite": "^7.2.2",
|
||||
"vite-plugin-devtools-json": "^0.2.0",
|
||||
"vitest": "^3.2.3",
|
||||
"vitest-browser-svelte": "^0.1.0"
|
||||
},
|
||||
"dependencies": {
|
||||
"@modelcontextprotocol/sdk": "^1.25.1",
|
||||
"highlight.js": "^11.11.1",
|
||||
"mode-watcher": "^1.1.0",
|
||||
"pdfjs-dist": "^5.4.54",
|
||||
"rehype-highlight": "^7.0.2",
|
||||
"rehype-stringify": "^10.0.1",
|
||||
"remark": "^15.0.1",
|
||||
"remark-breaks": "^4.0.0",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"remark-html": "^16.0.1",
|
||||
"remark-rehype": "^11.1.2",
|
||||
"svelte-sonner": "^1.0.5",
|
||||
"unist-util-visit": "^5.0.0",
|
||||
"zod": "^4.2.1"
|
||||
}
|
||||
}
|
||||
11
tools/server/webui/playwright.config.ts
Normal file
11
tools/server/webui/playwright.config.ts
Normal file
@@ -0,0 +1,11 @@
|
||||
import { defineConfig } from '@playwright/test';
|
||||
|
||||
export default defineConfig({
|
||||
webServer: {
|
||||
command: 'npm run build && http-server ../public -p 8181',
|
||||
port: 8181,
|
||||
timeout: 120000,
|
||||
reuseExistingServer: false
|
||||
},
|
||||
testDir: 'tests/e2e'
|
||||
});
|
||||
57
tools/server/webui/scripts/dev.sh
Normal file
57
tools/server/webui/scripts/dev.sh
Normal file
@@ -0,0 +1,57 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Development script for llama.cpp webui
|
||||
#
|
||||
# This script starts the webui development servers (Storybook and Vite).
|
||||
# Note: You need to start llama-server separately.
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/dev.sh
|
||||
# npm run dev
|
||||
|
||||
cd ../../../
|
||||
|
||||
# Check and install git hooks if missing
|
||||
check_and_install_hooks() {
|
||||
local hooks_missing=false
|
||||
|
||||
# Check for required hooks
|
||||
if [ ! -f ".git/hooks/pre-commit" ] || [ ! -f ".git/hooks/pre-push" ] || [ ! -f ".git/hooks/post-push" ]; then
|
||||
hooks_missing=true
|
||||
fi
|
||||
|
||||
if [ "$hooks_missing" = true ]; then
|
||||
echo "🔧 Git hooks missing, installing them..."
|
||||
cd tools/server/webui
|
||||
if bash scripts/install-git-hooks.sh; then
|
||||
echo "✅ Git hooks installed successfully"
|
||||
else
|
||||
echo "⚠️ Failed to install git hooks, continuing anyway..."
|
||||
fi
|
||||
cd ../../../
|
||||
else
|
||||
echo "✅ Git hooks already installed"
|
||||
fi
|
||||
}
|
||||
|
||||
# Install git hooks if needed
|
||||
check_and_install_hooks
|
||||
|
||||
# Cleanup function
|
||||
cleanup() {
|
||||
echo "🧹 Cleaning up..."
|
||||
exit
|
||||
}
|
||||
|
||||
# Set up signal handlers
|
||||
trap cleanup SIGINT SIGTERM
|
||||
|
||||
echo "🚀 Starting development servers..."
|
||||
echo "📝 Note: Make sure to start llama-server separately if needed"
|
||||
cd tools/server/webui
|
||||
# Use --insecure-http-parser to handle malformed HTTP responses from llama-server
|
||||
# (some responses have both Content-Length and Transfer-Encoding headers)
|
||||
storybook dev -p 6006 --ci & NODE_OPTIONS="--insecure-http-parser" vite dev --host 0.0.0.0 &
|
||||
|
||||
# Wait for all background processes
|
||||
wait
|
||||
82
tools/server/webui/scripts/install-git-hooks.sh
Executable file
82
tools/server/webui/scripts/install-git-hooks.sh
Executable file
@@ -0,0 +1,82 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to install pre-commit hook for webui
|
||||
# Pre-commit: formats, checks, builds, and stages build output
|
||||
|
||||
REPO_ROOT=$(git rev-parse --show-toplevel)
|
||||
PRE_COMMIT_HOOK="$REPO_ROOT/.git/hooks/pre-commit"
|
||||
|
||||
echo "Installing pre-commit hook for webui..."
|
||||
|
||||
# Create the pre-commit hook
|
||||
cat > "$PRE_COMMIT_HOOK" << 'EOF'
|
||||
#!/bin/bash
|
||||
|
||||
# Check if there are any changes in the webui directory
|
||||
if git diff --cached --name-only | grep -q "^tools/server/webui/"; then
|
||||
REPO_ROOT=$(git rev-parse --show-toplevel)
|
||||
cd "$REPO_ROOT/tools/server/webui"
|
||||
|
||||
# Check if package.json exists
|
||||
if [ ! -f "package.json" ]; then
|
||||
echo "Error: package.json not found in tools/server/webui"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Formatting and checking webui code..."
|
||||
|
||||
# Run the format command
|
||||
npm run format
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: npm run format failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run the lint command
|
||||
npm run lint
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: npm run lint failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run the check command
|
||||
npm run check
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: npm run check failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✅ Webui code formatted and checked successfully"
|
||||
|
||||
# Build the webui
|
||||
echo "Building webui..."
|
||||
npm run build
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "❌ npm run build failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Stage the build output alongside the source changes
|
||||
cd "$REPO_ROOT"
|
||||
git add tools/server/public/
|
||||
|
||||
echo "✅ Webui built and build output staged"
|
||||
fi
|
||||
|
||||
exit 0
|
||||
EOF
|
||||
|
||||
# Make hook executable
|
||||
chmod +x "$PRE_COMMIT_HOOK"
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "✅ Git hook installed successfully!"
|
||||
echo " Pre-commit: $PRE_COMMIT_HOOK"
|
||||
echo ""
|
||||
echo "The hook will automatically:"
|
||||
echo " • Format, lint and check webui code before commits"
|
||||
echo " • Build webui and stage tools/server/public/ into the same commit"
|
||||
else
|
||||
echo "❌ Failed to make hook executable"
|
||||
exit 1
|
||||
fi
|
||||
3
tools/server/webui/scripts/post-build.sh
Executable file
3
tools/server/webui/scripts/post-build.sh
Executable file
@@ -0,0 +1,3 @@
|
||||
rm -rf ../public/_app;
|
||||
rm ../public/favicon.svg;
|
||||
rm -f ../public/index.html.gz; # deprecated, but may still be generated by older versions of the build process
|
||||
84
tools/server/webui/scripts/vite-plugin-llama-cpp-build.ts
Normal file
84
tools/server/webui/scripts/vite-plugin-llama-cpp-build.ts
Normal file
@@ -0,0 +1,84 @@
|
||||
import { readFileSync, writeFileSync, existsSync, readdirSync, copyFileSync } from 'fs';
|
||||
import { resolve } from 'path';
|
||||
import type { Plugin } from 'vite';
|
||||
|
||||
const GUIDE_FOR_FRONTEND = `
|
||||
<!--
|
||||
This is a static build of the frontend.
|
||||
It is automatically generated by the build process.
|
||||
Do not edit this file directly.
|
||||
To make changes, refer to the "Web UI" section in the README.
|
||||
-->
|
||||
`.trim();
|
||||
|
||||
export function llamaCppBuildPlugin(): Plugin {
|
||||
return {
|
||||
name: 'llamacpp:build',
|
||||
apply: 'build',
|
||||
closeBundle() {
|
||||
// Ensure the SvelteKit adapter has finished writing to ../public
|
||||
setTimeout(() => {
|
||||
try {
|
||||
const indexPath = resolve('../public/index.html');
|
||||
if (!existsSync(indexPath)) return;
|
||||
|
||||
let content = readFileSync(indexPath, 'utf-8');
|
||||
|
||||
const faviconPath = resolve('static/favicon.svg');
|
||||
|
||||
if (existsSync(faviconPath)) {
|
||||
const faviconContent = readFileSync(faviconPath, 'utf-8');
|
||||
const faviconBase64 = Buffer.from(faviconContent).toString('base64');
|
||||
const faviconDataUrl = `data:image/svg+xml;base64,${faviconBase64}`;
|
||||
|
||||
content = content.replace(/href="[^"]*favicon\.svg"/g, `href="${faviconDataUrl}"`);
|
||||
|
||||
console.log('✓ Inlined favicon.svg as base64 data URL');
|
||||
}
|
||||
|
||||
content = content.replace(/\r/g, '');
|
||||
content = GUIDE_FOR_FRONTEND + '\n' + content;
|
||||
content = content.replace(/\/_app\/immutable\/bundle\.[^"]+\.js/g, './bundle.js');
|
||||
content = content.replace(
|
||||
/\/_app\/immutable\/assets\/bundle\.[^"]+\.css/g,
|
||||
'./bundle.css'
|
||||
);
|
||||
content = content.replace(/__sveltekit_[a-z0-9]+/g, '__sveltekit__');
|
||||
|
||||
writeFileSync(indexPath, content, 'utf-8');
|
||||
console.log('✓ Updated index.html');
|
||||
|
||||
// Copy bundle.*.js -> ../public/bundle.js
|
||||
const immutableDir = resolve('../public/_app/immutable');
|
||||
const bundleDir = resolve('../public/_app/immutable/assets');
|
||||
|
||||
if (existsSync(immutableDir)) {
|
||||
const jsFiles = readdirSync(immutableDir).filter((f) => f.match(/^bundle\..+\.js$/));
|
||||
|
||||
if (jsFiles.length > 0) {
|
||||
copyFileSync(resolve(immutableDir, jsFiles[0]), resolve('../public/bundle.js'));
|
||||
// Normalize __sveltekit_<hash> to __sveltekit__ in bundle.js
|
||||
const bundleJsPath = resolve('../public/bundle.js');
|
||||
let bundleJs = readFileSync(bundleJsPath, 'utf-8');
|
||||
bundleJs = bundleJs.replace(/__sveltekit_[a-z0-9]+/g, '__sveltekit__');
|
||||
writeFileSync(bundleJsPath, bundleJs, 'utf-8');
|
||||
console.log(`✓ Copied ${jsFiles[0]} -> bundle.js`);
|
||||
}
|
||||
}
|
||||
|
||||
// Copy bundle.*.css -> ../public/bundle.css
|
||||
if (existsSync(bundleDir)) {
|
||||
const cssFiles = readdirSync(bundleDir).filter((f) => f.match(/^bundle\..+\.css$/));
|
||||
|
||||
if (cssFiles.length > 0) {
|
||||
copyFileSync(resolve(bundleDir, cssFiles[0]), resolve('../public/bundle.css'));
|
||||
console.log(`✓ Copied ${cssFiles[0]} -> bundle.css`);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to update index.html:', error);
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
};
|
||||
}
|
||||
186
tools/server/webui/src/app.css
Normal file
186
tools/server/webui/src/app.css
Normal file
@@ -0,0 +1,186 @@
|
||||
@import 'tailwindcss';
|
||||
|
||||
@import 'tw-animate-css';
|
||||
|
||||
@custom-variant dark (&:is(.dark *));
|
||||
|
||||
:root {
|
||||
--radius: 0.625rem;
|
||||
--background: oklch(1 0 0);
|
||||
--foreground: oklch(0.145 0 0);
|
||||
--card: oklch(1 0 0);
|
||||
--card-foreground: oklch(0.145 0 0);
|
||||
--popover: oklch(1 0 0);
|
||||
--popover-foreground: oklch(0.145 0 0);
|
||||
--primary: oklch(0.205 0 0);
|
||||
--primary-foreground: oklch(0.985 0 0);
|
||||
--secondary: oklch(0.95 0 0);
|
||||
--secondary-foreground: oklch(0.205 0 0);
|
||||
--muted: oklch(0.97 0 0);
|
||||
--muted-foreground: oklch(0.556 0 0);
|
||||
--accent: oklch(0.95 0 0);
|
||||
--accent-foreground: oklch(0.205 0 0);
|
||||
--destructive: oklch(0.577 0.245 27.325);
|
||||
--border: oklch(0.875 0 0);
|
||||
--input: oklch(0.92 0 0);
|
||||
--ring: oklch(0.708 0 0);
|
||||
--chart-1: oklch(0.646 0.222 41.116);
|
||||
--chart-2: oklch(0.6 0.118 184.704);
|
||||
--chart-3: oklch(0.398 0.07 227.392);
|
||||
--chart-4: oklch(0.828 0.189 84.429);
|
||||
--chart-5: oklch(0.769 0.188 70.08);
|
||||
--sidebar: oklch(0.985 0 0);
|
||||
--sidebar-foreground: oklch(0.145 0 0);
|
||||
--sidebar-primary: oklch(0.205 0 0);
|
||||
--sidebar-primary-foreground: oklch(0.985 0 0);
|
||||
--sidebar-accent: oklch(0.97 0 0);
|
||||
--sidebar-accent-foreground: oklch(0.205 0 0);
|
||||
--sidebar-border: oklch(0.922 0 0);
|
||||
--sidebar-ring: oklch(0.708 0 0);
|
||||
--code-background: oklch(0.985 0 0);
|
||||
--code-foreground: oklch(0.145 0 0);
|
||||
--layer-popover: 1000000;
|
||||
|
||||
--chat-form-area-height: 8rem;
|
||||
--chat-form-area-offset: 2rem;
|
||||
--max-message-height: max(24rem, min(80dvh, calc(100dvh - var(--chat-form-area-height) - 12rem)));
|
||||
}
|
||||
|
||||
@media (min-width: 640px) {
|
||||
:root {
|
||||
--chat-form-area-height: 24rem;
|
||||
--chat-form-area-offset: 12rem;
|
||||
}
|
||||
}
|
||||
|
||||
.dark {
|
||||
--background: oklch(0.16 0 0);
|
||||
--foreground: oklch(0.985 0 0);
|
||||
--card: oklch(0.205 0 0);
|
||||
--card-foreground: oklch(0.985 0 0);
|
||||
--popover: oklch(0.205 0 0);
|
||||
--popover-foreground: oklch(0.985 0 0);
|
||||
--primary: oklch(0.922 0 0);
|
||||
--primary-foreground: oklch(0.205 0 0);
|
||||
--secondary: oklch(0.29 0 0);
|
||||
--secondary-foreground: oklch(0.985 0 0);
|
||||
--muted: oklch(0.269 0 0);
|
||||
--muted-foreground: oklch(0.708 0 0);
|
||||
--accent: oklch(0.269 0 0);
|
||||
--accent-foreground: oklch(0.985 0 0);
|
||||
--destructive: oklch(0.704 0.191 22.216);
|
||||
--border: oklch(1 0 0 / 30%);
|
||||
--input: oklch(1 0 0 / 30%);
|
||||
--ring: oklch(0.556 0 0);
|
||||
--chart-1: oklch(0.488 0.243 264.376);
|
||||
--chart-2: oklch(0.696 0.17 162.48);
|
||||
--chart-3: oklch(0.769 0.188 70.08);
|
||||
--chart-4: oklch(0.627 0.265 303.9);
|
||||
--chart-5: oklch(0.645 0.246 16.439);
|
||||
--sidebar: oklch(0.2 0 0);
|
||||
--sidebar-foreground: oklch(0.985 0 0);
|
||||
--sidebar-primary: oklch(0.488 0.243 264.376);
|
||||
--sidebar-primary-foreground: oklch(0.985 0 0);
|
||||
--sidebar-accent: oklch(0.269 0 0);
|
||||
--sidebar-accent-foreground: oklch(0.985 0 0);
|
||||
--sidebar-border: oklch(1 0 0 / 10%);
|
||||
--sidebar-ring: oklch(0.556 0 0);
|
||||
--code-background: oklch(0.225 0 0);
|
||||
--code-foreground: oklch(0.875 0 0);
|
||||
}
|
||||
|
||||
@theme inline {
|
||||
--radius-sm: calc(var(--radius) - 4px);
|
||||
--radius-md: calc(var(--radius) - 2px);
|
||||
--radius-lg: var(--radius);
|
||||
--radius-xl: calc(var(--radius) + 4px);
|
||||
--color-background: var(--background);
|
||||
--color-foreground: var(--foreground);
|
||||
--color-card: var(--card);
|
||||
--color-card-foreground: var(--card-foreground);
|
||||
--color-popover: var(--popover);
|
||||
--color-popover-foreground: var(--popover-foreground);
|
||||
--color-primary: var(--primary);
|
||||
--color-primary-foreground: var(--primary-foreground);
|
||||
--color-secondary: var(--secondary);
|
||||
--color-secondary-foreground: var(--secondary-foreground);
|
||||
--color-muted: var(--muted);
|
||||
--color-muted-foreground: var(--muted-foreground);
|
||||
--color-accent: var(--accent);
|
||||
--color-accent-foreground: var(--accent-foreground);
|
||||
--color-destructive: var(--destructive);
|
||||
--color-border: var(--border);
|
||||
--color-input: var(--input);
|
||||
--color-ring: var(--ring);
|
||||
--color-chart-1: var(--chart-1);
|
||||
--color-chart-2: var(--chart-2);
|
||||
--color-chart-3: var(--chart-3);
|
||||
--color-chart-4: var(--chart-4);
|
||||
--color-chart-5: var(--chart-5);
|
||||
--color-sidebar: var(--sidebar);
|
||||
--color-sidebar-foreground: var(--sidebar-foreground);
|
||||
--color-sidebar-primary: var(--sidebar-primary);
|
||||
--color-sidebar-primary-foreground: var(--sidebar-primary-foreground);
|
||||
--color-sidebar-accent: var(--sidebar-accent);
|
||||
--color-sidebar-accent-foreground: var(--sidebar-accent-foreground);
|
||||
--color-sidebar-border: var(--sidebar-border);
|
||||
--color-sidebar-ring: var(--sidebar-ring);
|
||||
}
|
||||
|
||||
@layer base {
|
||||
* {
|
||||
@apply border-border outline-ring/50;
|
||||
}
|
||||
|
||||
body {
|
||||
@apply bg-background text-foreground;
|
||||
scrollbar-width: thin;
|
||||
scrollbar-gutter: stable;
|
||||
}
|
||||
|
||||
/* Global scrollbar styling - visible only on hover */
|
||||
* {
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: transparent transparent;
|
||||
transition: scrollbar-color 0.2s ease;
|
||||
}
|
||||
|
||||
*:hover {
|
||||
scrollbar-color: hsl(var(--muted-foreground) / 0.3) transparent;
|
||||
}
|
||||
|
||||
*::-webkit-scrollbar {
|
||||
width: 6px;
|
||||
height: 6px;
|
||||
}
|
||||
|
||||
*::-webkit-scrollbar-track {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
*::-webkit-scrollbar-thumb {
|
||||
background: transparent;
|
||||
border-radius: 3px;
|
||||
transition: background 0.2s ease;
|
||||
}
|
||||
|
||||
*:hover::-webkit-scrollbar-thumb {
|
||||
background: hsl(var(--muted-foreground) / 0.3);
|
||||
}
|
||||
|
||||
*::-webkit-scrollbar-thumb:hover {
|
||||
background: hsl(var(--muted-foreground) / 0.5);
|
||||
}
|
||||
}
|
||||
|
||||
@layer utilities {
|
||||
.scrollbar-hide {
|
||||
/* Hide scrollbar for Chrome, Safari and Opera */
|
||||
&::-webkit-scrollbar {
|
||||
display: none;
|
||||
}
|
||||
/* Hide scrollbar for IE, Edge and Firefox */
|
||||
-ms-overflow-style: none;
|
||||
scrollbar-width: none;
|
||||
}
|
||||
}
|
||||
131
tools/server/webui/src/app.d.ts
vendored
Normal file
131
tools/server/webui/src/app.d.ts
vendored
Normal file
@@ -0,0 +1,131 @@
|
||||
// See https://svelte.dev/docs/kit/types#app.d.ts
|
||||
// for information about these interfaces
|
||||
|
||||
// Import chat types from dedicated module
|
||||
|
||||
import type {
|
||||
// API types
|
||||
ApiChatCompletionRequest,
|
||||
ApiChatCompletionResponse,
|
||||
ApiChatCompletionStreamChunk,
|
||||
ApiChatCompletionToolCall,
|
||||
ApiChatCompletionToolCallDelta,
|
||||
ApiChatMessageData,
|
||||
ApiChatMessageContentPart,
|
||||
ApiContextSizeError,
|
||||
ApiErrorResponse,
|
||||
ApiLlamaCppServerProps,
|
||||
ApiModelDataEntry,
|
||||
ApiModelListResponse,
|
||||
ApiProcessingState,
|
||||
ApiRouterModelMeta,
|
||||
ApiRouterModelsLoadRequest,
|
||||
ApiRouterModelsLoadResponse,
|
||||
ApiRouterModelsStatusRequest,
|
||||
ApiRouterModelsStatusResponse,
|
||||
ApiRouterModelsListResponse,
|
||||
ApiRouterModelsUnloadRequest,
|
||||
ApiRouterModelsUnloadResponse,
|
||||
// Chat types
|
||||
ChatAttachmentDisplayItem,
|
||||
ChatMessageType,
|
||||
ChatRole,
|
||||
ChatUploadedFile,
|
||||
ChatMessageSiblingInfo,
|
||||
ChatMessagePromptProgress,
|
||||
ChatMessageTimings,
|
||||
// Database types
|
||||
DatabaseConversation,
|
||||
DatabaseMessage,
|
||||
DatabaseMessageExtra,
|
||||
DatabaseMessageExtraAudioFile,
|
||||
DatabaseMessageExtraImageFile,
|
||||
DatabaseMessageExtraTextFile,
|
||||
DatabaseMessageExtraPdfFile,
|
||||
DatabaseMessageExtraLegacyContext,
|
||||
ExportedConversation,
|
||||
ExportedConversations,
|
||||
// Model types
|
||||
ModelModalities,
|
||||
ModelOption,
|
||||
// Settings types
|
||||
SettingsChatServiceOptions,
|
||||
SettingsConfigValue,
|
||||
SettingsFieldConfig,
|
||||
SettingsConfigType
|
||||
} from '$lib/types';
|
||||
|
||||
import { ServerRole, ServerModelStatus, ModelModality } from '$lib/enums';
|
||||
|
||||
declare global {
|
||||
// namespace App {
|
||||
// interface Error {}
|
||||
// interface Locals {}
|
||||
// interface PageData {}
|
||||
// interface PageState {}
|
||||
// interface Platform {}
|
||||
// }
|
||||
|
||||
export {
|
||||
// API types
|
||||
ApiChatCompletionRequest,
|
||||
ApiChatCompletionResponse,
|
||||
ApiChatCompletionStreamChunk,
|
||||
ApiChatCompletionToolCall,
|
||||
ApiChatCompletionToolCallDelta,
|
||||
ApiChatMessageData,
|
||||
ApiChatMessageContentPart,
|
||||
ApiContextSizeError,
|
||||
ApiErrorResponse,
|
||||
ApiLlamaCppServerProps,
|
||||
ApiModelDataEntry,
|
||||
ApiModelListResponse,
|
||||
ApiProcessingState,
|
||||
ApiRouterModelMeta,
|
||||
ApiRouterModelsLoadRequest,
|
||||
ApiRouterModelsLoadResponse,
|
||||
ApiRouterModelsStatusRequest,
|
||||
ApiRouterModelsStatusResponse,
|
||||
ApiRouterModelsListResponse,
|
||||
ApiRouterModelsUnloadRequest,
|
||||
ApiRouterModelsUnloadResponse,
|
||||
// Chat types
|
||||
ChatAttachmentDisplayItem,
|
||||
ChatMessagePromptProgress,
|
||||
ChatMessageSiblingInfo,
|
||||
ChatMessageTimings,
|
||||
ChatMessageType,
|
||||
ChatRole,
|
||||
ChatUploadedFile,
|
||||
// Database types
|
||||
DatabaseConversation,
|
||||
DatabaseMessage,
|
||||
DatabaseMessageExtra,
|
||||
DatabaseMessageExtraAudioFile,
|
||||
DatabaseMessageExtraImageFile,
|
||||
DatabaseMessageExtraTextFile,
|
||||
DatabaseMessageExtraPdfFile,
|
||||
DatabaseMessageExtraLegacyContext,
|
||||
ExportedConversation,
|
||||
ExportedConversations,
|
||||
// Enum types
|
||||
ModelModality,
|
||||
ServerRole,
|
||||
ServerModelStatus,
|
||||
// Model types
|
||||
ModelModalities,
|
||||
ModelOption,
|
||||
// Settings types
|
||||
SettingsChatServiceOptions,
|
||||
SettingsConfigValue,
|
||||
SettingsFieldConfig,
|
||||
SettingsConfigType
|
||||
};
|
||||
}
|
||||
|
||||
declare global {
|
||||
interface Window {
|
||||
idxThemeStyle?: number;
|
||||
idxCodeBlock?: number;
|
||||
}
|
||||
}
|
||||
12
tools/server/webui/src/app.html
Normal file
12
tools/server/webui/src/app.html
Normal file
@@ -0,0 +1,12 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<link rel="icon" href="%sveltekit.assets%/favicon.svg" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
%sveltekit.head%
|
||||
</head>
|
||||
<body data-sveltekit-preload-data="hover">
|
||||
<div style="display: contents">%sveltekit.body%</div>
|
||||
</body>
|
||||
</html>
|
||||
47
tools/server/webui/src/lib/actions/fade-in-view.svelte.ts
Normal file
47
tools/server/webui/src/lib/actions/fade-in-view.svelte.ts
Normal file
@@ -0,0 +1,47 @@
|
||||
import { isElementInViewport } from '$lib/utils/viewport';
|
||||
|
||||
/**
|
||||
* Svelte action that fades in an element when it enters the viewport.
|
||||
* Uses IntersectionObserver for efficient viewport detection.
|
||||
*
|
||||
* If skipIfVisible is set and the element is already visible in the viewport
|
||||
* when the action attaches (e.g. a markdown block promoted from unstable
|
||||
* during streaming), the fade is skipped entirely to avoid a flash.
|
||||
*/
|
||||
export function fadeInView(
|
||||
node: HTMLElement,
|
||||
options: { duration?: number; y?: number; skipIfVisible?: boolean } = {}
|
||||
) {
|
||||
const { duration = 300, y = 0, skipIfVisible = false } = options;
|
||||
|
||||
if (skipIfVisible && isElementInViewport(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
node.style.opacity = '0';
|
||||
node.style.transform = `translateY(${y}px)`;
|
||||
node.style.transition = `opacity ${duration}ms ease-out, transform ${duration}ms ease-out`;
|
||||
|
||||
$effect(() => {
|
||||
const observer = new IntersectionObserver(
|
||||
(entries) => {
|
||||
for (const entry of entries) {
|
||||
if (entry.isIntersecting) {
|
||||
requestAnimationFrame(() => {
|
||||
node.style.opacity = '1';
|
||||
node.style.transform = 'translateY(0)';
|
||||
});
|
||||
observer.disconnect();
|
||||
}
|
||||
}
|
||||
},
|
||||
{ threshold: 0.05 }
|
||||
);
|
||||
|
||||
observer.observe(node);
|
||||
|
||||
return () => {
|
||||
observer.disconnect();
|
||||
};
|
||||
});
|
||||
}
|
||||
11
tools/server/webui/src/lib/components/app/SKILL.md
Normal file
11
tools/server/webui/src/lib/components/app/SKILL.md
Normal file
@@ -0,0 +1,11 @@
|
||||
---
|
||||
name: app
|
||||
description: Opinionated app components building on top of ./ui primitives
|
||||
---
|
||||
|
||||
- Can include business logic and state management
|
||||
- Can include data fetching and caching logic
|
||||
- Should use original spelling for HTML-native events and `camelCase` for custom events
|
||||
- Props and markup attributes should be listed alphabetically
|
||||
- Use JS Objects and Arrays for CSS classes and styles when they are dynamic
|
||||
- Whenever there can be repetition in the component's markup, if it's too small to be decoupled as a separate component — use Svelte 5's `{#snippet}` + `{@render}`
|
||||
@@ -0,0 +1,60 @@
|
||||
<script lang="ts">
|
||||
import { Button, type ButtonVariant, type ButtonSize } from '$lib/components/ui/button';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import type { Component } from 'svelte';
|
||||
import { TooltipSide } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
ariaLabel?: string;
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
icon: Component;
|
||||
iconSize?: string;
|
||||
onclick: (e?: MouseEvent) => void;
|
||||
size?: ButtonSize;
|
||||
stopPropagationOnClick?: boolean;
|
||||
tooltip: string;
|
||||
variant?: ButtonVariant;
|
||||
tooltipSide?: TooltipSide;
|
||||
}
|
||||
|
||||
let {
|
||||
icon,
|
||||
tooltip,
|
||||
variant = 'ghost',
|
||||
size = 'sm',
|
||||
class: className = '',
|
||||
disabled = false,
|
||||
iconSize = 'h-3 w-3',
|
||||
tooltipSide = TooltipSide.TOP,
|
||||
stopPropagationOnClick = false,
|
||||
onclick,
|
||||
ariaLabel
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<Button
|
||||
{variant}
|
||||
{size}
|
||||
{disabled}
|
||||
onclick={(e: MouseEvent) => {
|
||||
if (stopPropagationOnClick) e.stopPropagation();
|
||||
|
||||
onclick?.(e);
|
||||
}}
|
||||
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
|
||||
aria-label={ariaLabel || tooltip}
|
||||
>
|
||||
{#if icon}
|
||||
{@const IconComponent = icon}
|
||||
<IconComponent class={iconSize} />
|
||||
{/if}
|
||||
</Button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side={tooltipSide}>
|
||||
<p>{tooltip}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
@@ -0,0 +1,17 @@
|
||||
<script lang="ts">
|
||||
import { Copy } from '@lucide/svelte';
|
||||
import { copyToClipboard } from '$lib/utils';
|
||||
import ActionIcon from './ActionIcon.svelte';
|
||||
|
||||
export let ariaLabel: string = 'Copy to clipboard';
|
||||
export let canCopy: boolean = true;
|
||||
export let text: string;
|
||||
</script>
|
||||
|
||||
<ActionIcon
|
||||
icon={Copy}
|
||||
tooltip={ariaLabel}
|
||||
iconSize="h-4 w-4"
|
||||
disabled={!canCopy}
|
||||
onclick={() => canCopy && copyToClipboard(text)}
|
||||
/>
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user