Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions pkg/llm-d-inference-sim/tools_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,52 @@ var invalidTools = [][]openai.ChatCompletionToolParam{
},
}

// toolWithArray defines a single test tool ("multiply_numbers") whose only
// parameter is a one-dimensional array of numbers. It exercises the
// simulator's generation of array-typed tool-call arguments.
// Uses `any` consistently (the file elsewhere uses `any`, not `interface{}`).
var toolWithArray = []openai.ChatCompletionToolParam{
	{
		Function: openai.FunctionDefinitionParam{
			Name:        "multiply_numbers",
			Description: openai.String("Multiply an array of numbers"),
			Parameters: openai.FunctionParameters{
				"type": "object",
				"properties": map[string]any{
					"numbers": map[string]any{
						"type":        "array",
						"items":       map[string]any{"type": "number"},
						"description": "List of numbers to multiply",
					},
				},
				"required": []string{"numbers"},
			},
		},
	},
}

// toolWith3DArray defines a single test tool ("process_tensor") whose only
// parameter is a three-level nested array (3D tensor) of strings. It
// exercises recursive generation of nested array tool-call arguments.
// Uses `any` consistently instead of mixing `interface{}` and `any`.
var toolWith3DArray = []openai.ChatCompletionToolParam{
	{
		Function: openai.FunctionDefinitionParam{
			Name:        "process_tensor",
			Description: openai.String("Process a 3D tensor of strings"),
			Parameters: openai.FunctionParameters{
				"type": "object",
				"properties": map[string]any{
					"tensor": map[string]any{
						"type": "array",
						"items": map[string]any{
							"type": "array",
							"items": map[string]any{
								"type":  "array",
								"items": map[string]any{"type": "string"},
							},
						},
						"description": "List of strings",
					},
				},
				"required": []string{"tensor"},
			},
		},
	},
}

var _ = Describe("Simulator for request with tools", func() {

DescribeTable("streaming",
Expand Down Expand Up @@ -309,4 +355,105 @@ var _ = Describe("Simulator for request with tools", func() {
},
Entry(nil, modeRandom),
)

DescribeTable("array parameter, no streaming",
func(mode string) {
ctx := context.TODO()
client, err := startServer(ctx, mode)
Expect(err).NotTo(HaveOccurred())

openaiclient := openai.NewClient(
option.WithBaseURL(baseURL),
option.WithHTTPClient(client))

params := openai.ChatCompletionNewParams{
Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage(userMessage)},
Model: model,
ToolChoice: openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")},
Tools: toolWithArray,
}

resp, err := openaiclient.Chat.Completions.New(ctx, params)
Expect(err).NotTo(HaveOccurred())
Expect(resp.Choices).ShouldNot(BeEmpty())
Expect(string(resp.Object)).To(Equal(chatCompletionObject))

Expect(resp.Usage.PromptTokens).To(Equal(int64(4)))
Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))

content := resp.Choices[0].Message.Content
Expect(content).Should(BeEmpty())

toolCalls := resp.Choices[0].Message.ToolCalls
Expect(toolCalls).To(HaveLen(1))
tc := toolCalls[0]
Expect(tc.Function.Name).To(Equal("multiply_numbers"))
Expect(tc.ID).NotTo(BeEmpty())
Expect(string(tc.Type)).To(Equal("function"))
args := make(map[string][]int)
err = json.Unmarshal([]byte(tc.Function.Arguments), &args)
Expect(err).NotTo(HaveOccurred())
Expect(args["numbers"]).ToNot(BeEmpty())
},
func(mode string) string {
return "mode: " + mode
},
// Call several times because the tools and arguments are chosen randomly
Entry(nil, modeRandom),
Entry(nil, modeRandom),
Entry(nil, modeRandom),
Entry(nil, modeRandom),
)

DescribeTable("3D array parameter, no streaming",
func(mode string) {
ctx := context.TODO()
client, err := startServer(ctx, mode)
Expect(err).NotTo(HaveOccurred())

openaiclient := openai.NewClient(
option.WithBaseURL(baseURL),
option.WithHTTPClient(client))

params := openai.ChatCompletionNewParams{
Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage(userMessage)},
Model: model,
ToolChoice: openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")},
Tools: toolWith3DArray,
}

resp, err := openaiclient.Chat.Completions.New(ctx, params)
Expect(err).NotTo(HaveOccurred())
Expect(resp.Choices).ShouldNot(BeEmpty())
Expect(string(resp.Object)).To(Equal(chatCompletionObject))

Expect(resp.Usage.PromptTokens).To(Equal(int64(4)))
Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))

