diff --git a/src/core/checkpoints/__tests__/checkpoint.test.ts b/src/core/checkpoints/__tests__/checkpoint.test.ts new file mode 100644 index 0000000000..49b26a4c2d --- /dev/null +++ b/src/core/checkpoints/__tests__/checkpoint.test.ts @@ -0,0 +1,434 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest" +import { Task } from "../../task/Task" +import { ClineProvider } from "../../webview/ClineProvider" +import { checkpointSave, checkpointRestore, checkpointDiff, getCheckpointService } from "../index" +import * as vscode from "vscode" + +// Mock vscode +vi.mock("vscode", () => ({ + window: { + showErrorMessage: vi.fn(), + createTextEditorDecorationType: vi.fn(() => ({})), + showInformationMessage: vi.fn(), + }, + Uri: { + file: vi.fn((path: string) => ({ fsPath: path })), + parse: vi.fn((uri: string) => ({ with: vi.fn(() => ({})) })), + }, + commands: { + executeCommand: vi.fn(), + }, +})) + +// Mock other dependencies +vi.mock("@roo-code/telemetry", () => ({ + TelemetryService: { + instance: { + captureCheckpointCreated: vi.fn(), + captureCheckpointRestored: vi.fn(), + captureCheckpointDiffed: vi.fn(), + }, + }, +})) + +vi.mock("../../../utils/path", () => ({ + getWorkspacePath: vi.fn(() => "/test/workspace"), +})) + +vi.mock("../../../services/checkpoints") + +describe("Checkpoint functionality", () => { + let mockProvider: any + let mockTask: any + let mockCheckpointService: any + + beforeEach(async () => { + // Create mock checkpoint service + mockCheckpointService = { + isInitialized: true, + saveCheckpoint: vi.fn().mockResolvedValue({ commit: "test-commit-hash" }), + restoreCheckpoint: vi.fn().mockResolvedValue(undefined), + getDiff: vi.fn().mockResolvedValue([]), + on: vi.fn(), + initShadowGit: vi.fn().mockResolvedValue(undefined), + } + + // Create mock provider + mockProvider = { + context: { + globalStorageUri: { fsPath: "/test/storage" }, + }, + log: vi.fn(), + postMessageToWebview: vi.fn(), + postStateToWebview: vi.fn(), + cancelTask: vi.fn(), + } + + // Create mock task + mockTask = { + taskId: "test-task-id", + enableCheckpoints: true, + checkpointService: mockCheckpointService, + checkpointServiceInitializing: false, + providerRef: { + deref: () => mockProvider, + }, + clineMessages: [], + apiConversationHistory: [], + pendingUserMessageCheckpoint: undefined, + say: vi.fn().mockResolvedValue(undefined), + overwriteClineMessages: vi.fn(), + overwriteApiConversationHistory: vi.fn(), + combineMessages: vi.fn().mockReturnValue([]), + } + + // Update the mock to return our mockCheckpointService + const checkpointsModule = await import("../../../services/checkpoints") + vi.mocked(checkpointsModule.RepoPerTaskCheckpointService.create).mockReturnValue(mockCheckpointService) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe("checkpointSave", () => { + it("should wait for checkpoint service initialization before saving", async () => { + // Set up task with uninitialized service + mockCheckpointService.isInitialized = false + mockTask.checkpointService = mockCheckpointService + + // Simulate service initialization after a delay + setTimeout(() => { + mockCheckpointService.isInitialized = true + }, 100) + + // Call checkpointSave + const savePromise = checkpointSave(mockTask, true) + + // Wait for the save to complete + const result = await savePromise + + // saveCheckpoint should have been called + expect(mockCheckpointService.saveCheckpoint).toHaveBeenCalledWith( + expect.stringContaining("Task: test-task-id"), + { allowEmpty: true }, + ) + + // Result should contain the commit hash + expect(result).toEqual({ commit: "test-commit-hash" }) + + // Task should still have checkpoints enabled + expect(mockTask.enableCheckpoints).toBe(true) + }) + + it("should handle timeout when service doesn't initialize", async () => { + // Service never initializes + mockCheckpointService.isInitialized = false + + // Call checkpointSave with a task that has no checkpoint service + const taskWithNoService = { + ...mockTask, + checkpointService: undefined, + enableCheckpoints: false, + } + + const result = await checkpointSave(taskWithNoService, true) + + // Result should be undefined + expect(result).toBeUndefined() + + // saveCheckpoint should not have been called + expect(mockCheckpointService.saveCheckpoint).not.toHaveBeenCalled() + }) + + it("should preserve checkpoint data through message deletion flow", async () => { + // Initialize service + mockCheckpointService.isInitialized = true + mockTask.checkpointService = mockCheckpointService + + // Simulate saving checkpoint before user message + const checkpointResult = await checkpointSave(mockTask, true) + expect(checkpointResult).toEqual({ commit: "test-commit-hash" }) + + // Simulate setting pendingUserMessageCheckpoint + if (checkpointResult && "commit" in checkpointResult) { + mockTask.pendingUserMessageCheckpoint = { + hash: checkpointResult.commit, + timestamp: Date.now(), + type: "user_message", + } + } + + // Verify checkpoint data is preserved + expect(mockTask.pendingUserMessageCheckpoint).toBeDefined() + expect(mockTask.pendingUserMessageCheckpoint.hash).toBe("test-commit-hash") + + // Simulate message deletion and reinitialization + mockTask.clineMessages = [] + mockTask.checkpointService = mockCheckpointService // Keep service available + mockTask.checkpointServiceInitializing = false + + // Save checkpoint again after deletion + const newCheckpointResult = await checkpointSave(mockTask, true) + + // Should still work after reinitialization + expect(newCheckpointResult).toEqual({ commit: "test-commit-hash" }) + expect(mockTask.enableCheckpoints).toBe(true) + }) + + it("should handle errors gracefully and disable checkpoints", async () => { + mockCheckpointService.saveCheckpoint.mockRejectedValue(new Error("Save failed")) + + const result = await checkpointSave(mockTask) + + expect(result).toBeUndefined() + expect(mockTask.enableCheckpoints).toBe(false) + }) + }) + + describe("checkpointRestore", () => { + beforeEach(() => { + mockTask.clineMessages = [ + { ts: 1, say: "user", text: "Message 1" }, + { ts: 2, say: "assistant", text: "Message 2" }, + { ts: 3, say: "user", text: "Message 3" }, + ] + mockTask.apiConversationHistory = [ + { ts: 1, role: "user", content: [{ type: "text", text: "Message 1" }] }, + { ts: 2, role: "assistant", content: [{ type: "text", text: "Message 2" }] }, + { ts: 3, role: "user", content: [{ type: "text", text: "Message 3" }] }, + ] + }) + + it("should restore checkpoint for delete operation", async () => { + await checkpointRestore(mockTask, { + ts: 2, + commitHash: "abc123", + mode: "restore", + operation: "delete", + }) + + expect(mockCheckpointService.restoreCheckpoint).toHaveBeenCalledWith("abc123") + expect(mockTask.overwriteApiConversationHistory).toHaveBeenCalledWith([ + { ts: 1, role: "user", content: [{ type: "text", text: "Message 1" }] }, + ]) + expect(mockTask.overwriteClineMessages).toHaveBeenCalledWith([{ ts: 1, say: "user", text: "Message 1" }]) + expect(mockProvider.cancelTask).toHaveBeenCalled() + }) + + it("should restore checkpoint for edit operation", async () => { + await checkpointRestore(mockTask, { + ts: 2, + commitHash: "abc123", + mode: "restore", + operation: "edit", + }) + + expect(mockCheckpointService.restoreCheckpoint).toHaveBeenCalledWith("abc123") + expect(mockTask.overwriteApiConversationHistory).toHaveBeenCalledWith([ + { ts: 1, role: "user", content: [{ type: "text", text: "Message 1" }] }, + ]) + // For edit operation, should include the message being edited + expect(mockTask.overwriteClineMessages).toHaveBeenCalledWith([ + { ts: 1, say: "user", text: "Message 1" }, + { ts: 2, say: "assistant", text: "Message 2" }, + ]) + expect(mockProvider.cancelTask).toHaveBeenCalled() + }) + + it("should handle preview mode without modifying messages", async () => { + await checkpointRestore(mockTask, { + ts: 2, + commitHash: "abc123", + mode: "preview", + }) + + expect(mockCheckpointService.restoreCheckpoint).toHaveBeenCalledWith("abc123") + expect(mockTask.overwriteApiConversationHistory).not.toHaveBeenCalled() + expect(mockTask.overwriteClineMessages).not.toHaveBeenCalled() + expect(mockProvider.cancelTask).toHaveBeenCalled() + }) + + it("should handle missing message gracefully", async () => { + await checkpointRestore(mockTask, { + ts: 999, // Non-existent timestamp + commitHash: "abc123", + mode: "restore", + }) + + expect(mockCheckpointService.restoreCheckpoint).not.toHaveBeenCalled() + }) + + it("should disable checkpoints on error", async () => { + mockCheckpointService.restoreCheckpoint.mockRejectedValue(new Error("Restore failed")) + + await checkpointRestore(mockTask, { + ts: 2, + commitHash: "abc123", + mode: "restore", + }) + + expect(mockTask.enableCheckpoints).toBe(false) + expect(mockProvider.log).toHaveBeenCalledWith("[checkpointRestore] disabling checkpoints for this task") + }) + }) + + describe("checkpointDiff", () => { + beforeEach(() => { + mockTask.clineMessages = [ + { ts: 1, say: "user", text: "Message 1" }, + { ts: 2, say: "checkpoint_saved", text: "commit1" }, + { ts: 3, say: "user", text: "Message 2" }, + { ts: 4, say: "checkpoint_saved", text: "commit2" }, + ] + }) + + it("should show diff for full mode", async () => { + const mockChanges = [ + { + paths: { absolute: "/test/file.ts", relative: "file.ts" }, + content: { before: "old content", after: "new content" }, + }, + ] + mockCheckpointService.getDiff.mockResolvedValue(mockChanges) + + await checkpointDiff(mockTask, { + ts: 4, + commitHash: "commit2", + mode: "full", + }) + + expect(mockCheckpointService.getDiff).toHaveBeenCalledWith({ + from: undefined, + to: "commit2", + }) + expect(vscode.commands.executeCommand).toHaveBeenCalledWith( + "vscode.changes", + "Changes since task started", + expect.any(Array), + ) + }) + + it("should show diff for checkpoint mode with previous commit", async () => { + const mockChanges = [ + { + paths: { absolute: "/test/file.ts", relative: "file.ts" }, + content: { before: "old content", after: "new content" }, + }, + ] + mockCheckpointService.getDiff.mockResolvedValue(mockChanges) + + await checkpointDiff(mockTask, { + ts: 4, + previousCommitHash: "commit1", + commitHash: "commit2", + mode: "checkpoint", + }) + + expect(mockCheckpointService.getDiff).toHaveBeenCalledWith({ + from: "commit1", + to: "commit2", + }) + expect(vscode.commands.executeCommand).toHaveBeenCalledWith( + "vscode.changes", + "Changes since previous checkpoint", + expect.any(Array), + ) + }) + + it("should find previous checkpoint automatically in checkpoint mode", async () => { + const mockChanges = [ + { + paths: { absolute: "/test/file.ts", relative: "file.ts" }, + content: { before: "old content", after: "new content" }, + }, + ] + mockCheckpointService.getDiff.mockResolvedValue(mockChanges) + + await checkpointDiff(mockTask, { + ts: 4, + commitHash: "commit2", + mode: "checkpoint", + }) + + expect(mockCheckpointService.getDiff).toHaveBeenCalledWith({ + from: "commit1", // Should find the previous checkpoint + to: "commit2", + }) + }) + + it("should show information message when no changes found", async () => { + mockCheckpointService.getDiff.mockResolvedValue([]) + + await checkpointDiff(mockTask, { + ts: 4, + commitHash: "commit2", + mode: "full", + }) + + expect(vscode.window.showInformationMessage).toHaveBeenCalledWith("No changes found.") + expect(vscode.commands.executeCommand).not.toHaveBeenCalled() + }) + + it("should disable checkpoints on error", async () => { + mockCheckpointService.getDiff.mockRejectedValue(new Error("Diff failed")) + + await checkpointDiff(mockTask, { + ts: 4, + commitHash: "commit2", + mode: "full", + }) + + expect(mockTask.enableCheckpoints).toBe(false) + expect(mockProvider.log).toHaveBeenCalledWith("[checkpointDiff] disabling checkpoints for this task") + }) + }) + + describe("getCheckpointService", () => { + it("should return existing service if available", () => { + const service = getCheckpointService(mockTask) + expect(service).toBe(mockCheckpointService) + }) + + it("should return undefined if checkpoints are disabled", () => { + mockTask.enableCheckpoints = false + const service = getCheckpointService(mockTask) + expect(service).toBeUndefined() + }) + + it("should return undefined if service is still initializing", () => { + mockTask.checkpointService = undefined + mockTask.checkpointServiceInitializing = true + const service = getCheckpointService(mockTask) + expect(service).toBeUndefined() + }) + + it("should create new service if none exists", async () => { + mockTask.checkpointService = undefined + mockTask.checkpointServiceInitializing = false + + const service = getCheckpointService(mockTask) + + const checkpointsModule = await import("../../../services/checkpoints") + expect(vi.mocked(checkpointsModule.RepoPerTaskCheckpointService.create)).toHaveBeenCalledWith({ + taskId: "test-task-id", + workspaceDir: "/test/workspace", + shadowDir: "/test/storage", + log: expect.any(Function), + }) + }) + + it("should disable checkpoints if workspace path is not found", async () => { + const pathModule = await import("../../../utils/path") + vi.mocked(pathModule.getWorkspacePath).mockReturnValue(null as any) + + mockTask.checkpointService = undefined + mockTask.checkpointServiceInitializing = false + + const service = getCheckpointService(mockTask) + + expect(service).toBeUndefined() + expect(mockTask.enableCheckpoints).toBe(false) + }) + }) +}) diff --git a/src/core/checkpoints/index.ts b/src/core/checkpoints/index.ts index 02fb5dfc5a..96ebe1ee73 100644 --- a/src/core/checkpoints/index.ts +++ b/src/core/checkpoints/index.ts @@ -192,35 +192,36 @@ async function getInitializedCheckpointService( } export async function checkpointSave(cline: Task, force = false) { - const service = getCheckpointService(cline) - - if (!service) { - return - } + try { + // Use getInitializedCheckpointService to wait for initialization + const service = await getInitializedCheckpointService(cline) - if (!service.isInitialized) { - const provider = cline.providerRef.deref() - provider?.log("[checkpointSave] checkpoints didn't initialize in time, disabling checkpoints for this task") - cline.enableCheckpoints = false - return - } + if (!service) { + return + } - TelemetryService.instance.captureCheckpointCreated(cline.taskId) + TelemetryService.instance.captureCheckpointCreated(cline.taskId) - // Start the checkpoint process in the background. - return service.saveCheckpoint(`Task: ${cline.taskId}, Time: ${Date.now()}`, { allowEmpty: force }).catch((err) => { + // Start the checkpoint process in the background. + return await service.saveCheckpoint(`Task: ${cline.taskId}, Time: ${Date.now()}`, { allowEmpty: force }) + } catch (err) { console.error("[Task#checkpointSave] caught unexpected error, disabling checkpoints", err) cline.enableCheckpoints = false - }) + return undefined + } } export type CheckpointRestoreOptions = { ts: number commitHash: string mode: "preview" | "restore" + operation?: "delete" | "edit" // Optional to maintain backward compatibility } -export async function checkpointRestore(cline: Task, { ts, commitHash, mode }: CheckpointRestoreOptions) { +export async function checkpointRestore( + cline: Task, + { ts, commitHash, mode, operation = "delete" }: CheckpointRestoreOptions, +) { const service = await getInitializedCheckpointService(cline) if (!service) { @@ -249,7 +250,10 @@ export async function checkpointRestore(cline: Task, { ts, commitHash, mode }: C cline.combineMessages(deletedMessages), ) - await cline.overwriteClineMessages(cline.clineMessages.slice(0, index + 1)) + // For delete operations, exclude the checkpoint message itself + // For edit operations, include the checkpoint message (to be edited) + const endIndex = operation === "edit" ? index + 1 : index + await cline.overwriteClineMessages(cline.clineMessages.slice(0, endIndex)) // TODO: Verify that this is working as expected. await cline.say( diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 905e657b37..0960829101 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -81,6 +81,16 @@ export type ClineProviderEvents = { clineCreated: [cline: Task] } +interface PendingEditOperation { + messageTs: number + editedContent: string + images?: string[] + messageIndex: number + apiConversationHistoryIndex: number + timeoutId: NodeJS.Timeout + createdAt: number +} + class OrganizationAllowListViolationError extends Error { constructor(message: string) { super(message) @@ -109,6 +119,8 @@ export class ClineProvider protected mcpHub?: McpHub // Change from private to protected private marketplaceManager: MarketplaceManager private mdmService?: MdmService + private pendingOperations: Map = new Map() + private static readonly PENDING_OPERATION_TIMEOUT_MS = 30000 // 30 seconds public isViewLaunched = false public settingsImportedAt?: number @@ -242,6 +254,72 @@ export class ClineProvider await this.removeClineFromStack() } + // Pending Edit Operations Management + + /** + * Sets a pending edit operation with automatic timeout cleanup + */ + public setPendingEditOperation( + operationId: string, + editData: { + messageTs: number + editedContent: string + images?: string[] + messageIndex: number + apiConversationHistoryIndex: number + }, + ): void { + // Clear any existing operation with the same ID + this.clearPendingEditOperation(operationId) + + // Create timeout for automatic cleanup + const timeoutId = setTimeout(() => { + this.clearPendingEditOperation(operationId) + this.log(`[setPendingEditOperation] Automatically cleared stale pending operation: ${operationId}`) + }, ClineProvider.PENDING_OPERATION_TIMEOUT_MS) + + // Store the operation + this.pendingOperations.set(operationId, { + ...editData, + timeoutId, + createdAt: Date.now(), + }) + + this.log(`[setPendingEditOperation] Set pending operation: ${operationId}`) + } + + /** + * Gets a pending edit operation by ID + */ + private getPendingEditOperation(operationId: string): PendingEditOperation | undefined { + return this.pendingOperations.get(operationId) + } + + /** + * Clears a specific pending edit operation + */ + private clearPendingEditOperation(operationId: string): boolean { + const operation = this.pendingOperations.get(operationId) + if (operation) { + clearTimeout(operation.timeoutId) + this.pendingOperations.delete(operationId) + this.log(`[clearPendingEditOperation] Cleared pending operation: ${operationId}`) + return true + } + return false + } + + /** + * Clears all pending edit operations + */ + private clearAllPendingEditOperations(): void { + for (const [operationId, operation] of this.pendingOperations) { + clearTimeout(operation.timeoutId) + } + this.pendingOperations.clear() + this.log(`[clearAllPendingEditOperations] Cleared all pending operations`) + } + /* VSCode extensions use the disposable pattern to clean up resources when the sidebar/editor tab is closed by the user or system. This applies to event listening, commands, interacting with the UI, etc. - https://vscode-docs.readthedocs.io/en/stable/extensions/patterns-and-principles/ @@ -261,6 +339,10 @@ export class ClineProvider await this.removeClineFromStack() this.log("Cleared task") + // Clear all pending edit operations to prevent memory leaks + this.clearAllPendingEditOperations() + this.log("Cleared pending operations") + if (this.view && "dispose" in this.view) { this.view.dispose() this.log("Disposed webview") @@ -605,6 +687,50 @@ export class ClineProvider this.log( `[subtasks] ${cline.parentTask ? "child" : "parent"} task ${cline.taskId}.${cline.instanceId} instantiated`, ) + + // Check if there's a pending edit after checkpoint restoration + const operationId = `task-${cline.taskId}` + const pendingEdit = this.getPendingEditOperation(operationId) + if (pendingEdit) { + this.clearPendingEditOperation(operationId) // Clear the pending edit + + this.log(`[initClineWithHistoryItem] Processing pending edit after checkpoint restoration`) + + // Process the pending edit after a short delay to ensure the task is fully initialized + setTimeout(async () => { + try { + // Find the message index in the restored state + const { messageIndex, apiConversationHistoryIndex } = (() => { + const messageIndex = cline.clineMessages.findIndex((msg) => msg.ts === pendingEdit.messageTs) + const apiConversationHistoryIndex = cline.apiConversationHistory.findIndex( + (msg) => msg.ts === pendingEdit.messageTs, + ) + return { messageIndex, apiConversationHistoryIndex } + })() + + if (messageIndex !== -1) { + // Remove the target message and all subsequent messages + await cline.overwriteClineMessages(cline.clineMessages.slice(0, messageIndex)) + + if (apiConversationHistoryIndex !== -1) { + await cline.overwriteApiConversationHistory( + cline.apiConversationHistory.slice(0, apiConversationHistoryIndex), + ) + } + + // Process the edited message + await cline.handleWebviewAskResponse( + "messageResponse", + pendingEdit.editedContent, + pendingEdit.images, + ) + } + } catch (error) { + this.log(`[initClineWithHistoryItem] Error processing pending edit: ${error}`) + } + }, 100) // Small delay to ensure task is fully ready + } + return cline } diff --git a/src/core/webview/__tests__/ClineProvider.spec.ts b/src/core/webview/__tests__/ClineProvider.spec.ts index 344b098816..d1d0755d51 100644 --- a/src/core/webview/__tests__/ClineProvider.spec.ts +++ b/src/core/webview/__tests__/ClineProvider.spec.ts @@ -46,6 +46,12 @@ vi.mock("axios", () => ({ vi.mock("../../../utils/safeWriteJson") +vi.mock("../../../utils/storage", () => ({ + getSettingsDirectoryPath: vi.fn().mockResolvedValue("/test/settings/path"), + getTaskDirectoryPath: vi.fn().mockResolvedValue("/test/task/path"), + getGlobalStoragePath: vi.fn().mockResolvedValue("/test/storage/path"), +})) + vi.mock("@modelcontextprotocol/sdk/types.js", () => ({ CallToolResultSchema: {}, ListResourcesResultSchema: {}, @@ -1173,8 +1179,8 @@ describe("ClineProvider", () => { const mockMessages = [ { ts: 1000, type: "say", say: "user_feedback" }, // User message 1 { ts: 2000, type: "say", say: "tool" }, // Tool message - { ts: 3000, type: "say", say: "text", value: 4000 }, // Message to delete - { ts: 4000, type: "say", say: "browser_action" }, // Response to delete + { ts: 3000, type: "say", say: "text" }, // Message before delete + { ts: 4000, type: "say", say: "browser_action" }, // Message to delete { ts: 5000, type: "say", say: "user_feedback" }, // Next user message { ts: 6000, type: "say", say: "user_feedback" }, // Final message ] as ClineMessage[] @@ -1210,22 +1216,28 @@ describe("ClineProvider", () => { expect(mockPostMessage).toHaveBeenCalledWith({ type: "showDeleteMessageDialog", messageTs: 4000, + hasCheckpoint: false, }) // Simulate user confirming deletion through the dialog await messageHandler({ type: "deleteMessageConfirm", messageTs: 4000 }) // Verify only messages before the deleted message were kept - expect(mockCline.overwriteClineMessages).toHaveBeenCalledWith([mockMessages[0], mockMessages[1]]) + expect(mockCline.overwriteClineMessages).toHaveBeenCalledWith([ + mockMessages[0], + mockMessages[1], + mockMessages[2], + ]) // Verify only API messages before the deleted message were kept expect(mockCline.overwriteApiConversationHistory).toHaveBeenCalledWith([ mockApiHistory[0], mockApiHistory[1], + mockApiHistory[2], ]) - // Verify initClineWithHistoryItem was called - expect((provider as any).initClineWithHistoryItem).toHaveBeenCalledWith({ id: "test-task-id" }) + // initClineWithHistoryItem is only called when restoring checkpoints or aborting tasks + expect((provider as any).initClineWithHistoryItem).not.toHaveBeenCalled() }) test("handles case when no current task exists", async () => { @@ -1255,8 +1267,8 @@ describe("ClineProvider", () => { const mockMessages = [ { ts: 1000, type: "say", say: "user_feedback" }, // User message 1 { ts: 2000, type: "say", say: "tool" }, // Tool message - { ts: 3000, type: "say", say: "text", value: 4000 }, // Message to edit - { ts: 4000, type: "say", say: "browser_action" }, // Response to edit + { ts: 3000, type: "say", say: "text" }, // Message before edit + { ts: 4000, type: "say", say: "browser_action" }, // Message to edit { ts: 5000, type: "say", say: "user_feedback" }, // Next user message { ts: 6000, type: "say", say: "user_feedback" }, // Final message ] as ClineMessage[] @@ -1303,6 +1315,8 @@ describe("ClineProvider", () => { type: "showEditMessageDialog", messageTs: 4000, text: "Edited message content", + hasCheckpoint: false, + images: undefined, }) // Simulate user confirming edit through the dialog @@ -1313,12 +1327,17 @@ describe("ClineProvider", () => { }) // Verify correct messages were kept (only messages before the edited one) - expect(mockCline.overwriteClineMessages).toHaveBeenCalledWith([mockMessages[0], mockMessages[1]]) + expect(mockCline.overwriteClineMessages).toHaveBeenCalledWith([ + mockMessages[0], + mockMessages[1], + mockMessages[2], + ]) // Verify correct API messages were kept (only messages before the edited one) expect(mockCline.overwriteApiConversationHistory).toHaveBeenCalledWith([ mockApiHistory[0], mockApiHistory[1], + mockApiHistory[2], ]) // The new flow calls webviewMessageHandler recursively with askResponse @@ -2712,6 +2731,8 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { type: "showEditMessageDialog", messageTs: 3000, text: "Edited message with preserved images", + hasCheckpoint: false, + images: undefined, }) // Simulate confirmation @@ -2721,9 +2742,9 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { text: "Edited message with preserved images", }) - // Verify messages were edited correctly - only the first message should remain - expect(mockCline.overwriteClineMessages).toHaveBeenCalledWith([mockMessages[0]]) - expect(mockCline.overwriteApiConversationHistory).toHaveBeenCalledWith([{ ts: 1000 }]) + // Verify messages were edited correctly - messages up to the edited message should remain + expect(mockCline.overwriteClineMessages).toHaveBeenCalledWith([mockMessages[0], mockMessages[1]]) + expect(mockCline.overwriteApiConversationHistory).toHaveBeenCalledWith([{ ts: 1000 }, { ts: 2000 }]) }) test("handles editing messages with file attachments", async () => { @@ -2764,6 +2785,8 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { type: "showEditMessageDialog", messageTs: 3000, text: "Edited message with file attachment", + hasCheckpoint: false, + images: undefined, }) // Simulate user confirming the edit @@ -2820,6 +2843,8 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { type: "showEditMessageDialog", messageTs: 2000, text: "Edited message", + hasCheckpoint: false, + images: undefined, }) // Simulate user confirming the edit @@ -2860,6 +2885,8 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { type: "showEditMessageDialog", messageTs: 2000, text: "Edited message", + hasCheckpoint: false, + images: undefined, }) // Simulate user confirming the edit @@ -2916,11 +2943,15 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { type: "showEditMessageDialog", messageTs: 2000, text: "Edited message 1", + hasCheckpoint: false, + images: undefined, }) expect(mockPostMessage).toHaveBeenCalledWith({ type: "showEditMessageDialog", messageTs: 4000, text: "Edited message 2", + hasCheckpoint: false, + images: undefined, }) // Simulate user confirming both edits @@ -3106,6 +3137,8 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { type: "showEditMessageDialog", messageTs: 5000, text: "Edited non-existent message", + hasCheckpoint: false, + images: undefined, }) // Simulate user confirming the edit @@ -3146,6 +3179,7 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { expect(mockPostMessage).toHaveBeenCalledWith({ type: "showDeleteMessageDialog", messageTs: 5000, + hasCheckpoint: false, }) // Simulate user confirming the delete @@ -3197,6 +3231,8 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { type: "showEditMessageDialog", messageTs: 2000, text: "Edited message", + hasCheckpoint: false, + images: undefined, }) // Simulate user confirming the edit @@ -3236,6 +3272,7 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { expect(mockPostMessage).toHaveBeenCalledWith({ type: "showDeleteMessageDialog", messageTs: 2000, + hasCheckpoint: false, }) // Simulate user confirming the delete @@ -3289,6 +3326,8 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { type: "showEditMessageDialog", messageTs: 2000, text: largeEditedContent, + hasCheckpoint: false, + images: undefined, }) // Simulate user confirming the edit @@ -3331,18 +3370,23 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { expect(mockPostMessage).toHaveBeenCalledWith({ type: "showDeleteMessageDialog", messageTs: 3000, + hasCheckpoint: false, }) // Simulate user confirming the delete await messageHandler({ type: "deleteMessageConfirm", messageTs: 3000 }) - // Should handle large payloads without issues - expect(mockCline.overwriteClineMessages).toHaveBeenCalledWith([mockMessages[0]]) - expect(mockCline.overwriteApiConversationHistory).toHaveBeenCalledWith([{ ts: 1000 }]) + // Should handle large payloads without issues - keeps messages before the deleted one + expect(mockCline.overwriteClineMessages).toHaveBeenCalledWith([mockMessages[0], mockMessages[1]]) + expect(mockCline.overwriteApiConversationHistory).toHaveBeenCalledWith([{ ts: 1000 }, { ts: 2000 }]) }) }) describe("Error Messaging and User Feedback", () => { + beforeEach(async () => { + await provider.resolveWebviewView(mockWebviewView) + }) + // Note: Error messaging test removed as the implementation may not have proper error handling in place test("provides user feedback for successful operations", async () => { @@ -3369,6 +3413,7 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { expect(mockPostMessage).toHaveBeenCalledWith({ type: "showDeleteMessageDialog", messageTs: 2000, + hasCheckpoint: false, }) // Simulate user confirming the delete @@ -3376,7 +3421,7 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { // Verify successful operation completed expect(mockCline.overwriteClineMessages).toHaveBeenCalled() - expect(provider.initClineWithHistoryItem).toHaveBeenCalled() + // initClineWithHistoryItem is only called when restoring checkpoints or aborting tasks expect(vscode.window.showErrorMessage).not.toHaveBeenCalled() }) @@ -3442,6 +3487,7 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { expect(mockPostMessage).toHaveBeenCalledWith({ type: "showDeleteMessageDialog", messageTs: 1000, + hasCheckpoint: false, }) // Simulate user confirming the delete @@ -3492,6 +3538,8 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { type: "showEditMessageDialog", messageTs: futureTimestamp + 1000, text: "Edited future message", + hasCheckpoint: false, + images: undefined, }) // Simulate user confirming the edit diff --git a/src/core/webview/__tests__/checkpointRestoreHandler.spec.ts b/src/core/webview/__tests__/checkpointRestoreHandler.spec.ts new file mode 100644 index 0000000000..d321449f20 --- /dev/null +++ b/src/core/webview/__tests__/checkpointRestoreHandler.spec.ts @@ -0,0 +1,242 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import { handleCheckpointRestoreOperation } from "../checkpointRestoreHandler" +import { saveTaskMessages } from "../../task-persistence" +import pWaitFor from "p-wait-for" +import * as vscode from "vscode" + +// Mock dependencies +vi.mock("../../task-persistence", () => ({ + saveTaskMessages: vi.fn(), +})) +vi.mock("p-wait-for") +vi.mock("vscode", () => ({ + window: { + showErrorMessage: vi.fn(), + }, +})) + +describe("checkpointRestoreHandler", () => { + let mockProvider: any + let mockCline: any + + beforeEach(() => { + vi.clearAllMocks() + + // Setup mock Cline instance + mockCline = { + taskId: "test-task-123", + abort: false, + abortTask: vi.fn(() => { + mockCline.abort = true + }), + checkpointRestore: vi.fn(), + clineMessages: [ + { ts: 1, type: "user", say: "user", text: "First message" }, + { ts: 2, type: "assistant", say: "assistant", text: "Response" }, + { + ts: 3, + type: "user", + say: "user", + text: "Checkpoint message", + checkpoint: { hash: "abc123" }, + }, + { ts: 4, type: "assistant", say: "assistant", text: "After checkpoint" }, + ], + } + + // Setup mock provider + mockProvider = { + getCurrentCline: vi.fn(() => mockCline), + postMessageToWebview: vi.fn(), + getTaskWithId: vi.fn(() => ({ + historyItem: { id: "test-task-123", messages: mockCline.clineMessages }, + })), + initClineWithHistoryItem: vi.fn(), + setPendingEditOperation: vi.fn(), + contextProxy: { + globalStorageUri: { fsPath: "/test/storage" }, + }, + } + + // Mock pWaitFor to resolve immediately + ;(pWaitFor as any).mockImplementation(async (condition: () => boolean) => { + // Simulate the condition being met + return Promise.resolve() + }) + }) + + describe("handleCheckpointRestoreOperation", () => { + it("should abort task before checkpoint restore for delete operations", async () => { + // Simulate a task that hasn't been aborted yet + mockCline.abort = false + + await handleCheckpointRestoreOperation({ + provider: mockProvider, + currentCline: mockCline, + messageTs: 3, + messageIndex: 2, + checkpoint: { hash: "abc123" }, + operation: "delete", + }) + + // Verify abortTask was called before checkpointRestore + expect(mockCline.abortTask).toHaveBeenCalled() + expect(mockCline.checkpointRestore).toHaveBeenCalled() + + // Verify the order of operations + const abortOrder = mockCline.abortTask.mock.invocationCallOrder[0] + const restoreOrder = mockCline.checkpointRestore.mock.invocationCallOrder[0] + expect(abortOrder).toBeLessThan(restoreOrder) + }) + + it("should not abort task if already aborted", async () => { + // Simulate a task that's already aborted + mockCline.abort = true + + await handleCheckpointRestoreOperation({ + provider: mockProvider, + currentCline: mockCline, + messageTs: 3, + messageIndex: 2, + checkpoint: { hash: "abc123" }, + operation: "delete", + }) + + // Verify abortTask was not called + expect(mockCline.abortTask).not.toHaveBeenCalled() + expect(mockCline.checkpointRestore).toHaveBeenCalled() + }) + + it("should handle edit operations with pending edit data", async () => { + const editData = { + editedContent: "Edited content", + images: ["image1.png"], + apiConversationHistoryIndex: 2, + } + + await handleCheckpointRestoreOperation({ + provider: mockProvider, + currentCline: mockCline, + messageTs: 3, + messageIndex: 2, + checkpoint: { hash: "abc123" }, + operation: "edit", + editData, + }) + + // Verify abortTask was NOT called for edit operations + expect(mockCline.abortTask).not.toHaveBeenCalled() + + // Verify pending edit operation was set + expect(mockProvider.setPendingEditOperation).toHaveBeenCalledWith("task-test-task-123", { + messageTs: 3, + editedContent: "Edited content", + images: ["image1.png"], + messageIndex: 2, + apiConversationHistoryIndex: 2, + }) + + // Verify checkpoint restore was called with edit operation + expect(mockCline.checkpointRestore).toHaveBeenCalledWith({ + ts: 3, + commitHash: "abc123", + mode: "restore", + operation: "edit", + }) + }) + + it("should save messages after delete operation", async () => { + // Mock the checkpoint restore to simulate message deletion + mockCline.checkpointRestore.mockImplementation(async () => { + mockCline.clineMessages = mockCline.clineMessages.slice(0, 2) + }) + + await handleCheckpointRestoreOperation({ + provider: mockProvider, + currentCline: mockCline, + messageTs: 3, + messageIndex: 2, + checkpoint: { hash: "abc123" }, + operation: "delete", + }) + + // Verify saveTaskMessages was called + expect(saveTaskMessages).toHaveBeenCalledWith({ + messages: mockCline.clineMessages, + taskId: "test-task-123", + globalStoragePath: "/test/storage", + }) + + // Verify initClineWithHistoryItem was called + expect(mockProvider.initClineWithHistoryItem).toHaveBeenCalled() + }) + + it("should reinitialize task with correct history item after delete", async () => { + const expectedHistoryItem = { + id: "test-task-123", + messages: mockCline.clineMessages, + } + + await handleCheckpointRestoreOperation({ + provider: mockProvider, + currentCline: mockCline, + messageTs: 3, + messageIndex: 2, + checkpoint: { hash: "abc123" }, + operation: "delete", + }) + + // Verify getTaskWithId was called + expect(mockProvider.getTaskWithId).toHaveBeenCalledWith("test-task-123") + + // Verify initClineWithHistoryItem was called with the correct history item + expect(mockProvider.initClineWithHistoryItem).toHaveBeenCalledWith(expectedHistoryItem) + }) + + it("should not save messages or reinitialize for edit operation", async () => { + const editData = { + editedContent: "Edited content", + images: [], + apiConversationHistoryIndex: 2, + } + + await handleCheckpointRestoreOperation({ + provider: mockProvider, + currentCline: mockCline, + messageTs: 3, + messageIndex: 2, + checkpoint: { hash: "abc123" }, + operation: "edit", + editData, + }) + + // Verify saveTaskMessages was NOT called for edit operation + expect(saveTaskMessages).not.toHaveBeenCalled() + + // Verify initClineWithHistoryItem was NOT called for edit operation + expect(mockProvider.initClineWithHistoryItem).not.toHaveBeenCalled() + }) + + it("should handle errors gracefully", async () => { + // Mock checkpoint restore to throw an error + mockCline.checkpointRestore.mockRejectedValue(new Error("Checkpoint restore failed")) + + // The function should throw and show an error message + await expect( + handleCheckpointRestoreOperation({ + provider: mockProvider, + currentCline: mockCline, + messageTs: 3, + messageIndex: 2, + checkpoint: { hash: "abc123" }, + operation: "delete", + }), + ).rejects.toThrow("Checkpoint restore failed") + + // Verify error message was shown + expect(vscode.window.showErrorMessage).toHaveBeenCalledWith( + "Error during checkpoint restore: Checkpoint restore failed", + ) + }) + }) +}) diff --git a/src/core/webview/__tests__/webviewMessageHandler.checkpoint.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.checkpoint.spec.ts new file mode 100644 index 0000000000..6a22632136 --- /dev/null +++ b/src/core/webview/__tests__/webviewMessageHandler.checkpoint.spec.ts @@ -0,0 +1,131 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import { webviewMessageHandler } from "../webviewMessageHandler" +import { saveTaskMessages } from "../../task-persistence" +import { handleCheckpointRestoreOperation } from "../checkpointRestoreHandler" + +// Mock dependencies +vi.mock("../../task-persistence") +vi.mock("../checkpointRestoreHandler") +vi.mock("vscode", () => ({ + window: { + showErrorMessage: vi.fn(), + }, + workspace: { + workspaceFolders: undefined, + }, +})) + +describe("webviewMessageHandler - checkpoint operations", () => { + let mockProvider: any + let mockCline: any + + beforeEach(() => { + vi.clearAllMocks() + + // Setup mock Cline instance + mockCline = { + taskId: "test-task-123", + clineMessages: [ + { ts: 1, type: "user", say: "user", text: "First message" }, + { ts: 2, type: "assistant", say: "checkpoint_saved", text: "abc123" }, + { ts: 3, type: "user", say: "user", text: "Message to delete" }, + { ts: 4, type: "assistant", say: "assistant", text: "After message" }, + ], + apiConversationHistory: [ + { ts: 1, role: "user", content: [{ type: "text", text: "First message" }] }, + { ts: 3, role: "user", content: [{ type: "text", text: "Message to delete" }] }, + { ts: 4, role: "assistant", content: [{ type: "text", text: "After message" }] }, + ], + checkpointRestore: vi.fn(), + overwriteClineMessages: vi.fn(), + overwriteApiConversationHistory: vi.fn(), + } + + // Setup mock provider + mockProvider = { + getCurrentCline: vi.fn(() => mockCline), + postMessageToWebview: vi.fn(), + getTaskWithId: vi.fn(() => ({ + historyItem: { id: "test-task-123", messages: mockCline.clineMessages }, + })), + initClineWithHistoryItem: vi.fn(), + setPendingEditOperation: vi.fn(), + contextProxy: { + globalStorageUri: { fsPath: "/test/storage" }, + }, + } + }) + + describe("delete operations with checkpoint restoration", () => { + it("should call handleCheckpointRestoreOperation for checkpoint deletes", async () => { + // Mock handleCheckpointRestoreOperation + ;(handleCheckpointRestoreOperation as any).mockResolvedValue(undefined) + + // Call the handler with delete confirmation + await webviewMessageHandler(mockProvider, { + type: "deleteMessageConfirm", + messageTs: 3, + restoreCheckpoint: true, + }) + + // Verify handleCheckpointRestoreOperation was called with correct parameters + expect(handleCheckpointRestoreOperation).toHaveBeenCalledWith({ + provider: mockProvider, + currentCline: mockCline, + messageTs: 3, + messageIndex: 2, + checkpoint: { hash: "abc123" }, + operation: "delete", + }) + }) + + it("should save messages for non-checkpoint deletes", async () => { + // Call the handler with delete confirmation (no checkpoint restoration) + await webviewMessageHandler(mockProvider, { + type: "deleteMessageConfirm", + messageTs: 2, + restoreCheckpoint: false, + }) + + // Verify saveTaskMessages was called + expect(saveTaskMessages).toHaveBeenCalledWith({ + messages: expect.any(Array), + taskId: "test-task-123", + globalStoragePath: "/test/storage", + }) + + // Verify checkpoint restore was NOT called + expect(mockCline.checkpointRestore).not.toHaveBeenCalled() + }) + }) + + describe("edit operations with checkpoint restoration", () => { + it("should call handleCheckpointRestoreOperation for checkpoint edits", async () => { + // Mock handleCheckpointRestoreOperation + ;(handleCheckpointRestoreOperation as any).mockResolvedValue(undefined) + + // Call the handler with edit confirmation + await webviewMessageHandler(mockProvider, { + type: "editMessageConfirm", + messageTs: 3, + text: "Edited checkpoint message", + restoreCheckpoint: true, + }) + + // Verify handleCheckpointRestoreOperation was called with correct parameters + expect(handleCheckpointRestoreOperation).toHaveBeenCalledWith({ + provider: mockProvider, + currentCline: mockCline, + messageTs: 3, + messageIndex: 2, + checkpoint: { hash: "abc123" }, + operation: "edit", + editData: { + editedContent: "Edited checkpoint message", + images: undefined, + apiConversationHistoryIndex: 1, + }, + }) + }) + }) +}) diff --git a/src/core/webview/__tests__/webviewMessageHandler.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.spec.ts index 284ee98944..f2052f5533 100644 --- a/src/core/webview/__tests__/webviewMessageHandler.spec.ts +++ b/src/core/webview/__tests__/webviewMessageHandler.spec.ts @@ -502,7 +502,10 @@ describe("webviewMessageHandler - message dialog preferences", () => { describe("deleteMessage", () => { it("should always show dialog for delete confirmation", async () => { - vi.mocked(mockClineProvider.getCurrentCline).mockReturnValue({} as any) // Mock current cline exists + vi.mocked(mockClineProvider.getCurrentCline).mockReturnValue({ + clineMessages: [], + apiConversationHistory: [], + } as any) // Mock current cline with proper structure await webviewMessageHandler(mockClineProvider, { type: "deleteMessage", @@ -512,13 +515,17 @@ describe("webviewMessageHandler - message dialog preferences", () => { expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ type: "showDeleteMessageDialog", messageTs: 123456789, + hasCheckpoint: false, }) }) }) describe("submitEditedMessage", () => { it("should always show dialog for edit confirmation", async () => { - vi.mocked(mockClineProvider.getCurrentCline).mockReturnValue({} as any) // Mock current cline exists + vi.mocked(mockClineProvider.getCurrentCline).mockReturnValue({ + clineMessages: [], + apiConversationHistory: [], + } as any) // Mock current cline with proper structure await webviewMessageHandler(mockClineProvider, { type: "submitEditedMessage", @@ -530,6 +537,8 @@ describe("webviewMessageHandler - message dialog preferences", () => { type: "showEditMessageDialog", messageTs: 123456789, text: "edited content", + hasCheckpoint: false, + images: undefined, }) }) }) diff --git a/src/core/webview/checkpointRestoreHandler.ts b/src/core/webview/checkpointRestoreHandler.ts new file mode 100644 index 0000000000..ac86f0c4a0 --- /dev/null +++ b/src/core/webview/checkpointRestoreHandler.ts @@ -0,0 +1,104 @@ +import { Task } from "../task/Task" +import { ClineProvider } from "./ClineProvider" +import { saveTaskMessages } from "../task-persistence" +import * as vscode from "vscode" +import pWaitFor from "p-wait-for" +import { t } from "../../i18n" + +export interface CheckpointRestoreConfig { + provider: ClineProvider + currentCline: Task + messageTs: number + messageIndex: number + checkpoint: { hash: string } + operation: "delete" | "edit" + editData?: { + editedContent: string + images?: string[] + apiConversationHistoryIndex: number + } +} + +/** + * Handles checkpoint restoration for both delete and edit operations. + * This consolidates the common logic while handling operation-specific behavior. + */ +export async function handleCheckpointRestoreOperation(config: CheckpointRestoreConfig): Promise { + const { provider, currentCline, messageTs, checkpoint, operation, editData } = config + + try { + // For delete operations, ensure the task is properly aborted to handle any pending ask operations + // This prevents "Current ask promise was ignored" errors + // For edit operations, we don't abort because the checkpoint restore will handle it + if (operation === "delete" && currentCline && !currentCline.abort) { + currentCline.abortTask() + // Wait a bit for the abort to complete + await pWaitFor(() => currentCline.abort === true, { + timeout: 1000, + interval: 50, + }).catch(() => { + // Continue even if timeout - the abort flag should be set + }) + } + + // For edit operations, set up pending edit data before restoration + if (operation === "edit" && editData) { + const operationId = `task-${currentCline.taskId}` + provider.setPendingEditOperation(operationId, { + messageTs, + editedContent: editData.editedContent, + images: editData.images, + messageIndex: config.messageIndex, + apiConversationHistoryIndex: editData.apiConversationHistoryIndex, + }) + } + + // Perform the checkpoint restoration + await currentCline.checkpointRestore({ + ts: messageTs, + commitHash: checkpoint.hash, + mode: "restore", + operation, + }) + + // For delete operations, we need to save messages and reinitialize + // For edit operations, the reinitialization happens automatically + // and processes the pending edit + if (operation === "delete") { + // Save the updated messages to disk after checkpoint restoration + await saveTaskMessages({ + messages: currentCline.clineMessages, + taskId: currentCline.taskId, + globalStoragePath: provider.contextProxy.globalStorageUri.fsPath, + }) + + // Get the updated history item and reinitialize + const { historyItem } = await provider.getTaskWithId(currentCline.taskId) + await provider.initClineWithHistoryItem(historyItem) + } + // For edit operations, the task cancellation in checkpointRestore + // will trigger reinitialization, which will process pendingEditAfterRestore + } catch (error) { + console.error(`Error in checkpoint restore (${operation}):`, error) + vscode.window.showErrorMessage( + `Error during checkpoint restore: ${error instanceof Error ? error.message : String(error)}`, + ) + throw error + } +} + +/** + * Common checkpoint restore validation and initialization utility. + * This can be used by any checkpoint restore flow that needs to wait for initialization. + */ +export async function waitForClineInitialization(provider: ClineProvider, timeoutMs: number = 3000): Promise { + try { + await pWaitFor(() => provider.getCurrentCline()?.isInitialized === true, { + timeout: timeoutMs, + }) + return true + } catch (error) { + vscode.window.showErrorMessage(t("common:errors.checkpoint_timeout")) + return false + } +} diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index c739c2ade8..2fc978b112 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -16,8 +16,10 @@ import { import { CloudService } from "@roo-code/cloud" import { TelemetryService } from "@roo-code/telemetry" import { type ApiMessage } from "../task-persistence/apiMessages" +import { saveTaskMessages } from "../task-persistence" import { ClineProvider } from "./ClineProvider" +import { handleCheckpointRestoreOperation } from "./checkpointRestoreHandler" import { changeLanguage, t } from "../../i18n" import { Package } from "../../shared/package" import { RouterName, toRouterName, ModelRecord } from "../../shared/api" @@ -69,10 +71,10 @@ export const webviewMessageHandler = async ( * Shared utility to find message indices based on timestamp */ const findMessageIndices = (messageTs: number, currentCline: any) => { - const timeCutoff = messageTs - 1000 // 1 second buffer before the message - const messageIndex = currentCline.clineMessages.findIndex((msg: ClineMessage) => msg.ts && msg.ts >= timeCutoff) + // Find the exact message by timestamp, not the first one after a cutoff + const messageIndex = currentCline.clineMessages.findIndex((msg: ClineMessage) => msg.ts === messageTs) const apiConversationHistoryIndex = currentCline.apiConversationHistory.findIndex( - (msg: ApiMessage) => msg.ts && msg.ts >= timeCutoff, + (msg: ApiMessage) => msg.ts === messageTs, ) return { messageIndex, apiConversationHistoryIndex } } @@ -99,38 +101,110 @@ export const webviewMessageHandler = async ( * Handles message deletion operations with user confirmation */ const handleDeleteOperation = async (messageTs: number): Promise => { + // Check if there's a checkpoint before this message + const currentCline = provider.getCurrentCline() + let hasCheckpoint = false + if (currentCline) { + const { messageIndex } = findMessageIndices(messageTs, currentCline) + if (messageIndex !== -1) { + // Find the last checkpoint before this message + const checkpoints = currentCline.clineMessages + .filter((msg) => msg.say === "checkpoint_saved" && msg.ts < messageTs) + .sort((a, b) => b.ts - a.ts) + + hasCheckpoint = checkpoints.length > 0 + } else { + console.log("[webviewMessageHandler] Message not found! Looking for ts:", messageTs) + } + } + // Send message to webview to show delete confirmation dialog await provider.postMessageToWebview({ type: "showDeleteMessageDialog", messageTs, + hasCheckpoint, }) } /** * Handles confirmed message deletion from webview dialog */ - const handleDeleteMessageConfirm = async (messageTs: number): Promise => { - // Only proceed if we have a current cline - if (provider.getCurrentCline()) { - const currentCline = provider.getCurrentCline()! - const { messageIndex, apiConversationHistoryIndex } = findMessageIndices(messageTs, currentCline) + const handleDeleteMessageConfirm = async (messageTs: number, restoreCheckpoint?: boolean): Promise => { + const currentCline = provider.getCurrentCline() + if (!currentCline) { + console.error("[handleDeleteMessageConfirm] No current cline available") + return + } - if (messageIndex !== -1) { - try { - const { historyItem } = await provider.getTaskWithId(currentCline.taskId) + const { messageIndex, apiConversationHistoryIndex } = findMessageIndices(messageTs, currentCline) - // Delete this message and all subsequent messages - await removeMessagesThisAndSubsequent(currentCline, messageIndex, apiConversationHistoryIndex) + if (messageIndex === -1) { + const errorMessage = `Message with timestamp ${messageTs} not found` + console.error("[handleDeleteMessageConfirm]", errorMessage) + await vscode.window.showErrorMessage(errorMessage) + return + } - // Initialize with history item after deletion - await provider.initClineWithHistoryItem(historyItem) - } catch (error) { - console.error("Error in delete message:", error) - vscode.window.showErrorMessage( - `Error deleting message: ${error instanceof Error ? error.message : String(error)}`, - ) + try { + const targetMessage = currentCline.clineMessages[messageIndex] + + // If checkpoint restoration is requested, find and restore to the last checkpoint before this message + if (restoreCheckpoint) { + // Find the last checkpoint before this message + const checkpoints = currentCline.clineMessages + .filter((msg) => msg.say === "checkpoint_saved" && msg.ts < messageTs) + .sort((a, b) => b.ts - a.ts) + + const lastCheckpoint = checkpoints[0] + + if (lastCheckpoint && lastCheckpoint.text) { + await handleCheckpointRestoreOperation({ + provider, + currentCline, + messageTs: targetMessage.ts!, + messageIndex, + checkpoint: { hash: lastCheckpoint.text }, + operation: "delete", + }) + } else { + // No checkpoint found before this message + console.log("[handleDeleteMessageConfirm] No checkpoint found before message") + vscode.window.showWarningMessage("No checkpoint found before this message") + } + } else { + // For non-checkpoint deletes, preserve checkpoint associations for remaining messages + // Store checkpoints from messages that will be preserved + const preservedCheckpoints = new Map() + for (let i = 0; i < messageIndex; i++) { + const msg = currentCline.clineMessages[i] + if (msg?.checkpoint && msg.ts) { + preservedCheckpoints.set(msg.ts, msg.checkpoint) + } + } + + // Delete this message and all subsequent messages + await removeMessagesThisAndSubsequent(currentCline, messageIndex, apiConversationHistoryIndex) + + // Restore checkpoint associations for preserved messages + for (const [ts, checkpoint] of preservedCheckpoints) { + const msgIndex = currentCline.clineMessages.findIndex((msg) => msg.ts === ts) + if (msgIndex !== -1) { + currentCline.clineMessages[msgIndex].checkpoint = checkpoint + } } + + // Save the updated messages with restored checkpoints + await saveTaskMessages({ + messages: currentCline.clineMessages, + taskId: currentCline.taskId, + globalStoragePath: provider.contextProxy.globalStorageUri.fsPath, + }) } + } catch (error) { + console.error("Error in delete message:", error) + vscode.window.showErrorMessage( + `Error deleting message: ${error instanceof Error ? error.message : String(error)}`, + ) } } @@ -138,11 +212,31 @@ export const webviewMessageHandler = async ( * Handles message editing operations with user confirmation */ const handleEditOperation = async (messageTs: number, editedContent: string, images?: string[]): Promise => { + // Check if there's a checkpoint before this message + const currentCline = provider.getCurrentCline() + let hasCheckpoint = false + if (currentCline) { + const { messageIndex } = findMessageIndices(messageTs, currentCline) + if (messageIndex !== -1) { + // Find the last checkpoint before this message + const checkpoints = currentCline.clineMessages + .filter((msg) => msg.say === "checkpoint_saved" && msg.ts < messageTs) + .sort((a, b) => b.ts - a.ts) + + hasCheckpoint = checkpoints.length > 0 + } else { + console.log("[webviewMessageHandler] Edit - Message not found in clineMessages!") + } + } else { + console.log("[webviewMessageHandler] Edit - No currentCline available!") + } + // Send message to webview to show edit confirmation dialog await provider.postMessageToWebview({ type: "showEditMessageDialog", messageTs, text: editedContent, + hasCheckpoint, images, }) } @@ -153,38 +247,105 @@ export const webviewMessageHandler = async ( const handleEditMessageConfirm = async ( messageTs: number, editedContent: string, + restoreCheckpoint?: boolean, images?: string[], ): Promise => { - // Only proceed if we have a current cline - if (provider.getCurrentCline()) { - const currentCline = provider.getCurrentCline()! + const currentCline = provider.getCurrentCline() + if (!currentCline) { + console.error("[handleEditMessageConfirm] No current cline available") + return + } - // Use findMessageIndices to find messages based on timestamp - const { messageIndex, apiConversationHistoryIndex } = findMessageIndices(messageTs, currentCline) + // Use findMessageIndices to find messages based on timestamp + const { messageIndex, apiConversationHistoryIndex } = findMessageIndices(messageTs, currentCline) - if (messageIndex !== -1) { - try { - // Edit this message and delete subsequent - await removeMessagesThisAndSubsequent(currentCline, messageIndex, apiConversationHistoryIndex) - - // Process the edited message as a regular user message - // This will add it to the conversation and trigger an AI response - webviewMessageHandler(provider, { - type: "askResponse", - askResponse: "messageResponse", - text: editedContent, - images, + if (messageIndex === -1) { + const errorMessage = `Message with timestamp ${messageTs} not found` + console.error("[handleEditMessageConfirm]", errorMessage) + await vscode.window.showErrorMessage(errorMessage) + return + } + + try { + const targetMessage = currentCline.clineMessages[messageIndex] + + // If checkpoint restoration is requested, find and restore to the last checkpoint before this message + if (restoreCheckpoint) { + // Find the last checkpoint before this message + const checkpoints = currentCline.clineMessages + .filter((msg) => msg.say === "checkpoint_saved" && msg.ts < messageTs) + .sort((a, b) => b.ts - a.ts) + + const lastCheckpoint = checkpoints[0] + + if (lastCheckpoint && lastCheckpoint.text) { + await handleCheckpointRestoreOperation({ + provider, + currentCline, + messageTs: targetMessage.ts!, + messageIndex, + checkpoint: { hash: lastCheckpoint.text }, + operation: "edit", + editData: { + editedContent, + images, + apiConversationHistoryIndex, + }, }) + // The task will be cancelled and reinitialized by checkpointRestore + // The pending edit will be processed in the reinitialized task + return + } else { + // No checkpoint found before this message + console.log("[handleEditMessageConfirm] No checkpoint found before message") + vscode.window.showWarningMessage("No checkpoint found before this message") + // Continue with non-checkpoint edit + } + } - // Don't initialize with history item for edit operations - // The webviewMessageHandler will handle the conversation state - } catch (error) { - console.error("Error in edit message:", error) - vscode.window.showErrorMessage( - `Error editing message: ${error instanceof Error ? error.message : String(error)}`, - ) + // For non-checkpoint edits, preserve checkpoint associations for remaining messages + // Store checkpoints from messages that will be preserved + const preservedCheckpoints = new Map() + for (let i = 0; i < messageIndex; i++) { + const msg = currentCline.clineMessages[i] + if (msg?.checkpoint && msg.ts) { + preservedCheckpoints.set(msg.ts, msg.checkpoint) } } + + // Edit this message and delete subsequent + await removeMessagesThisAndSubsequent(currentCline, messageIndex, apiConversationHistoryIndex) + + // Restore checkpoint associations for preserved messages + for (const [ts, checkpoint] of preservedCheckpoints) { + const msgIndex = currentCline.clineMessages.findIndex((msg) => msg.ts === ts) + if (msgIndex !== -1) { + currentCline.clineMessages[msgIndex].checkpoint = checkpoint + } + } + + // Save the updated messages with restored checkpoints + await saveTaskMessages({ + messages: currentCline.clineMessages, + taskId: currentCline.taskId, + globalStoragePath: provider.contextProxy.globalStorageUri.fsPath, + }) + + // Process the edited message as a regular user message + webviewMessageHandler(provider, { + type: "askResponse", + askResponse: "messageResponse", + text: editedContent, + images, + }) + + // Don't initialize with history item for edit operations + // The webviewMessageHandler will handle the conversation state + } catch (error) { + console.error("Error in edit message:", error) + vscode.window.showErrorMessage( + `Error editing message: ${error instanceof Error ? error.message : String(error)}`, + ) } } @@ -1568,12 +1729,17 @@ export const webviewMessageHandler = async ( break case "deleteMessageConfirm": if (message.messageTs) { - await handleDeleteMessageConfirm(message.messageTs) + await handleDeleteMessageConfirm(message.messageTs, message.restoreCheckpoint) } break case "editMessageConfirm": if (message.messageTs && message.text) { - await handleEditMessageConfirm(message.messageTs, message.text, message.images) + await handleEditMessageConfirm( + message.messageTs, + message.text, + message.restoreCheckpoint, + message.images, + ) } break case "getListApiConfiguration": diff --git a/src/services/checkpoints/ShadowCheckpointService.ts b/src/services/checkpoints/ShadowCheckpointService.ts index be2c86852a..89ba1a54ab 100644 --- a/src/services/checkpoints/ShadowCheckpointService.ts +++ b/src/services/checkpoints/ShadowCheckpointService.ts @@ -20,6 +20,7 @@ export abstract class ShadowCheckpointService extends EventEmitter { protected _checkpoints: string[] = [] protected _baseHash?: string + protected _hasFirstCheckpoint: boolean = false protected readonly dotGitDir: string protected git?: SimpleGit @@ -93,6 +94,11 @@ export abstract class ShadowCheckpointService extends EventEmitter { await this.writeExcludeFile() this.baseHash = await git.revparse(["HEAD"]) + // Check if there are any commits beyond the initial commit + const commits = await git.log() + if (commits.total > 0) { + this._hasFirstCheckpoint = true + } } else { this.log(`[${this.constructor.name}#initShadowGit] creating shadow git repo at ${this.checkpointsDir}`) await git.init() @@ -211,7 +217,7 @@ export abstract class ShadowCheckpointService extends EventEmitter { await this.stageAll(this.git) const commitArgs = options?.allowEmpty ? { "--allow-empty": null } : undefined const result = await this.git.commit(message, commitArgs) - const isFirst = this._checkpoints.length === 0 + const isFirst = !this._hasFirstCheckpoint const fromHash = this._checkpoints[this._checkpoints.length - 1] ?? this.baseHash! const toHash = result.commit || fromHash this._checkpoints.push(toHash) @@ -219,6 +225,9 @@ export abstract class ShadowCheckpointService extends EventEmitter { if (isFirst || result.commit) { this.emit("checkpoint", { type: "checkpoint", isFirst, fromHash, toHash, duration }) + if (isFirst) { + this._hasFirstCheckpoint = true + } } if (result.commit) { diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index 000762e317..ce0fc5bc99 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -184,6 +184,7 @@ export interface ExtensionMessage { rulesFolderPath?: string settings?: any messageTs?: number + hasCheckpoint?: boolean context?: string } diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index 795e276522..6984f7fa68 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -235,6 +235,7 @@ export interface WebviewMessage { hasSystemPromptOverride?: boolean terminalOperation?: "continue" | "abort" messageTs?: number + restoreCheckpoint?: boolean historyPreviewCollapsed?: boolean filters?: { type?: string; search?: string; tags?: string[] } url?: string // For openExternal diff --git a/webview-ui/src/App.tsx b/webview-ui/src/App.tsx index 3782242707..11913f6892 100644 --- a/webview-ui/src/App.tsx +++ b/webview-ui/src/App.tsx @@ -19,6 +19,7 @@ import McpView from "./components/mcp/McpView" import { MarketplaceView } from "./components/marketplace/MarketplaceView" import ModesView from "./components/modes/ModesView" import { HumanRelayDialog } from "./components/human-relay/HumanRelayDialog" +import { CheckpointRestoreDialog } from "./components/chat/CheckpointRestoreDialog" import { DeleteMessageDialog, EditMessageDialog } from "./components/chat/MessageModificationConfirmationDialog" import ErrorBoundary from "./components/ErrorBoundary" import { AccountView } from "./components/account/AccountView" @@ -37,18 +38,21 @@ interface HumanRelayDialogState { interface DeleteMessageDialogState { isOpen: boolean messageTs: number + hasCheckpoint: boolean } interface EditMessageDialogState { isOpen: boolean messageTs: number text: string + hasCheckpoint: boolean images?: string[] } // Memoize dialog components to prevent unnecessary re-renders const MemoizedDeleteMessageDialog = React.memo(DeleteMessageDialog) const MemoizedEditMessageDialog = React.memo(EditMessageDialog) +const MemoizedCheckpointRestoreDialog = React.memo(CheckpointRestoreDialog) const MemoizedHumanRelayDialog = React.memo(HumanRelayDialog) const tabsByMessageAction: Partial, Tab>> = { @@ -91,12 +95,14 @@ const App = () => { const [deleteMessageDialogState, setDeleteMessageDialogState] = useState({ isOpen: false, messageTs: 0, + hasCheckpoint: false, }) const [editMessageDialogState, setEditMessageDialogState] = useState({ isOpen: false, messageTs: 0, text: "", + hasCheckpoint: false, images: [], }) @@ -156,7 +162,11 @@ const App = () => { } if (message.type === "showDeleteMessageDialog" && message.messageTs) { - setDeleteMessageDialogState({ isOpen: true, messageTs: message.messageTs }) + setDeleteMessageDialogState({ + isOpen: true, + messageTs: message.messageTs, + hasCheckpoint: message.hasCheckpoint || false, + }) } if (message.type === "showEditMessageDialog" && message.messageTs && message.text) { @@ -164,6 +174,7 @@ const App = () => { isOpen: true, messageTs: message.messageTs, text: message.text, + hasCheckpoint: message.hasCheckpoint || false, images: message.images || [], }) } @@ -268,30 +279,65 @@ const App = () => { onSubmit={(requestId, text) => vscode.postMessage({ type: "humanRelayResponse", requestId, text })} onCancel={(requestId) => vscode.postMessage({ type: "humanRelayCancel", requestId })} /> - setDeleteMessageDialogState((prev) => ({ ...prev, isOpen: open }))} - onConfirm={() => { - vscode.postMessage({ - type: "deleteMessageConfirm", - messageTs: deleteMessageDialogState.messageTs, - }) - setDeleteMessageDialogState((prev) => ({ ...prev, isOpen: false })) - }} - /> - setEditMessageDialogState((prev) => ({ ...prev, isOpen: open }))} - onConfirm={() => { - vscode.postMessage({ - type: "editMessageConfirm", - messageTs: editMessageDialogState.messageTs, - text: editMessageDialogState.text, - images: editMessageDialogState.images, - }) - setEditMessageDialogState((prev) => ({ ...prev, isOpen: false })) - }} - /> + {deleteMessageDialogState.hasCheckpoint ? ( + setDeleteMessageDialogState((prev) => ({ ...prev, isOpen: open }))} + onConfirm={(restoreCheckpoint: boolean) => { + vscode.postMessage({ + type: "deleteMessageConfirm", + messageTs: deleteMessageDialogState.messageTs, + restoreCheckpoint, + }) + setDeleteMessageDialogState((prev) => ({ ...prev, isOpen: false })) + }} + /> + ) : ( + setDeleteMessageDialogState((prev) => ({ ...prev, isOpen: open }))} + onConfirm={() => { + vscode.postMessage({ + type: "deleteMessageConfirm", + messageTs: deleteMessageDialogState.messageTs, + }) + setDeleteMessageDialogState((prev) => ({ ...prev, isOpen: false })) + }} + /> + )} + {editMessageDialogState.hasCheckpoint ? ( + setEditMessageDialogState((prev) => ({ ...prev, isOpen: open }))} + onConfirm={(restoreCheckpoint: boolean) => { + vscode.postMessage({ + type: "editMessageConfirm", + messageTs: editMessageDialogState.messageTs, + text: editMessageDialogState.text, + restoreCheckpoint, + }) + setEditMessageDialogState((prev) => ({ ...prev, isOpen: false })) + }} + /> + ) : ( + setEditMessageDialogState((prev) => ({ ...prev, isOpen: open }))} + onConfirm={() => { + vscode.postMessage({ + type: "editMessageConfirm", + messageTs: editMessageDialogState.messageTs, + text: editMessageDialogState.text, + images: editMessageDialogState.images, + }) + setEditMessageDialogState((prev) => ({ ...prev, isOpen: false })) + }} + /> + )} ) } diff --git a/webview-ui/src/components/chat/ChatRow.tsx b/webview-ui/src/components/chat/ChatRow.tsx index 4fa921f443..f20eec1c23 100644 --- a/webview-ui/src/components/chat/ChatRow.tsx +++ b/webview-ui/src/components/chat/ChatRow.tsx @@ -1090,7 +1090,7 @@ export const ChatRowContent = ({