Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions packages/cloud/src/CloudAPI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,14 @@ export class CloudAPI {
.parse(data),
})
}

async creditBalance(): Promise<number> {
return this.request("/api/extension/credit-balance", {
method: "GET",
parseResponse: (data) => {
const result = z.object({ balance: z.number() }).parse(data)
return result.balance
},
})
}
}
96 changes: 96 additions & 0 deletions packages/cloud/src/__tests__/CloudAPI.creditBalance.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import { describe, it, expect, vi, beforeEach, type Mock } from "vitest"
import { CloudAPI } from "../CloudAPI.js"
import { AuthenticationError, CloudAPIError } from "../errors.js"
import type { AuthService } from "@roo-code/types"

// Mock the config module
vi.mock("../config.js", () => ({
getRooCodeApiUrl: () => "https://api.test.com",
}))

// Mock the utils module
vi.mock("../utils.js", () => ({
getUserAgent: () => "test-user-agent",
}))

describe("CloudAPI.creditBalance", () => {
let mockAuthService: {
getSessionToken: Mock<() => string | undefined>
}
let cloudAPI: CloudAPI

beforeEach(() => {
mockAuthService = {
getSessionToken: vi.fn(),
}
cloudAPI = new CloudAPI(mockAuthService as unknown as AuthService)

// Reset fetch mock
global.fetch = vi.fn()
})

it("should fetch credit balance successfully", async () => {
const mockBalance = 12.34
mockAuthService.getSessionToken.mockReturnValue("test-session-token")

global.fetch = vi.fn().mockResolvedValue({
ok: true,
json: async () => ({ balance: mockBalance }),
})

const balance = await cloudAPI.creditBalance()

expect(balance).toBe(mockBalance)
expect(global.fetch).toHaveBeenCalledWith(
"https://api.test.com/api/extension/credit-balance",
expect.objectContaining({
method: "GET",
headers: expect.objectContaining({
Authorization: "Bearer test-session-token",
"Content-Type": "application/json",
"User-Agent": "test-user-agent",
}),
}),
)
})

it("should throw AuthenticationError when session token is missing", async () => {
mockAuthService.getSessionToken.mockReturnValue(undefined)

await expect(cloudAPI.creditBalance()).rejects.toThrow(AuthenticationError)
})

it("should handle API errors", async () => {
mockAuthService.getSessionToken.mockReturnValue("test-session-token")

global.fetch = vi.fn().mockResolvedValue({
ok: false,
status: 500,
statusText: "Internal Server Error",
json: async () => ({ error: "Server error" }),
})

await expect(cloudAPI.creditBalance()).rejects.toThrow(CloudAPIError)
})

it("should handle network errors", async () => {
mockAuthService.getSessionToken.mockReturnValue("test-session-token")

global.fetch = vi.fn().mockRejectedValue(new TypeError("fetch failed"))

await expect(cloudAPI.creditBalance()).rejects.toThrow(
"Network error while calling /api/extension/credit-balance",
)
})

it("should handle invalid response format", async () => {
mockAuthService.getSessionToken.mockReturnValue("test-session-token")

global.fetch = vi.fn().mockResolvedValue({
ok: true,
json: async () => ({ invalid: "response" }),
})

await expect(cloudAPI.creditBalance()).rejects.toThrow()
})
})
117 changes: 117 additions & 0 deletions src/core/webview/__tests__/webviewMessageHandler.rooBalance.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import { describe, it, expect, vi, beforeEach } from "vitest"
import { webviewMessageHandler } from "../webviewMessageHandler"
import { CloudService } from "@roo-code/cloud"

vi.mock("@roo-code/cloud", () => ({
CloudService: {
hasInstance: vi.fn(),
instance: {
cloudAPI: {
creditBalance: vi.fn(),
},
},
},
}))