content := resp.Choices[0].Message.Content
Expect(content).Should(BeEmpty())

toolCalls := resp.Choices[0].Message.ToolCalls
Expect(toolCalls).To(HaveLen(1))
tc := toolCalls[0]
Expect(tc.Function.Name).To(Equal("process_tensor"))
Expect(tc.ID).NotTo(BeEmpty())
Expect(string(tc.Type)).To(Equal("function"))

args := make(map[string][][][]string)
err = json.Unmarshal([]byte(tc.Function.Arguments), &args)
Expect(err).NotTo(HaveOccurred())
Expect(args["tensor"]).ToNot(BeEmpty())
},
func(mode string) string {
return "mode: " + mode
},
// Call several times because the tools and arguments are chosen randomly
Entry(nil, modeRandom),
Entry(nil, modeRandom),
Entry(nil, modeRandom),
Entry(nil, modeRandom),
)
})
62 changes: 52 additions & 10 deletions pkg/llm-d-inference-sim/tools_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ package llmdinferencesim
import (
"encoding/json"
"fmt"
"math/rand"
"time"

"github.com/santhosh-tekuri/jsonschema/v5"
)
Expand Down Expand Up @@ -89,11 +87,6 @@ func createToolCalls(tools []tool, toolChoice string) ([]toolCall, string, int,
return calls, toolsFinishReason, countTokensForToolCalls(calls), nil
}

func getStringArgument() string {
index := rand.New(rand.NewSource(time.Now().UnixNano())).Intn(len(fakeStringArguments))
return fakeStringArguments[index]
}

func generateToolArguments(tool tool) (map[string]any, error) {
arguments := make(map[string]any)
properties, _ := tool.Function.Parameters["properties"].(map[string]any)
Expand Down Expand Up @@ -144,11 +137,29 @@ func createArgument(property any) (any, error) {
return randomInt(100, false), nil
case "boolean":
return flipCoin(), nil
case "array":
items := propertyMap["items"]
itemsMap := items.(map[string]any)
numberOfElements := randomInt(5, true)
array := make([]any, numberOfElements)
for i := range numberOfElements {
elem, err := createArgument(itemsMap)
if err != nil {
return nil, err
}
array[i] = elem
}
return array, nil
default:
return nil, fmt.Errorf("tool parameters of type %s are currently not supported", paramType)
}
}

func getStringArgument() string {
index := randomInt(len(fakeStringArguments)-1, false)
return fakeStringArguments[index]
}

// validator wraps a compiled JSON schema. Based on the surrounding file it
// is presumably used to validate tool parameter definitions against the
// package-level schema constant — confirm against the methods defined on it.
type validator struct {
	schema *jsonschema.Schema
}
Expand Down Expand Up @@ -262,6 +273,7 @@ const schema = `{
"string",
"number",
"boolean",
"array",
"null"
]
},
Expand All @@ -275,12 +287,29 @@ const schema = `{
"string",
"number",
"boolean",
"array",
"null"
]
}
},
"additionalProperties": {
"type": "boolean"
"properties": {
"type": "object",
"additionalProperties": {
"$ref": "#/$defs/property_definition"
}
},
"items": {
"anyOf": [
{
"$ref": "#/$defs/property_definition"
},
{
"type": "array",
"items": {
"$ref": "#/$defs/property_definition"
}
}
]
}
},
"required": [
Expand Down Expand Up @@ -360,9 +389,22 @@ const schema = `{
]
}
}
},
{
"if": {
"properties": {
"type": {
"const": "array"
}
}
},
"then": {
"required": [
"items"
]
}
}
]
}
}
}
}`
2 changes: 1 addition & 1 deletion pkg/llm-d-inference-sim/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func getMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error)
// getRandomResponseText returns random response text from the pre-defined list of responses
// considering max completion tokens if it is not nil, and a finish reason (stop or length)
func getRandomResponseText(maxCompletionTokens *int64) (string, string) {
index := rand.New(rand.NewSource(time.Now().UnixNano())).Intn(len(chatCompletionFakeResponses))
index := randomInt(len(chatCompletionFakeResponses)-1, false)
text := chatCompletionFakeResponses[index]

return getResponseText(maxCompletionTokens, text)
Expand Down