ollama source for Momentry Core verification
This commit is contained in:
922
openai/openai.go
Normal file
922
openai/openai.go
Normal file
@@ -0,0 +1,922 @@
|
||||
// openai package provides core transformation logic for partial compatibility with the OpenAI REST API
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
var finishReasonToolCalls = "tool_calls"
|
||||
|
||||
type Error struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Param any `json:"param"`
|
||||
Code *string `json:"code"`
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error Error `json:"error"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
Reasoning string `json:"reasoning,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
type ChoiceLogprobs struct {
|
||||
Content []api.Logprob `json:"content"`
|
||||
}
|
||||
|
||||
type Choice struct {
|
||||
Index int `json:"index"`
|
||||
Message Message `json:"message"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
Logprobs *ChoiceLogprobs `json:"logprobs,omitempty"`
|
||||
}
|
||||
|
||||
type ChunkChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta Message `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
Logprobs *ChoiceLogprobs `json:"logprobs,omitempty"`
|
||||
}
|
||||
|
||||
type CompleteChunkChoice struct {
|
||||
Text string `json:"text"`
|
||||
Index int `json:"index"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
Logprobs *ChoiceLogprobs `json:"logprobs,omitempty"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type"`
|
||||
JsonSchema *JsonSchema `json:"json_schema,omitempty"`
|
||||
}
|
||||
|
||||
type JsonSchema struct {
|
||||
Schema json.RawMessage `json:"schema"`
|
||||
}
|
||||
|
||||
type EmbedRequest struct {
|
||||
Input any `json:"input"`
|
||||
Model string `json:"model"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
EncodingFormat string `json:"encoding_format,omitempty"` // "float" or "base64"
|
||||
}
|
||||
|
||||
type StreamOptions struct {
|
||||
IncludeUsage bool `json:"include_usage"`
|
||||
}
|
||||
|
||||
type Reasoning struct {
|
||||
Effort string `json:"effort,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
StreamOptions *StreamOptions `json:"stream_options"`
|
||||
MaxTokens *int `json:"max_tokens"`
|
||||
Seed *int `json:"seed"`
|
||||
Stop any `json:"stop"`
|
||||
Temperature *float64 `json:"temperature"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty"`
|
||||
TopP *float64 `json:"top_p"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format"`
|
||||
Tools []api.Tool `json:"tools"`
|
||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
Logprobs *bool `json:"logprobs"`
|
||||
TopLogprobs int `json:"top_logprobs"`
|
||||
DebugRenderOnly bool `json:"_debug_render_only"`
|
||||
}
|
||||
|
||||
type ChatCompletion struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
SystemFingerprint string `json:"system_fingerprint"`
|
||||
Choices []Choice `json:"choices"`
|
||||
Usage Usage `json:"usage,omitempty"`
|
||||
DebugInfo *api.DebugInfo `json:"_debug_info,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionChunk struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
SystemFingerprint string `json:"system_fingerprint"`
|
||||
Choices []ChunkChoice `json:"choices"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
|
||||
type CompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
FrequencyPenalty float32 `json:"frequency_penalty"`
|
||||
MaxTokens *int `json:"max_tokens"`
|
||||
PresencePenalty float32 `json:"presence_penalty"`
|
||||
Seed *int `json:"seed"`
|
||||
Stop any `json:"stop"`
|
||||
Stream bool `json:"stream"`
|
||||
StreamOptions *StreamOptions `json:"stream_options"`
|
||||
Temperature *float32 `json:"temperature"`
|
||||
TopP float32 `json:"top_p"`
|
||||
Suffix string `json:"suffix"`
|
||||
Logprobs *int `json:"logprobs"`
|
||||
DebugRenderOnly bool `json:"_debug_render_only"`
|
||||
}
|
||||
|
||||
type Completion struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
SystemFingerprint string `json:"system_fingerprint"`
|
||||
Choices []CompleteChunkChoice `json:"choices"`
|
||||
Usage Usage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type CompletionChunk struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Choices []CompleteChunkChoice `json:"choices"`
|
||||
Model string `json:"model"`
|
||||
SystemFingerprint string `json:"system_fingerprint"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Index int `json:"index"`
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
} `json:"function"`
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
}
|
||||
|
||||
type Embedding struct {
|
||||
Object string `json:"object"`
|
||||
Embedding any `json:"embedding"` // Can be []float32 (float format) or string (base64 format)
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type ListCompletion struct {
|
||||
Object string `json:"object"`
|
||||
Data []Model `json:"data"`
|
||||
}
|
||||
|
||||
type EmbeddingList struct {
|
||||
Object string `json:"object"`
|
||||
Data []Embedding `json:"data"`
|
||||
Model string `json:"model"`
|
||||
Usage EmbeddingUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type EmbeddingUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
func NewError(code int, message string) ErrorResponse {
|
||||
var etype string
|
||||
switch code {
|
||||
case http.StatusBadRequest:
|
||||
etype = "invalid_request_error"
|
||||
case http.StatusNotFound:
|
||||
etype = "not_found_error"
|
||||
default:
|
||||
etype = "api_error"
|
||||
}
|
||||
|
||||
return ErrorResponse{Error{Type: etype, Message: message}}
|
||||
}
|
||||
|
||||
// ToUsage converts an api.ChatResponse to Usage
|
||||
func ToUsage(r api.ChatResponse) Usage {
|
||||
return Usage{
|
||||
PromptTokens: r.Metrics.PromptEvalCount,
|
||||
CompletionTokens: r.Metrics.EvalCount,
|
||||
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
|
||||
}
|
||||
}
|
||||
|
||||
// ToToolCalls converts api.ToolCall to OpenAI ToolCall format
|
||||
func ToToolCalls(tc []api.ToolCall) []ToolCall {
|
||||
toolCalls := make([]ToolCall, len(tc))
|
||||
for i, tc := range tc {
|
||||
toolCalls[i].ID = tc.ID
|
||||
toolCalls[i].Type = "function"
|
||||
toolCalls[i].Function.Name = tc.Function.Name
|
||||
toolCalls[i].Index = tc.Function.Index
|
||||
|
||||
args, err := json.Marshal(tc.Function.Arguments)
|
||||
if err != nil {
|
||||
slog.Error("could not marshall function arguments to json", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
toolCalls[i].Function.Arguments = string(args)
|
||||
}
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
// ToChatCompletion converts an api.ChatResponse to ChatCompletion
|
||||
func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||
toolCalls := ToToolCalls(r.Message.ToolCalls)
|
||||
|
||||
var logprobs *ChoiceLogprobs
|
||||
if len(r.Logprobs) > 0 {
|
||||
logprobs = &ChoiceLogprobs{Content: r.Logprobs}
|
||||
}
|
||||
|
||||
return ChatCompletion{
|
||||
Id: id,
|
||||
Object: "chat.completion",
|
||||
Created: r.CreatedAt.Unix(),
|
||||
Model: r.Model,
|
||||
SystemFingerprint: "fp_ollama",
|
||||
Choices: []Choice{{
|
||||
Index: 0,
|
||||
Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking},
|
||||
FinishReason: func(reason string) *string {
|
||||
if len(toolCalls) > 0 {
|
||||
reason = "tool_calls"
|
||||
}
|
||||
if len(reason) > 0 {
|
||||
return &reason
|
||||
}
|
||||
return nil
|
||||
}(r.DoneReason),
|
||||
Logprobs: logprobs,
|
||||
}}, Usage: ToUsage(r),
|
||||
DebugInfo: r.DebugInfo,
|
||||
}
|
||||
}
|
||||
|
||||
func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
|
||||
toolCalls := ToToolCalls(r.Message.ToolCalls)
|
||||
|
||||
var logprobs *ChoiceLogprobs
|
||||
if len(r.Logprobs) > 0 {
|
||||
logprobs = &ChoiceLogprobs{Content: r.Logprobs}
|
||||
}
|
||||
|
||||
return ChatCompletionChunk{
|
||||
Id: id,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: time.Now().Unix(),
|
||||
Model: r.Model,
|
||||
SystemFingerprint: "fp_ollama",
|
||||
Choices: []ChunkChoice{{
|
||||
Index: 0,
|
||||
Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking},
|
||||
FinishReason: func(reason string) *string {
|
||||
if len(reason) > 0 {
|
||||
if toolCallSent || len(toolCalls) > 0 {
|
||||
return &finishReasonToolCalls
|
||||
}
|
||||
return &reason
|
||||
}
|
||||
return nil
|
||||
}(r.DoneReason),
|
||||
Logprobs: logprobs,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
// ToChunks converts an api.ChatResponse to one or more ChatCompletionChunk values.
|
||||
func ToChunks(id string, r api.ChatResponse, toolCallSent bool) []ChatCompletionChunk {
|
||||
hasMixedResponse := r.Message.Thinking != "" && (r.Message.Content != "" || len(r.Message.ToolCalls) > 0)
|
||||
if !hasMixedResponse {
|
||||
return []ChatCompletionChunk{toChunk(id, r, toolCallSent)}
|
||||
}
|
||||
|
||||
reasoningChunk := toChunk(id, r, toolCallSent)
|
||||
// The logprobs here might include tokens not in this chunk because we now split between thinking and content/tool calls.
|
||||
reasoningChunk.Choices[0].Delta.Content = ""
|
||||
reasoningChunk.Choices[0].Delta.ToolCalls = nil
|
||||
reasoningChunk.Choices[0].FinishReason = nil
|
||||
|
||||
contentOrToolCallsChunk := toChunk(id, r, toolCallSent)
|
||||
// Keep both split chunks on the same timestamp since they represent one logical emission.
|
||||
contentOrToolCallsChunk.Created = reasoningChunk.Created
|
||||
contentOrToolCallsChunk.Choices[0].Delta.Reasoning = ""
|
||||
contentOrToolCallsChunk.Choices[0].Logprobs = nil
|
||||
|
||||
return []ChatCompletionChunk{
|
||||
reasoningChunk,
|
||||
contentOrToolCallsChunk,
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated: use ToChunks for streaming conversion.
|
||||
func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
|
||||
return toChunk(id, r, toolCallSent)
|
||||
}
|
||||
|
||||
// ToUsageGenerate converts an api.GenerateResponse to Usage
|
||||
func ToUsageGenerate(r api.GenerateResponse) Usage {
|
||||
return Usage{
|
||||
PromptTokens: r.Metrics.PromptEvalCount,
|
||||
CompletionTokens: r.Metrics.EvalCount,
|
||||
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
|
||||
}
|
||||
}
|
||||
|
||||
// ToCompletion converts an api.GenerateResponse to Completion
|
||||
func ToCompletion(id string, r api.GenerateResponse) Completion {
|
||||
return Completion{
|
||||
Id: id,
|
||||
Object: "text_completion",
|
||||
Created: r.CreatedAt.Unix(),
|
||||
Model: r.Model,
|
||||
SystemFingerprint: "fp_ollama",
|
||||
Choices: []CompleteChunkChoice{{
|
||||
Text: r.Response,
|
||||
Index: 0,
|
||||
FinishReason: func(reason string) *string {
|
||||
if len(reason) > 0 {
|
||||
return &reason
|
||||
}
|
||||
return nil
|
||||
}(r.DoneReason),
|
||||
}},
|
||||
Usage: ToUsageGenerate(r),
|
||||
}
|
||||
}
|
||||
|
||||
// ToCompleteChunk converts an api.GenerateResponse to CompletionChunk
|
||||
func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
|
||||
return CompletionChunk{
|
||||
Id: id,
|
||||
Object: "text_completion",
|
||||
Created: time.Now().Unix(),
|
||||
Model: r.Model,
|
||||
SystemFingerprint: "fp_ollama",
|
||||
Choices: []CompleteChunkChoice{{
|
||||
Text: r.Response,
|
||||
Index: 0,
|
||||
FinishReason: func(reason string) *string {
|
||||
if len(reason) > 0 {
|
||||
return &reason
|
||||
}
|
||||
return nil
|
||||
}(r.DoneReason),
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
// ToListCompletion converts an api.ListResponse to ListCompletion
|
||||
func ToListCompletion(r api.ListResponse) ListCompletion {
|
||||
var data []Model
|
||||
for _, m := range r.Models {
|
||||
data = append(data, Model{
|
||||
Id: m.Name,
|
||||
Object: "model",
|
||||
Created: m.ModifiedAt.Unix(),
|
||||
OwnedBy: model.ParseName(m.Name).Namespace,
|
||||
})
|
||||
}
|
||||
|
||||
return ListCompletion{
|
||||
Object: "list",
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList
|
||||
// encodingFormat can be "float", "base64", or empty (defaults to "float")
|
||||
func ToEmbeddingList(model string, r api.EmbedResponse, encodingFormat string) EmbeddingList {
|
||||
if r.Embeddings != nil {
|
||||
var data []Embedding
|
||||
for i, e := range r.Embeddings {
|
||||
var embedding any
|
||||
if strings.EqualFold(encodingFormat, "base64") {
|
||||
embedding = floatsToBase64(e)
|
||||
} else {
|
||||
embedding = e
|
||||
}
|
||||
|
||||
data = append(data, Embedding{
|
||||
Object: "embedding",
|
||||
Embedding: embedding,
|
||||
Index: i,
|
||||
})
|
||||
}
|
||||
|
||||
return EmbeddingList{
|
||||
Object: "list",
|
||||
Data: data,
|
||||
Model: model,
|
||||
Usage: EmbeddingUsage{
|
||||
PromptTokens: r.PromptEvalCount,
|
||||
TotalTokens: r.PromptEvalCount,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return EmbeddingList{}
|
||||
}
|
||||
|
||||
// floatsToBase64 encodes a []float32 to a base64 string
|
||||
func floatsToBase64(floats []float32) string {
|
||||
var buf bytes.Buffer
|
||||
binary.Write(&buf, binary.LittleEndian, floats)
|
||||
return base64.StdEncoding.EncodeToString(buf.Bytes())
|
||||
}
|
||||
|
||||
// ToModel converts an api.ShowResponse to Model
|
||||
func ToModel(r api.ShowResponse, m string) Model {
|
||||
return Model{
|
||||
Id: m,
|
||||
Object: "model",
|
||||
Created: r.ModifiedAt.Unix(),
|
||||
OwnedBy: model.ParseName(m).Namespace,
|
||||
}
|
||||
}
|
||||
|
||||
// FromChatRequest converts a ChatCompletionRequest to api.ChatRequest
|
||||
func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||
var messages []api.Message
|
||||
for _, msg := range r.Messages {
|
||||
toolName := ""
|
||||
if strings.ToLower(msg.Role) == "tool" {
|
||||
toolName = msg.Name
|
||||
if toolName == "" && msg.ToolCallID != "" {
|
||||
toolName = nameFromToolCallID(r.Messages, msg.ToolCallID)
|
||||
}
|
||||
}
|
||||
switch content := msg.Content.(type) {
|
||||
case string:
|
||||
toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages = append(messages, api.Message{Role: msg.Role, Content: content, Thinking: msg.Reasoning, ToolCalls: toolCalls, ToolName: toolName, ToolCallID: msg.ToolCallID})
|
||||
case []any:
|
||||
for _, c := range content {
|
||||
data, ok := c.(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid message format")
|
||||
}
|
||||
switch data["type"] {
|
||||
case "text":
|
||||
text, ok := data["text"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid message format")
|
||||
}
|
||||
messages = append(messages, api.Message{Role: msg.Role, Content: text})
|
||||
case "image_url":
|
||||
var url string
|
||||
if urlMap, ok := data["image_url"].(map[string]any); ok {
|
||||
if url, ok = urlMap["url"].(string); !ok {
|
||||
return nil, errors.New("invalid message format")
|
||||
}
|
||||
} else {
|
||||
if url, ok = data["image_url"].(string); !ok {
|
||||
return nil, errors.New("invalid message format")
|
||||
}
|
||||
}
|
||||
|
||||
img, err := decodeImageURL(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
|
||||
case "input_audio":
|
||||
audioMap, ok := data["input_audio"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid input_audio format")
|
||||
}
|
||||
b64Data, ok := audioMap["data"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid input_audio format: missing data")
|
||||
}
|
||||
audioBytes, err := base64.StdEncoding.DecodeString(b64Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid input_audio base64 data: %w", err)
|
||||
}
|
||||
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{audioBytes}})
|
||||
default:
|
||||
return nil, errors.New("invalid message format")
|
||||
}
|
||||
}
|
||||
// since we might have added multiple messages above, if we have tools
|
||||
// calls we'll add them to the last message
|
||||
if len(messages) > 0 && len(msg.ToolCalls) > 0 {
|
||||
toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages[len(messages)-1].ToolCalls = toolCalls
|
||||
messages[len(messages)-1].ToolName = toolName
|
||||
messages[len(messages)-1].ToolCallID = msg.ToolCallID
|
||||
messages[len(messages)-1].Thinking = msg.Reasoning
|
||||
}
|
||||
default:
|
||||
// content is only optional if tool calls are present
|
||||
if msg.ToolCalls == nil {
|
||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||
}
|
||||
|
||||
toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages = append(messages, api.Message{Role: msg.Role, Thinking: msg.Reasoning, ToolCalls: toolCalls, ToolCallID: msg.ToolCallID})
|
||||
}
|
||||
}
|
||||
|
||||
options := make(map[string]any)
|
||||
|
||||
switch stop := r.Stop.(type) {
|
||||
case string:
|
||||
options["stop"] = []string{stop}
|
||||
case []any:
|
||||
var stops []string
|
||||
for _, s := range stop {
|
||||
if str, ok := s.(string); ok {
|
||||
stops = append(stops, str)
|
||||
}
|
||||
}
|
||||
options["stop"] = stops
|
||||
}
|
||||
|
||||
if r.MaxTokens != nil {
|
||||
options["num_predict"] = *r.MaxTokens
|
||||
}
|
||||
|
||||
if r.Temperature != nil {
|
||||
options["temperature"] = *r.Temperature
|
||||
} else {
|
||||
options["temperature"] = 1.0
|
||||
}
|
||||
|
||||
if r.Seed != nil {
|
||||
options["seed"] = *r.Seed
|
||||
}
|
||||
|
||||
if r.FrequencyPenalty != nil {
|
||||
options["frequency_penalty"] = *r.FrequencyPenalty
|
||||
}
|
||||
|
||||
if r.PresencePenalty != nil {
|
||||
options["presence_penalty"] = *r.PresencePenalty
|
||||
}
|
||||
|
||||
if r.TopP != nil {
|
||||
options["top_p"] = *r.TopP
|
||||
} else {
|
||||
options["top_p"] = 1.0
|
||||
}
|
||||
|
||||
var format json.RawMessage
|
||||
if r.ResponseFormat != nil {
|
||||
switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
|
||||
// Support the old "json_object" type for OpenAI compatibility
|
||||
case "json_object":
|
||||
format = json.RawMessage(`"json"`)
|
||||
case "json_schema":
|
||||
if r.ResponseFormat.JsonSchema != nil {
|
||||
format = r.ResponseFormat.JsonSchema.Schema
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var think *api.ThinkValue
|
||||
var effort string
|
||||
|
||||
if r.Reasoning != nil {
|
||||
effort = r.Reasoning.Effort
|
||||
} else if r.ReasoningEffort != nil {
|
||||
effort = *r.ReasoningEffort
|
||||
}
|
||||
|
||||
if effort != "" {
|
||||
if !slices.Contains([]string{"high", "medium", "low", "max", "none"}, effort) {
|
||||
return nil, fmt.Errorf("invalid reasoning value: '%s' (must be \"high\", \"medium\", \"low\", \"max\", or \"none\")", effort)
|
||||
}
|
||||
|
||||
if effort == "none" {
|
||||
think = &api.ThinkValue{Value: false}
|
||||
} else {
|
||||
think = &api.ThinkValue{Value: effort}
|
||||
}
|
||||
}
|
||||
|
||||
return &api.ChatRequest{
|
||||
Model: r.Model,
|
||||
Messages: messages,
|
||||
Format: format,
|
||||
Options: options,
|
||||
Stream: &r.Stream,
|
||||
Tools: r.Tools,
|
||||
Think: think,
|
||||
Logprobs: r.Logprobs != nil && *r.Logprobs,
|
||||
TopLogprobs: r.TopLogprobs,
|
||||
DebugRenderOnly: r.DebugRenderOnly,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func nameFromToolCallID(messages []Message, toolCallID string) string {
|
||||
// iterate backwards to be more resilient to duplicate tool call IDs (this
|
||||
// follows "last one wins")
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
msg := messages[i]
|
||||
for _, tc := range msg.ToolCalls {
|
||||
if tc.ID == toolCallID {
|
||||
return tc.Function.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// decodeImageURL decodes a base64 data URI into raw image bytes.
|
||||
func decodeImageURL(url string) (api.ImageData, error) {
|
||||
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
|
||||
return nil, errors.New("image URLs are not currently supported, please use base64 encoded data instead")
|
||||
}
|
||||
|
||||
types := []string{"jpeg", "jpg", "png", "webp"}
|
||||
|
||||
// Support blank mime type to match /api/chat's behavior of taking just unadorned base64
|
||||
if strings.HasPrefix(url, "data:;base64,") {
|
||||
url = strings.TrimPrefix(url, "data:;base64,")
|
||||
} else {
|
||||
valid := false
|
||||
for _, t := range types {
|
||||
prefix := "data:image/" + t + ";base64,"
|
||||
if strings.HasPrefix(url, prefix) {
|
||||
url = strings.TrimPrefix(url, prefix)
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return nil, errors.New("invalid image input")
|
||||
}
|
||||
}
|
||||
|
||||
img, err := base64.StdEncoding.DecodeString(url)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid image input")
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// FromCompletionToolCall converts OpenAI ToolCall format to api.ToolCall
|
||||
func FromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
|
||||
apiToolCalls := make([]api.ToolCall, len(toolCalls))
|
||||
for i, tc := range toolCalls {
|
||||
apiToolCalls[i].ID = tc.ID
|
||||
apiToolCalls[i].Function.Name = tc.Function.Name
|
||||
err := json.Unmarshal([]byte(tc.Function.Arguments), &apiToolCalls[i].Function.Arguments)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid tool call arguments")
|
||||
}
|
||||
}
|
||||
|
||||
return apiToolCalls, nil
|
||||
}
|
||||
|
||||
// FromCompleteRequest converts a CompletionRequest to api.GenerateRequest
|
||||
func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
options := make(map[string]any)
|
||||
|
||||
switch stop := r.Stop.(type) {
|
||||
case string:
|
||||
options["stop"] = []string{stop}
|
||||
case []any:
|
||||
var stops []string
|
||||
for _, s := range stop {
|
||||
if str, ok := s.(string); ok {
|
||||
stops = append(stops, str)
|
||||
} else {
|
||||
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
|
||||
}
|
||||
}
|
||||
options["stop"] = stops
|
||||
}
|
||||
|
||||
if r.MaxTokens != nil {
|
||||
options["num_predict"] = *r.MaxTokens
|
||||
}
|
||||
|
||||
if r.Temperature != nil {
|
||||
options["temperature"] = *r.Temperature
|
||||
} else {
|
||||
options["temperature"] = 1.0
|
||||
}
|
||||
|
||||
if r.Seed != nil {
|
||||
options["seed"] = *r.Seed
|
||||
}
|
||||
|
||||
options["frequency_penalty"] = r.FrequencyPenalty
|
||||
|
||||
options["presence_penalty"] = r.PresencePenalty
|
||||
|
||||
if r.TopP != 0.0 {
|
||||
options["top_p"] = r.TopP
|
||||
} else {
|
||||
options["top_p"] = 1.0
|
||||
}
|
||||
|
||||
var logprobs bool
|
||||
var topLogprobs int
|
||||
if r.Logprobs != nil && *r.Logprobs > 0 {
|
||||
logprobs = true
|
||||
topLogprobs = *r.Logprobs
|
||||
}
|
||||
|
||||
return api.GenerateRequest{
|
||||
Model: r.Model,
|
||||
Prompt: r.Prompt,
|
||||
Options: options,
|
||||
Stream: &r.Stream,
|
||||
Suffix: r.Suffix,
|
||||
Logprobs: logprobs,
|
||||
TopLogprobs: topLogprobs,
|
||||
DebugRenderOnly: r.DebugRenderOnly,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ImageGenerationRequest is an OpenAI-compatible image generation request.
|
||||
type ImageGenerationRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Seed *int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// ImageGenerationResponse is an OpenAI-compatible image generation response.
|
||||
type ImageGenerationResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []ImageURLOrData `json:"data"`
|
||||
}
|
||||
|
||||
// ImageURLOrData contains either a URL or base64-encoded image data.
|
||||
type ImageURLOrData struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
}
|
||||
|
||||
// FromImageGenerationRequest converts an OpenAI image generation request to an Ollama GenerateRequest.
|
||||
func FromImageGenerationRequest(r ImageGenerationRequest) api.GenerateRequest {
|
||||
req := api.GenerateRequest{
|
||||
Model: r.Model,
|
||||
Prompt: r.Prompt,
|
||||
}
|
||||
// Parse size if provided (e.g., "1024x768")
|
||||
if r.Size != "" {
|
||||
var w, h int32
|
||||
if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
|
||||
req.Width = w
|
||||
req.Height = h
|
||||
}
|
||||
}
|
||||
if r.Seed != nil {
|
||||
if req.Options == nil {
|
||||
req.Options = map[string]any{}
|
||||
}
|
||||
req.Options["seed"] = *r.Seed
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// ToImageGenerationResponse converts an Ollama GenerateResponse to an OpenAI ImageGenerationResponse.
|
||||
func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationResponse {
|
||||
var data []ImageURLOrData
|
||||
if resp.Image != "" {
|
||||
data = []ImageURLOrData{{B64JSON: resp.Image}}
|
||||
}
|
||||
return ImageGenerationResponse{
|
||||
Created: resp.CreatedAt.Unix(),
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// TranscriptionResponse is the response format for /v1/audio/transcriptions.
|
||||
type TranscriptionResponse struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// TranscriptionRequest holds parsed fields from the multipart form.
|
||||
type TranscriptionRequest struct {
|
||||
Model string
|
||||
AudioData []byte
|
||||
ResponseFormat string // "json", "text", "verbose_json"
|
||||
Language string
|
||||
Prompt string
|
||||
}
|
||||
|
||||
// FromTranscriptionRequest converts a transcription request into a ChatRequest
|
||||
// by wrapping the audio with a system prompt for transcription.
|
||||
func FromTranscriptionRequest(r TranscriptionRequest) (*api.ChatRequest, error) {
|
||||
systemPrompt := "Transcribe the following audio exactly as spoken. Output only the transcription text, nothing else."
|
||||
if r.Language != "" {
|
||||
systemPrompt += " The audio is in " + r.Language + "."
|
||||
}
|
||||
if r.Prompt != "" {
|
||||
systemPrompt += " Context: " + r.Prompt
|
||||
}
|
||||
|
||||
stream := true
|
||||
return &api.ChatRequest{
|
||||
Model: r.Model,
|
||||
Messages: []api.Message{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: "Transcribe this audio.", Images: []api.ImageData{r.AudioData}},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ImageEditRequest is an OpenAI-compatible image edit request.
|
||||
type ImageEditRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Image string `json:"image"` // Base64-encoded image data
|
||||
Size string `json:"size,omitempty"` // e.g., "1024x1024"
|
||||
Seed *int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// FromImageEditRequest converts an OpenAI image edit request to an Ollama GenerateRequest.
|
||||
func FromImageEditRequest(r ImageEditRequest) (api.GenerateRequest, error) {
|
||||
req := api.GenerateRequest{
|
||||
Model: r.Model,
|
||||
Prompt: r.Prompt,
|
||||
}
|
||||
|
||||
// Decode the input image
|
||||
if r.Image != "" {
|
||||
imgData, err := decodeImageURL(r.Image)
|
||||
if err != nil {
|
||||
return api.GenerateRequest{}, fmt.Errorf("invalid image: %w", err)
|
||||
}
|
||||
req.Images = append(req.Images, imgData)
|
||||
}
|
||||
|
||||
// Parse size if provided (e.g., "1024x768")
|
||||
if r.Size != "" {
|
||||
var w, h int32
|
||||
if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
|
||||
req.Width = w
|
||||
req.Height = h
|
||||
}
|
||||
}
|
||||
|
||||
if r.Seed != nil {
|
||||
if req.Options == nil {
|
||||
req.Options = map[string]any{}
|
||||
}
|
||||
req.Options["seed"] = *r.Seed
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
139
openai/openai_encoding_format_test.go
Normal file
139
openai/openai_encoding_format_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestToEmbeddingList(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
embeddings [][]float32
|
||||
format string
|
||||
expectType string // "float" or "base64"
|
||||
expectBase64 []string
|
||||
expectCount int
|
||||
promptEval int
|
||||
}{
|
||||
{"float format", [][]float32{{0.1, -0.2, 0.3}}, "float", "float", nil, 1, 10},
|
||||
{"base64 format", [][]float32{{0.1, -0.2, 0.3}}, "base64", "base64", []string{"zczMPc3MTL6amZk+"}, 1, 5},
|
||||
{"default to float", [][]float32{{0.1, -0.2, 0.3}}, "", "float", nil, 1, 0},
|
||||
{"invalid defaults to float", [][]float32{{0.1, -0.2, 0.3}}, "invalid", "float", nil, 1, 0},
|
||||
{"multiple embeddings", [][]float32{{0.1, 0.2}, {0.3, 0.4}, {0.5, 0.6}}, "base64", "base64", []string{"zczMPc3MTD4=", "mpmZPs3MzD4=", "AAAAP5qZGT8="}, 3, 0},
|
||||
{"empty embeddings", nil, "float", "", nil, 0, 0},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp := api.EmbedResponse{
|
||||
Embeddings: tc.embeddings,
|
||||
PromptEvalCount: tc.promptEval,
|
||||
}
|
||||
|
||||
result := ToEmbeddingList("test-model", resp, tc.format)
|
||||
|
||||
if tc.expectCount == 0 {
|
||||
if len(result.Data) != 0 {
|
||||
t.Errorf("expected 0 embeddings, got %d", len(result.Data))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(result.Data) != tc.expectCount {
|
||||
t.Fatalf("expected %d embeddings, got %d", tc.expectCount, len(result.Data))
|
||||
}
|
||||
|
||||
if result.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got %q", result.Model)
|
||||
}
|
||||
|
||||
// Check type of first embedding
|
||||
switch tc.expectType {
|
||||
case "float":
|
||||
if _, ok := result.Data[0].Embedding.([]float32); !ok {
|
||||
t.Errorf("expected []float32, got %T", result.Data[0].Embedding)
|
||||
}
|
||||
case "base64":
|
||||
for i, data := range result.Data {
|
||||
embStr, ok := data.Embedding.(string)
|
||||
if !ok {
|
||||
t.Errorf("embedding %d: expected string, got %T", i, data.Embedding)
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify it's valid base64
|
||||
if _, err := base64.StdEncoding.DecodeString(embStr); err != nil {
|
||||
t.Errorf("embedding %d: invalid base64: %v", i, err)
|
||||
}
|
||||
|
||||
// Compare against expected base64 string if provided
|
||||
if tc.expectBase64 != nil && i < len(tc.expectBase64) {
|
||||
if embStr != tc.expectBase64[i] {
|
||||
t.Errorf("embedding %d: expected base64 %q, got %q", i, tc.expectBase64[i], embStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check indices
|
||||
for i := range result.Data {
|
||||
if result.Data[i].Index != i {
|
||||
t.Errorf("embedding %d: expected index %d, got %d", i, i, result.Data[i].Index)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.promptEval > 0 && result.Usage.PromptTokens != tc.promptEval {
|
||||
t.Errorf("expected %d prompt tokens, got %d", tc.promptEval, result.Usage.PromptTokens)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFloatsToBase64(t *testing.T) {
|
||||
floats := []float32{0.1, -0.2, 0.3, -0.4, 0.5}
|
||||
|
||||
result := floatsToBase64(floats)
|
||||
|
||||
// Verify it's valid base64
|
||||
decoded, err := base64.StdEncoding.DecodeString(result)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to decode base64: %v", err)
|
||||
}
|
||||
|
||||
// Check length
|
||||
expectedBytes := len(floats) * 4
|
||||
if len(decoded) != expectedBytes {
|
||||
t.Errorf("expected %d bytes, got %d", expectedBytes, len(decoded))
|
||||
}
|
||||
|
||||
// Decode and verify values
|
||||
for i, expected := range floats {
|
||||
offset := i * 4
|
||||
bits := uint32(decoded[offset]) |
|
||||
uint32(decoded[offset+1])<<8 |
|
||||
uint32(decoded[offset+2])<<16 |
|
||||
uint32(decoded[offset+3])<<24
|
||||
decodedFloat := math.Float32frombits(bits)
|
||||
|
||||
if math.Abs(float64(decodedFloat-expected)) > 1e-6 {
|
||||
t.Errorf("float[%d]: expected %f, got %f", i, expected, decodedFloat)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFloatsToBase64_EmptySlice(t *testing.T) {
|
||||
result := floatsToBase64([]float32{})
|
||||
|
||||
// Should return valid base64 for empty slice
|
||||
decoded, err := base64.StdEncoding.DecodeString(result)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to decode base64: %v", err)
|
||||
}
|
||||
|
||||
if len(decoded) != 0 {
|
||||
t.Errorf("expected 0 bytes, got %d", len(decoded))
|
||||
}
|
||||
}
|
||||
867
openai/openai_test.go
Normal file
867
openai/openai_test.go
Normal file
@@ -0,0 +1,867 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
|
||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value
|
||||
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
|
||||
return cmp.Equal(a.ToMap(), b.ToMap())
|
||||
})
|
||||
|
||||
const (
|
||||
prefix = `data:image/jpeg;base64,`
|
||||
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||
)
|
||||
|
||||
func TestFromChatRequest_Basic(t *testing.T) {
|
||||
req := ChatCompletionRequest{
|
||||
Model: "test-model",
|
||||
Messages: []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromChatRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got %q", result.Model)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
|
||||
t.Errorf("unexpected message: %+v", result.Messages[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromChatRequest_ReasoningEffort(t *testing.T) {
|
||||
effort := func(s string) *string { return &s }
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
effort *string
|
||||
want any // expected ThinkValue.Value; nil means req.Think should be nil
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "unset", effort: nil, want: nil},
|
||||
{name: "high", effort: effort("high"), want: "high"},
|
||||
{name: "medium", effort: effort("medium"), want: "medium"},
|
||||
{name: "low", effort: effort("low"), want: "low"},
|
||||
{name: "max", effort: effort("max"), want: "max"},
|
||||
{name: "none disables", effort: effort("none"), want: false},
|
||||
{name: "invalid", effort: effort("extreme"), wantErr: true},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := ChatCompletionRequest{
|
||||
Model: "test-model",
|
||||
Messages: []Message{{Role: "user", Content: "hi"}},
|
||||
ReasoningEffort: tc.effort,
|
||||
}
|
||||
result, err := FromChatRequest(req)
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for effort=%v, got none", *tc.effort)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if tc.want == nil {
|
||||
if result.Think != nil {
|
||||
t.Fatalf("expected nil Think, got %+v", result.Think)
|
||||
}
|
||||
return
|
||||
}
|
||||
if result.Think == nil {
|
||||
t.Fatalf("expected Think=%v, got nil", tc.want)
|
||||
}
|
||||
if result.Think.Value != tc.want {
|
||||
t.Fatalf("got Think.Value=%v, want %v", result.Think.Value, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromChatRequest_WithImage(t *testing.T) {
|
||||
imgData, _ := base64.StdEncoding.DecodeString(image)
|
||||
|
||||
req := ChatCompletionRequest{
|
||||
Model: "test-model",
|
||||
Messages: []Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{"type": "text", "text": "Hello"},
|
||||
map[string]any{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]any{"url": prefix + image},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromChatRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Content != "Hello" {
|
||||
t.Errorf("expected first message content 'Hello', got %q", result.Messages[0].Content)
|
||||
}
|
||||
|
||||
if len(result.Messages[1].Images) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(result.Messages[1].Images))
|
||||
}
|
||||
|
||||
if string(result.Messages[1].Images[0]) != string(imgData) {
|
||||
t.Error("image data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromCompleteRequest_Basic(t *testing.T) {
|
||||
temp := float32(0.8)
|
||||
req := CompletionRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello",
|
||||
Temperature: &temp,
|
||||
}
|
||||
|
||||
result, err := FromCompleteRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got %q", result.Model)
|
||||
}
|
||||
|
||||
if result.Prompt != "Hello" {
|
||||
t.Errorf("expected prompt 'Hello', got %q", result.Prompt)
|
||||
}
|
||||
|
||||
if tempVal, ok := result.Options["temperature"].(float32); !ok || tempVal != 0.8 {
|
||||
t.Errorf("expected temperature 0.8, got %v", result.Options["temperature"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestToUsage(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 10,
|
||||
EvalCount: 20,
|
||||
},
|
||||
}
|
||||
|
||||
usage := ToUsage(resp)
|
||||
|
||||
if usage.PromptTokens != 10 {
|
||||
t.Errorf("expected PromptTokens 10, got %d", usage.PromptTokens)
|
||||
}
|
||||
|
||||
if usage.CompletionTokens != 20 {
|
||||
t.Errorf("expected CompletionTokens 20, got %d", usage.CompletionTokens)
|
||||
}
|
||||
|
||||
if usage.TotalTokens != 30 {
|
||||
t.Errorf("expected TotalTokens 30, got %d", usage.TotalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewError(t *testing.T) {
|
||||
tests := []struct {
|
||||
code int
|
||||
want string
|
||||
}{
|
||||
{400, "invalid_request_error"},
|
||||
{404, "not_found_error"},
|
||||
{500, "api_error"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := NewError(tt.code, "test message")
|
||||
if result.Error.Type != tt.want {
|
||||
t.Errorf("NewError(%d) type = %q, want %q", tt.code, result.Error.Type, tt.want)
|
||||
}
|
||||
if result.Error.Message != "test message" {
|
||||
t.Errorf("NewError(%d) message = %q, want %q", tt.code, result.Error.Message, "test message")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToToolCallsPreservesIDs(t *testing.T) {
|
||||
original := []api.ToolCall{
|
||||
{
|
||||
ID: "call_abc123",
|
||||
Function: api.ToolCallFunction{
|
||||
Index: 2,
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Seattle",
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "call_def456",
|
||||
Function: api.ToolCallFunction{
|
||||
Index: 7,
|
||||
Name: "get_time",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"timezone": "UTC",
|
||||
}),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
toolCalls := make([]api.ToolCall, len(original))
|
||||
copy(toolCalls, original)
|
||||
got := ToToolCalls(toolCalls)
|
||||
|
||||
if len(got) != len(original) {
|
||||
t.Fatalf("expected %d tool calls, got %d", len(original), len(got))
|
||||
}
|
||||
|
||||
expected := []ToolCall{
|
||||
{
|
||||
ID: "call_abc123",
|
||||
Type: "function",
|
||||
Index: 2,
|
||||
Function: struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}{
|
||||
Name: "get_weather",
|
||||
Arguments: `{"location":"Seattle"}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "call_def456",
|
||||
Type: "function",
|
||||
Index: 7,
|
||||
Function: struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}{
|
||||
Name: "get_time",
|
||||
Arguments: `{"timezone":"UTC"}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(expected, got); diff != "" {
|
||||
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(original, toolCalls, argsComparer); diff != "" {
|
||||
t.Errorf("input tool calls mutated (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromChatRequest_WithLogprobs(t *testing.T) {
|
||||
trueVal := true
|
||||
|
||||
req := ChatCompletionRequest{
|
||||
Model: "test-model",
|
||||
Messages: []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Logprobs: &trueVal,
|
||||
TopLogprobs: 5,
|
||||
}
|
||||
|
||||
result, err := FromChatRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !result.Logprobs {
|
||||
t.Error("expected Logprobs to be true")
|
||||
}
|
||||
|
||||
if result.TopLogprobs != 5 {
|
||||
t.Errorf("expected TopLogprobs to be 5, got %d", result.TopLogprobs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromChatRequest_LogprobsDefault(t *testing.T) {
|
||||
req := ChatCompletionRequest{
|
||||
Model: "test-model",
|
||||
Messages: []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromChatRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Logprobs {
|
||||
t.Error("expected Logprobs to be false by default")
|
||||
}
|
||||
|
||||
if result.TopLogprobs != 0 {
|
||||
t.Errorf("expected TopLogprobs to be 0 by default, got %d", result.TopLogprobs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromCompleteRequest_WithLogprobs(t *testing.T) {
|
||||
logprobsVal := 5
|
||||
|
||||
req := CompletionRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello",
|
||||
Logprobs: &logprobsVal,
|
||||
}
|
||||
|
||||
result, err := FromCompleteRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !result.Logprobs {
|
||||
t.Error("expected Logprobs to be true")
|
||||
}
|
||||
|
||||
if result.TopLogprobs != 5 {
|
||||
t.Errorf("expected TopLogprobs to be 5, got %d", result.TopLogprobs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToChatCompletion_WithLogprobs(t *testing.T) {
|
||||
createdAt := time.Unix(1234567890, 0)
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
CreatedAt: createdAt,
|
||||
Message: api.Message{Role: "assistant", Content: "Hello there"},
|
||||
Logprobs: []api.Logprob{
|
||||
{
|
||||
TokenLogprob: api.TokenLogprob{
|
||||
Token: "Hello",
|
||||
Logprob: -0.5,
|
||||
},
|
||||
TopLogprobs: []api.TokenLogprob{
|
||||
{Token: "Hello", Logprob: -0.5},
|
||||
{Token: "Hi", Logprob: -1.2},
|
||||
},
|
||||
},
|
||||
{
|
||||
TokenLogprob: api.TokenLogprob{
|
||||
Token: " there",
|
||||
Logprob: -0.3,
|
||||
},
|
||||
TopLogprobs: []api.TokenLogprob{
|
||||
{Token: " there", Logprob: -0.3},
|
||||
{Token: " world", Logprob: -1.5},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 5,
|
||||
EvalCount: 2,
|
||||
},
|
||||
}
|
||||
|
||||
id := "test-id"
|
||||
|
||||
result := ToChatCompletion(id, resp)
|
||||
|
||||
if result.Id != id {
|
||||
t.Errorf("expected Id %q, got %q", id, result.Id)
|
||||
}
|
||||
|
||||
if result.Created != 1234567890 {
|
||||
t.Errorf("expected Created %d, got %d", int64(1234567890), result.Created)
|
||||
}
|
||||
|
||||
if len(result.Choices) != 1 {
|
||||
t.Fatalf("expected 1 choice, got %d", len(result.Choices))
|
||||
}
|
||||
|
||||
choice := result.Choices[0]
|
||||
if choice.Message.Content != "Hello there" {
|
||||
t.Errorf("expected content %q, got %q", "Hello there", choice.Message.Content)
|
||||
}
|
||||
|
||||
if choice.Logprobs == nil {
|
||||
t.Fatal("expected Logprobs to be present")
|
||||
}
|
||||
|
||||
if len(choice.Logprobs.Content) != 2 {
|
||||
t.Fatalf("expected 2 logprobs, got %d", len(choice.Logprobs.Content))
|
||||
}
|
||||
|
||||
// Verify first logprob
|
||||
if choice.Logprobs.Content[0].Token != "Hello" {
|
||||
t.Errorf("expected first token %q, got %q", "Hello", choice.Logprobs.Content[0].Token)
|
||||
}
|
||||
if choice.Logprobs.Content[0].Logprob != -0.5 {
|
||||
t.Errorf("expected first logprob -0.5, got %f", choice.Logprobs.Content[0].Logprob)
|
||||
}
|
||||
if len(choice.Logprobs.Content[0].TopLogprobs) != 2 {
|
||||
t.Errorf("expected 2 top_logprobs, got %d", len(choice.Logprobs.Content[0].TopLogprobs))
|
||||
}
|
||||
|
||||
// Verify second logprob
|
||||
if choice.Logprobs.Content[1].Token != " there" {
|
||||
t.Errorf("expected second token %q, got %q", " there", choice.Logprobs.Content[1].Token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToChatCompletion_WithoutLogprobs(t *testing.T) {
|
||||
createdAt := time.Unix(1234567890, 0)
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
CreatedAt: createdAt,
|
||||
Message: api.Message{Role: "assistant", Content: "Hello"},
|
||||
Done: true,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 5,
|
||||
EvalCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
id := "test-id"
|
||||
|
||||
result := ToChatCompletion(id, resp)
|
||||
|
||||
if len(result.Choices) != 1 {
|
||||
t.Fatalf("expected 1 choice, got %d", len(result.Choices))
|
||||
}
|
||||
|
||||
// When no logprobs, Logprobs should be nil
|
||||
if result.Choices[0].Logprobs != nil {
|
||||
t.Error("expected Logprobs to be nil when not requested")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToChunks_SplitsThinkingAndContent(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Thinking: "step-by-step",
|
||||
Content: "final answer",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
chunks := ToChunks("test-id", resp, false)
|
||||
if len(chunks) != 2 {
|
||||
t.Fatalf("expected 2 chunks, got %d", len(chunks))
|
||||
}
|
||||
|
||||
reasoning := chunks[0].Choices[0]
|
||||
if reasoning.Delta.Reasoning != "step-by-step" {
|
||||
t.Fatalf("expected reasoning chunk to contain thinking, got %q", reasoning.Delta.Reasoning)
|
||||
}
|
||||
if reasoning.Delta.Content != "" {
|
||||
t.Fatalf("expected reasoning chunk content to be empty, got %v", reasoning.Delta.Content)
|
||||
}
|
||||
if len(reasoning.Delta.ToolCalls) != 0 {
|
||||
t.Fatalf("expected reasoning chunk tool calls to be empty, got %d", len(reasoning.Delta.ToolCalls))
|
||||
}
|
||||
if reasoning.FinishReason != nil {
|
||||
t.Fatalf("expected reasoning chunk finish reason to be nil, got %q", *reasoning.FinishReason)
|
||||
}
|
||||
|
||||
content := chunks[1].Choices[0]
|
||||
if content.Delta.Reasoning != "" {
|
||||
t.Fatalf("expected content chunk reasoning to be empty, got %q", content.Delta.Reasoning)
|
||||
}
|
||||
if content.Delta.Content != "final answer" {
|
||||
t.Fatalf("expected content chunk content %q, got %v", "final answer", content.Delta.Content)
|
||||
}
|
||||
if content.FinishReason == nil || *content.FinishReason != "stop" {
|
||||
t.Fatalf("expected content chunk finish reason %q, got %v", "stop", content.FinishReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToChunks_SplitsThinkingAndToolCalls(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Thinking: "need a tool",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Seattle",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
chunks := ToChunks("test-id", resp, false)
|
||||
if len(chunks) != 2 {
|
||||
t.Fatalf("expected 2 chunks, got %d", len(chunks))
|
||||
}
|
||||
|
||||
reasoning := chunks[0].Choices[0]
|
||||
if reasoning.Delta.Reasoning != "need a tool" {
|
||||
t.Fatalf("expected reasoning chunk to contain thinking, got %q", reasoning.Delta.Reasoning)
|
||||
}
|
||||
if len(reasoning.Delta.ToolCalls) != 0 {
|
||||
t.Fatalf("expected reasoning chunk tool calls to be empty, got %d", len(reasoning.Delta.ToolCalls))
|
||||
}
|
||||
if reasoning.FinishReason != nil {
|
||||
t.Fatalf("expected reasoning chunk finish reason to be nil, got %q", *reasoning.FinishReason)
|
||||
}
|
||||
|
||||
toolCallChunk := chunks[1].Choices[0]
|
||||
if toolCallChunk.Delta.Reasoning != "" {
|
||||
t.Fatalf("expected tool-call chunk reasoning to be empty, got %q", toolCallChunk.Delta.Reasoning)
|
||||
}
|
||||
if len(toolCallChunk.Delta.ToolCalls) != 1 {
|
||||
t.Fatalf("expected one tool call in second chunk, got %d", len(toolCallChunk.Delta.ToolCalls))
|
||||
}
|
||||
if toolCallChunk.Delta.ToolCalls[0].ID != "call_123" {
|
||||
t.Fatalf("expected tool call id %q, got %q", "call_123", toolCallChunk.Delta.ToolCalls[0].ID)
|
||||
}
|
||||
if toolCallChunk.FinishReason == nil || *toolCallChunk.FinishReason != finishReasonToolCalls {
|
||||
t.Fatalf("expected tool-call chunk finish reason %q, got %v", finishReasonToolCalls, toolCallChunk.FinishReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToChunks_SingleChunkForNonMixedResponses(t *testing.T) {
|
||||
toolCalls := []api.ToolCall{
|
||||
{
|
||||
ID: "call_456",
|
||||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "get_time",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"timezone": "UTC",
|
||||
}),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
message api.Message
|
||||
}{
|
||||
{
|
||||
name: "thinking-only",
|
||||
message: api.Message{Thinking: "pondering"},
|
||||
},
|
||||
{
|
||||
name: "content-only",
|
||||
message: api.Message{Content: "hello"},
|
||||
},
|
||||
{
|
||||
name: "toolcalls-only",
|
||||
message: api.Message{ToolCalls: toolCalls},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: tt.message,
|
||||
}
|
||||
|
||||
chunks := ToChunks("test-id", resp, false)
|
||||
if len(chunks) != 1 {
|
||||
t.Fatalf("expected 1 chunk, got %d", len(chunks))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToChunks_SplitsThinkingAndToolCallsWhenNotDone(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Thinking: "need a tool",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_789",
|
||||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "San Francisco",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: false,
|
||||
}
|
||||
|
||||
chunks := ToChunks("test-id", resp, false)
|
||||
if len(chunks) != 2 {
|
||||
t.Fatalf("expected 2 chunks, got %d", len(chunks))
|
||||
}
|
||||
|
||||
reasoning := chunks[0].Choices[0]
|
||||
if reasoning.Delta.Reasoning != "need a tool" {
|
||||
t.Fatalf("expected reasoning chunk to contain thinking, got %q", reasoning.Delta.Reasoning)
|
||||
}
|
||||
if reasoning.FinishReason != nil {
|
||||
t.Fatalf("expected reasoning chunk finish reason nil, got %v", reasoning.FinishReason)
|
||||
}
|
||||
|
||||
toolCallChunk := chunks[1].Choices[0]
|
||||
if len(toolCallChunk.Delta.ToolCalls) != 1 {
|
||||
t.Fatalf("expected one tool call in second chunk, got %d", len(toolCallChunk.Delta.ToolCalls))
|
||||
}
|
||||
if toolCallChunk.Delta.ToolCalls[0].ID != "call_789" {
|
||||
t.Fatalf("expected tool call id %q, got %q", "call_789", toolCallChunk.Delta.ToolCalls[0].ID)
|
||||
}
|
||||
if toolCallChunk.FinishReason != nil {
|
||||
t.Fatalf("expected tool-call chunk finish reason nil when not done, got %v", toolCallChunk.FinishReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToChunks_SplitsThinkingAndContentWhenNotDone(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Thinking: "thinking",
|
||||
Content: "partial content",
|
||||
},
|
||||
Done: false,
|
||||
}
|
||||
|
||||
chunks := ToChunks("test-id", resp, false)
|
||||
if len(chunks) != 2 {
|
||||
t.Fatalf("expected 2 chunks, got %d", len(chunks))
|
||||
}
|
||||
|
||||
reasoning := chunks[0].Choices[0]
|
||||
if reasoning.Delta.Reasoning != "thinking" {
|
||||
t.Fatalf("expected reasoning chunk to contain thinking, got %q", reasoning.Delta.Reasoning)
|
||||
}
|
||||
if reasoning.FinishReason != nil {
|
||||
t.Fatalf("expected reasoning chunk finish reason nil, got %v", reasoning.FinishReason)
|
||||
}
|
||||
|
||||
content := chunks[1].Choices[0]
|
||||
if content.Delta.Content != "partial content" {
|
||||
t.Fatalf("expected content chunk content %q, got %v", "partial content", content.Delta.Content)
|
||||
}
|
||||
if content.FinishReason != nil {
|
||||
t.Fatalf("expected content chunk finish reason nil when not done, got %v", content.FinishReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToChunks_SplitSendsLogprobsOnlyOnFirstChunk(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Thinking: "thinking",
|
||||
Content: "content",
|
||||
},
|
||||
Logprobs: []api.Logprob{
|
||||
{
|
||||
TokenLogprob: api.TokenLogprob{
|
||||
Token: "tok",
|
||||
Logprob: -0.25,
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
chunks := ToChunks("test-id", resp, false)
|
||||
if len(chunks) != 2 {
|
||||
t.Fatalf("expected 2 chunks, got %d", len(chunks))
|
||||
}
|
||||
|
||||
first := chunks[0].Choices[0]
|
||||
if first.Logprobs == nil {
|
||||
t.Fatal("expected first chunk to include logprobs")
|
||||
}
|
||||
if len(first.Logprobs.Content) != 1 || first.Logprobs.Content[0].Token != "tok" {
|
||||
t.Fatalf("unexpected first chunk logprobs: %+v", first.Logprobs.Content)
|
||||
}
|
||||
|
||||
second := chunks[1].Choices[0]
|
||||
if second.Logprobs != nil {
|
||||
t.Fatalf("expected second chunk logprobs to be nil, got %+v", second.Logprobs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToChunk_LegacyMixedThinkingAndContentSingleChunk(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Thinking: "reasoning",
|
||||
Content: "answer",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
chunk := ToChunk("test-id", resp, false)
|
||||
if len(chunk.Choices) != 1 {
|
||||
t.Fatalf("expected 1 choice, got %d", len(chunk.Choices))
|
||||
}
|
||||
|
||||
delta := chunk.Choices[0].Delta
|
||||
if delta.Reasoning != "reasoning" {
|
||||
t.Fatalf("expected reasoning %q, got %q", "reasoning", delta.Reasoning)
|
||||
}
|
||||
if delta.Content != "answer" {
|
||||
t.Fatalf("expected content %q, got %v", "answer", delta.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromChatRequest_TopLogprobsRange(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
topLogprobs int
|
||||
expectValid bool
|
||||
}{
|
||||
{name: "valid: 0", topLogprobs: 0, expectValid: true},
|
||||
{name: "valid: 1", topLogprobs: 1, expectValid: true},
|
||||
{name: "valid: 10", topLogprobs: 10, expectValid: true},
|
||||
{name: "valid: 20", topLogprobs: 20, expectValid: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
trueVal := true
|
||||
req := ChatCompletionRequest{
|
||||
Model: "test-model",
|
||||
Messages: []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
Logprobs: &trueVal,
|
||||
TopLogprobs: tt.topLogprobs,
|
||||
}
|
||||
|
||||
result, err := FromChatRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.TopLogprobs != tt.topLogprobs {
|
||||
t.Errorf("expected TopLogprobs %d, got %d", tt.topLogprobs, result.TopLogprobs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_Basic(t *testing.T) {
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: prefix + image,
|
||||
}
|
||||
|
||||
result, err := FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got %q", result.Model)
|
||||
}
|
||||
|
||||
if result.Prompt != "make it blue" {
|
||||
t.Errorf("expected prompt 'make it blue', got %q", result.Prompt)
|
||||
}
|
||||
|
||||
if len(result.Images) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(result.Images))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_WithSize(t *testing.T) {
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: prefix + image,
|
||||
Size: "512x768",
|
||||
}
|
||||
|
||||
result, err := FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Width != 512 {
|
||||
t.Errorf("expected width 512, got %d", result.Width)
|
||||
}
|
||||
|
||||
if result.Height != 768 {
|
||||
t.Errorf("expected height 768, got %d", result.Height)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_WithSeed(t *testing.T) {
|
||||
seed := int64(12345)
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: prefix + image,
|
||||
Seed: &seed,
|
||||
}
|
||||
|
||||
result, err := FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Options == nil {
|
||||
t.Fatal("expected options to be set")
|
||||
}
|
||||
|
||||
if result.Options["seed"] != seed {
|
||||
t.Errorf("expected seed %d, got %v", seed, result.Options["seed"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_InvalidImage(t *testing.T) {
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: "not-valid-base64",
|
||||
}
|
||||
|
||||
_, err := FromImageEditRequest(req)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid image")
|
||||
}
|
||||
}
|
||||
1382
openai/responses.go
Normal file
1382
openai/responses.go
Normal file
File diff suppressed because it is too large
Load Diff
2051
openai/responses_test.go
Normal file
2051
openai/responses_test.go
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user