describe("webviewMessageHandler - requestRooCreditBalance", () => {
let mockProvider: any

beforeEach(() => {
mockProvider = {
postMessageToWebview: vi.fn(),
contextProxy: {
getValue: vi.fn(),
setValue: vi.fn(),
},
getCurrentTask: vi.fn(),
cwd: "/test/path",
}

vi.clearAllMocks()
})

it("should handle requestRooCreditBalance and return balance", async () => {
const mockBalance = 42.75
const requestId = "test-request-id"

;(CloudService.hasInstance as any).mockReturnValue(true)
;(CloudService.instance.cloudAPI!.creditBalance as any).mockResolvedValue(mockBalance)

await webviewMessageHandler(
mockProvider as any,
{
type: "requestRooCreditBalance",
requestId,
} as any,
)

expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({
type: "rooCreditBalance",
requestId,
values: { balance: mockBalance },
})
})

it("should handle CloudAPI errors", async () => {
const requestId = "test-request-id"
const errorMessage = "Failed to fetch balance"

;(CloudService.hasInstance as any).mockReturnValue(true)
;(CloudService.instance.cloudAPI!.creditBalance as any).mockRejectedValue(new Error(errorMessage))

await webviewMessageHandler(
mockProvider as any,
{
type: "requestRooCreditBalance",
requestId,
} as any,
)

expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({
type: "rooCreditBalance",
requestId,
values: { error: errorMessage },
})
})

it("should handle missing CloudService", async () => {
const requestId = "test-request-id"

;(CloudService.hasInstance as any).mockReturnValue(false)

await webviewMessageHandler(
mockProvider as any,
{
type: "requestRooCreditBalance",
requestId,
} as any,
)

expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({
type: "rooCreditBalance",
requestId,
values: { error: "Cloud service not available" },
})
})

it("should handle missing cloudAPI", async () => {
const requestId = "test-request-id"

;(CloudService.hasInstance as any).mockReturnValue(true)
;(CloudService.instance as any).cloudAPI = null

await webviewMessageHandler(
mockProvider as any,
{
type: "requestRooCreditBalance",
requestId,
} as any,
)

expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({
type: "rooCreditBalance",
requestId,
values: { error: "Cloud service not available" },
})
})
})
25 changes: 25 additions & 0 deletions src/core/webview/webviewMessageHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,31 @@ export const webviewMessageHandler = async (
}
break
}
case "requestRooCreditBalance": {
// Fetch Roo credit balance using CloudAPI
const requestId = message.requestId
try {
if (!CloudService.hasInstance() || !CloudService.instance.cloudAPI) {
throw new Error("Cloud service not available")
}

const balance = await CloudService.instance.cloudAPI.creditBalance()

provider.postMessageToWebview({
type: "rooCreditBalance",
requestId,
values: { balance },
})
} catch (error) {
const errorMessage = error instanceof Error ? error.message : String(error)
provider.postMessageToWebview({
type: "rooCreditBalance",
requestId,
values: { error: errorMessage },
})
}
break
}
case "requestOpenAiModels":
if (message?.values?.baseUrl && message?.values?.apiKey) {
const openAiModels = await getOpenAiModels(
Expand Down
1 change: 1 addition & 0 deletions src/shared/ExtensionMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ export interface ExtensionMessage {
| "authenticatedUser"
| "condenseTaskContextResponse"
| "singleRouterModelFetchResponse"
| "rooCreditBalance"
| "indexingStatusUpdate"
| "indexCleared"
| "codebaseIndexConfig"
Expand Down
1 change: 1 addition & 0 deletions src/shared/WebviewMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ export interface WebviewMessage {
| "requestOllamaModels"
| "requestLmStudioModels"
| "requestRooModels"
| "requestRooCreditBalance"
| "requestVsCodeLmModels"
| "requestHuggingFaceModels"
| "openImage"
Expand Down
17 changes: 11 additions & 6 deletions webview-ui/src/components/settings/ApiOptions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ import { TemperatureControl } from "./TemperatureControl"
import { RateLimitSecondsControl } from "./RateLimitSecondsControl"
import { ConsecutiveMistakeLimitControl } from "./ConsecutiveMistakeLimitControl"
import { BedrockCustomArn } from "./providers/BedrockCustomArn"
import { RooBalanceDisplay } from "./providers/RooBalanceDisplay"
import { buildDocLink } from "@src/utils/docLinks"

export interface ApiOptionsProps {
Expand Down Expand Up @@ -460,12 +461,16 @@ const ApiOptions = ({
<div className="flex flex-col gap-1 relative">
<div className="flex justify-between items-center">
<label className="block font-medium mb-1">{t("settings:providers.apiProvider")}</label>
{docs && (
<div className="text-xs text-vscode-descriptionForeground">
<VSCodeLink href={docs.url} className="hover:text-vscode-foreground" target="_blank">
{t("settings:providers.providerDocumentation", { provider: docs.name })}
</VSCodeLink>
</div>
{selectedProvider === "roo" && cloudIsAuthenticated ? (
<RooBalanceDisplay />
) : (
docs && (
<div className="text-xs text-vscode-descriptionForeground">
<VSCodeLink href={docs.url} className="hover:text-vscode-foreground" target="_blank">
{t("settings:providers.providerDocumentation", { provider: docs.name })}
</VSCodeLink>
</div>
)
)}
</div>
<SearchableSelect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query"

import { type ModelInfo, type ProviderSettings, openAiModelInfoSaneDefaults } from "@roo-code/types"

import { ExtensionStateContextProvider } from "@src/context/ExtensionStateContext"
import * as ExtensionStateContext from "@src/context/ExtensionStateContext"
const { ExtensionStateContextProvider } = ExtensionStateContext

import ApiOptions, { ApiOptionsProps } from "../ApiOptions"

Expand Down Expand Up @@ -238,6 +239,18 @@ vi.mock("../providers/LiteLLM", () => ({
),
}))

// Mock Roo provider for tests
vi.mock("../providers/Roo", () => ({
Roo: ({ cloudIsAuthenticated }: any) => (
<div data-testid="roo-provider">{cloudIsAuthenticated ? "Authenticated" : "Not Authenticated"}</div>
),
}))

// Mock RooBalanceDisplay for tests
vi.mock("../providers/RooBalanceDisplay", () => ({
RooBalanceDisplay: () => <div data-testid="roo-balance-display">Balance: $10.00</div>,
}))

vi.mock("@src/components/ui/hooks/useSelectedModel", () => ({
useSelectedModel: vi.fn((apiConfiguration: ProviderSettings) => {
if (apiConfiguration.apiModelId?.includes("thinking")) {
Expand Down Expand Up @@ -563,4 +576,40 @@ describe("ApiOptions", () => {
expect(screen.queryByTestId("litellm-provider")).not.toBeInTheDocument()
})
})

describe("Roo provider tests", () => {
it("shows balance display when authenticated", () => {
// Mock useExtensionState to return authenticated state
const useExtensionStateMock = vi.spyOn(ExtensionStateContext, "useExtensionState")
useExtensionStateMock.mockReturnValue({
cloudIsAuthenticated: true,
organizationAllowList: { providers: {} },
} as any)

renderApiOptions({
apiConfiguration: {
apiProvider: "roo",
},
})

expect(screen.getByTestId("roo-balance-display")).toBeInTheDocument()
})

it("does not show balance display when not authenticated", () => {
// Mock useExtensionState to return unauthenticated state
const useExtensionStateMock = vi.spyOn(ExtensionStateContext, "useExtensionState")
useExtensionStateMock.mockReturnValue({
cloudIsAuthenticated: false,
organizationAllowList: { providers: {} },
} as any)

renderApiOptions({
apiConfiguration: {
apiProvider: "roo",
},
})

expect(screen.queryByTestId("roo-balance-display")).not.toBeInTheDocument()
})
})
})
6 changes: 4 additions & 2 deletions webview-ui/src/components/settings/providers/Roo.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ export const Roo = ({
return (
<>
{cloudIsAuthenticated ? (
<div className="text-sm text-vscode-descriptionForeground">
{t("settings:providers.roo.authenticatedMessage")}
<div className="flex justify-between items-center mb-2">
<div className="text-sm text-vscode-descriptionForeground">
{t("settings:providers.roo.authenticatedMessage")}
</div>
</div>
) : (
<div className="flex flex-col gap-2">
Expand Down
Loading
Loading