diff --git a/example/convex/_generated/api.d.ts b/example/convex/_generated/api.d.ts
index 7a9c3698..6739088c 100644
--- a/example/convex/_generated/api.d.ts
+++ b/example/convex/_generated/api.d.ts
@@ -233,6 +233,7 @@ export declare const components: {
             vectors: Array<Array<number> | null>;
           };
           failPendingSteps?: boolean;
+          hideFromUserIdSearch?: boolean;
           messages: Array<{
             error?: string;
             fileIds?: Array<string>;
@@ -836,6 +837,22 @@ export declare const components: {
           }>;
         }
       >;
+      cloneThread: FunctionReference<
+        "action",
+        "internal",
+        {
+          batchSize?: number;
+          copyUserIdForVectorSearch?: boolean;
+          excludeToolMessages?: boolean;
+          insertAtOrder?: number;
+          limit?: number;
+          sourceThreadId: string;
+          statuses?: Array<"pending" | "success" | "failed">;
+          targetThreadId: string;
+          upToAndIncludingMessageId?: string;
+        },
+        number
+      >;
       deleteByIds: FunctionReference<
         "mutation",
         "internal",
diff --git a/src/component/_generated/api.d.ts b/src/component/_generated/api.d.ts
index 3257e2f7..ca538cff 100644
--- a/src/component/_generated/api.d.ts
+++ b/src/component/_generated/api.d.ts
@@ -141,6 +141,7 @@ export type Mounts = {
           vectors: Array<Array<number> | null>;
         };
         failPendingSteps?: boolean;
+        hideFromUserIdSearch?: boolean;
         messages: Array<{
           error?: string;
           fileIds?: Array<string>;
@@ -662,6 +663,22 @@ export type Mounts = {
         }>;
       }
     >;
+    cloneThread: FunctionReference<
+      "action",
+      "public",
+      {
+        batchSize?: number;
+        copyUserIdForVectorSearch?: boolean;
+        excludeToolMessages?: boolean;
+        insertAtOrder?: number;
+        limit?: number;
+        sourceThreadId: string;
+        statuses?: Array<"pending" | "success" | "failed">;
+        targetThreadId: string;
+        upToAndIncludingMessageId?: string;
+      },
+      number
+    >;
     deleteByIds: FunctionReference<
       "mutation",
       "public",
diff --git a/src/component/messages.ts b/src/component/messages.ts
index 380e1022..9af93fd8 100644
--- a/src/component/messages.ts
+++ b/src/component/messages.ts
@@ -24,6 +24,7 @@ import { api, internal } from "./_generated/api.js";
 import type { Doc, Id } from "./_generated/dataModel.js";
 import {
   action,
+  internalMutation,
   internalQuery,
   mutation,
   type MutationCtx,
@@ -130,6 +131,9 @@ const addMessagesArgs = {
   failPendingSteps: v.optional(v.boolean()),
   // A pending message to update. If the pending message failed, abort.
   pendingMessageId: v.optional(v.id("messages")),
+  // if set to true, these messages will not show up in text or vector search
+  // results for the userId
+  hideFromUserIdSearch: v.optional(v.boolean()),
 };
 export const addMessages = mutation({
   args: addMessagesArgs,
@@ -153,6 +157,7 @@ async function addMessagesHandler(
     messages,
     promptMessageId,
     pendingMessageId,
+    hideFromUserIdSearch,
     ...rest
   } = args;
   const promptMessage = promptMessageId && (await ctx.db.get(promptMessageId));
@@ -219,7 +224,7 @@ async function addMessagesHandler(
           vector: embeddings.vectors[i]!,
           model: embeddings.model,
           table: "messages",
-          userId,
+          userId: hideFromUserIdSearch ? undefined : userId,
           threadId,
         });
       }
@@ -230,7 +235,7 @@ async function addMessagesHandler(
       parentMessageId: promptMessageId,
       userId,
       tool: isTool(message.message),
-      text: extractText(message.message),
+      text: hideFromUserIdSearch ? undefined : extractText(message.message),
       status: fail ? "failed" : (message.status ?? "success"),
       error: fail ? error : message.error,
     } satisfies Omit<
@@ -432,64 +437,253 @@ export const updateMessage = mutation({
   },
 });
 
