ollama source for Momentry Core verification
This commit is contained in:
955
middleware/anthropic.go
Normal file
955
middleware/anthropic.go
Normal file
@@ -0,0 +1,955 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/anthropic"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
|
||||
type AnthropicWriter struct {
|
||||
BaseWriter
|
||||
stream bool
|
||||
id string
|
||||
converter *anthropic.StreamConverter
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeError(data []byte) (int, error) {
|
||||
var errData struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &errData); err != nil {
|
||||
// If the error response isn't valid JSON, use the raw bytes as the
|
||||
// error message rather than surfacing a confusing JSON parse error.
|
||||
errData.Error = string(data)
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.Status(), errData.Error)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
|
||||
return writeSSE(w.ResponseWriter, eventType, data)
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
|
||||
var chatResponse api.ChatResponse
|
||||
err := json.Unmarshal(data, &chatResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if w.stream {
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
events := w.converter.Process(chatResponse)
|
||||
logutil.Trace("anthropic middleware: stream chunk", "resp", anthropic.TraceChatResponse(chatResponse), "events", len(events))
|
||||
for _, event := range events {
|
||||
if err := w.writeEvent(event.Event, event.Data); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
response := anthropic.ToMessagesResponse(w.id, chatResponse)
|
||||
logutil.Trace("anthropic middleware: converted response", "resp", anthropic.TraceMessagesResponse(response))
|
||||
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
// WebSearchAnthropicWriter intercepts responses containing web_search tool calls,
|
||||
// executes the search, re-invokes the model with results, and assembles the
|
||||
// Anthropic-format response (server_tool_use + web_search_tool_result + text).
|
||||
type WebSearchAnthropicWriter struct {
|
||||
BaseWriter
|
||||
newLoopContext func() (context.Context, context.CancelFunc)
|
||||
inner *AnthropicWriter
|
||||
req anthropic.MessagesRequest // original Anthropic request
|
||||
chatReq *api.ChatRequest // converted Ollama request (for followup calls)
|
||||
stream bool
|
||||
|
||||
estimatedInputTokens int
|
||||
|
||||
terminalSent bool
|
||||
|
||||
observedPromptEvalCount int
|
||||
observedEvalCount int
|
||||
|
||||
loopInFlight bool
|
||||
loopBaseInputTok int
|
||||
loopBaseOutputTok int
|
||||
loopResultCh chan webSearchLoopResult
|
||||
|
||||
streamMessageStarted bool
|
||||
streamHasOpenBlock bool
|
||||
streamOpenBlockIndex int
|
||||
streamNextIndex int
|
||||
}
|
||||
|
||||
const maxWebSearchLoops = 3
|
||||
|
||||
type webSearchLoopResult struct {
|
||||
response anthropic.MessagesResponse
|
||||
loopErr *webSearchLoopError
|
||||
}
|
||||
|
||||
type webSearchLoopError struct {
|
||||
code string
|
||||
query string
|
||||
usage anthropic.Usage
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *webSearchLoopError) Error() string {
|
||||
if e.err == nil {
|
||||
return e.code
|
||||
}
|
||||
return fmt.Sprintf("%s: %v", e.code, e.err)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) Write(data []byte) (int, error) {
|
||||
if w.terminalSent {
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
code := w.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.inner.writeError(data)
|
||||
}
|
||||
|
||||
var chatResponse api.ChatResponse
|
||||
if err := json.Unmarshal(data, &chatResponse); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w.recordObservedUsage(chatResponse.Metrics)
|
||||
|
||||
if w.stream && w.loopInFlight {
|
||||
if !chatResponse.Done {
|
||||
return len(data), nil
|
||||
}
|
||||
if err := w.writeLoopResult(); err != nil {
|
||||
return len(data), err
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
webSearchCall, hasWebSearch, hasOtherTools := findWebSearchToolCall(chatResponse.Message.ToolCalls)
|
||||
logutil.Trace("anthropic middleware: upstream chunk",
|
||||
"resp", anthropic.TraceChatResponse(chatResponse),
|
||||
"web_search", hasWebSearch,
|
||||
"other_tools", hasOtherTools,
|
||||
)
|
||||
if hasWebSearch && hasOtherTools {
|
||||
// Prefer web_search if both server and client tools are present in one chunk.
|
||||
slog.Debug("preferring web_search tool call over client tool calls in mixed tool response")
|
||||
}
|
||||
|
||||
if !hasWebSearch {
|
||||
if w.stream {
|
||||
if err := w.writePassthroughStreamChunk(chatResponse); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
return w.inner.writeResponse(data)
|
||||
}
|
||||
|
||||
if w.stream {
|
||||
// Let the original generation continue to completion while web search runs in parallel.
|
||||
logutil.Trace("anthropic middleware: starting async web_search loop",
|
||||
"tool_call", anthropic.TraceToolCall(webSearchCall),
|
||||
"resp", anthropic.TraceChatResponse(chatResponse),
|
||||
)
|
||||
w.startLoopWorker(chatResponse, webSearchCall)
|
||||
if chatResponse.Done {
|
||||
if err := w.writeLoopResult(); err != nil {
|
||||
return len(data), err
|
||||
}
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
loopCtx, cancel := w.startLoopContext()
|
||||
defer cancel()
|
||||
|
||||
initialUsage := anthropic.Usage{
|
||||
InputTokens: max(w.observedPromptEvalCount, chatResponse.Metrics.PromptEvalCount),
|
||||
OutputTokens: max(w.observedEvalCount, chatResponse.Metrics.EvalCount),
|
||||
}
|
||||
logutil.Trace("anthropic middleware: starting sync web_search loop",
|
||||
"tool_call", anthropic.TraceToolCall(webSearchCall),
|
||||
"resp", anthropic.TraceChatResponse(chatResponse),
|
||||
"usage", initialUsage,
|
||||
)
|
||||
response, loopErr := w.runWebSearchLoop(loopCtx, chatResponse, webSearchCall, initialUsage)
|
||||
if loopErr != nil {
|
||||
return len(data), w.sendError(loopErr.code, loopErr.query, loopErr.usage)
|
||||
}
|
||||
|
||||
if err := w.writeTerminalResponse(response); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) runWebSearchLoop(ctx context.Context, initialResponse api.ChatResponse, initialToolCall api.ToolCall, initialUsage anthropic.Usage) (anthropic.MessagesResponse, *webSearchLoopError) {
|
||||
followUpMessages := make([]api.Message, 0, len(w.chatReq.Messages)+maxWebSearchLoops*2)
|
||||
followUpMessages = append(followUpMessages, w.chatReq.Messages...)
|
||||
|
||||
followUpTools := append(api.Tools(nil), w.chatReq.Tools...)
|
||||
usage := initialUsage
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search loop init",
|
||||
"model", w.req.Model,
|
||||
"tool_call", anthropic.TraceToolCall(initialToolCall),
|
||||
"messages", len(followUpMessages),
|
||||
"tools", len(followUpTools),
|
||||
"max_loops", maxWebSearchLoops,
|
||||
)
|
||||
|
||||
currentResponse := initialResponse
|
||||
currentToolCall := initialToolCall
|
||||
|
||||
var serverContent []anthropic.ContentBlock
|
||||
|
||||
for loop := 1; loop <= maxWebSearchLoops; loop++ {
|
||||
query := extractQueryFromToolCall(¤tToolCall)
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search loop iteration",
|
||||
"loop", loop,
|
||||
"query", anthropic.TraceTruncateString(query),
|
||||
"messages", len(followUpMessages),
|
||||
)
|
||||
if query == "" {
|
||||
return anthropic.MessagesResponse{}, &webSearchLoopError{
|
||||
code: "invalid_request",
|
||||
query: "",
|
||||
usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
const defaultMaxResults = 5
|
||||
searchResp, err := anthropic.WebSearch(ctx, query, defaultMaxResults)
|
||||
if err != nil {
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search request failed",
|
||||
"loop", loop,
|
||||
"query", query,
|
||||
"error", err,
|
||||
)
|
||||
return anthropic.MessagesResponse{}, &webSearchLoopError{
|
||||
code: "unavailable",
|
||||
query: query,
|
||||
usage: usage,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search results",
|
||||
"loop", loop,
|
||||
"results", len(searchResp.Results),
|
||||
)
|
||||
|
||||
toolUseID := loopServerToolUseID(w.inner.id, loop)
|
||||
searchResults := anthropic.ConvertOllamaToAnthropicResults(searchResp)
|
||||
serverContent = append(serverContent,
|
||||
anthropic.ContentBlock{
|
||||
Type: "server_tool_use",
|
||||
ID: toolUseID,
|
||||
Name: "web_search",
|
||||
Input: queryArgs(query),
|
||||
},
|
||||
anthropic.ContentBlock{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: toolUseID,
|
||||
Content: searchResults,
|
||||
},
|
||||
)
|
||||
|
||||
assistantMsg := buildWebSearchAssistantMessage(currentResponse, currentToolCall)
|
||||
toolResultMsg := api.Message{
|
||||
Role: "tool",
|
||||
Content: formatWebSearchResultsForToolMessage(searchResp.Results),
|
||||
ToolCallID: currentToolCall.ID,
|
||||
}
|
||||
followUpMessages = append(followUpMessages, assistantMsg, toolResultMsg)
|
||||
|
||||
followUpResponse, err := w.callFollowUpChat(ctx, followUpMessages, followUpTools)
|
||||
if err != nil {
|
||||
logutil.TraceContext(ctx, "anthropic middleware: followup /api/chat failed",
|
||||
"loop", loop,
|
||||
"query", query,
|
||||
"error", err,
|
||||
)
|
||||
return anthropic.MessagesResponse{}, &webSearchLoopError{
|
||||
code: "api_error",
|
||||
query: query,
|
||||
usage: usage,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic middleware: followup response",
|
||||
"loop", loop,
|
||||
"resp", anthropic.TraceChatResponse(followUpResponse),
|
||||
)
|
||||
|
||||
usage.InputTokens += followUpResponse.Metrics.PromptEvalCount
|
||||
usage.OutputTokens += followUpResponse.Metrics.EvalCount
|
||||
|
||||
nextToolCall, hasWebSearch, hasOtherTools := findWebSearchToolCall(followUpResponse.Message.ToolCalls)
|
||||
if hasWebSearch && hasOtherTools {
|
||||
// Prefer web_search if both server and client tools are present in one chunk.
|
||||
slog.Debug("preferring web_search tool call over client tool calls in mixed followup response")
|
||||
}
|
||||
|
||||
if !hasWebSearch {
|
||||
finalResponse := w.combineServerAndFinalContent(serverContent, followUpResponse, usage)
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search loop complete",
|
||||
"loop", loop,
|
||||
"resp", anthropic.TraceMessagesResponse(finalResponse),
|
||||
)
|
||||
return finalResponse, nil
|
||||
}
|
||||
|
||||
currentResponse = followUpResponse
|
||||
currentToolCall = nextToolCall
|
||||
}
|
||||
|
||||
maxLoopQuery := extractQueryFromToolCall(¤tToolCall)
|
||||
maxLoopToolUseID := loopServerToolUseID(w.inner.id, maxWebSearchLoops+1)
|
||||
serverContent = append(serverContent,
|
||||
anthropic.ContentBlock{
|
||||
Type: "server_tool_use",
|
||||
ID: maxLoopToolUseID,
|
||||
Name: "web_search",
|
||||
Input: queryArgs(maxLoopQuery),
|
||||
},
|
||||
anthropic.ContentBlock{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: maxLoopToolUseID,
|
||||
Content: anthropic.WebSearchToolResultError{
|
||||
Type: "web_search_tool_result_error",
|
||||
ErrorCode: "max_uses_exceeded",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
maxResponse := anthropic.MessagesResponse{
|
||||
ID: w.inner.id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: w.req.Model,
|
||||
Content: serverContent,
|
||||
StopReason: "end_turn",
|
||||
Usage: usage,
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search loop max reached",
|
||||
"resp", anthropic.TraceMessagesResponse(maxResponse),
|
||||
)
|
||||
return maxResponse, nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) startLoopWorker(initialResponse api.ChatResponse, initialToolCall api.ToolCall) {
|
||||
if w.loopInFlight {
|
||||
return
|
||||
}
|
||||
|
||||
initialUsage := anthropic.Usage{
|
||||
InputTokens: max(w.observedPromptEvalCount, initialResponse.Metrics.PromptEvalCount),
|
||||
OutputTokens: max(w.observedEvalCount, initialResponse.Metrics.EvalCount),
|
||||
}
|
||||
w.loopBaseInputTok = initialUsage.InputTokens
|
||||
w.loopBaseOutputTok = initialUsage.OutputTokens
|
||||
w.loopResultCh = make(chan webSearchLoopResult, 1)
|
||||
w.loopInFlight = true
|
||||
logutil.Trace("anthropic middleware: loop worker started",
|
||||
"usage", initialUsage,
|
||||
"tool_call", anthropic.TraceToolCall(initialToolCall),
|
||||
)
|
||||
|
||||
go func() {
|
||||
ctx, cancel := w.startLoopContext()
|
||||
defer cancel()
|
||||
|
||||
response, loopErr := w.runWebSearchLoop(ctx, initialResponse, initialToolCall, initialUsage)
|
||||
w.loopResultCh <- webSearchLoopResult{
|
||||
response: response,
|
||||
loopErr: loopErr,
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) writeLoopResult() error {
|
||||
if w.loopResultCh == nil {
|
||||
return w.sendError("api_error", "", w.currentObservedUsage())
|
||||
}
|
||||
|
||||
result := <-w.loopResultCh
|
||||
w.loopResultCh = nil
|
||||
w.loopInFlight = false
|
||||
if result.loopErr != nil {
|
||||
logutil.Trace("anthropic middleware: loop worker returned error",
|
||||
"code", result.loopErr.code,
|
||||
"query", result.loopErr.query,
|
||||
"usage", result.loopErr.usage,
|
||||
"error", result.loopErr.err,
|
||||
)
|
||||
usage := result.loopErr.usage
|
||||
w.applyObservedUsageDeltaToUsage(&usage)
|
||||
return w.sendError(result.loopErr.code, result.loopErr.query, usage)
|
||||
}
|
||||
logutil.Trace("anthropic middleware: loop worker done", "resp", anthropic.TraceMessagesResponse(result.response))
|
||||
|
||||
w.applyObservedUsageDelta(&result.response)
|
||||
return w.writeTerminalResponse(result.response)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) applyObservedUsageDelta(response *anthropic.MessagesResponse) {
|
||||
w.applyObservedUsageDeltaToUsage(&response.Usage)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) recordObservedUsage(metrics api.Metrics) {
|
||||
if metrics.PromptEvalCount > w.observedPromptEvalCount {
|
||||
w.observedPromptEvalCount = metrics.PromptEvalCount
|
||||
}
|
||||
if metrics.EvalCount > w.observedEvalCount {
|
||||
w.observedEvalCount = metrics.EvalCount
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) applyObservedUsageDeltaToUsage(usage *anthropic.Usage) {
|
||||
if deltaIn := w.observedPromptEvalCount - w.loopBaseInputTok; deltaIn > 0 {
|
||||
usage.InputTokens += deltaIn
|
||||
}
|
||||
if deltaOut := w.observedEvalCount - w.loopBaseOutputTok; deltaOut > 0 {
|
||||
usage.OutputTokens += deltaOut
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) currentObservedUsage() anthropic.Usage {
|
||||
return anthropic.Usage{
|
||||
InputTokens: w.observedPromptEvalCount,
|
||||
OutputTokens: w.observedEvalCount,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) startLoopContext() (context.Context, context.CancelFunc) {
|
||||
if w.newLoopContext != nil {
|
||||
return w.newLoopContext()
|
||||
}
|
||||
return context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) combineServerAndFinalContent(serverContent []anthropic.ContentBlock, finalResponse api.ChatResponse, usage anthropic.Usage) anthropic.MessagesResponse {
|
||||
converted := anthropic.ToMessagesResponse(w.inner.id, finalResponse)
|
||||
|
||||
content := make([]anthropic.ContentBlock, 0, len(serverContent)+len(converted.Content))
|
||||
content = append(content, serverContent...)
|
||||
content = append(content, converted.Content...)
|
||||
|
||||
return anthropic.MessagesResponse{
|
||||
ID: w.inner.id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: w.req.Model,
|
||||
Content: content,
|
||||
StopReason: converted.StopReason,
|
||||
StopSequence: converted.StopSequence,
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
func buildWebSearchAssistantMessage(response api.ChatResponse, webSearchCall api.ToolCall) api.Message {
|
||||
assistantMsg := api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{webSearchCall},
|
||||
}
|
||||
if response.Message.Content != "" {
|
||||
assistantMsg.Content = response.Message.Content
|
||||
}
|
||||
if response.Message.Thinking != "" {
|
||||
assistantMsg.Thinking = response.Message.Thinking
|
||||
}
|
||||
return assistantMsg
|
||||
}
|
||||
|
||||
func formatWebSearchResultsForToolMessage(results []anthropic.OllamaWebSearchResult) string {
|
||||
var resultText strings.Builder
|
||||
for _, r := range results {
|
||||
fmt.Fprintf(&resultText, "Title: %s\nURL: %s\n", r.Title, r.URL)
|
||||
if r.Content != "" {
|
||||
fmt.Fprintf(&resultText, "Content: %s\n", r.Content)
|
||||
}
|
||||
resultText.WriteString("\n")
|
||||
}
|
||||
return resultText.String()
|
||||
}
|
||||
|
||||
func findWebSearchToolCall(toolCalls []api.ToolCall) (api.ToolCall, bool, bool) {
|
||||
var webSearchCall api.ToolCall
|
||||
hasWebSearch := false
|
||||
hasOtherTools := false
|
||||
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCall.Function.Name == "web_search" {
|
||||
if !hasWebSearch {
|
||||
webSearchCall = toolCall
|
||||
hasWebSearch = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
hasOtherTools = true
|
||||
}
|
||||
|
||||
return webSearchCall, hasWebSearch, hasOtherTools
|
||||
}
|
||||
|
||||
func loopServerToolUseID(messageID string, loop int) string {
|
||||
base := serverToolUseID(messageID)
|
||||
if loop <= 1 {
|
||||
return base
|
||||
}
|
||||
return fmt.Sprintf("%s_%d", base, loop)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) callFollowUpChat(ctx context.Context, messages []api.Message, tools api.Tools) (api.ChatResponse, error) {
|
||||
streaming := false
|
||||
followUp := api.ChatRequest{
|
||||
Model: w.chatReq.Model,
|
||||
Messages: messages,
|
||||
Stream: &streaming,
|
||||
Tools: tools,
|
||||
Options: w.chatReq.Options,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(followUp)
|
||||
if err != nil {
|
||||
return api.ChatResponse{}, err
|
||||
}
|
||||
|
||||
chatURL := envconfig.Host().String() + "/api/chat"
|
||||
logutil.TraceContext(ctx, "anthropic middleware: followup request",
|
||||
"url", chatURL,
|
||||
"req", anthropic.TraceChatRequest(&followUp),
|
||||
)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", chatURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return api.ChatResponse{}, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return api.ChatResponse{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
logutil.TraceContext(ctx, "anthropic middleware: followup non-200 response",
|
||||
"status", resp.StatusCode,
|
||||
"response", strings.TrimSpace(string(respBody)),
|
||||
)
|
||||
return api.ChatResponse{}, fmt.Errorf("followup /api/chat returned status %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
|
||||
}
|
||||
|
||||
var chatResp api.ChatResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
||||
return api.ChatResponse{}, err
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic middleware: followup decoded", "resp", anthropic.TraceChatResponse(chatResp))
|
||||
|
||||
return chatResp, nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) writePassthroughStreamChunk(chatResponse api.ChatResponse) error {
|
||||
events := w.inner.converter.Process(chatResponse)
|
||||
for _, event := range events {
|
||||
switch e := event.Data.(type) {
|
||||
case anthropic.MessageStartEvent:
|
||||
w.streamMessageStarted = true
|
||||
case anthropic.ContentBlockStartEvent:
|
||||
w.streamHasOpenBlock = true
|
||||
w.streamOpenBlockIndex = e.Index
|
||||
if e.Index+1 > w.streamNextIndex {
|
||||
w.streamNextIndex = e.Index + 1
|
||||
}
|
||||
case anthropic.ContentBlockStopEvent:
|
||||
if w.streamHasOpenBlock && w.streamOpenBlockIndex == e.Index {
|
||||
w.streamHasOpenBlock = false
|
||||
}
|
||||
if e.Index+1 > w.streamNextIndex {
|
||||
w.streamNextIndex = e.Index + 1
|
||||
}
|
||||
case anthropic.MessageStopEvent:
|
||||
w.terminalSent = true
|
||||
}
|
||||
|
||||
if err := writeSSE(w.ResponseWriter, event.Event, event.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) ensureStreamMessageStart(usage anthropic.Usage) error {
|
||||
if w.streamMessageStarted {
|
||||
return nil
|
||||
}
|
||||
|
||||
inputTokens := usage.InputTokens
|
||||
if inputTokens == 0 {
|
||||
inputTokens = w.estimatedInputTokens
|
||||
}
|
||||
|
||||
if err := writeSSE(w.ResponseWriter, "message_start", anthropic.MessageStartEvent{
|
||||
Type: "message_start",
|
||||
Message: anthropic.MessagesResponse{
|
||||
ID: w.inner.id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: w.req.Model,
|
||||
Content: []anthropic.ContentBlock{},
|
||||
Usage: anthropic.Usage{
|
||||
InputTokens: inputTokens,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.streamMessageStarted = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) closeOpenStreamBlock() error {
|
||||
if !w.streamHasOpenBlock {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := writeSSE(w.ResponseWriter, "content_block_stop", anthropic.ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: w.streamOpenBlockIndex,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if w.streamOpenBlockIndex+1 > w.streamNextIndex {
|
||||
w.streamNextIndex = w.streamOpenBlockIndex + 1
|
||||
}
|
||||
w.streamHasOpenBlock = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) writeStreamContentBlocks(content []anthropic.ContentBlock) error {
|
||||
for _, block := range content {
|
||||
index := w.streamNextIndex
|
||||
if block.Type == "text" {
|
||||
emptyText := ""
|
||||
if err := writeSSE(w.ResponseWriter, "content_block_start", anthropic.ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: index,
|
||||
ContentBlock: anthropic.ContentBlock{
|
||||
Type: "text",
|
||||
Text: &emptyText,
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
text := ""
|
||||
if block.Text != nil {
|
||||
text = *block.Text
|
||||
}
|
||||
if err := writeSSE(w.ResponseWriter, "content_block_delta", anthropic.ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: index,
|
||||
Delta: anthropic.Delta{
|
||||
Type: "text_delta",
|
||||
Text: text,
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := writeSSE(w.ResponseWriter, "content_block_start", anthropic.ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: index,
|
||||
ContentBlock: block,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := writeSSE(w.ResponseWriter, "content_block_stop", anthropic.ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: index,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.streamNextIndex++
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) writeTerminalResponse(response anthropic.MessagesResponse) error {
|
||||
if w.terminalSent {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !w.stream {
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w.ResponseWriter).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
w.terminalSent = true
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := w.ensureStreamMessageStart(response.Usage); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := w.closeOpenStreamBlock(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := w.writeStreamContentBlocks(response.Content); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writeSSE(w.ResponseWriter, "message_delta", anthropic.MessageDeltaEvent{
|
||||
Type: "message_delta",
|
||||
Delta: anthropic.MessageDelta{
|
||||
StopReason: response.StopReason,
|
||||
},
|
||||
Usage: anthropic.DeltaUsage{
|
||||
InputTokens: response.Usage.InputTokens,
|
||||
OutputTokens: response.Usage.OutputTokens,
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writeSSE(w.ResponseWriter, "message_stop", anthropic.MessageStopEvent{
|
||||
Type: "message_stop",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.terminalSent = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// streamResponse emits a complete MessagesResponse as SSE events.
|
||||
func (w *WebSearchAnthropicWriter) streamResponse(response anthropic.MessagesResponse) error {
|
||||
return w.writeTerminalResponse(response)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) webSearchErrorResponse(errorCode, query string, usage anthropic.Usage) anthropic.MessagesResponse {
|
||||
toolUseID := serverToolUseID(w.inner.id)
|
||||
|
||||
return anthropic.MessagesResponse{
|
||||
ID: w.inner.id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: w.req.Model,
|
||||
Content: []anthropic.ContentBlock{
|
||||
{
|
||||
Type: "server_tool_use",
|
||||
ID: toolUseID,
|
||||
Name: "web_search",
|
||||
Input: queryArgs(query),
|
||||
},
|
||||
{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: toolUseID,
|
||||
Content: anthropic.WebSearchToolResultError{
|
||||
Type: "web_search_tool_result_error",
|
||||
ErrorCode: errorCode,
|
||||
},
|
||||
},
|
||||
},
|
||||
StopReason: "end_turn",
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
// sendError sends a web search error response.
|
||||
func (w *WebSearchAnthropicWriter) sendError(errorCode, query string, usage anthropic.Usage) error {
|
||||
response := w.webSearchErrorResponse(errorCode, query, usage)
|
||||
logutil.Trace("anthropic middleware: web_search error", "code", errorCode, "query", query, "usage", usage)
|
||||
return w.writeTerminalResponse(response)
|
||||
}
|
||||
|
||||
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
|
||||
func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
requestCtx := c.Request.Context()
|
||||
|
||||
var req anthropic.MessagesRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.MaxTokens <= 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "max_tokens is required and must be positive"))
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Messages) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, "messages is required"))
|
||||
return
|
||||
}
|
||||
|
||||
chatReq, err := anthropic.FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, anthropic.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Set think to nil when being used with Anthropic API to connect to tools like claude code
|
||||
c.Set("relax_thinking", true)
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
messageID := anthropic.GenerateMessageID()
|
||||
|
||||
// Estimate input tokens for streaming (actual count not available until generation completes)
|
||||
estimatedTokens := anthropic.EstimateInputTokens(req)
|
||||
|
||||
innerWriter := &AnthropicWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: messageID,
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
}
|
||||
|
||||
if hasWebSearchTool(req.Tools) {
|
||||
// Guard against runtime cloud-disable policy (OLLAMA_NO_CLOUD/server.json)
|
||||
// for cloud models. Local models may still receive web_search tool definitions;
|
||||
// execution is validated when the model actually emits a web_search tool call.
|
||||
if isCloudModelName(req.Model) {
|
||||
if disabled, _ := internalcloud.Status(); disabled {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, anthropic.NewError(http.StatusForbidden, internalcloud.DisabledError("web search is unavailable")))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Writer = &WebSearchAnthropicWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
newLoopContext: func() (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(requestCtx, 5*time.Minute)
|
||||
},
|
||||
inner: innerWriter,
|
||||
req: req,
|
||||
chatReq: chatReq,
|
||||
stream: req.Stream,
|
||||
estimatedInputTokens: estimatedTokens,
|
||||
}
|
||||
} else {
|
||||
c.Writer = innerWriter
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// hasWebSearchTool checks if the request tools include a web_search tool
|
||||
func hasWebSearchTool(tools []anthropic.Tool) bool {
|
||||
for _, tool := range tools {
|
||||
if strings.HasPrefix(tool.Type, "web_search") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isCloudModelName(name string) bool {
|
||||
return modelref.HasExplicitCloudSource(name)
|
||||
}
|
||||
|
||||
// extractQueryFromToolCall extracts the search query from a web_search tool call
|
||||
func extractQueryFromToolCall(tc *api.ToolCall) string {
|
||||
q, ok := tc.Function.Arguments.Get("query")
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
if s, ok := q.(string); ok {
|
||||
return s
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// writeSSE writes a Server-Sent Event
|
||||
func writeSSE(w http.ResponseWriter, eventType string, data any) error {
|
||||
d, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, d); err != nil {
|
||||
return err
|
||||
}
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// queryArgs creates a ToolCallFunctionArguments with a single "query" key.
|
||||
func queryArgs(query string) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
args.Set("query", query)
|
||||
return args
|
||||
}
|
||||
|
||||
// serverToolUseID derives a server tool use ID from a message ID
|
||||
func serverToolUseID(messageID string) string {
|
||||
return "srvtoolu_" + strings.TrimPrefix(messageID, "msg_")
|
||||
}
|
||||
3006
middleware/anthropic_test.go
Normal file
3006
middleware/anthropic_test.go
Normal file
File diff suppressed because it is too large
Load Diff
790
middleware/openai.go
Normal file
790
middleware/openai.go
Normal file
@@ -0,0 +1,790 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
)
|
||||
|
||||
// maxDecompressedBodySize limits the size of a decompressed request body
|
||||
const maxDecompressedBodySize = 20 << 20
|
||||
|
||||
type BaseWriter struct {
|
||||
gin.ResponseWriter
|
||||
}
|
||||
|
||||
type ChatWriter struct {
|
||||
stream bool
|
||||
streamOptions *openai.StreamOptions
|
||||
id string
|
||||
toolCallSent bool
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
type CompleteWriter struct {
|
||||
stream bool
|
||||
streamOptions *openai.StreamOptions
|
||||
id string
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
type ListWriter struct {
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
type RetrieveWriter struct {
|
||||
BaseWriter
|
||||
model string
|
||||
}
|
||||
|
||||
type EmbedWriter struct {
|
||||
BaseWriter
|
||||
model string
|
||||
encodingFormat string
|
||||
}
|
||||
|
||||
func (w *BaseWriter) writeError(data []byte) (int, error) {
|
||||
var serr api.StatusError
|
||||
if err := json.Unmarshal(data, &serr); err != nil {
|
||||
// If the error response isn't valid JSON, use the raw bytes as the
|
||||
// error message rather than surfacing a confusing JSON parse error.
|
||||
serr.ErrorMessage = string(data)
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w.ResponseWriter).Encode(openai.NewError(w.ResponseWriter.Status(), serr.Error())); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
||||
var chatResponse api.ChatResponse
|
||||
err := json.Unmarshal(data, &chatResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// chat chunk
|
||||
if w.stream {
|
||||
chunks := openai.ToChunks(w.id, chatResponse, w.toolCallSent)
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
for _, c := range chunks {
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
|
||||
w.toolCallSent = true
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
if chatResponse.Done {
|
||||
c := openai.ToChunk(w.id, chatResponse, w.toolCallSent)
|
||||
if len(chunks) > 0 {
|
||||
c = chunks[len(chunks)-1]
|
||||
} else {
|
||||
slog.Warn("ToChunks returned no chunks; falling back to ToChunk for usage chunk", "id", w.id, "model", chatResponse.Model)
|
||||
}
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
u := openai.ToUsage(chatResponse)
|
||||
c.Usage = &u
|
||||
c.Choices = []openai.ChunkChoice{}
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// chat completion
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToChatCompletion(w.id, chatResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ChatWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
||||
var generateResponse api.GenerateResponse
|
||||
err := json.Unmarshal(data, &generateResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// completion chunk
|
||||
if w.stream {
|
||||
c := openai.ToCompleteChunk(w.id, generateResponse)
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
c.Usage = &openai.Usage{}
|
||||
}
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if generateResponse.Done {
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
u := openai.ToUsageGenerate(generateResponse)
|
||||
c.Usage = &u
|
||||
c.Choices = []openai.CompleteChunkChoice{}
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// completion
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToCompletion(w.id, generateResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *CompleteWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *ListWriter) writeResponse(data []byte) (int, error) {
|
||||
var listResponse api.ListResponse
|
||||
err := json.Unmarshal(data, &listResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToListCompletion(listResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ListWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
|
||||
var showResponse api.ShowResponse
|
||||
err := json.Unmarshal(data, &showResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// retrieve completion
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToModel(showResponse, w.model))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
||||
var embedResponse api.EmbedResponse
|
||||
err := json.Unmarshal(data, &embedResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToEmbeddingList(w.model, embedResponse, w.encodingFormat))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *EmbedWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func ListMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
w := &ListWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func RetrieveMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &RetrieveWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
model: c.Param("model"),
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func CompletionsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.CompletionRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
genReq, err := openai.FromCompleteRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &CompleteWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
||||
streamOptions: req.StreamOptions,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func EmbeddingsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.EmbedRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate encoding_format parameter
|
||||
if req.EncodingFormat != "" {
|
||||
if !strings.EqualFold(req.EncodingFormat, "float") && !strings.EqualFold(req.EncodingFormat, "base64") {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, fmt.Sprintf("Invalid value for 'encoding_format' = %s. Supported values: ['float', 'base64'].", req.EncodingFormat)))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if req.Input == "" {
|
||||
req.Input = []string{""}
|
||||
}
|
||||
|
||||
if req.Input == nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input"))
|
||||
return
|
||||
}
|
||||
|
||||
if v, ok := req.Input.([]any); ok && len(v) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input"))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &EmbedWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
model: req.Model,
|
||||
encodingFormat: req.EncodingFormat,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func ChatMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.ChatCompletionRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Messages) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
|
||||
chatReq, err := openai.FromChatRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &ChatWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
||||
streamOptions: req.StreamOptions,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
type ResponsesWriter struct {
|
||||
BaseWriter
|
||||
converter *openai.ResponsesStreamConverter
|
||||
model string
|
||||
stream bool
|
||||
responseID string
|
||||
itemID string
|
||||
request openai.ResponsesRequest
|
||||
}
|
||||
|
||||
func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
|
||||
d, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
|
||||
var chatResponse api.ChatResponse
|
||||
if err := json.Unmarshal(data, &chatResponse); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if w.stream {
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
events := w.converter.Process(chatResponse)
|
||||
for _, event := range events {
|
||||
if err := w.writeEvent(event.Event, event.Data); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// Non-streaming response
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request)
|
||||
completedAt := time.Now().Unix()
|
||||
response.CompletedAt = &completedAt
|
||||
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
|
||||
}
|
||||
|
||||
func (w *ResponsesWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func ResponsesMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.GetHeader("Content-Encoding") == "zstd" {
|
||||
reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "failed to decompress zstd body"))
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, io.NopCloser(reader), maxDecompressedBodySize)
|
||||
c.Request.Header.Del("Content-Encoding")
|
||||
}
|
||||
|
||||
var req openai.ResponsesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
chatReq, err := openai.FromResponsesRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Check if client requested streaming (defaults to false)
|
||||
streamRequested := req.Stream != nil && *req.Stream
|
||||
|
||||
// Pass streaming preference to the underlying chat request
|
||||
chatReq.Stream = &streamRequested
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
responseID := fmt.Sprintf("resp_%d", rand.Intn(999999))
|
||||
itemID := fmt.Sprintf("msg_%d", rand.Intn(999999))
|
||||
|
||||
w := &ResponsesWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req),
|
||||
model: req.Model,
|
||||
stream: streamRequested,
|
||||
responseID: responseID,
|
||||
itemID: itemID,
|
||||
request: req,
|
||||
}
|
||||
|
||||
// Set headers based on streaming mode
|
||||
if streamRequested {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
type ImageWriter struct {
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
func (w *ImageWriter) writeResponse(data []byte) (int, error) {
|
||||
var generateResponse api.GenerateResponse
|
||||
if err := json.Unmarshal(data, &generateResponse); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Only write response when done with image
|
||||
if generateResponse.Done && generateResponse.Image != "" {
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
return len(data), json.NewEncoder(w.ResponseWriter).Encode(openai.ToImageGenerationResponse(generateResponse))
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ImageWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func ImageGenerationsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.ImageGenerationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(openai.FromImageGenerationRequest(req)); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &ImageWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func ImageEditsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.ImageEditRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Image == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "image is required"))
|
||||
return
|
||||
}
|
||||
|
||||
genReq, err := openai.FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &ImageWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// TranscriptionWriter collects streamed chat responses and outputs a transcription response.
|
||||
type TranscriptionWriter struct {
|
||||
BaseWriter
|
||||
responseFormat string
|
||||
text strings.Builder
|
||||
}
|
||||
|
||||
func (w *TranscriptionWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
var chatResponse api.ChatResponse
|
||||
if err := json.Unmarshal(data, &chatResponse); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.text.WriteString(chatResponse.Message.Content)
|
||||
|
||||
if chatResponse.Done {
|
||||
text := strings.TrimSpace(w.text.String())
|
||||
|
||||
if w.responseFormat == "text" {
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/plain")
|
||||
_, err := w.ResponseWriter.Write([]byte(text))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
resp := openai.TranscriptionResponse{Text: text}
|
||||
if err := json.NewEncoder(w.ResponseWriter).Encode(resp); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// TranscriptionMiddleware handles /v1/audio/transcriptions requests.
|
||||
// It accepts multipart/form-data with an audio file and converts it to a chat request.
|
||||
func TranscriptionMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Parse multipart form (limit 25MB).
|
||||
if err := c.Request.ParseMultipartForm(25 << 20); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "failed to parse multipart form: "+err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
model := c.Request.FormValue("model")
|
||||
if model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
file, _, err := c.Request.FormFile("file")
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "file is required: "+err.Error()))
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
audioData, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, "failed to read audio file"))
|
||||
return
|
||||
}
|
||||
|
||||
if len(audioData) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "audio file is empty"))
|
||||
return
|
||||
}
|
||||
|
||||
req := openai.TranscriptionRequest{
|
||||
Model: model,
|
||||
AudioData: audioData,
|
||||
ResponseFormat: c.Request.FormValue("response_format"),
|
||||
Language: c.Request.FormValue("language"),
|
||||
Prompt: c.Request.FormValue("prompt"),
|
||||
}
|
||||
|
||||
chatReq, err := openai.FromTranscriptionRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
c.Request.ContentLength = int64(b.Len())
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := &TranscriptionWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
responseFormat: req.ResponseFormat,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
220
middleware/openai_encoding_format_test.go
Normal file
220
middleware/openai_encoding_format_test.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
)
|
||||
|
||||
func TestEmbeddingsMiddleware_EncodingFormats(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
encodingFormat string
|
||||
expectType string // "array" or "string"
|
||||
verifyBase64 bool
|
||||
}{
|
||||
{"float format", "float", "array", false},
|
||||
{"base64 format", "base64", "string", true},
|
||||
{"default format", "", "array", false},
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
resp := api.EmbedResponse{
|
||||
Embeddings: [][]float32{{0.1, -0.2, 0.3}},
|
||||
PromptEvalCount: 5,
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
router := gin.New()
|
||||
router.Use(EmbeddingsMiddleware())
|
||||
router.Handle(http.MethodPost, "/api/embed", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
body := `{"input": "test", "model": "test-model"`
|
||||
if tc.encodingFormat != "" {
|
||||
body += `, "encoding_format": "` + tc.encodingFormat + `"`
|
||||
}
|
||||
body += `}`
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", resp.Code)
|
||||
}
|
||||
|
||||
var result openai.EmbeddingList
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Data) != 1 {
|
||||
t.Fatalf("expected 1 embedding, got %d", len(result.Data))
|
||||
}
|
||||
|
||||
switch tc.expectType {
|
||||
case "array":
|
||||
if _, ok := result.Data[0].Embedding.([]interface{}); !ok {
|
||||
t.Errorf("expected array, got %T", result.Data[0].Embedding)
|
||||
}
|
||||
case "string":
|
||||
embStr, ok := result.Data[0].Embedding.(string)
|
||||
if !ok {
|
||||
t.Errorf("expected string, got %T", result.Data[0].Embedding)
|
||||
} else if tc.verifyBase64 {
|
||||
decoded, err := base64.StdEncoding.DecodeString(embStr)
|
||||
if err != nil {
|
||||
t.Errorf("invalid base64: %v", err)
|
||||
} else if len(decoded) != 12 {
|
||||
t.Errorf("expected 12 bytes, got %d", len(decoded))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddingsMiddleware_BatchWithBase64(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
resp := api.EmbedResponse{
|
||||
Embeddings: [][]float32{
|
||||
{0.1, 0.2},
|
||||
{0.3, 0.4},
|
||||
{0.5, 0.6},
|
||||
},
|
||||
PromptEvalCount: 10,
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
router := gin.New()
|
||||
router.Use(EmbeddingsMiddleware())
|
||||
router.Handle(http.MethodPost, "/api/embed", endpoint)
|
||||
|
||||
body := `{
|
||||
"input": ["hello", "world", "test"],
|
||||
"model": "test-model",
|
||||
"encoding_format": "base64"
|
||||
}`
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", resp.Code)
|
||||
}
|
||||
|
||||
var result openai.EmbeddingList
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Data) != 3 {
|
||||
t.Fatalf("expected 3 embeddings, got %d", len(result.Data))
|
||||
}
|
||||
|
||||
// All should be base64 strings
|
||||
for i := range 3 {
|
||||
embeddingStr, ok := result.Data[i].Embedding.(string)
|
||||
if !ok {
|
||||
t.Errorf("embedding %d: expected string, got %T", i, result.Data[i].Embedding)
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify it's valid base64
|
||||
if _, err := base64.StdEncoding.DecodeString(embeddingStr); err != nil {
|
||||
t.Errorf("embedding %d: invalid base64: %v", i, err)
|
||||
}
|
||||
|
||||
// Check index
|
||||
if result.Data[i].Index != i {
|
||||
t.Errorf("embedding %d: expected index %d, got %d", i, i, result.Data[i].Index)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddingsMiddleware_InvalidEncodingFormat(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
router := gin.New()
|
||||
router.Use(EmbeddingsMiddleware())
|
||||
router.Handle(http.MethodPost, "/api/embed", endpoint)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
encodingFormat string
|
||||
shouldFail bool
|
||||
}{
|
||||
{"valid: float", "float", false},
|
||||
{"valid: base64", "base64", false},
|
||||
{"valid: FLOAT (uppercase)", "FLOAT", false},
|
||||
{"valid: BASE64 (uppercase)", "BASE64", false},
|
||||
{"valid: Float (mixed)", "Float", false},
|
||||
{"valid: Base64 (mixed)", "Base64", false},
|
||||
{"invalid: json", "json", true},
|
||||
{"invalid: hex", "hex", true},
|
||||
{"invalid: invalid_format", "invalid_format", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
body := `{
|
||||
"input": "test",
|
||||
"model": "test-model",
|
||||
"encoding_format": "` + tc.encodingFormat + `"
|
||||
}`
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if tc.shouldFail {
|
||||
if resp.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", resp.Code)
|
||||
}
|
||||
|
||||
var errResp openai.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatalf("failed to unmarshal error response: %v", err)
|
||||
}
|
||||
|
||||
if errResp.Error.Type != "invalid_request_error" {
|
||||
t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type)
|
||||
}
|
||||
|
||||
if !strings.Contains(errResp.Error.Message, "encoding_format") {
|
||||
t.Errorf("expected error message to mention encoding_format, got %q", errResp.Error.Message)
|
||||
}
|
||||
} else {
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
1633
middleware/openai_test.go
Normal file
1633
middleware/openai_test.go
Normal file
File diff suppressed because it is too large
Load Diff
22
middleware/test_home_test.go
Normal file
22
middleware/test_home_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
func setTestHome(t *testing.T, home string) {
|
||||
t.Helper()
|
||||
t.Setenv("HOME", home)
|
||||
t.Setenv("USERPROFILE", home)
|
||||
envconfig.ReloadServerConfig()
|
||||
}
|
||||
|
||||
// enableCloudForTest sets HOME to a clean temp dir and clears OLLAMA_NO_CLOUD
|
||||
// so that cloud features are enabled for the duration of the test.
|
||||
func enableCloudForTest(t *testing.T) {
|
||||
t.Helper()
|
||||
t.Setenv("OLLAMA_NO_CLOUD", "")
|
||||
setTestHome(t, t.TempDir())
|
||||
}
|
||||
Reference in New Issue
Block a user