Commit 0cd2025
feat: add support for tool calls and enhance diff application functionality
- Introduced `applyDiffToolLegacy` to handle diff content formatting based on tool use parameters.
- Enhanced `attemptCompletionTool` to push tool result messages when tool calls are enabled.
- Updated `multiApplyDiffTool` to manage diffs with various content types and conditions.
- Added `ToolCallSettingsControl` component for user settings regarding tool calls.
- Implemented localization for tool call settings in multiple languages.
- Updated relevant tests to cover new functionality and ensure proper behavior with tool calls.
1 parent 195f4eb commit 0cd2025

85 files changed (+4926 additions, -150 deletions)


packages/types/src/provider-settings.ts

Lines changed: 1 addition & 0 deletions
@@ -100,6 +100,7 @@ const baseProviderSettingsSchema = z.object({
 	includeMaxTokens: z.boolean().optional(),
 	diffEnabled: z.boolean().optional(),
 	todoListEnabled: z.boolean().optional(),
+	toolCallEnabled: z.boolean().optional(),
 	fuzzyMatchThreshold: z.number().optional(),
 	modelTemperature: z.number().nullish(),
 	rateLimitSeconds: z.number().optional(),
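
Because the new flag is `.optional()`, previously persisted settings objects parse unchanged. A minimal zod sketch of that behavior; the inline schema is a stripped-down stand-in, not the real `baseProviderSettingsSchema`:

import { z } from "zod"

// Stand-in schema (assumption: only the fields relevant here are reproduced).
const settingsSchema = z.object({
	diffEnabled: z.boolean().optional(),
	toolCallEnabled: z.boolean().optional(),
})

// Settings saved before this commit simply lack the key...
const legacy = settingsSchema.parse({ diffEnabled: true })
console.log(legacy.toolCallEnabled) // undefined

// ...while new settings can opt in explicitly.
const current = settingsSchema.parse({ diffEnabled: true, toolCallEnabled: true })
console.log(current.toolCallEnabled) // true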

src/api/index.ts

Lines changed: 10 additions & 1 deletion
@@ -1,6 +1,6 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 
-import type { ProviderSettings, ModelInfo } from "@roo-code/types"
+import type { ProviderSettings, ModelInfo, ToolName } from "@roo-code/types"
 
 import { ApiStream } from "./transform/stream"
 
@@ -42,6 +42,7 @@ import {
 	DeepInfraHandler,
 } from "./providers"
 import { NativeOllamaHandler } from "./providers/native-ollama"
+import { ToolArgs } from "../core/prompts/tools/types"
 
 export interface SingleCompletionHandler {
 	completePrompt(prompt: string): Promise<string>
@@ -65,6 +66,14 @@ export interface ApiHandlerCreateMessageMetadata {
 	 * @default true
 	 */
 	store?: boolean
+	/**
+	 * tool call
+	 */
+	tools?: ToolName[]
+	/**
+	 * tool call args
+	 */
+	toolArgs?: ToolArgs
 }
 
 export interface ApiHandler {
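
To make the new metadata concrete, a hedged sketch of a possible call site; the handler options, import path, and tool names are placeholders borrowed from the test files below, not code from this commit:

import { OpenAiHandler } from "./providers/openai" // assumed path

const handler = new OpenAiHandler({
	openAiApiKey: "key",
	openAiModelId: "gpt-4",
	openAiBaseUrl: "https://api.openai.com/v1",
})

const stream = handler.createMessage(
	"You are a helpful assistant.",
	[{ role: "user", content: "Hello!" }],
	{
		taskId: "task-1",
		tools: ["read_file", "write_to_file"], // ToolName[] from @roo-code/types
		toolArgs: { cwd: ".", supportsComputerUse: false }, // ToolArgs
	},
)

for await (const chunk of stream) {
	if (chunk.type === "tool_call") {
		// Providers surface model tool calls as dedicated stream chunks
		// (see the specs below for the chunk shape).
		console.log(chunk.toolCallType, chunk.toolCalls)
	}
}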

src/api/providers/__tests__/lmstudio.spec.ts

Lines changed: 111 additions & 0 deletions
@@ -164,4 +164,115 @@ describe("LmStudioHandler", () => {
 			expect(modelInfo.info.contextWindow).toBe(128_000)
 		})
 	})
+	describe("LmStudioHandler Tool Calling", () => {
+		let handler: LmStudioHandler
+		let mockOptions: ApiHandlerOptions
+
+		beforeEach(() => {
+			mockOptions = {
+				apiModelId: "local-model",
+				lmStudioModelId: "local-model",
+				lmStudioBaseUrl: "http://localhost:1234",
+			}
+			handler = new LmStudioHandler(mockOptions)
+			mockCreate.mockClear()
+		})
+
+		describe("createMessage with tool calls", () => {
+			const systemPrompt = "You are a helpful assistant."
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: "Hello!",
+				},
+			]
+
+			it("should include tool call parameters when tools are provided", async () => {
+				mockCreate.mockImplementation(async function* () {
+					yield {
+						choices: [
+							{
+								delta: { content: "Test response" },
+								index: 0,
+							},
+						],
+						usage: null,
+					}
+				})
+
+				const stream = handler.createMessage(systemPrompt, messages, {
+					tools: ["test_tool" as any],
+					taskId: "test-task-id",
+				})
+
+				// Consume the stream
+				for await (const _ of stream) {
+					//
+				}
+
+				expect(mockCreate).toHaveBeenCalledWith(
+					expect.objectContaining({
+						tools: expect.any(Array),
+						tool_choice: "auto",
+					}),
+				)
+			})
+
+			it("should yield tool_call chunks when model returns tool calls", async () => {
+				const toolCallChunk = {
+					choices: [
+						{
+							delta: {
+								tool_calls: [
+									{
+										index: 0,
+										id: "tool-call-1",
+										type: "function",
+										function: {
+											name: "test_tool",
+											arguments: '{"param1":"value1"}',
+										},
+									},
+								],
+							},
+							index: 0,
+						},
+					],
+				}
+				const finalChunk = {
+					choices: [
+						{
+							delta: {},
+							finish_reason: "tool_calls",
+						},
+					],
+					usage: {
+						prompt_tokens: 10,
+						completion_tokens: 5,
+						total_tokens: 15,
+					},
+				}
+
+				mockCreate.mockImplementation(async function* () {
+					yield toolCallChunk
+					yield finalChunk
+				})
+
+				const stream = handler.createMessage(systemPrompt, messages, {
+					tools: ["test_tool" as any],
+					taskId: "test-task-id",
+				})
+
+				const chunks: any[] = []
+				for await (const chunk of stream) {
+					chunks.push(chunk)
+				}
+
+				const toolCallChunks = chunks.filter((c) => c.type === "tool_call")
+				expect(toolCallChunks.length).toBe(1)
+				expect(toolCallChunks[0].toolCalls).toEqual(toolCallChunk.choices[0].delta.tool_calls)
+				expect(toolCallChunks[0].toolCallType).toBe("openai")
+			})
+		})
+	})
 })
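
The mock above returns a complete tool call in a single delta, but OpenAI-compatible streams (LM Studio's included) may split `function.arguments` across several deltas keyed by `index`. A sketch of the accumulation a consumer of these `tool_call` chunks would typically perform, using simplified stand-in types rather than the project's own:

// Simplified delta shape, mirroring the OpenAI streaming format mocked above.
interface ToolCallDelta {
	index: number
	id?: string
	function?: { name?: string; arguments?: string }
}

// Merge partial deltas into complete calls, keyed by their stream index.
function accumulateToolCalls(deltas: ToolCallDelta[]) {
	const calls = new Map<number, { id?: string; name: string; args: string }>()
	for (const d of deltas) {
		const call = calls.get(d.index) ?? { name: "", args: "" }
		if (d.id) call.id = d.id
		if (d.function?.name) call.name += d.function.name
		if (d.function?.arguments) call.args += d.function.arguments
		calls.set(d.index, call)
	}
	return calls
}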
src/api/providers/__tests__/openai-tool-call.spec.ts

Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
+// npx vitest run api/providers/__tests__/openai-tool-call.spec.ts
+
+import { OpenAiHandler } from "../openai"
+import { ApiHandlerOptions } from "../../../shared/api"
+import OpenAI from "openai"
+import { getToolRegistry } from "../../../core/prompts/tools/schemas/tool-registry"
+import { ToolName } from "@roo-code/types"
+
+const mockCreate = vitest.fn()
+const mockGenerateFunctionCallSchemas = vitest.fn()
+
+vitest.mock("openai", () => {
+	const mockConstructor = vitest.fn()
+	return {
+		__esModule: true,
+		default: mockConstructor.mockImplementation(() => ({
+			chat: {
+				completions: {
+					create: mockCreate,
+				},
+			},
+		})),
+	}
+})
+
+vitest.mock("../../../core/prompts/tools/schemas/tool-registry", () => ({
+	getToolRegistry: () => ({
+		generateFunctionCallSchemas: mockGenerateFunctionCallSchemas,
+	}),
+}))
+
+describe("OpenAiHandler Tool Call", () => {
+	let handler: OpenAiHandler
+	let mockOptions: ApiHandlerOptions
+
+	beforeEach(() => {
+		mockOptions = {
+			openAiApiKey: "test-api-key",
+			openAiModelId: "gpt-4",
+			openAiBaseUrl: "https://api.openai.com/v1",
+		}
+		handler = new OpenAiHandler(mockOptions)
+		mockCreate.mockClear()
+		mockGenerateFunctionCallSchemas.mockClear()
+	})
+
+	it("should include tools and tool_choice in the request when metadata.tools are provided", async () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages = [
+			{
+				role: "user" as const,
+				content: "Hello!",
+			},
+		]
+		const metadata = {
+			taskId: "test-task-id",
+			tools: ["read_file" as ToolName],
+			toolArgs: { cwd: ".", supportsComputerUse: true },
+		}
+
+		mockGenerateFunctionCallSchemas.mockReturnValue([
+			{
+				type: "function" as const,
+				function: {
+					name: "read_file",
+					description: "A function to interact with files.",
+					parameters: {},
+				},
+			},
+		])
+
+		mockCreate.mockImplementation(async function* () {
+			yield {
+				choices: [
+					{
+						delta: { content: "Test response" },
+						index: 0,
+					},
+				],
+				usage: null,
+			}
+		})
+
+		const stream = handler.createMessage(systemPrompt, messages, metadata)
+
+		for await (const _ of stream) {
+			// Consume stream
+		}
+
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				tools: [
+					{
+						type: "function",
+						function: {
+							name: "read_file",
+							description: "A function to interact with files.",
+							parameters: {},
+						},
+					},
+				],
+				tool_choice: "auto",
+			}),
+			expect.any(Object),
+		)
+	})
+
+	it("should yield a tool_call event when the API returns tool_calls", async () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages = [
+			{
+				role: "user" as const,
+				content: "Hello!",
+			},
+		]
+		const metadata = {
+			taskId: "test-task-id",
+			tools: ["write_to_file" as ToolName],
+			toolArgs: { cwd: ".", supportsComputerUse: true },
+		}
+
+		mockCreate.mockImplementation(async function* () {
+			yield {
+				choices: [
+					{
+						delta: {
+							tool_calls: [
+								{
+									index: 0,
+									id: "call_123",
+									type: "function",
+									function: {
+										name: "write_to_file",
+										arguments: '{"query":"test"}',
+									},
+								},
+							],
+						},
+						index: 0,
+					},
+				],
+			}
+		})
+
+		const stream = handler.createMessage(systemPrompt, messages, metadata)
+		const chunks: any[] = []
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		const toolCallChunk = chunks.find((chunk) => chunk.type === "tool_call")
+
+		expect(toolCallChunk).toBeDefined()
+		expect(toolCallChunk.toolCalls).toEqual([
+			{
+				index: 0,
+				id: "call_123",
+				type: "function",
+				function: {
+					name: "write_to_file",
+					arguments: '{"query":"test"}',
+				},
+			},
+		])
+	})
+
+	it("should not include tools and tool_choice in the request when metadata.tools are not provided", async () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages = [
+			{
+				role: "user" as const,
+				content: "Hello!",
+			},
+		]
+
+		mockCreate.mockImplementation(async function* () {
+			yield {
+				choices: [
+					{
+						delta: { content: "Test response" },
+						index: 0,
+					},
+				],
+				usage: null,
+			}
+		})
+
+		const stream = handler.createMessage(systemPrompt, messages)
+		for await (const _ of stream) {
+			// Consume stream
+		}
+
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.not.objectContaining({
+				tools: expect.any(Array),
+				tool_choice: expect.any(String),
+			}),
+			expect.any(Object),
+		)
+	})
+})
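
For orientation, a sketch of the registry call these tests mock out. The argument list is inferred from the mocks above and from `toolArgs` in the metadata, so treat the signature as an assumption rather than the registry's documented API:

import type { ToolName } from "@roo-code/types"
import { getToolRegistry } from "../../../core/prompts/tools/schemas/tool-registry"
import type { ToolArgs } from "../../../core/prompts/tools/types"

// Assumed flow: the handler converts metadata.tools + metadata.toolArgs into
// OpenAI function-tool schemas before calling chat.completions.create.
const toolArgs: ToolArgs = { cwd: ".", supportsComputerUse: true }
const tools = getToolRegistry().generateFunctionCallSchemas(
	["read_file", "write_to_file"] as ToolName[],
	toolArgs, // hypothetical second argument; only the method name appears in these tests
)

// The request then carries { tools, tool_choice: "auto" }, as asserted above.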
