Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/cloud/src/bridge/BaseChannel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ export abstract class BaseChannel<TCommand = unknown, TEventName extends string
/**
* Handle incoming commands - must be implemented by subclasses.
*/
public abstract handleCommand(command: TCommand): void
public abstract handleCommand(command: TCommand): Promise<void>

/**
* Handle connection-specific logic.
Expand Down
44 changes: 30 additions & 14 deletions packages/cloud/src/bridge/ExtensionChannel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ export class ExtensionChannel extends BaseChannel<
this.setupListeners()
}

/**
* Handle extension-specific commands from the web app
*/
public handleCommand(command: ExtensionBridgeCommand): void {
public async handleCommand(command: ExtensionBridgeCommand): Promise<void> {
if (command.instanceId !== this.instanceId) {
console.log(`[ExtensionChannel] command -> instance id mismatch | ${this.instanceId}`, {
messageInstanceId: command.instanceId,
Expand All @@ -69,13 +66,22 @@ export class ExtensionChannel extends BaseChannel<
console.log(`[ExtensionChannel] command -> createTask() | ${command.instanceId}`, {
text: command.payload.text?.substring(0, 100) + "...",
hasImages: !!command.payload.images,
mode: command.payload.mode,
providerProfile: command.payload.providerProfile,
})

this.provider.createTask(command.payload.text, command.payload.images)
this.provider.createTask(
command.payload.text,
command.payload.images,
undefined, // parentTask
undefined, // options
{ mode: command.payload.mode, currentApiConfigName: command.payload.providerProfile },
)

break
}
case ExtensionBridgeCommandName.StopTask: {
const instance = this.updateInstance()
const instance = await this.updateInstance()

if (instance.task.taskStatus === TaskStatus.Running) {
console.log(`[ExtensionChannel] command -> cancelTask() | ${command.instanceId}`)
Expand All @@ -86,14 +92,14 @@ export class ExtensionChannel extends BaseChannel<
this.provider.clearTask()
this.provider.postStateToWebview()
}

break
}
case ExtensionBridgeCommandName.ResumeTask: {
console.log(`[ExtensionChannel] command -> resumeTask() | ${command.instanceId}`, {
taskId: command.payload.taskId,
})

// Resume the task from history by taskId
this.provider.resumeTask(command.payload.taskId)
this.provider.postStateToWebview()
break
Expand Down Expand Up @@ -122,20 +128,20 @@ export class ExtensionChannel extends BaseChannel<
}

private async registerInstance(_socket: Socket): Promise<void> {
const instance = this.updateInstance()
const instance = await this.updateInstance()
await this.publish(ExtensionSocketEvents.REGISTER, instance)
}

private async unregisterInstance(_socket: Socket): Promise<void> {
const instance = this.updateInstance()
const instance = await this.updateInstance()
await this.publish(ExtensionSocketEvents.UNREGISTER, instance)
}

private startHeartbeat(socket: Socket): void {
this.stopHeartbeat()

this.heartbeatInterval = setInterval(async () => {
const instance = this.updateInstance()
const instance = await this.updateInstance()

try {
socket.emit(ExtensionSocketEvents.HEARTBEAT, instance)
Expand Down Expand Up @@ -172,11 +178,11 @@ export class ExtensionChannel extends BaseChannel<
] as const

eventMapping.forEach(({ from, to }) => {
// Create and store the listener function for cleanup/
const listener = (..._args: unknown[]) => {
// Create and store the listener function for cleanup.
const listener = async (..._args: unknown[]) => {
this.publish(ExtensionSocketEvents.EVENT, {
type: to,
instance: this.updateInstance(),
instance: await this.updateInstance(),
timestamp: Date.now(),
})
}
Expand All @@ -195,10 +201,16 @@ export class ExtensionChannel extends BaseChannel<
this.eventListeners.clear()
}

private updateInstance(): ExtensionInstance {
private async updateInstance(): Promise<ExtensionInstance> {
const task = this.provider?.getCurrentTask()
const taskHistory = this.provider?.getRecentTasks() ?? []

const mode = await this.provider?.getMode()
const modes = (await this.provider?.getModes()) ?? []

const providerProfile = await this.provider?.getProviderProfile()
const providerProfiles = (await this.provider?.getProviderProfiles()) ?? []

this.extensionInstance = {
...this.extensionInstance,
appProperties: this.extensionInstance.appProperties ?? this.provider.appProperties,
Expand All @@ -213,6 +225,10 @@ export class ExtensionChannel extends BaseChannel<
: { taskId: "", taskStatus: TaskStatus.None },
taskAsk: task?.taskAsk,
taskHistory,
mode,
providerProfile,
modes,
providerProfiles,
}

return this.extensionInstance
Expand Down
12 changes: 10 additions & 2 deletions packages/cloud/src/bridge/TaskChannel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ export class TaskChannel extends BaseChannel<
super(instanceId)
}

public handleCommand(command: TaskBridgeCommand): void {
public async handleCommand(command: TaskBridgeCommand): Promise<void> {
const task = this.subscribedTasks.get(command.taskId)

if (!task) {
Expand All @@ -87,14 +87,22 @@ export class TaskChannel extends BaseChannel<
`[TaskChannel] ${TaskBridgeCommandName.Message} ${command.taskId} -> submitUserMessage()`,
command,
)
task.submitUserMessage(command.payload.text, command.payload.images)

await task.submitUserMessage(
command.payload.text,
command.payload.images,
command.payload.mode,
command.payload.providerProfile,
)

break

case TaskBridgeCommandName.ApproveAsk:
console.log(
`[TaskChannel] ${TaskBridgeCommandName.ApproveAsk} ${command.taskId} -> approveAsk()`,
command,
)

task.approveAsk(command.payload)
break

Expand Down
12 changes: 11 additions & 1 deletion packages/cloud/src/bridge/__tests__/ExtensionChannel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ describe("ExtensionChannel", () => {
postStateToWebview: vi.fn(),
postMessageToWebview: vi.fn(),
getTelemetryProperties: vi.fn(),
getMode: vi.fn().mockResolvedValue("code"),
getModes: vi.fn().mockResolvedValue([
{ slug: "code", name: "Code", description: "Code mode" },
{ slug: "architect", name: "Architect", description: "Architect mode" },
]),
getProviderProfile: vi.fn().mockResolvedValue("default"),
getProviderProfiles: vi.fn().mockResolvedValue([{ name: "default", description: "Default profile" }]),
on: vi.fn((event: keyof TaskProviderEvents, listener: (...args: unknown[]) => unknown) => {
if (!eventListeners.has(event)) {
eventListeners.set(event, new Set())
Expand Down Expand Up @@ -184,6 +191,9 @@ describe("ExtensionChannel", () => {
// Connect the socket to enable publishing
await extensionChannel.onConnect(mockSocket)

// Clear the mock calls from the connection (which emits a register event)
;(mockSocket.emit as any).mockClear()

// Get a listener that was registered for TaskStarted
const taskStartedListeners = eventListeners.get(RooCodeEventName.TaskStarted)
expect(taskStartedListeners).toBeDefined()
Expand All @@ -192,7 +202,7 @@ describe("ExtensionChannel", () => {
// Trigger the listener
const listener = Array.from(taskStartedListeners!)[0]
if (listener) {
listener("test-task-id")
await listener("test-task-id")
}

// Verify the event was published to the socket
Expand Down
7 changes: 6 additions & 1 deletion packages/cloud/src/bridge/__tests__/TaskChannel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,12 @@ describe("TaskChannel", () => {

taskChannel.handleCommand(command)

expect(mockTask.submitUserMessage).toHaveBeenCalledWith(command.payload.text, command.payload.images)
expect(mockTask.submitUserMessage).toHaveBeenCalledWith(
command.payload.text,
command.payload.images,
undefined,
undefined,
)
})

it("should handle ApproveAsk command", () => {
Expand Down
2 changes: 1 addition & 1 deletion packages/types/npm/package.metadata.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@roo-code/types",
"version": "1.65.0",
"version": "1.66.0",
"description": "TypeScript type definitions for Roo Code.",
"publishConfig": {
"access": "public",
Expand Down
27 changes: 24 additions & 3 deletions packages/types/src/cloud.ts
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,10 @@ export const extensionInstanceSchema = z.object({
task: extensionTaskSchema,
taskAsk: clineMessageSchema.optional(),
taskHistory: z.array(z.string()),
mode: z.string().optional(),
modes: z.array(z.object({ slug: z.string(), name: z.string() })).optional(),
providerProfile: z.string().optional(),
providerProfiles: z.array(z.object({ name: z.string(), provider: z.string().optional() })).optional(),
})

export type ExtensionInstance = z.infer<typeof extensionInstanceSchema>
Expand All @@ -398,6 +402,9 @@ export enum ExtensionBridgeEventName {
TaskResumable = RooCodeEventName.TaskResumable,
TaskIdle = RooCodeEventName.TaskIdle,

ModeChanged = RooCodeEventName.ModeChanged,
ProviderProfileChanged = RooCodeEventName.ProviderProfileChanged,

InstanceRegistered = "instance_registered",
InstanceUnregistered = "instance_unregistered",
HeartbeatUpdated = "heartbeat_updated",
Expand Down Expand Up @@ -469,6 +476,18 @@ export const extensionBridgeEventSchema = z.discriminatedUnion("type", [
instance: extensionInstanceSchema,
timestamp: z.number(),
}),
z.object({
type: z.literal(ExtensionBridgeEventName.ModeChanged),
instance: extensionInstanceSchema,
mode: z.string(),
timestamp: z.number(),
}),
z.object({
type: z.literal(ExtensionBridgeEventName.ProviderProfileChanged),
instance: extensionInstanceSchema,
providerProfile: z.object({ name: z.string(), provider: z.string().optional() }),
timestamp: z.number(),
}),
])

export type ExtensionBridgeEvent = z.infer<typeof extensionBridgeEventSchema>
Expand All @@ -490,6 +509,8 @@ export const extensionBridgeCommandSchema = z.discriminatedUnion("type", [
payload: z.object({
text: z.string(),
images: z.array(z.string()).optional(),
mode: z.string().optional(),
providerProfile: z.string().optional(),
}),
timestamp: z.number(),
}),
Expand All @@ -502,9 +523,7 @@ export const extensionBridgeCommandSchema = z.discriminatedUnion("type", [
z.object({
type: z.literal(ExtensionBridgeCommandName.ResumeTask),
instanceId: z.string(),
payload: z.object({
taskId: z.string(),
}),
payload: z.object({ taskId: z.string() }),
timestamp: z.number(),
}),
])
Expand Down Expand Up @@ -558,6 +577,8 @@ export const taskBridgeCommandSchema = z.discriminatedUnion("type", [
payload: z.object({
text: z.string(),
images: z.array(z.string()).optional(),
mode: z.string().optional(),
providerProfile: z.string().optional(),
}),
timestamp: z.number(),
}),
Expand Down
7 changes: 7 additions & 0 deletions packages/types/src/events.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ export enum RooCodeEventName {
TaskTokenUsageUpdated = "taskTokenUsageUpdated",
TaskToolFailed = "taskToolFailed",

// Configuration Changes
ModeChanged = "modeChanged",
ProviderProfileChanged = "providerProfileChanged",

// Evals
EvalPass = "evalPass",
EvalFail = "evalFail",
Expand Down Expand Up @@ -81,6 +85,9 @@ export const rooCodeEventsSchema = z.object({

[RooCodeEventName.TaskToolFailed]: z.tuple([z.string(), toolNamesSchema, z.string()]),
[RooCodeEventName.TaskTokenUsageUpdated]: z.tuple([z.string(), tokenUsageSchema]),

[RooCodeEventName.ModeChanged]: z.tuple([z.string()]),
[RooCodeEventName.ProviderProfileChanged]: z.tuple([z.object({ name: z.string(), provider: z.string() })]),
})

export type RooCodeEvents = z.infer<typeof rooCodeEventsSchema>
Expand Down
4 changes: 3 additions & 1 deletion packages/types/src/provider-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -421,9 +421,11 @@ export const providerSettingsSchema = z.object({
export type ProviderSettings = z.infer<typeof providerSettingsSchema>

export const providerSettingsWithIdSchema = providerSettingsSchema.extend({ id: z.string().optional() })

export const discriminatedProviderSettingsWithIdSchema = providerSettingsSchemaDiscriminated.and(
z.object({ id: z.string().optional() }),
)

export type ProviderSettingsWithId = z.infer<typeof providerSettingsWithIdSchema>

export const PROVIDER_SETTINGS_KEYS = providerSettingsSchema.keyof().options
Expand Down Expand Up @@ -461,7 +463,7 @@ export const getApiProtocol = (provider: ProviderName | undefined, modelId?: str
return "anthropic"
}

// Vercel AI Gateway uses anthropic protocol for anthropic models
// Vercel AI Gateway uses anthropic protocol for anthropic models.
if (provider && provider === "vercel-ai-gateway" && modelId && modelId.toLowerCase().startsWith("anthropic/")) {
return "anthropic"
}
Expand Down
Loading
Loading