From a250a7eee9e2743ab7b1fd559ed9564290998db8 Mon Sep 17 00:00:00 2001 From: "Ton Hoang Nguyen (Bill)" <32552798+HahaBill@users.noreply.github.com> Date: Tue, 26 Aug 2025 23:24:58 +0100 Subject: [PATCH 1/3] feat: Tackling Race/State condition issue by Changing the Code Design - adding new type `ApiStreamGroundingChunk` to the stream type - collecting sources in the `Task.ts` instead -> decoupling --- src/api/providers/gemini.ts | 36 +++++++++++++++++++++++------------- src/api/transform/stream.ts | 18 +++++++++++++++++- src/core/task/Task.ts | 35 ++++++++++++++++++++++++++++++++--- 3 files changed, 72 insertions(+), 17 deletions(-) diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index 5e547edbdc6..775d763a05f 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -15,7 +15,7 @@ import { safeJsonParse } from "../../shared/safeJsonParse" import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format" import { t } from "i18next" -import type { ApiStream } from "../transform/stream" +import type { ApiStream, GroundingSource } from "../transform/stream" import { getModelParams } from "../transform/model-params" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" @@ -132,9 +132,9 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl } if (pendingGroundingMetadata) { - const citations = this.extractCitationsOnly(pendingGroundingMetadata) - if (citations) { - yield { type: "text", text: `\n\n${t("common:errors.gemini.sources")} ${citations}` } + const sources = this.extractGroundingSources(pendingGroundingMetadata) + if (sources.length > 0) { + yield { type: "grounding", sources } } } @@ -175,28 +175,38 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl return { id: id.endsWith(":thinking") ? id.replace(":thinking", "") : id, info, ...params } } - private extractCitationsOnly(groundingMetadata?: GroundingMetadata): string | null { + private extractGroundingSources(groundingMetadata?: GroundingMetadata): GroundingSource[] { const chunks = groundingMetadata?.groundingChunks if (!chunks) { - return null + return [] } - const citationLinks = chunks - .map((chunk, i) => { + return chunks + .map((chunk): GroundingSource | null => { const uri = chunk.web?.uri + const title = chunk.web?.title || uri || "Unknown Source" + if (uri) { - return `[${i + 1}](${uri})` + return { + title, + url: uri, + } } return null }) - .filter((link): link is string => link !== null) + .filter((source): source is GroundingSource => source !== null) + } + + private extractCitationsOnly(groundingMetadata?: GroundingMetadata): string | null { + const sources = this.extractGroundingSources(groundingMetadata) - if (citationLinks.length > 0) { - return citationLinks.join(", ") + if (sources.length === 0) { + return null } - return null + const citationLinks = sources.map((source, i) => `[${i + 1}](${source.url})`) + return citationLinks.join(", ") } async completePrompt(prompt: string): Promise { diff --git a/src/api/transform/stream.ts b/src/api/transform/stream.ts index 89655a3f562..8484e625958 100644 --- a/src/api/transform/stream.ts +++ b/src/api/transform/stream.ts @@ -1,6 +1,11 @@ export type ApiStream = AsyncGenerator -export type ApiStreamChunk = ApiStreamTextChunk | ApiStreamUsageChunk | ApiStreamReasoningChunk | ApiStreamError +export type ApiStreamChunk = + | ApiStreamTextChunk + | ApiStreamUsageChunk + | ApiStreamReasoningChunk + | ApiStreamGroundingChunk + | ApiStreamError export interface ApiStreamError { type: "error" @@ -27,3 +32,14 @@ export interface ApiStreamUsageChunk { reasoningTokens?: number totalCost?: number } + +export interface ApiStreamGroundingChunk { + type: "grounding" + sources: GroundingSource[] +} + +export interface GroundingSource { + title: string + url: string + snippet?: string +} diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index 104cb872067..2da7bc51737 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -39,7 +39,7 @@ import { CloudService, ExtensionBridgeService } from "@roo-code/cloud" // api import { ApiHandler, ApiHandlerCreateMessageMetadata, buildApiHandler } from "../../api" -import { ApiStream } from "../../api/transform/stream" +import { ApiStream, GroundingSource } from "../../api/transform/stream" // shared import { findLastIndex } from "../../shared/array" @@ -1746,7 +1746,7 @@ export class Task extends EventEmitter implements TaskLike { this.didFinishAbortingStream = true } - // Reset streaming state. + // Reset streaming state for each new API request this.currentStreamingContentIndex = 0 this.currentStreamingDidCheckpoint = false this.assistantMessageContent = [] @@ -1767,6 +1767,7 @@ export class Task extends EventEmitter implements TaskLike { const stream = this.attemptApiRequest() let assistantMessage = "" let reasoningMessage = "" + let pendingGroundingSources: GroundingSource[] = [] this.isStreaming = true try { @@ -1793,6 +1794,13 @@ export class Task extends EventEmitter implements TaskLike { cacheReadTokens += chunk.cacheReadTokens ?? 0 totalCost = chunk.totalCost break + case "grounding": + // Handle grounding sources separately from regular content + // to prevent state persistence issues - store them separately + if (chunk.sources && chunk.sources.length > 0) { + pendingGroundingSources.push(...chunk.sources) + } + break case "text": { assistantMessage += chunk.text @@ -2086,9 +2094,30 @@ export class Task extends EventEmitter implements TaskLike { let didEndLoop = false if (assistantMessage.length > 0) { + // Display grounding sources to the user if they exist + if (pendingGroundingSources.length > 0) { + const citationLinks = pendingGroundingSources.map((source, i) => `[${i + 1}](${source.url})`) + const sourcesText = `Sources: ${citationLinks.join(", ")}` + + await this.say("text", sourcesText, undefined, false, undefined, undefined, { + isNonInteractive: true, + }) + } + + // Strip grounding sources from assistant message before persisting to API history + // This prevents state persistence issues while maintaining user experience + let cleanAssistantMessage = assistantMessage + if (pendingGroundingSources.length > 0) { + // Remove any grounding source references that might have been integrated into the message + cleanAssistantMessage = assistantMessage + .replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "") + .replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "") + .trim() + } + await this.addToApiConversationHistory({ role: "assistant", - content: [{ type: "text", text: assistantMessage }], + content: [{ type: "text", text: cleanAssistantMessage }], }) TelemetryService.instance.captureConversationMessage(this.taskId, "assistant") From edd6f3d1e420aab6559335cb5f7354297fa5e3cf Mon Sep 17 00:00:00 2001 From: daniel-lxs Date: Tue, 26 Aug 2025 18:22:10 -0500 Subject: [PATCH 2/3] fix: update gemini grounding test to match new architecture - Fixed failing test to expect grounding chunk instead of text with citations - Added inline comments to regex patterns for source stripping in Task.ts - Added test coverage for grounding source handling to prevent regression --- .../__tests__/gemini-handler.spec.ts | 23 +- src/core/task/Task.ts | 4 +- .../task/__tests__/grounding-sources.test.ts | 226 ++++++++++++++++++ 3 files changed, 243 insertions(+), 10 deletions(-) create mode 100644 src/core/task/__tests__/grounding-sources.test.ts diff --git a/src/api/providers/__tests__/gemini-handler.spec.ts b/src/api/providers/__tests__/gemini-handler.spec.ts index 7c61639cfd6..c07f6d71c98 100644 --- a/src/api/providers/__tests__/gemini-handler.spec.ts +++ b/src/api/providers/__tests__/gemini-handler.spec.ts @@ -85,7 +85,7 @@ describe("GeminiHandler backend support", () => { groundingMetadata: { groundingChunks: [ { web: null }, // Missing URI - { web: { uri: "https://example.com" } }, // Valid + { web: { uri: "https://example.com", title: "Example Site" } }, // Valid {}, // Missing web property entirely ], }, @@ -105,13 +105,20 @@ describe("GeminiHandler backend support", () => { messages.push(chunk) } - // Should only include valid citations - const sourceMessage = messages.find((m) => m.type === "text" && m.text?.includes("[2]")) - expect(sourceMessage).toBeDefined() - if (sourceMessage && "text" in sourceMessage) { - expect(sourceMessage.text).toContain("https://example.com") - expect(sourceMessage.text).not.toContain("[1]") - expect(sourceMessage.text).not.toContain("[3]") + // Should have the text response + const textMessage = messages.find((m) => m.type === "text") + expect(textMessage).toBeDefined() + if (textMessage && "text" in textMessage) { + expect(textMessage.text).toBe("test response") + } + + // Should have grounding chunk with only valid sources + const groundingMessage = messages.find((m) => m.type === "grounding") + expect(groundingMessage).toBeDefined() + if (groundingMessage && "sources" in groundingMessage) { + expect(groundingMessage.sources).toHaveLength(1) + expect(groundingMessage.sources[0].url).toBe("https://example.com") + expect(groundingMessage.sources[0].title).toBe("Example Site") } }) diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index 2da7bc51737..38d2cba1b48 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -2110,8 +2110,8 @@ export class Task extends EventEmitter implements TaskLike { if (pendingGroundingSources.length > 0) { // Remove any grounding source references that might have been integrated into the message cleanAssistantMessage = assistantMessage - .replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "") - .replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "") + .replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "") // e.g., "[1] Example Source: https://example.com" + .replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "") // e.g., "Sources: [1](url1), [2](url2)" .trim() } diff --git a/src/core/task/__tests__/grounding-sources.test.ts b/src/core/task/__tests__/grounding-sources.test.ts new file mode 100644 index 00000000000..ba747f40c77 --- /dev/null +++ b/src/core/task/__tests__/grounding-sources.test.ts @@ -0,0 +1,226 @@ +import { describe, it, expect, vi, beforeEach, beforeAll } from "vitest" +import type { ClineProvider } from "../../webview/ClineProvider" +import type { ProviderSettings } from "@roo-code/types" + +// Mock vscode module before importing Task +vi.mock("vscode", () => ({ + workspace: { + createFileSystemWatcher: vi.fn(() => ({ + onDidCreate: vi.fn(), + onDidChange: vi.fn(), + onDidDelete: vi.fn(), + dispose: vi.fn(), + })), + getConfiguration: vi.fn(() => ({ + get: vi.fn(() => true), + })), + openTextDocument: vi.fn(), + applyEdit: vi.fn(), + }, + RelativePattern: vi.fn((base, pattern) => ({ base, pattern })), + window: { + createOutputChannel: vi.fn(() => ({ + appendLine: vi.fn(), + dispose: vi.fn(), + })), + createTextEditorDecorationType: vi.fn(() => ({ + dispose: vi.fn(), + })), + showTextDocument: vi.fn(), + activeTextEditor: undefined, + }, + Uri: { + file: vi.fn((path) => ({ fsPath: path })), + parse: vi.fn((str) => ({ toString: () => str })), + }, + Range: vi.fn(), + Position: vi.fn(), + WorkspaceEdit: vi.fn(() => ({ + replace: vi.fn(), + insert: vi.fn(), + delete: vi.fn(), + })), + ViewColumn: { + One: 1, + Two: 2, + Three: 3, + }, +})) + +// Mock other dependencies +vi.mock("../../services/mcp/McpServerManager", () => ({ + McpServerManager: { + getInstance: vi.fn().mockResolvedValue(null), + }, +})) + +vi.mock("../../integrations/terminal/TerminalRegistry", () => ({ + TerminalRegistry: { + releaseTerminalsForTask: vi.fn(), + }, +})) + +vi.mock("@roo-code/telemetry", () => ({ + TelemetryService: { + instance: { + captureTaskCreated: vi.fn(), + captureTaskRestarted: vi.fn(), + captureConversationMessage: vi.fn(), + captureLlmCompletion: vi.fn(), + captureConsecutiveMistakeError: vi.fn(), + }, + }, +})) + +describe("Task grounding sources handling", () => { + let mockProvider: Partial + let mockApiConfiguration: ProviderSettings + let Task: any + + beforeAll(async () => { + // Import Task after mocks are set up + const taskModule = await import("../Task") + Task = taskModule.Task + }) + + beforeEach(() => { + // Mock provider with necessary methods + mockProvider = { + postStateToWebview: vi.fn().mockResolvedValue(undefined), + getState: vi.fn().mockResolvedValue({ + mode: "code", + experiments: {}, + }), + context: { + globalStorageUri: { fsPath: "/test/storage" }, + extensionPath: "/test/extension", + } as any, + log: vi.fn(), + updateTaskHistory: vi.fn().mockResolvedValue(undefined), + postMessageToWebview: vi.fn().mockResolvedValue(undefined), + } + + mockApiConfiguration = { + apiProvider: "gemini", + geminiApiKey: "test-key", + enableGrounding: true, + } as ProviderSettings + }) + + it("should strip grounding sources from assistant message before persisting to API history", async () => { + // Create a task instance + const task = new Task({ + provider: mockProvider as ClineProvider, + apiConfiguration: mockApiConfiguration, + task: "Test task", + startTask: false, + }) + + // Mock the API conversation history + task.apiConversationHistory = [] + + // Simulate an assistant message with grounding sources + const assistantMessageWithSources = ` +This is the main response content. + +[1] Example Source: https://example.com +[2] Another Source: https://another.com + +Sources: [1](https://example.com), [2](https://another.com) + `.trim() + + // Mock grounding sources + const mockGroundingSources = [ + { title: "Example Source", url: "https://example.com" }, + { title: "Another Source", url: "https://another.com" }, + ] + + // Spy on addToApiConversationHistory to check what gets persisted + const addToApiHistorySpy = vi.spyOn(task as any, "addToApiConversationHistory") + + // Simulate the logic from Task.ts that strips grounding sources + let cleanAssistantMessage = assistantMessageWithSources + if (mockGroundingSources.length > 0) { + cleanAssistantMessage = assistantMessageWithSources + .replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "") // e.g., "[1] Example Source: https://example.com" + .replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "") // e.g., "Sources: [1](url1), [2](url2)" + .trim() + } + + // Add the cleaned message to API history + await (task as any).addToApiConversationHistory({ + role: "assistant", + content: [{ type: "text", text: cleanAssistantMessage }], + }) + + // Verify that the cleaned message was added without grounding sources + expect(addToApiHistorySpy).toHaveBeenCalledWith({ + role: "assistant", + content: [{ type: "text", text: "This is the main response content." }], + }) + + // Verify the API conversation history contains the cleaned message + expect(task.apiConversationHistory).toHaveLength(1) + expect(task.apiConversationHistory[0].content).toEqual([ + { type: "text", text: "This is the main response content." }, + ]) + }) + + it("should not modify assistant message when no grounding sources are present", async () => { + const task = new Task({ + provider: mockProvider as ClineProvider, + apiConfiguration: mockApiConfiguration, + task: "Test task", + startTask: false, + }) + + task.apiConversationHistory = [] + + const assistantMessage = "This is a regular response without any sources." + const mockGroundingSources: any[] = [] // No grounding sources + + // Apply the same logic + let cleanAssistantMessage = assistantMessage + if (mockGroundingSources.length > 0) { + cleanAssistantMessage = assistantMessage + .replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "") + .replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "") + .trim() + } + + await (task as any).addToApiConversationHistory({ + role: "assistant", + content: [{ type: "text", text: cleanAssistantMessage }], + }) + + // Message should remain unchanged + expect(task.apiConversationHistory[0].content).toEqual([ + { type: "text", text: "This is a regular response without any sources." }, + ]) + }) + + it("should handle various grounding source formats", () => { + const testCases = [ + { + input: "[1] Source Title: https://example.com\n[2] Another: https://test.com\nMain content here", + expected: "Main content here", + }, + { + input: "Content first\n\nSources: [1](https://example.com), [2](https://test.com)", + expected: "Content first", + }, + { + input: "Mixed content\n[1] Inline Source: https://inline.com\nMore content\nSource: [1](https://inline.com)", + expected: "Mixed content\n\nMore content", + }, + ] + + testCases.forEach(({ input, expected }) => { + const cleaned = input + .replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "") + .replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "") + .trim() + expect(cleaned).toBe(expected) + }) + }) +}) From 7e9a8c4c00741a7b3343dc8855679818218575fe Mon Sep 17 00:00:00 2001 From: "Ton Hoang Nguyen (Bill)" <32552798+HahaBill@users.noreply.github.com> Date: Wed, 27 Aug 2025 15:19:55 +0100 Subject: [PATCH 3/3] fix: making changes from the code review --- src/core/task/Task.ts | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index 38d2cba1b48..672b0d4b2e8 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -2097,27 +2097,16 @@ export class Task extends EventEmitter implements TaskLike { // Display grounding sources to the user if they exist if (pendingGroundingSources.length > 0) { const citationLinks = pendingGroundingSources.map((source, i) => `[${i + 1}](${source.url})`) - const sourcesText = `Sources: ${citationLinks.join(", ")}` + const sourcesText = `${t("common:gemini.sources")} ${citationLinks.join(", ")}` await this.say("text", sourcesText, undefined, false, undefined, undefined, { isNonInteractive: true, }) } - // Strip grounding sources from assistant message before persisting to API history - // This prevents state persistence issues while maintaining user experience - let cleanAssistantMessage = assistantMessage - if (pendingGroundingSources.length > 0) { - // Remove any grounding source references that might have been integrated into the message - cleanAssistantMessage = assistantMessage - .replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "") // e.g., "[1] Example Source: https://example.com" - .replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "") // e.g., "Sources: [1](url1), [2](url2)" - .trim() - } - await this.addToApiConversationHistory({ role: "assistant", - content: [{ type: "text", text: cleanAssistantMessage }], + content: [{ type: "text", text: assistantMessage }], }) TelemetryService.instance.captureConversationMessage(this.taskId, "assistant")