Skip to content

Commit 0b7c978

Browse files
committed
pass tool values through more places
1 parent f6db4a1 commit 0b7c978

File tree

2 files changed

+39
-43
lines changed

2 files changed

+39
-43
lines changed

src/UIMessages.ts

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -381,10 +381,9 @@ function createAssistantUIMessage<
381381
const allParts: UIMessage<METADATA, DATA_PARTS, TOOLS>["parts"] = [];
382382

383383
for (const message of group) {
384-
const coreMessage = message.message && toModelMessage(message.message);
385-
if (!coreMessage) continue;
384+
if (!message.message) continue;
386385

387-
const content = coreMessage.content;
386+
const content = message.message.content;
388387
const nonStringContent =
389388
content && typeof content !== "string" ? content : [];
390389
const text = extractTextFromMessageDoc(message);
@@ -441,12 +440,11 @@ function createAssistantUIMessage<
441440
type: "step-start",
442441
} satisfies StepStartUIPart);
443442
const toolPart: ToolUIPart<TOOLS> = {
443+
...omit(contentPart, ["args", "type"]),
444444
type: `tool-${contentPart.toolName as keyof TOOLS & string}`,
445-
toolCallId: contentPart.toolCallId,
446-
input: contentPart.input as DeepPartial<
445+
input: contentPart.args as DeepPartial<
447446
TOOLS[keyof TOOLS & string]["input"]
448447
>,
449-
providerExecuted: contentPart.providerExecuted,
450448
...(message.streaming
451449
? { state: "input-streaming" }
452450
: {
@@ -478,25 +476,22 @@ function createAssistantUIMessage<
478476
call.output = output;
479477
}
480478
} else {
481-
console.warn(
482-
"Tool result without preceding tool call.. adding anyways",
483-
contentPart,
484-
);
485479
if (message.error) {
486480
allParts.push({
481+
input: contentPart.args,
482+
...omit(contentPart, ["args", "type", "output"]),
487483
type: `tool-${contentPart.toolName}`,
488-
toolCallId: contentPart.toolCallId,
489484
state: "output-error",
490-
input: undefined,
485+
output,
491486
errorText: message.error,
492487
callProviderMetadata: message.providerMetadata,
493488
} satisfies ToolUIPart<TOOLS>);
494489
} else {
495490
allParts.push({
491+
input: contentPart.args,
492+
...omit(contentPart, ["args", "type", "output"]),
496493
type: `tool-${contentPart.toolName}`,
497-
toolCallId: contentPart.toolCallId,
498494
state: "output-available",
499-
input: undefined,
500495
output,
501496
callProviderMetadata: message.providerMetadata,
502497
} satisfies ToolUIPart<TOOLS>);

src/mapping.ts

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ import {
5454
getProviderName,
5555
type ModelOrMetadata,
5656
} from "./shared.js";
57-
import { pick } from "convex-helpers";
57+
import { omit, pick } from "convex-helpers";
5858
export type AIMessageWithoutId = Omit<AIMessage, "id">;
5959

6060
export type SerializeUrlsAndUint8Arrays<T> = T extends URL
@@ -102,9 +102,7 @@ export function fromModelMessage(message: ModelMessage): Message {
102102
return {
103103
role: message.role,
104104
content,
105-
...(message.providerOptions
106-
? { providerOptions: message.providerOptions }
107-
: {}),
105+
...pick(message, ["providerOptions"]),
108106
} as SerializedMessage;
109107
}
110108

@@ -121,9 +119,7 @@ export async function serializeOrThrow(
121119
return {
122120
role: message.role,
123121
content,
124-
...(message.providerOptions
125-
? { providerOptions: message.providerOptions }
126-
: {}),
122+
...pick(message, ["providerOptions"]),
127123
} as SerializedMessage;
128124
}
129125

@@ -613,20 +609,20 @@ export function toModelMessageContent(
613609
mediaType: getMimeOrMediaType(part)!,
614610
...metadata,
615611
} satisfies FilePart;
616-
case "tool-call": {
617-
const input = "input" in part ? part.input : part.args;
612+
case "tool-call":
618613
return {
619-
type: part.type,
620-
input: input ?? null,
621-
toolCallId: part.toolCallId,
622-
toolName: part.toolName,
623-
providerExecuted: part.providerExecuted,
624-
...metadata,
614+
input: ("input" in part ? part.input : part.args) ?? null,
615+
...omit(part as Infer<typeof vToolCallPart>, ["args"]),
625616
} satisfies ToolCallPart;
626-
}
627-
case "tool-result": {
628-
return normalizeToolResult(part, metadata);
629-
}
617+
case "tool-result":
618+
return {
619+
input: (part as Infer<typeof vToolResultPart>).args,
620+
output: normalizeToolOutput(
621+
(part as Infer<typeof vToolResultPart>).result,
622+
),
623+
...omit(part as Infer<typeof vToolResultPart>, ["result", "args"]),
624+
...metadata,
625+
} satisfies ToolResultPart & { input: unknown };
630626
case "reasoning":
631627
return {
632628
type: part.type,
@@ -684,16 +680,13 @@ function normalizeToolResult(
684680
providerOptions?: ProviderOptions;
685681
providerMetadata?: ProviderMetadata;
686682
},
687-
): ToolResultPart & Infer<typeof vToolResultPart> {
683+
): ToolResultPart & Infer<typeof vToolResultPart> & { input: unknown } {
688684
return {
689-
type: part.type,
690-
output:
691-
part.output ??
692-
normalizeToolOutput("result" in part ? part.result : undefined),
693-
toolCallId: part.toolCallId,
694-
toolName: part.toolName,
685+
input: (part as Infer<typeof vToolResultPart>).args,
686+
output: normalizeToolOutput("result" in part ? part.result : undefined),
687+
...omit(part as Infer<typeof vToolResultPart>, ["result", "args"]),
695688
...metadata,
696-
} satisfies ToolResultPart;
689+
} satisfies ToolResultPart & { input: unknown };
697690
}
698691

699692
/**
@@ -804,16 +797,24 @@ export function toModelMessageDataOrUrl(
804797
return urlOrString;
805798
}
806799

807-
export function toUIFilePart(part: ImagePart | FilePart): FileUIPart {
800+
export function toUIFilePart(
801+
part:
802+
| ImagePart
803+
| FilePart
804+
| Infer<typeof vImagePart>
805+
| Infer<typeof vFilePart>,
806+
): FileUIPart {
808807
const dataOrUrl = part.type === "image" ? part.image : part.data;
809808
const url =
810809
dataOrUrl instanceof ArrayBuffer
811810
? convertUint8ArrayToBase64(new Uint8Array(dataOrUrl))
812811
: dataOrUrl.toString();
813812

813+
const mediaType = getMimeOrMediaType(part);
814+
814815
return {
815816
type: "file",
816-
mediaType: part.mediaType!,
817+
mediaType: mediaType!,
817818
filename: part.type === "file" ? part.filename : undefined,
818819
url,
819820
providerMetadata: part.providerOptions,

0 commit comments

Comments
 (0)