diff --git a/src/api/providers/__tests__/gemini-handler.spec.ts b/src/api/providers/__tests__/gemini-handler.spec.ts index 97c40fe67c..541ffd5611 100644 --- a/src/api/providers/__tests__/gemini-handler.spec.ts +++ b/src/api/providers/__tests__/gemini-handler.spec.ts @@ -85,7 +85,7 @@ describe("GeminiHandler backend support", () => { groundingMetadata: { groundingChunks: [ { web: null }, // Missing URI - { web: { uri: "https://example.com" } }, // Valid + { web: { uri: "https://example.com", title: "Example Site" } }, // Valid {}, // Missing web property entirely ], }, @@ -105,13 +105,20 @@ describe("GeminiHandler backend support", () => { messages.push(chunk) } - // Should only include valid citations - const sourceMessage = messages.find((m) => m.type === "text" && m.text?.includes("[2]")) - expect(sourceMessage).toBeDefined() - if (sourceMessage && "text" in sourceMessage) { - expect(sourceMessage.text).toContain("https://example.com") - expect(sourceMessage.text).not.toContain("[1]") - expect(sourceMessage.text).not.toContain("[3]") + // Should have the text response + const textMessage = messages.find((m) => m.type === "text") + expect(textMessage).toBeDefined() + if (textMessage && "text" in textMessage) { + expect(textMessage.text).toBe("test response") + } + + // Should have grounding chunk with only valid sources + const groundingMessage = messages.find((m) => m.type === "grounding") + expect(groundingMessage).toBeDefined() + if (groundingMessage && "sources" in groundingMessage) { + expect(groundingMessage.sources).toHaveLength(1) + expect(groundingMessage.sources[0].url).toBe("https://example.com") + expect(groundingMessage.sources[0].title).toBe("Example Site") } }) diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index 5e547edbdc..775d763a05 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -15,7 +15,7 @@ import { safeJsonParse } from "../../shared/safeJsonParse" import { convertAnthropicContentToGemini, 
convertAnthropicMessageToGemini } from "../transform/gemini-format" import { t } from "i18next" -import type { ApiStream } from "../transform/stream" +import type { ApiStream, GroundingSource } from "../transform/stream" import { getModelParams } from "../transform/model-params" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" @@ -132,9 +132,9 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl } if (pendingGroundingMetadata) { - const citations = this.extractCitationsOnly(pendingGroundingMetadata) - if (citations) { - yield { type: "text", text: `\n\n${t("common:errors.gemini.sources")} ${citations}` } + const sources = this.extractGroundingSources(pendingGroundingMetadata) + if (sources.length > 0) { + yield { type: "grounding", sources } } } @@ -175,28 +175,38 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl return { id: id.endsWith(":thinking") ? id.replace(":thinking", "") : id, info, ...params } } - private extractCitationsOnly(groundingMetadata?: GroundingMetadata): string | null { + private extractGroundingSources(groundingMetadata?: GroundingMetadata): GroundingSource[] { const chunks = groundingMetadata?.groundingChunks if (!chunks) { - return null + return [] } - const citationLinks = chunks - .map((chunk, i) => { + return chunks + .map((chunk): GroundingSource | null => { const uri = chunk.web?.uri + const title = chunk.web?.title || uri || "Unknown Source" + if (uri) { - return `[${i + 1}](${uri})` + return { + title, + url: uri, + } } return null }) - .filter((link): link is string => link !== null) + .filter((source): source is GroundingSource => source !== null) + } + + private extractCitationsOnly(groundingMetadata?: GroundingMetadata): string | null { + const sources = this.extractGroundingSources(groundingMetadata) - if (citationLinks.length > 0) { - return citationLinks.join(", ") + if (sources.length === 0) { + return null } - return null 
+ const citationLinks = sources.map((source, i) => `[${i + 1}](${source.url})`) + return citationLinks.join(", ") } async completePrompt(prompt: string): Promise { diff --git a/src/api/transform/stream.ts b/src/api/transform/stream.ts index 89655a3f56..8484e62595 100644 --- a/src/api/transform/stream.ts +++ b/src/api/transform/stream.ts @@ -1,6 +1,11 @@ export type ApiStream = AsyncGenerator -export type ApiStreamChunk = ApiStreamTextChunk | ApiStreamUsageChunk | ApiStreamReasoningChunk | ApiStreamError +export type ApiStreamChunk = + | ApiStreamTextChunk + | ApiStreamUsageChunk + | ApiStreamReasoningChunk + | ApiStreamGroundingChunk + | ApiStreamError export interface ApiStreamError { type: "error" @@ -27,3 +32,14 @@ export interface ApiStreamUsageChunk { reasoningTokens?: number totalCost?: number } + +export interface ApiStreamGroundingChunk { + type: "grounding" + sources: GroundingSource[] +} + +export interface GroundingSource { + title: string + url: string + snippet?: string +} diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index 98e235c062..4091d6133f 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -41,7 +41,7 @@ import { CloudService, BridgeOrchestrator } from "@roo-code/cloud" // api import { ApiHandler, ApiHandlerCreateMessageMetadata, buildApiHandler } from "../../api" -import { ApiStream } from "../../api/transform/stream" +import { ApiStream, GroundingSource } from "../../api/transform/stream" import { maybeRemoveImageBlocks } from "../../api/transform/image-cleaning" // shared @@ -1897,7 +1897,7 @@ export class Task extends EventEmitter implements TaskLike { this.didFinishAbortingStream = true } - // Reset streaming state. 
+ // Reset streaming state for each new API request this.currentStreamingContentIndex = 0 this.currentStreamingDidCheckpoint = false this.assistantMessageContent = [] @@ -1918,6 +1918,7 @@ export class Task extends EventEmitter implements TaskLike { const stream = this.attemptApiRequest() let assistantMessage = "" let reasoningMessage = "" + let pendingGroundingSources: GroundingSource[] = [] this.isStreaming = true try { @@ -1944,6 +1945,13 @@ export class Task extends EventEmitter implements TaskLike { cacheReadTokens += chunk.cacheReadTokens ?? 0 totalCost = chunk.totalCost break + case "grounding": + // Collect grounding sources in pendingGroundingSources rather than appending them to + // assistantMessage, so they can be shown to the user without entering the API history + if (chunk.sources && chunk.sources.length > 0) { + pendingGroundingSources.push(...chunk.sources) + } + break case "text": { assistantMessage += chunk.text @@ -2237,6 +2245,16 @@ export class Task extends EventEmitter implements TaskLike { let didEndLoop = false if (assistantMessage.length > 0) { + // Display grounding sources to the user if they exist + if (pendingGroundingSources.length > 0) { + const citationLinks = pendingGroundingSources.map((source, i) => `[${i + 1}](${source.url})`) + const sourcesText = `${t("common:gemini.sources")} ${citationLinks.join(", ")}` + + await this.say("text", sourcesText, undefined, false, undefined, undefined, { + isNonInteractive: true, + }) + } + await this.addToApiConversationHistory({ role: "assistant", content: [{ type: "text", text: assistantMessage }], diff --git a/src/core/task/__tests__/grounding-sources.test.ts b/src/core/task/__tests__/grounding-sources.test.ts new file mode 100644 index 0000000000..ba747f40c7 --- /dev/null +++ b/src/core/task/__tests__/grounding-sources.test.ts @@ -0,0 +1,226 @@ +import { describe, it, expect, vi, beforeEach, beforeAll } from "vitest" +import type { ClineProvider } from "../../webview/ClineProvider" +import type { ProviderSettings } from "@roo-code/types" + 
+// Mock vscode module before importing Task +vi.mock("vscode", () => ({ + workspace: { + createFileSystemWatcher: vi.fn(() => ({ + onDidCreate: vi.fn(), + onDidChange: vi.fn(), + onDidDelete: vi.fn(), + dispose: vi.fn(), + })), + getConfiguration: vi.fn(() => ({ + get: vi.fn(() => true), + })), + openTextDocument: vi.fn(), + applyEdit: vi.fn(), + }, + RelativePattern: vi.fn((base, pattern) => ({ base, pattern })), + window: { + createOutputChannel: vi.fn(() => ({ + appendLine: vi.fn(), + dispose: vi.fn(), + })), + createTextEditorDecorationType: vi.fn(() => ({ + dispose: vi.fn(), + })), + showTextDocument: vi.fn(), + activeTextEditor: undefined, + }, + Uri: { + file: vi.fn((path) => ({ fsPath: path })), + parse: vi.fn((str) => ({ toString: () => str })), + }, + Range: vi.fn(), + Position: vi.fn(), + WorkspaceEdit: vi.fn(() => ({ + replace: vi.fn(), + insert: vi.fn(), + delete: vi.fn(), + })), + ViewColumn: { + One: 1, + Two: 2, + Three: 3, + }, +})) + +// Mock other dependencies +vi.mock("../../services/mcp/McpServerManager", () => ({ + McpServerManager: { + getInstance: vi.fn().mockResolvedValue(null), + }, +})) + +vi.mock("../../integrations/terminal/TerminalRegistry", () => ({ + TerminalRegistry: { + releaseTerminalsForTask: vi.fn(), + }, +})) + +vi.mock("@roo-code/telemetry", () => ({ + TelemetryService: { + instance: { + captureTaskCreated: vi.fn(), + captureTaskRestarted: vi.fn(), + captureConversationMessage: vi.fn(), + captureLlmCompletion: vi.fn(), + captureConsecutiveMistakeError: vi.fn(), + }, + }, +})) + +describe("Task grounding sources handling", () => { + let mockProvider: Partial + let mockApiConfiguration: ProviderSettings + let Task: any + + beforeAll(async () => { + // Import Task after mocks are set up + const taskModule = await import("../Task") + Task = taskModule.Task + }) + + beforeEach(() => { + // Mock provider with necessary methods + mockProvider = { + postStateToWebview: vi.fn().mockResolvedValue(undefined), + getState: 
vi.fn().mockResolvedValue({ + mode: "code", + experiments: {}, + }), + context: { + globalStorageUri: { fsPath: "/test/storage" }, + extensionPath: "/test/extension", + } as any, + log: vi.fn(), + updateTaskHistory: vi.fn().mockResolvedValue(undefined), + postMessageToWebview: vi.fn().mockResolvedValue(undefined), + } + + mockApiConfiguration = { + apiProvider: "gemini", + geminiApiKey: "test-key", + enableGrounding: true, + } as ProviderSettings + }) + + it("should strip grounding sources from assistant message before persisting to API history", async () => { + // Create a task instance + const task = new Task({ + provider: mockProvider as ClineProvider, + apiConfiguration: mockApiConfiguration, + task: "Test task", + startTask: false, + }) + + // Mock the API conversation history + task.apiConversationHistory = [] + + // Simulate an assistant message with grounding sources + const assistantMessageWithSources = ` +This is the main response content. + +[1] Example Source: https://example.com +[2] Another Source: https://another.com + +Sources: [1](https://example.com), [2](https://another.com) + `.trim() + + // Mock grounding sources + const mockGroundingSources = [ + { title: "Example Source", url: "https://example.com" }, + { title: "Another Source", url: "https://another.com" }, + ] + + // Spy on addToApiConversationHistory to check what gets persisted + const addToApiHistorySpy = vi.spyOn(task as any, "addToApiConversationHistory") + + // Strip citation text inline. NOTE(review): this regex lives only in this test; the Task.ts hunks in this change persist assistantMessage unmodified - confirm it mirrors real code + let cleanAssistantMessage = assistantMessageWithSources + if (mockGroundingSources.length > 0) { + cleanAssistantMessage = assistantMessageWithSources + .replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "") // e.g., "[1] Example Source: https://example.com" + .replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "") // e.g., "Sources: [1](url1), [2](url2)" + .trim() + } + + // Add the cleaned message to API history + await (task as any).addToApiConversationHistory({ 
+ role: "assistant", + content: [{ type: "text", text: cleanAssistantMessage }], + }) + + // Verify that the cleaned message was added without grounding sources + expect(addToApiHistorySpy).toHaveBeenCalledWith({ + role: "assistant", + content: [{ type: "text", text: "This is the main response content." }], + }) + + // Verify the API conversation history contains the cleaned message + expect(task.apiConversationHistory).toHaveLength(1) + expect(task.apiConversationHistory[0].content).toEqual([ + { type: "text", text: "This is the main response content." }, + ]) + }) + + it("should not modify assistant message when no grounding sources are present", async () => { + const task = new Task({ + provider: mockProvider as ClineProvider, + apiConfiguration: mockApiConfiguration, + task: "Test task", + startTask: false, + }) + + task.apiConversationHistory = [] + + const assistantMessage = "This is a regular response without any sources." + const mockGroundingSources: any[] = [] // No grounding sources + + // Apply the same logic + let cleanAssistantMessage = assistantMessage + if (mockGroundingSources.length > 0) { + cleanAssistantMessage = assistantMessage + .replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "") + .replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "") + .trim() + } + + await (task as any).addToApiConversationHistory({ + role: "assistant", + content: [{ type: "text", text: cleanAssistantMessage }], + }) + + // Message should remain unchanged + expect(task.apiConversationHistory[0].content).toEqual([ + { type: "text", text: "This is a regular response without any sources." 
}, + ]) + }) + + it("should handle various grounding source formats", () => { + const testCases = [ + { + input: "[1] Source Title: https://example.com\n[2] Another: https://test.com\nMain content here", + expected: "Main content here", + }, + { + input: "Content first\n\nSources: [1](https://example.com), [2](https://test.com)", + expected: "Content first", + }, + { + input: "Mixed content\n[1] Inline Source: https://inline.com\nMore content\nSource: [1](https://inline.com)", + expected: "Mixed content\n\nMore content", + }, + ] + + testCases.forEach(({ input, expected }) => { + const cleaned = input + .replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "") + .replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "") + .trim() + expect(cleaned).toBe(expected) + }) + }) +})