-export const listMessagesByThreadId = query({
+const cloneMessageArgs = {
+  sourceThreadId: v.id("threads"),
+  targetThreadId: v.id("threads"),
+  // defaults to false, so searching for a message by userId will not find
+  // these copies
+  copyUserIdForVectorSearch: v.optional(v.boolean()),
+  // defaults to false, so tool calls & responses will be copied
+  excludeToolMessages: v.optional(v.boolean()),
+  // defaults to copying all messages, but you could just copy success messages.
+  statuses: v.optional(v.array(vMessageStatus)),
+  // stop at this message id
+  upToAndIncludingMessageId: v.optional(v.id("messages")),
+  // defaults to 0. the messages will be inserted starting at this order.
+  insertAtOrder: v.optional(v.number()),
+};
+/**
+ * Copies one page of messages from the source thread into the target thread.
+ * The page is read in "desc" order. A message is skipped when the target
+ * already has a message at the same destination order and stepOrder.
+ * changeRefcount is called for any fileIds, and embedding vectors are
+ * re-inserted for the target thread.
+ *
+ * Returns the page size plus the cursor/isDone needed to fetch the next page.
+ */
+export const cloneMessageBatch = internalMutation({
   args: {
-    threadId: v.id("threads"),
-    excludeToolMessages: v.optional(v.boolean()),
-    /** What order to sort the messages in. To get the latest, use "desc". */
-    order: v.union(v.literal("asc"), v.literal("desc")),
-    paginationOpts: v.optional(paginationOptsValidator),
-    statuses: v.optional(v.array(vMessageStatus)),
-    upToAndIncludingMessageId: v.optional(v.id("messages")),
+    ...cloneMessageArgs,
+    paginationOpts: paginationOptsValidator,
   },
-  handler: async (ctx, args) => {
-    const statuses =
-      args.statuses ?? vMessageStatus.members.map((m) => m.value);
-    const last =
-      args.upToAndIncludingMessageId &&
-      (await ctx.db.get(args.upToAndIncludingMessageId));
-    assert(
-      !last || last.threadId === args.threadId,
-      "upToAndIncludingMessageId must be a message in the thread",
-    );
-    const toolOptions = args.excludeToolMessages ? [false] : [true, false];
-    const order = args.order ?? "desc";
-    const streams = toolOptions.flatMap((tool) =>
-      statuses.map((status) =>
-        stream(ctx.db, schema)
-          .query("messages")
-          .withIndex("threadId_status_tool_order_stepOrder", (q) => {
-            const qq = q
-              .eq("threadId", args.threadId)
-              .eq("status", status)
-              .eq("tool", tool);
-            if (last) {
-              return qq.lte("order", last.order);
-            }
-            return qq;
-          })
-          .order(order)
-          .filterWith(
-            // We allow all messages on the same order.
-            async (m) =>
-              !last || m.order < last.order || m.order === last.order,
-          ),
-      ),
-    );
-    const messages = await mergedStream(streams, [
-      "order",
-      "stepOrder",
-    ]).paginate(
-      args.paginationOpts ?? {
-        numItems: DEFAULT_RECENT_MESSAGES,
-        cursor: null,
-      },
+  handler: async (
+    ctx,
+    args,
+  ): Promise<{
+    numCopied: number;
+    continueCursor: string;
+    isDone: boolean;
+  }> => {
+    const orderOffset = args.insertAtOrder ?? 0;
+    const result = await listMessagesByThreadIdHandler(ctx, {
+      threadId: args.sourceThreadId,
+      excludeToolMessages: args.excludeToolMessages,
+      order: "desc",
+      paginationOpts: args.paginationOpts,
+      statuses: args.statuses,
+      upToAndIncludingMessageId: args.upToAndIncludingMessageId,
+    });
+
+    // The page is in "desc" order, so its first element is not the lowest
+    // order. Compute the bounds explicitly, in target-thread terms: copies
+    // are written at `orderOffset + m.order`.
+    const targetOrders = result.page.map((m) => orderOffset + m.order);
+    const minTargetOrder = Math.min(...targetOrders);
+    const maxTargetOrder = Math.max(...targetOrders);
+    const existing =
+      result.page.length === 0
+        ? []
+        : await mergedStream(
+            [true, false].flatMap((tool) =>
+              messageStatuses.map((status) =>
+                stream(ctx.db, schema)
+                  .query("messages")
+                  .withIndex("threadId_status_tool_order_stepOrder", (q) =>
+                    q
+                      .eq("threadId", args.targetThreadId)
+                      .eq("status", status)
+                      .eq("tool", tool)
+                      .gte("order", minTargetOrder)
+                      .lte("order", maxTargetOrder),
+                  ),
+              ),
+            ),
+            ["order", "stepOrder"],
+          ).collect();
+
+    await Promise.all(
+      result.page
+        .filter(
+          (m) =>
+            !existing.some(
+              (e) =>
+                e.order === orderOffset + m.order &&
+                e.stepOrder === m.stepOrder,
+            ),
+        )
+        .map(async (m) => {
+          // update file refs
+          if (m.fileIds) {
+            await changeRefcount(ctx, [], m.fileIds);
+          }
+          let embeddingId: VectorTableId | undefined = undefined;
+          if (m.embeddingId) {
+            const vector = await ctx.db.get(m.embeddingId);
+            assert(vector, `Vector ${m.embeddingId} not found`);
+            const dimension = vector.vector.length;
+            validateVectorDimension(dimension);
+            embeddingId = await insertVector(ctx, dimension, {
+              ...pick(vector, ["model", "table", "vector"]),
+              userId: args.copyUserIdForVectorSearch
+                ? vector.userId
+                : undefined,
+              threadId: args.targetThreadId,
+            });
+          }
+          await ctx.db.insert("messages", {
+            ...omit(m, [
+              "_id",
+              "_creationTime",
+              "threadId",
+              "order",
+              "embeddingId",
+            ]),
+            embeddingId,
+            threadId: args.targetThreadId,
+            order: orderOffset + m.order,
+          });
+        }),
     );
+    return {
+      // counts every message read, so the caller's limit also covers skips
+      numCopied: result.page.length,
+      continueCursor: result.continueCursor,
+      isDone: result.isDone,
+    };
+  },
+});
+
+/**
+ * Copies messages from one thread into another by repeatedly running
+ * cloneMessageBatch. Processes at most `limit` messages (no cap when
+ * omitted), `batchSize` per mutation (DEFAULT_RECENT_MESSAGES when omitted).
+ *
+ * @returns the sum of the per-batch `numCopied` counts.
+ */
+export const cloneThread = action({
+  args: {
+    ...cloneMessageArgs,
+    batchSize: v.optional(v.number()),
+    // how many messages to copy
+    limit: v.optional(v.number()),
+  },
+  returns: v.number(),
+  handler: async (ctx, args) => {
+    // batchSize and limit only steer this loop; cloneMessageBatch's args do
+    // not declare them, so they must not be forwarded.
+    const { batchSize, limit, ...cloneArgs } = args;
+    const maxToCopy = limit ?? Infinity;
+    let cursor: string | null = null;
+    let copiedSoFar = 0;
+    while (copiedSoFar < maxToCopy) {
+      const numToCopy = Math.min(
+        batchSize ?? DEFAULT_RECENT_MESSAGES,
+        // never request more than what is left of the limit
+        maxToCopy - copiedSoFar,
+      );
+      const result: {
+        numCopied: number;
+        continueCursor: string;
+        isDone: boolean;
+      } = await ctx.runMutation(internal.messages.cloneMessageBatch, {
+        ...cloneArgs,
+        paginationOpts: {
+          cursor,
+          numItems: numToCopy,
+        },
+      });
+      copiedSoFar += result.numCopied;
+      cursor = result.continueCursor;
+      if (result.isDone) {
+        break;
+      }
+    }
+    return copiedSoFar;
+  },
+});
+
+export const listMessagesByThreadIdArgs = {
+  threadId: v.id("threads"),
+  excludeToolMessages: v.optional(v.boolean()),
+  /** What order to sort the messages in. To get the latest, use "desc". */
+  order: v.union(v.literal("asc"), v.literal("desc")),
+  paginationOpts: v.optional(paginationOptsValidator),
+  statuses: v.optional(v.array(vMessageStatus)),
+  upToAndIncludingMessageId: v.optional(v.id("messages")),
+};
+export const listMessagesByThreadId = query({
+  args: listMessagesByThreadIdArgs,
+  handler: async (ctx, args) => {
+    const messages = await listMessagesByThreadIdHandler(ctx, args);
     return { ...messages, page: messages.page.map(publicMessage) };
   },
   returns: vPaginationResult(vMessageDoc),
 });
 
+/**
+ * Pagination logic shared by listMessagesByThreadId and cloneMessageBatch.
+ * Merges one index stream per (tool, status) pair on ["order", "stepOrder"];
+ * with `upToAndIncludingMessageId`, only orders up to that message's order
+ * are included. Returns the paginated result without mapping the documents.
+ */
+async function listMessagesByThreadIdHandler(
+  ctx: QueryCtx,
+  args: ObjectType<typeof listMessagesByThreadIdArgs>,
+) {
+  const statuses = args.statuses ?? vMessageStatus.members.map((m) => m.value);
+  const last =
+    args.upToAndIncludingMessageId &&
+    (await ctx.db.get(args.upToAndIncludingMessageId));
+  assert(
+    !last || last.threadId === args.threadId,
+    "upToAndIncludingMessageId must be a message in the thread",
+  );
+  const toolOptions = args.excludeToolMessages ? [false] : [true, false];
+  const order = args.order ?? "desc";
+  const streams = toolOptions.flatMap((tool) =>
+    statuses.map((status) =>
+      stream(ctx.db, schema)
+        .query("messages")
+        .withIndex("threadId_status_tool_order_stepOrder", (q) => {
+          const qq = q
+            .eq("threadId", args.threadId)
+            .eq("status", status)
+            .eq("tool", tool);
+          if (last) {
+            return qq.lte("order", last.order);
+          }
+          return qq;
+        })
+        .order(order)
+        .filterWith(
+          // We allow all messages on the same order.
+          async (m) => !last || m.order <= last.order,
+        ),
+    ),
+  );
+  const messages = await mergedStream(streams, ["order", "stepOrder"]).paginate(
+    args.paginationOpts ?? {
+      numItems: DEFAULT_RECENT_MESSAGES,
+      cursor: null,
+    },
+  );
+  // Treat an empty page as the end so callers looping on isDone terminate.
+  if (messages.page.length === 0) {
+    messages.isDone = true;
+  }
+  return messages;
+}
+
 export const getMessagesByIds = query({
   args: { messageIds: v.array(v.id("messages")) },
   handler: async (ctx, args) => {