From cf3871ab3f37d1a10088fcd9f299dbd9c304925d Mon Sep 17 00:00:00 2001 From: John Richmond <5629+jr@users.noreply.github.com> Date: Wed, 19 Nov 2025 00:07:50 -0800 Subject: [PATCH 1/3] Add a RCC credit balance display --- packages/cloud/src/CloudAPI.ts | 10 ++ .../__tests__/CloudAPI.creditBalance.spec.ts | 96 ++++++++++++++ .../webviewMessageHandler.rooBalance.spec.ts | 117 ++++++++++++++++++ src/core/webview/webviewMessageHandler.ts | 25 ++++ src/shared/ExtensionMessage.ts | 1 + src/shared/WebviewMessage.ts | 1 + .../src/components/settings/providers/Roo.tsx | 8 +- .../settings/providers/RooBalanceDisplay.tsx | 22 ++++ .../__tests__/RooBalanceDisplay.spec.tsx | 108 ++++++++++++++++ .../ui/hooks/useRooCreditBalance.ts | 55 ++++++++ 10 files changed, 441 insertions(+), 2 deletions(-) create mode 100644 packages/cloud/src/__tests__/CloudAPI.creditBalance.spec.ts create mode 100644 src/core/webview/__tests__/webviewMessageHandler.rooBalance.spec.ts create mode 100644 webview-ui/src/components/settings/providers/RooBalanceDisplay.tsx create mode 100644 webview-ui/src/components/settings/providers/__tests__/RooBalanceDisplay.spec.tsx create mode 100644 webview-ui/src/components/ui/hooks/useRooCreditBalance.ts diff --git a/packages/cloud/src/CloudAPI.ts b/packages/cloud/src/CloudAPI.ts index d1c3f89c2b..239dc9b564 100644 --- a/packages/cloud/src/CloudAPI.ts +++ b/packages/cloud/src/CloudAPI.ts @@ -134,4 +134,14 @@ export class CloudAPI { .parse(data), }) } + + async creditBalance(): Promise { + return this.request("/api/extension/credit-balance", { + method: "GET", + parseResponse: (data) => { + const result = z.object({ balance: z.number() }).parse(data) + return result.balance + }, + }) + } } diff --git a/packages/cloud/src/__tests__/CloudAPI.creditBalance.spec.ts b/packages/cloud/src/__tests__/CloudAPI.creditBalance.spec.ts new file mode 100644 index 0000000000..67ab0cf3b9 --- /dev/null +++ b/packages/cloud/src/__tests__/CloudAPI.creditBalance.spec.ts @@ -0,0 +1,96 @@ +import { describe, it, expect, vi, beforeEach, type Mock } from "vitest" +import { CloudAPI } from "../CloudAPI.js" +import { AuthenticationError, CloudAPIError } from "../errors.js" +import type { AuthService } from "@roo-code/types" + +// Mock the config module +vi.mock("../config.js", () => ({ + getRooCodeApiUrl: () => "https://api.test.com", +})) + +// Mock the utils module +vi.mock("../utils.js", () => ({ + getUserAgent: () => "test-user-agent", +})) + +describe("CloudAPI.creditBalance", () => { + let mockAuthService: { + getSessionToken: Mock<() => string | undefined> + } + let cloudAPI: CloudAPI + + beforeEach(() => { + mockAuthService = { + getSessionToken: vi.fn(), + } + cloudAPI = new CloudAPI(mockAuthService as unknown as AuthService) + + // Reset fetch mock + global.fetch = vi.fn() + }) + + it("should fetch credit balance successfully", async () => { + const mockBalance = 12.34 + mockAuthService.getSessionToken.mockReturnValue("test-session-token") + + global.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: async () => ({ balance: mockBalance }), + }) + + const balance = await cloudAPI.creditBalance() + + expect(balance).toBe(mockBalance) + expect(global.fetch).toHaveBeenCalledWith( + "https://api.test.com/api/extension/credit-balance", + expect.objectContaining({ + method: "GET", + headers: expect.objectContaining({ + Authorization: "Bearer test-session-token", + "Content-Type": "application/json", + "User-Agent": "test-user-agent", + }), + }), + ) + }) + + it("should throw AuthenticationError when session token is missing", async () => { + mockAuthService.getSessionToken.mockReturnValue(undefined) + + await expect(cloudAPI.creditBalance()).rejects.toThrow(AuthenticationError) + }) + + it("should handle API errors", async () => { + mockAuthService.getSessionToken.mockReturnValue("test-session-token") + + global.fetch = vi.fn().mockResolvedValue({ + ok: false, + status: 500, + statusText: "Internal Server Error", + json: async () => ({ error: "Server error" }), + }) + + await expect(cloudAPI.creditBalance()).rejects.toThrow(CloudAPIError) + }) + + it("should handle network errors", async () => { + mockAuthService.getSessionToken.mockReturnValue("test-session-token") + + global.fetch = vi.fn().mockRejectedValue(new TypeError("fetch failed")) + + await expect(cloudAPI.creditBalance()).rejects.toThrow( + "Network error while calling /api/extension/credit-balance", + ) + }) + + it("should handle invalid response format", async () => { + mockAuthService.getSessionToken.mockReturnValue("test-session-token") + + global.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: async () => ({ invalid: "response" }), + }) + + await expect(cloudAPI.creditBalance()).rejects.toThrow() + }) +}) diff --git a/src/core/webview/__tests__/webviewMessageHandler.rooBalance.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.rooBalance.spec.ts new file mode 100644 index 0000000000..1eec76b62c --- /dev/null +++ b/src/core/webview/__tests__/webviewMessageHandler.rooBalance.spec.ts @@ -0,0 +1,117 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import { webviewMessageHandler } from "../webviewMessageHandler" +import { CloudService } from "@roo-code/cloud" + +vi.mock("@roo-code/cloud", () => ({ + CloudService: { + hasInstance: vi.fn(), + instance: { + cloudAPI: { + creditBalance: vi.fn(), + }, + }, + }, +})) + +describe("webviewMessageHandler - requestRooCreditBalance", () => { + let mockProvider: any + + beforeEach(() => { + mockProvider = { + postMessageToWebview: vi.fn(), + contextProxy: { + getValue: vi.fn(), + setValue: vi.fn(), + }, + getCurrentTask: vi.fn(), + cwd: "/test/path", + } + + vi.clearAllMocks() + }) + + it("should handle requestRooCreditBalance and return balance", async () => { + const mockBalance = 42.75 + const requestId = "test-request-id" + + ;(CloudService.hasInstance as any).mockReturnValue(true) + ;(CloudService.instance.cloudAPI!.creditBalance as any).mockResolvedValue(mockBalance) + + await webviewMessageHandler( + mockProvider as any, + { + type: "requestRooCreditBalance", + requestId, + } as any, + ) + + expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "rooCreditBalance", + requestId, + values: { balance: mockBalance }, + }) + }) + + it("should handle CloudAPI errors", async () => { + const requestId = "test-request-id" + const errorMessage = "Failed to fetch balance" + + ;(CloudService.hasInstance as any).mockReturnValue(true) + ;(CloudService.instance.cloudAPI!.creditBalance as any).mockRejectedValue(new Error(errorMessage)) + + await webviewMessageHandler( + mockProvider as any, + { + type: "requestRooCreditBalance", + requestId, + } as any, + ) + + expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "rooCreditBalance", + requestId, + values: { error: errorMessage }, + }) + }) + + it("should handle missing CloudService", async () => { + const requestId = "test-request-id" + + ;(CloudService.hasInstance as any).mockReturnValue(false) + + await webviewMessageHandler( + mockProvider as any, + { + type: "requestRooCreditBalance", + requestId, + } as any, + ) + + expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "rooCreditBalance", + requestId, + values: { error: "Cloud service not available" }, + }) + }) + + it("should handle missing cloudAPI", async () => { + const requestId = "test-request-id" + + ;(CloudService.hasInstance as any).mockReturnValue(true) + ;(CloudService.instance as any).cloudAPI = null + + await webviewMessageHandler( + mockProvider as any, + { + type: "requestRooCreditBalance", + requestId, + } as any, + ) + + expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "rooCreditBalance", + requestId, + values: { error: "Cloud service not available" }, + }) + }) +}) diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index b7da941b43..8f89a9ec51 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -1006,6 +1006,31 @@ export const webviewMessageHandler = async ( } break } + case "requestRooCreditBalance": { + // Fetch Roo credit balance using CloudAPI + const requestId = message.requestId + try { + if (!CloudService.hasInstance() || !CloudService.instance.cloudAPI) { + throw new Error("Cloud service not available") + } + + const balance = await CloudService.instance.cloudAPI.creditBalance() + + provider.postMessageToWebview({ + type: "rooCreditBalance", + requestId, + values: { balance }, + }) + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error) + provider.postMessageToWebview({ + type: "rooCreditBalance", + requestId, + values: { error: errorMessage }, + }) + } + break + } case "requestOpenAiModels": if (message?.values?.baseUrl && message?.values?.apiKey) { const openAiModels = await getOpenAiModels( diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index 5b575c1bb7..59745b9cf9 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -112,6 +112,7 @@ export interface ExtensionMessage { | "authenticatedUser" | "condenseTaskContextResponse" | "singleRouterModelFetchResponse" + | "rooCreditBalance" | "indexingStatusUpdate" | "indexCleared" | "codebaseIndexConfig" diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index 02f0876ad3..1d403f16ca 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -60,6 +60,7 @@ export interface WebviewMessage { | "requestOllamaModels" | "requestLmStudioModels" | "requestRooModels" + | "requestRooCreditBalance" | "requestVsCodeLmModels" | "requestHuggingFaceModels" | "openImage" diff --git a/webview-ui/src/components/settings/providers/Roo.tsx b/webview-ui/src/components/settings/providers/Roo.tsx index 7a504a4487..3c02db18dd 100644 --- a/webview-ui/src/components/settings/providers/Roo.tsx +++ b/webview-ui/src/components/settings/providers/Roo.tsx @@ -7,6 +7,7 @@ import { vscode } from "@src/utils/vscode" import { Button } from "@src/components/ui" import { ModelPicker } from "../ModelPicker" +import { RooBalanceDisplay } from "./RooBalanceDisplay" type RooProps = { apiConfiguration: ProviderSettings @@ -30,8 +31,11 @@ export const Roo = ({ return ( <> {cloudIsAuthenticated ? ( -
- {t("settings:providers.roo.authenticatedMessage")} +
+
+ {t("settings:providers.roo.authenticatedMessage")} +
+
) : (
diff --git a/webview-ui/src/components/settings/providers/RooBalanceDisplay.tsx b/webview-ui/src/components/settings/providers/RooBalanceDisplay.tsx new file mode 100644 index 0000000000..755e5a844b --- /dev/null +++ b/webview-ui/src/components/settings/providers/RooBalanceDisplay.tsx @@ -0,0 +1,22 @@ +import { VSCodeLink } from "@vscode/webview-ui-toolkit/react" + +import { useRooCreditBalance } from "@/components/ui/hooks/useRooCreditBalance" +import { useExtensionState } from "@src/context/ExtensionStateContext" + +export const RooBalanceDisplay = () => { + const { data: balance } = useRooCreditBalance() + const { cloudApiUrl } = useExtensionState() + + if (balance === null || balance === undefined) { + return null + } + + const formattedBalance = balance.toFixed(2) + const billingUrl = cloudApiUrl ? `${cloudApiUrl.replace(/\/$/, "")}/billing` : "https://app.roocode.com/billing" + + return ( + + ${formattedBalance} + + ) +} diff --git a/webview-ui/src/components/settings/providers/__tests__/RooBalanceDisplay.spec.tsx b/webview-ui/src/components/settings/providers/__tests__/RooBalanceDisplay.spec.tsx new file mode 100644 index 0000000000..140c96b804 --- /dev/null +++ b/webview-ui/src/components/settings/providers/__tests__/RooBalanceDisplay.spec.tsx @@ -0,0 +1,108 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import { render, screen } from "@testing-library/react" +import { RooBalanceDisplay } from "../RooBalanceDisplay" + +// Mock the hooks +vi.mock("@/components/ui/hooks/useRooCreditBalance", () => ({ + useRooCreditBalance: vi.fn(), +})) + +vi.mock("@src/context/ExtensionStateContext", () => ({ + useExtensionState: vi.fn(), +})) + +import { useRooCreditBalance } from "@/components/ui/hooks/useRooCreditBalance" +import { useExtensionState } from "@src/context/ExtensionStateContext" + +describe("RooBalanceDisplay", () => { + beforeEach(() => { + vi.clearAllMocks() + ;(useExtensionState as any).mockReturnValue({ + cloudApiUrl: undefined, + }) + }) + + it("should render balance formatted to 2 decimal places", () => { + ;(useRooCreditBalance as any).mockReturnValue({ + data: 12.34, + isLoading: false, + error: null, + }) + + render() + + expect(screen.getByText("$12.34")).toBeInTheDocument() + }) + + it("should format balance to 2 decimal places when value has 1 decimal", () => { + ;(useRooCreditBalance as any).mockReturnValue({ + data: 7.8, + isLoading: false, + error: null, + }) + + render() + + expect(screen.getByText("$7.80")).toBeInTheDocument() + }) + + it("should format whole numbers with 2 decimal places", () => { + ;(useRooCreditBalance as any).mockReturnValue({ + data: 5, + isLoading: false, + error: null, + }) + + render() + + expect(screen.getByText("$5.00")).toBeInTheDocument() + }) + + it("should return null when balance is null", () => { + ;(useRooCreditBalance as any).mockReturnValue({ + data: null, + isLoading: false, + error: null, + }) + + const { container } = render() + + expect(container.firstChild).toBeNull() + }) + + it("should return null when balance is undefined", () => { + ;(useRooCreditBalance as any).mockReturnValue({ + data: undefined, + isLoading: false, + error: null, + }) + + const { container } = render() + + expect(container.firstChild).toBeNull() + }) + + it("should return null when there is an error", () => { + ;(useRooCreditBalance as any).mockReturnValue({ + data: null, + isLoading: false, + error: "Failed to fetch balance", + }) + + const { container } = render() + + expect(container.firstChild).toBeNull() + }) + + it("should render when balance is zero", () => { + ;(useRooCreditBalance as any).mockReturnValue({ + data: 0, + isLoading: false, + error: null, + }) + + render() + + expect(screen.getByText("$0.00")).toBeInTheDocument() + }) +}) diff --git a/webview-ui/src/components/ui/hooks/useRooCreditBalance.ts b/webview-ui/src/components/ui/hooks/useRooCreditBalance.ts new file mode 100644 index 0000000000..2b45dccd21 --- /dev/null +++ b/webview-ui/src/components/ui/hooks/useRooCreditBalance.ts @@ -0,0 +1,55 @@ +import { useEffect, useState } from "react" +import type { ExtensionMessage } from "@roo/ExtensionMessage" +import { vscode } from "@src/utils/vscode" + +/** + * Hook to fetch Roo Code Cloud credit balance + * Returns the balance in dollars or null if unavailable + */ +export const useRooCreditBalance = () => { + const [balance, setBalance] = useState(null) + const [isLoading, setIsLoading] = useState(false) + const [error, setError] = useState(null) + + useEffect(() => { + setIsLoading(true) + const requestId = `roo-balance-${Date.now()}` + + const handleMessage = (event: MessageEvent) => { + const message: ExtensionMessage = event.data + + if (message.type === "rooCreditBalance" && message.requestId === requestId) { + window.removeEventListener("message", handleMessage) + + if (message.values?.balance !== undefined) { + setBalance(message.values.balance) + setError(null) + } else if (message.values?.error) { + setError(message.values.error) + setBalance(null) + } + + setIsLoading(false) + } + } + + window.addEventListener("message", handleMessage) + + // Request the balance from the extension + vscode.postMessage({ type: "requestRooCreditBalance", requestId }) + + // Cleanup timeout + const timeout = setTimeout(() => { + window.removeEventListener("message", handleMessage) + setIsLoading(false) + setError("Request timed out") + }, 10000) // 10 second timeout + + return () => { + window.removeEventListener("message", handleMessage) + clearTimeout(timeout) + } + }, []) + + return { data: balance, isLoading, error } +} From 4c2e7a32dac0596cc9cc92f80f7c485184c0399d Mon Sep 17 00:00:00 2001 From: Matt Rubens Date: Wed, 19 Nov 2025 15:53:26 -0500 Subject: [PATCH 2/3] Replace the provider docs with the balance when logged in --- .../src/components/settings/ApiOptions.tsx | 17 ++++--- .../settings/__tests__/ApiOptions.spec.tsx | 51 ++++++++++++++++++- .../src/components/settings/providers/Roo.tsx | 2 - 3 files changed, 61 insertions(+), 9 deletions(-) diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index dfb789cdda..adf312dea6 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -115,6 +115,7 @@ import { TemperatureControl } from "./TemperatureControl" import { RateLimitSecondsControl } from "./RateLimitSecondsControl" import { ConsecutiveMistakeLimitControl } from "./ConsecutiveMistakeLimitControl" import { BedrockCustomArn } from "./providers/BedrockCustomArn" +import { RooBalanceDisplay } from "./providers/RooBalanceDisplay" import { buildDocLink } from "@src/utils/docLinks" export interface ApiOptionsProps { @@ -460,12 +461,16 @@ const ApiOptions = ({
- {docs && ( -
- - {t("settings:providers.providerDocumentation", { provider: docs.name })} - -
+ {selectedProvider === "roo" && cloudIsAuthenticated ? ( + + ) : ( + docs && ( +
+ + {t("settings:providers.providerDocumentation", { provider: docs.name })} + +
+ ) )}
({ ), })) +// Mock Roo provider for tests +vi.mock("../providers/Roo", () => ({ + Roo: ({ cloudIsAuthenticated }: any) => ( +
{cloudIsAuthenticated ? "Authenticated" : "Not Authenticated"}
+ ), +})) + +// Mock RooBalanceDisplay for tests +vi.mock("../providers/RooBalanceDisplay", () => ({ + RooBalanceDisplay: () =>
Balance: $10.00
, +})) + vi.mock("@src/components/ui/hooks/useSelectedModel", () => ({ useSelectedModel: vi.fn((apiConfiguration: ProviderSettings) => { if (apiConfiguration.apiModelId?.includes("thinking")) { @@ -563,4 +576,40 @@ describe("ApiOptions", () => { expect(screen.queryByTestId("litellm-provider")).not.toBeInTheDocument() }) }) + + describe("Roo provider tests", () => { + it("shows balance display when authenticated", () => { + // Mock useExtensionState to return authenticated state + const useExtensionStateMock = vi.spyOn(ExtensionStateContext, "useExtensionState") + useExtensionStateMock.mockReturnValue({ + cloudIsAuthenticated: true, + organizationAllowList: { providers: {} }, + } as any) + + renderApiOptions({ + apiConfiguration: { + apiProvider: "roo", + }, + }) + + expect(screen.getByTestId("roo-balance-display")).toBeInTheDocument() + }) + + it("does not show balance display when not authenticated", () => { + // Mock useExtensionState to return unauthenticated state + const useExtensionStateMock = vi.spyOn(ExtensionStateContext, "useExtensionState") + useExtensionStateMock.mockReturnValue({ + cloudIsAuthenticated: false, + organizationAllowList: { providers: {} }, + } as any) + + renderApiOptions({ + apiConfiguration: { + apiProvider: "roo", + }, + }) + + expect(screen.queryByTestId("roo-balance-display")).not.toBeInTheDocument() + }) + }) }) diff --git a/webview-ui/src/components/settings/providers/Roo.tsx b/webview-ui/src/components/settings/providers/Roo.tsx index 3c02db18dd..3fc7f090d7 100644 --- a/webview-ui/src/components/settings/providers/Roo.tsx +++ b/webview-ui/src/components/settings/providers/Roo.tsx @@ -7,7 +7,6 @@ import { vscode } from "@src/utils/vscode" import { Button } from "@src/components/ui" import { ModelPicker } from "../ModelPicker" -import { RooBalanceDisplay } from "./RooBalanceDisplay" type RooProps = { apiConfiguration: ProviderSettings @@ -35,7 +34,6 @@ export const Roo = ({
{t("settings:providers.roo.authenticatedMessage")}
-
) : (
From fb7982b5f7382938ec25e9f99a22b6368cabc551 Mon Sep 17 00:00:00 2001 From: John Richmond <5629+jr@users.noreply.github.com> Date: Wed, 19 Nov 2025 15:52:24 -0800 Subject: [PATCH 3/3] PR feedback --- .../src/components/ui/hooks/useRooCreditBalance.ts | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/webview-ui/src/components/ui/hooks/useRooCreditBalance.ts b/webview-ui/src/components/ui/hooks/useRooCreditBalance.ts index 2b45dccd21..86fe0236c2 100644 --- a/webview-ui/src/components/ui/hooks/useRooCreditBalance.ts +++ b/webview-ui/src/components/ui/hooks/useRooCreditBalance.ts @@ -20,6 +20,7 @@ export const useRooCreditBalance = () => { if (message.type === "rooCreditBalance" && message.requestId === requestId) { window.removeEventListener("message", handleMessage) + clearTimeout(timeout) if (message.values?.balance !== undefined) { setBalance(message.values.balance) @@ -33,17 +34,15 @@ export const useRooCreditBalance = () => { } } - window.addEventListener("message", handleMessage) - - // Request the balance from the extension - vscode.postMessage({ type: "requestRooCreditBalance", requestId }) - - // Cleanup timeout const timeout = setTimeout(() => { window.removeEventListener("message", handleMessage) setIsLoading(false) setError("Request timed out") - }, 10000) // 10 second timeout + }, 10000) + + window.addEventListener("message", handleMessage) + + vscode.postMessage({ type: "requestRooCreditBalance", requestId }) return () => { window.removeEventListener("message", handleMessage)