diff --git a/pkg/llm-d-inference-sim/tools_test.go b/pkg/llm-d-inference-sim/tools_test.go index 3e21eb9a..35b6495b 100644 --- a/pkg/llm-d-inference-sim/tools_test.go +++ b/pkg/llm-d-inference-sim/tools_test.go @@ -131,6 +131,52 @@ var invalidTools = [][]openai.ChatCompletionToolParam{ }, } +var toolWithArray = []openai.ChatCompletionToolParam{ + { + Function: openai.FunctionDefinitionParam{ + Name: "multiply_numbers", + Description: openai.String("Multiply an array of numbers"), + Parameters: openai.FunctionParameters{ + "type": "object", + "properties": map[string]interface{}{ + "numbers": map[string]interface{}{ + "type": "array", + "items": map[string]string{"type": "number"}, + "description": "List of numbers to multiply", + }, + }, + "required": []string{"numbers"}, + }, + }, + }, +} + +var toolWith3DArray = []openai.ChatCompletionToolParam{ + { + Function: openai.FunctionDefinitionParam{ + Name: "process_tensor", + Description: openai.String("Process a 3D tensor of strings"), + Parameters: openai.FunctionParameters{ + "type": "object", + "properties": map[string]interface{}{ + "tensor": map[string]interface{}{ + "type": "array", + "items": map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "array", + "items": map[string]string{"type": "string"}, + }, + }, + "description": "List of strings", + }, + }, + "required": []string{"tensor"}, + }, + }, + }, +} + var _ = Describe("Simulator for request with tools", func() { DescribeTable("streaming", @@ -309,4 +355,105 @@ var _ = Describe("Simulator for request with tools", func() { }, Entry(nil, modeRandom), ) + + DescribeTable("array parameter, no streaming", + func(mode string) { + ctx := context.TODO() + client, err := startServer(ctx, mode) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage(userMessage)}, + Model: model, + ToolChoice: openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")}, + Tools: toolWithArray, + } + + resp, err := openaiclient.Chat.Completions.New(ctx, params) + Expect(err).NotTo(HaveOccurred()) + Expect(resp.Choices).ShouldNot(BeEmpty()) + Expect(string(resp.Object)).To(Equal(chatCompletionObject)) + + Expect(resp.Usage.PromptTokens).To(Equal(int64(4))) + Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0)) + Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens)) + + content := resp.Choices[0].Message.Content + Expect(content).Should(BeEmpty()) + + toolCalls := resp.Choices[0].Message.ToolCalls + Expect(toolCalls).To(HaveLen(1)) + tc := toolCalls[0] + Expect(tc.Function.Name).To(Equal("multiply_numbers")) + Expect(tc.ID).NotTo(BeEmpty()) + Expect(string(tc.Type)).To(Equal("function")) + args := make(map[string][]int) + err = json.Unmarshal([]byte(tc.Function.Arguments), &args) + Expect(err).NotTo(HaveOccurred()) + Expect(args["numbers"]).ToNot(BeEmpty()) + }, + func(mode string) string { + return "mode: " + mode + }, + // Call several times because the tools and arguments are chosen randomly + Entry(nil, modeRandom), + Entry(nil, modeRandom), + Entry(nil, modeRandom), + Entry(nil, modeRandom), + ) + + DescribeTable("3D array parameter, no streaming", + func(mode string) { + ctx := context.TODO() + client, err := startServer(ctx, mode) + Expect(err).NotTo(HaveOccurred()) + + openaiclient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client)) + + params := openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage(userMessage)}, + Model: model, + ToolChoice: openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")}, + Tools: toolWith3DArray, + } + + resp, err := openaiclient.Chat.Completions.New(ctx, params) + Expect(err).NotTo(HaveOccurred()) + Expect(resp.Choices).ShouldNot(BeEmpty()) + Expect(string(resp.Object)).To(Equal(chatCompletionObject)) + + Expect(resp.Usage.PromptTokens).To(Equal(int64(4))) + Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0)) + Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens)) + + content := resp.Choices[0].Message.Content + Expect(content).Should(BeEmpty()) + + toolCalls := resp.Choices[0].Message.ToolCalls + Expect(toolCalls).To(HaveLen(1)) + tc := toolCalls[0] + Expect(tc.Function.Name).To(Equal("process_tensor")) + Expect(tc.ID).NotTo(BeEmpty()) + Expect(string(tc.Type)).To(Equal("function")) + + args := make(map[string][][][]string) + err = json.Unmarshal([]byte(tc.Function.Arguments), &args) + Expect(err).NotTo(HaveOccurred()) + Expect(args["tensor"]).ToNot(BeEmpty()) + }, + func(mode string) string { + return "mode: " + mode + }, + // Call several times because the tools and arguments are chosen randomly + Entry(nil, modeRandom), + Entry(nil, modeRandom), + Entry(nil, modeRandom), + Entry(nil, modeRandom), + ) }) diff --git a/pkg/llm-d-inference-sim/tools_utils.go b/pkg/llm-d-inference-sim/tools_utils.go index 144b4d3f..635ac3cf 100644 --- a/pkg/llm-d-inference-sim/tools_utils.go +++ b/pkg/llm-d-inference-sim/tools_utils.go @@ -19,8 +19,6 @@ package llmdinferencesim import ( "encoding/json" "fmt" - "math/rand" - "time" "github.com/santhosh-tekuri/jsonschema/v5" ) @@ -89,11 +87,6 @@ func createToolCalls(tools []tool, toolChoice string) ([]toolCall, string, int, return calls, toolsFinishReason, countTokensForToolCalls(calls), nil } -func getStringArgument() string { - index := rand.New(rand.NewSource(time.Now().UnixNano())).Intn(len(fakeStringArguments)) - return fakeStringArguments[index] -} - func generateToolArguments(tool tool) (map[string]any, error) { arguments := make(map[string]any) properties, _ := tool.Function.Parameters["properties"].(map[string]any) @@ -144,11 +137,29 @@ func createArgument(property any) (any, error) { return randomInt(100, false), nil case "boolean": return flipCoin(), nil + case "array": + items := propertyMap["items"] + itemsMap := items.(map[string]any) + numberOfElements := randomInt(5, true) + array := make([]any, numberOfElements) + for i := range numberOfElements { + elem, err := createArgument(itemsMap) + if err != nil { + return nil, err + } + array[i] = elem + } + return array, nil default: return nil, fmt.Errorf("tool parameters of type %s are currently not supported", paramType) } } +func getStringArgument() string { + index := randomInt(len(fakeStringArguments)-1, false) + return fakeStringArguments[index] +} + type validator struct { schema *jsonschema.Schema } @@ -262,6 +273,7 @@ const schema = `{ "string", "number", "boolean", + "array", "null" ] }, @@ -275,12 +287,29 @@ const schema = `{ "string", "number", "boolean", + "array", "null" ] } }, - "additionalProperties": { - "type": "boolean" + "properties": { + "type": "object", + "additionalProperties": { + "$ref": "#/$defs/property_definition" + } + }, + "items": { + "anyOf": [ + { + "$ref": "#/$defs/property_definition" + }, + { + "type": "array", + "items": { + "$ref": "#/$defs/property_definition" + } + } + ] } }, "required": [ @@ -360,9 +389,22 @@ const schema = `{ ] } } + }, + { + "if": { + "properties": { + "type": { + "const": "array" + } + } + }, + "then": { + "required": [ + "items" + ] + } } ] } } -} }` diff --git a/pkg/llm-d-inference-sim/utils.go b/pkg/llm-d-inference-sim/utils.go index 0bc6458c..e76e0d20 100644 --- a/pkg/llm-d-inference-sim/utils.go +++ b/pkg/llm-d-inference-sim/utils.go @@ -62,7 +62,7 @@ func getMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error) // getRandomResponseText returns random response text from the pre-defined list of responses // considering max completion tokens if it is not nil, and a finish reason (stop or length) func getRandomResponseText(maxCompletionTokens *int64) (string, string) { - index := rand.New(rand.NewSource(time.Now().UnixNano())).Intn(len(chatCompletionFakeResponses)) + index := randomInt(len(chatCompletionFakeResponses)-1, false) text := chatCompletionFakeResponses[index] return getResponseText(maxCompletionTokens, text)