Skip to content

Commit 1f67fe4

Browse files
authoredJul 18, 2024··
feat (ai/ui): stream tool calls with streamText and useChat (#2295)
1 parent f0bc1e7 commit 1f67fe4

File tree

15 files changed

+1063
-22
lines changed

15 files changed

+1063
-22
lines changed
 

‎.changeset/moody-rice-peel.md

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
'@ai-sdk/ui-utils': patch
3+
'ai': patch
4+
---
5+
6+
feat (ai/ui): stream tool calls with streamText and useChat

‎content/docs/05-ai-sdk-ui/03-chatbot-with-tool-calling.mdx

+49
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,52 @@ export default function Chat() {
195195
);
196196
}
197197
```
198+
199+
## Tool call streaming
200+
201+
<Note>This feature is experimental.</Note>
202+
203+
You can stream tool calls while they are being generated by enabling the
204+
`experimental_toolCallStreaming` option in `streamText`.
205+
206+
```tsx filename='app/page.tsx' highlight="5"
207+
export async function POST(req: Request) {
208+
// ...
209+
210+
const result = await streamText({
211+
experimental_toolCallStreaming: true,
212+
// ...
213+
});
214+
215+
return result.toAIStreamResponse();
216+
}
217+
```
218+
219+
When the flag is enabled, partial tool calls will be streamed as part of the AI stream.
220+
They are available through the `useChat` hook.
221+
The `toolInvocations` property of assistant messages will also contain partial tool calls.
222+
You can use the `state` property of the tool invocation to render the correct UI.
223+
224+
```tsx filename='app/page.tsx' highlight="9,10"
225+
export default function Chat() {
226+
// ...
227+
return (
228+
<>
229+
{messages?.map((m: Message) => (
230+
<div key={m.id}>
231+
{m.toolInvocations?.map((toolInvocation: ToolInvocation) => {
232+
switch (toolInvocation.state) {
233+
case 'partial-call':
234+
return <>render partial tool call</>;
235+
case 'call':
236+
return <>render full tool call</>;
237+
case 'result':
238+
return <>render tool result</>;
239+
}
240+
})}
241+
</div>
242+
))}
243+
</>
244+
);
245+
}
246+
```

‎content/docs/07-reference/ai-sdk-core/02-stream-text.mdx

+56
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,13 @@ for await (const textPart of textStream) {
366366
},
367367
],
368368
},
369+
{
370+
name: 'experimental_toolCallStreaming',
371+
type: 'boolean',
372+
isOptional: true,
373+
description:
374+
'Enable streaming of tool call deltas as they are generated. Disabled by default.',
375+
},
369376
{
370377
name: 'onFinish',
371378
type: '(result: OnFinishResult) => void',
@@ -590,6 +597,55 @@ for await (const textPart of textStream) {
590597
},
591598
],
592599
},
600+
{
601+
type: 'TextStreamPart',
602+
parameters: [
603+
{
604+
name: 'type',
605+
type: "'tool-call-streaming-start'",
606+
description:
607+
'Indicates the start of a tool call streaming. Only available when streaming tool calls.',
608+
},
609+
{
610+
name: 'toolCallId',
611+
type: 'string',
612+
description: 'The id of the tool call.',
613+
},
614+
{
615+
name: 'toolName',
616+
type: 'string',
617+
description:
618+
'The name of the tool, which typically would be the name of the function.',
619+
},
620+
],
621+
},
622+
{
623+
type: 'TextStreamPart',
624+
parameters: [
625+
{
626+
name: 'type',
627+
type: "'tool-call-delta'",
628+
description:
629+
'The type to identify the object as tool call delta. Only available when streaming tool calls.',
630+
},
631+
{
632+
name: 'toolCallId',
633+
type: 'string',
634+
description: 'The id of the tool call.',
635+
},
636+
{
637+
name: 'toolName',
638+
type: 'string',
639+
description:
640+
'The name of the tool, which typically would be the name of the function.',
641+
},
642+
{
643+
name: 'argsTextDelta',
644+
type: 'string',
645+
description: 'The text delta of the tool call arguments.',
646+
},
647+
],
648+
},
593649
{
594650
type: 'TextStreamPart',
595651
description: 'The result of a tool call execution.',

‎content/docs/07-reference/ai-sdk-ui/01-use-chat.mdx

+98
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,104 @@ Allows you to easily create a conversational user interface for your chatbot app
202202
description:
203203
'Additional annotations sent along with the message.',
204204
},
205+
{
206+
name: 'toolInvocations',
207+
type: 'Array<ToolInvocation>',
208+
isOptional: true,
209+
description:
210+
'An array of tool invocations that are associated with the (assistant) message.',
211+
properties: [
212+
{
213+
type: 'ToolInvocation',
214+
parameters: [
215+
{
216+
name: 'state',
217+
type: "'partial-call'",
218+
description:
219+
'The state of the tool call when it was partially created.',
220+
},
221+
{
222+
name: 'toolCallId',
223+
type: 'string',
224+
description:
225+
'ID of the tool call. This ID is used to match the tool call with the tool result.',
226+
},
227+
{
228+
name: 'toolName',
229+
type: 'string',
230+
description: 'Name of the tool that is being called.',
231+
},
232+
{
233+
name: 'args',
234+
type: 'any',
235+
description:
236+
'Partial arguments of the tool call. This is a JSON-serializable object.',
237+
},
238+
],
239+
},
240+
{
241+
type: 'ToolInvocation',
242+
parameters: [
243+
{
244+
name: 'state',
245+
type: "'call'",
246+
description:
247+
'The state of the tool call when it was fully created.',
248+
},
249+
{
250+
name: 'toolCallId',
251+
type: 'string',
252+
description:
253+
'ID of the tool call. This ID is used to match the tool call with the tool result.',
254+
},
255+
{
256+
name: 'toolName',
257+
type: 'string',
258+
description: 'Name of the tool that is being called.',
259+
},
260+
{
261+
name: 'args',
262+
type: 'any',
263+
description:
264+
'Arguments of the tool call. This is a JSON-serializable object that matches the tools input schema.',
265+
},
266+
],
267+
},
268+
{
269+
type: 'ToolInvocation',
270+
parameters: [
271+
{
272+
name: 'state',
273+
type: "'result'",
274+
description:
275+
'The state of the tool call when the result is available.',
276+
},
277+
{
278+
name: 'toolCallId',
279+
type: 'string',
280+
description:
281+
'ID of the tool call. This ID is used to match the tool call with the tool result.',
282+
},
283+
{
284+
name: 'toolName',
285+
type: 'string',
286+
description: 'Name of the tool that is being called.',
287+
},
288+
{
289+
name: 'args',
290+
type: 'any',
291+
description:
292+
'Arguments of the tool call. This is a JSON-serializable object that matches the tools input schema.',
293+
},
294+
{
295+
name: 'result',
296+
type: 'any',
297+
description: 'The result of the tool call.',
298+
},
299+
],
300+
},
301+
],
302+
},
205303
{
206304
name: 'experimental_attachments',
207305
type: 'Array<Attachment>',

‎examples/next-openai/app/api/use-chat-tools/route.ts

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ export async function POST(req: Request) {
1111
const result = await streamText({
1212
model: openai('gpt-4-turbo'),
1313
messages: convertToCoreMessages(messages),
14+
experimental_toolCallStreaming: true,
1415
tools: {
1516
// server-side tool with execute function:
1617
getWeatherInformation: {

‎examples/next-openai/app/use-chat-tools/page.tsx

+9
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,15 @@ export default function Chat() {
3232
{m.toolInvocations?.map((toolInvocation: ToolInvocation) => {
3333
const toolCallId = toolInvocation.toolCallId;
3434

35+
// example of pre-rendering streaming tool calls
36+
if (toolInvocation.state === 'partial-call') {
37+
return (
38+
<pre key={toolCallId}>
39+
{JSON.stringify(toolInvocation, null, 2)}
40+
</pre>
41+
);
42+
}
43+
3544
// render confirmation tool (client-side tool with user interaction)
3645
if (toolInvocation.toolName === 'askForConfirmation') {
3746
return (

‎packages/core/core/generate-text/run-tools-transformation.ts

+28-5
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@ import { parseToolCall } from './tool-call';
1010
export function runToolsTransformation<TOOLS extends Record<string, CoreTool>>({
1111
tools,
1212
generatorStream,
13+
toolCallStreaming,
1314
tracer,
1415
}: {
1516
tools?: TOOLS;
1617
generatorStream: ReadableStream<LanguageModelV1StreamPart>;
18+
toolCallStreaming: boolean;
1719
tracer: Tracer;
1820
}): ReadableStream<TextStreamPart<TOOLS>> {
1921
let canClose = false;
@@ -29,6 +31,9 @@ export function runToolsTransformation<TOOLS extends Record<string, CoreTool>>({
2931
},
3032
});
3133

34+
// keep track of active tool calls
35+
const activeToolCalls: Record<string, boolean> = {};
36+
3237
// forward stream
3338
const forwardStream = new TransformStream<
3439
LanguageModelV1StreamPart,
@@ -48,6 +53,29 @@ export function runToolsTransformation<TOOLS extends Record<string, CoreTool>>({
4853
break;
4954
}
5055

56+
// forward with less information:
57+
case 'tool-call-delta': {
58+
if (toolCallStreaming) {
59+
if (!activeToolCalls[chunk.toolCallId]) {
60+
controller.enqueue({
61+
type: 'tool-call-streaming-start',
62+
toolCallId: chunk.toolCallId,
63+
toolName: chunk.toolName,
64+
});
65+
66+
activeToolCalls[chunk.toolCallId] = true;
67+
}
68+
69+
controller.enqueue({
70+
type: 'tool-call-delta',
71+
toolCallId: chunk.toolCallId,
72+
toolName: chunk.toolName,
73+
argsTextDelta: chunk.argsTextDelta,
74+
});
75+
}
76+
break;
77+
}
78+
5179
// process tool call:
5280
case 'tool-call': {
5381
const toolName = chunk.toolName as keyof TOOLS & string;
@@ -162,11 +190,6 @@ export function runToolsTransformation<TOOLS extends Record<string, CoreTool>>({
162190
break;
163191
}
164192

165-
// ignore
166-
case 'tool-call-delta': {
167-
break;
168-
}
169-
170193
default: {
171194
const _exhaustiveCheck: never = chunkType;
172195
throw new Error(`Unhandled chunk type: ${_exhaustiveCheck}`);

‎packages/core/core/generate-text/stream-text.test.ts

+410
Large diffs are not rendered by default.

‎packages/core/core/generate-text/stream-text.ts

+47-4
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ export async function streamText<TOOLS extends Record<string, CoreTool>>({
8383
abortSignal,
8484
headers,
8585
experimental_telemetry: telemetry,
86+
experimental_toolCallStreaming: toolCallStreaming = false,
8687
onFinish,
8788
...settings
8889
}: CallSettings &
@@ -103,10 +104,15 @@ The tool choice strategy. Default: 'auto'.
103104
toolChoice?: CoreToolChoice<TOOLS>;
104105

105106
/**
106-
* Optional telemetry configuration (experimental).
107+
Optional telemetry configuration (experimental).
107108
*/
108109
experimental_telemetry?: TelemetrySettings;
109110

111+
/**
112+
Enable streaming of tool call deltas as they are generated. Disabled by default.
113+
*/
114+
experimental_toolCallStreaming?: boolean;
115+
110116
/**
111117
Callback that is called when the LLM response and all request tool executions
112118
(for tools that have an `execute` function) are finished.
@@ -212,6 +218,7 @@ Warnings from the model provider (e.g. unsupported settings).
212218
stream: runToolsTransformation({
213219
tools,
214220
generatorStream: stream,
221+
toolCallStreaming,
215222
tracer,
216223
}),
217224
warnings,
@@ -233,8 +240,15 @@ export type TextStreamPart<TOOLS extends Record<string, CoreTool>> =
233240
type: 'tool-call';
234241
} & ToToolCall<TOOLS>)
235242
| {
236-
type: 'error';
237-
error: unknown;
243+
type: 'tool-call-streaming-start';
244+
toolCallId: string;
245+
toolName: string;
246+
}
247+
| {
248+
type: 'tool-call-delta';
249+
toolCallId: string;
250+
toolName: string;
251+
argsTextDelta: string;
238252
}
239253
| ({
240254
type: 'tool-result';
@@ -248,6 +262,10 @@ export type TextStreamPart<TOOLS extends Record<string, CoreTool>> =
248262
completionTokens: number;
249263
totalTokens: number;
250264
};
265+
}
266+
| {
267+
type: 'error';
268+
error: unknown;
251269
};
252270

253271
/**
@@ -407,6 +425,8 @@ Response headers.
407425
resolveToolCalls(toolCalls);
408426
break;
409427

428+
case 'tool-call-streaming-start':
429+
case 'tool-call-delta':
410430
case 'error':
411431
// ignored
412432
break;
@@ -577,10 +597,27 @@ Stream callbacks that will be called when the stream emits events.
577597
string
578598
>({
579599
transform: async (chunk, controller) => {
580-
switch (chunk.type) {
600+
const chunkType = chunk.type;
601+
switch (chunkType) {
581602
case 'text-delta':
582603
controller.enqueue(formatStreamPart('text', chunk.textDelta));
583604
break;
605+
case 'tool-call-streaming-start':
606+
controller.enqueue(
607+
formatStreamPart('tool_call_streaming_start', {
608+
toolCallId: chunk.toolCallId,
609+
toolName: chunk.toolName,
610+
}),
611+
);
612+
break;
613+
case 'tool-call-delta':
614+
controller.enqueue(
615+
formatStreamPart('tool_call_delta', {
616+
toolCallId: chunk.toolCallId,
617+
argsTextDelta: chunk.argsTextDelta,
618+
}),
619+
);
620+
break;
584621
case 'tool-call':
585622
controller.enqueue(
586623
formatStreamPart('tool_call', {
@@ -605,6 +642,12 @@ Stream callbacks that will be called when the stream emits events.
605642
formatStreamPart('error', JSON.stringify(chunk.error)),
606643
);
607644
break;
645+
case 'finish':
646+
break; // ignored
647+
default: {
648+
const exhaustiveCheck: never = chunkType;
649+
throw new Error(`Unknown chunk type: ${exhaustiveCheck}`);
650+
}
608651
}
609652
},
610653
});

‎packages/core/core/types/language-model.ts

-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import {
44
LanguageModelV1FinishReason,
55
LanguageModelV1LogProbs,
66
} from '@ai-sdk/provider';
7-
import { CoreTool } from '../tool/tool';
87

98
/**
109
Language model that is used by the AI SDK Core functions.

‎packages/react/src/use-chat.ui.test.tsx

+171-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@
22
import { withTestServer } from '@ai-sdk/provider-utils/test';
33
import { formatStreamPart, getTextFromDataUrl } from '@ai-sdk/ui-utils';
44
import '@testing-library/jest-dom/vitest';
5-
import { cleanup, findByText, render, screen } from '@testing-library/react';
5+
import {
6+
RenderResult,
7+
cleanup,
8+
findByText,
9+
render,
10+
screen,
11+
waitFor,
12+
} from '@testing-library/react';
613
import userEvent from '@testing-library/user-event';
714
import React, { useRef, useState } from 'react';
815
import { useChat } from './use-chat';
@@ -447,6 +454,169 @@ describe('onToolCall', () => {
447454
);
448455
});
449456

457+
describe('tool invocations', () => {
458+
let rerender: RenderResult['rerender'];
459+
460+
const TestComponent = () => {
461+
const { messages, append } = useChat();
462+
463+
return (
464+
<div>
465+
{messages.map((m, idx) => (
466+
<div data-testid={`message-${idx}`} key={m.id}>
467+
{m.toolInvocations?.map((toolInvocation, toolIdx) => {
468+
return (
469+
<div key={toolIdx} data-testid={`tool-invocation-${toolIdx}`}>
470+
{JSON.stringify(toolInvocation)}
471+
</div>
472+
);
473+
})}
474+
</div>
475+
))}
476+
477+
<button
478+
data-testid="do-append"
479+
onClick={() => {
480+
append({ role: 'user', content: 'hi' });
481+
}}
482+
/>
483+
</div>
484+
);
485+
};
486+
487+
beforeEach(() => {
488+
const result = render(<TestComponent />);
489+
rerender = result.rerender;
490+
});
491+
492+
afterEach(() => {
493+
vi.restoreAllMocks();
494+
cleanup();
495+
});
496+
497+
it(
498+
'should display partial tool call, tool call, and tool result',
499+
withTestServer(
500+
{ url: '/api/chat', type: 'controlled-stream' },
501+
async ({ streamController }) => {
502+
await userEvent.click(screen.getByTestId('do-append'));
503+
504+
streamController.enqueue(
505+
formatStreamPart('tool_call_streaming_start', {
506+
toolCallId: 'tool-call-0',
507+
toolName: 'test-tool',
508+
}),
509+
);
510+
511+
await waitFor(() => {
512+
expect(screen.getByTestId('message-1')).toHaveTextContent(
513+
'{"state":"partial-call","toolCallId":"tool-call-0","toolName":"test-tool"}',
514+
);
515+
});
516+
517+
streamController.enqueue(
518+
formatStreamPart('tool_call_delta', {
519+
toolCallId: 'tool-call-0',
520+
argsTextDelta: '{"testArg":"t',
521+
}),
522+
);
523+
524+
await waitFor(() => {
525+
rerender(<TestComponent />);
526+
expect(screen.getByTestId('message-1')).toHaveTextContent(
527+
'{"state":"partial-call","toolCallId":"tool-call-0","toolName":"test-tool","args":{"testArg":"t"}}',
528+
);
529+
});
530+
531+
streamController.enqueue(
532+
formatStreamPart('tool_call_delta', {
533+
toolCallId: 'tool-call-0',
534+
argsTextDelta: 'est-value"}}',
535+
}),
536+
);
537+
538+
await waitFor(() => {
539+
rerender(<TestComponent />);
540+
expect(screen.getByTestId('message-1')).toHaveTextContent(
541+
'{"state":"partial-call","toolCallId":"tool-call-0","toolName":"test-tool","args":{"testArg":"test-value"}}',
542+
);
543+
});
544+
545+
streamController.enqueue(
546+
formatStreamPart('tool_call', {
547+
toolCallId: 'tool-call-0',
548+
toolName: 'test-tool',
549+
args: { testArg: 'test-value' },
550+
}),
551+
);
552+
553+
await waitFor(() => {
554+
rerender(<TestComponent />);
555+
expect(screen.getByTestId('message-1')).toHaveTextContent(
556+
'{"state":"call","toolCallId":"tool-call-0","toolName":"test-tool","args":{"testArg":"test-value"}}',
557+
);
558+
});
559+
560+
streamController.enqueue(
561+
formatStreamPart('tool_result', {
562+
toolCallId: 'tool-call-0',
563+
toolName: 'test-tool',
564+
args: { testArg: 'test-value' },
565+
result: 'test-result',
566+
}),
567+
);
568+
streamController.close();
569+
570+
await waitFor(() => {
571+
expect(screen.getByTestId('message-1')).toHaveTextContent(
572+
'{"state":"result","toolCallId":"tool-call-0","toolName":"test-tool","args":{"testArg":"test-value"},"result":"test-result"}',
573+
);
574+
});
575+
},
576+
),
577+
);
578+
579+
it(
580+
'should display partial tool call and tool result (when there is no tool call streaming)',
581+
withTestServer(
582+
{ url: '/api/chat', type: 'controlled-stream' },
583+
async ({ streamController }) => {
584+
await userEvent.click(screen.getByTestId('do-append'));
585+
586+
streamController.enqueue(
587+
formatStreamPart('tool_call', {
588+
toolCallId: 'tool-call-0',
589+
toolName: 'test-tool',
590+
args: { testArg: 'test-value' },
591+
}),
592+
);
593+
594+
await waitFor(() => {
595+
expect(screen.getByTestId('message-1')).toHaveTextContent(
596+
'{"state":"call","toolCallId":"tool-call-0","toolName":"test-tool","args":{"testArg":"test-value"}}',
597+
);
598+
});
599+
600+
streamController.enqueue(
601+
formatStreamPart('tool_result', {
602+
toolCallId: 'tool-call-0',
603+
toolName: 'test-tool',
604+
args: { testArg: 'test-value' },
605+
result: 'test-result',
606+
}),
607+
);
608+
streamController.close();
609+
610+
await waitFor(() => {
611+
expect(screen.getByTestId('message-1')).toHaveTextContent(
612+
'{"state":"result","toolCallId":"tool-call-0","toolName":"test-tool","args":{"testArg":"test-value"},"result":"test-result"}',
613+
);
614+
});
615+
},
616+
),
617+
);
618+
});
619+
450620
describe('maxToolRoundtrips', () => {
451621
describe('single automatic tool roundtrip', () => {
452622
let onToolCallInvoked = false;

‎packages/ui-utils/src/parse-complex-response.ts

+64-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { generateId as generateIdFunction } from '@ai-sdk/provider-utils';
2+
import { parsePartialJson } from './parse-partial-json';
23
import { readDataStream } from './read-data-stream';
34
import type {
45
FunctionCall,
@@ -58,6 +59,12 @@ export async function parseComplexResponse({
5859
// keep list of current message annotations for message
5960
let message_annotations: JSONValue[] | undefined = undefined;
6061

62+
// keep track of partial tool calls
63+
const partialToolCalls: Record<
64+
string,
65+
{ text: string; prefixMapIndex: number; toolName: string }
66+
> = {};
67+
6168
// we create a map of each prefix, and for each prefixed message we push to the map
6269
for await (const { type, value } of readDataStream(reader, {
6370
isAborted: () => abortControllerRef?.current === null,
@@ -79,7 +86,7 @@ export async function parseComplexResponse({
7986
}
8087

8188
// Tool invocations are part of an assistant message
82-
if (type === 'tool_call') {
89+
if (type === 'tool_call_streaming_start') {
8390
// create message if it doesn't exist
8491
if (prefixMap.text == null) {
8592
prefixMap.text = {
@@ -94,7 +101,56 @@ export async function parseComplexResponse({
94101
prefixMap.text.toolInvocations = [];
95102
}
96103

97-
prefixMap.text.toolInvocations.push(value);
104+
// add the partial tool call to the map
105+
partialToolCalls[value.toolCallId] = {
106+
text: '',
107+
toolName: value.toolName,
108+
prefixMapIndex: prefixMap.text.toolInvocations.length,
109+
};
110+
111+
prefixMap.text.toolInvocations.push({
112+
state: 'partial-call',
113+
toolCallId: value.toolCallId,
114+
toolName: value.toolName,
115+
args: undefined,
116+
});
117+
} else if (type === 'tool_call_delta') {
118+
const partialToolCall = partialToolCalls[value.toolCallId];
119+
120+
partialToolCall.text += value.argsTextDelta;
121+
122+
prefixMap.text!.toolInvocations![partialToolCall.prefixMapIndex] = {
123+
state: 'partial-call',
124+
toolCallId: value.toolCallId,
125+
toolName: partialToolCall.toolName,
126+
args: parsePartialJson(partialToolCall.text),
127+
};
128+
} else if (type === 'tool_call') {
129+
if (partialToolCalls[value.toolCallId] != null) {
130+
// change the partial tool call to a full tool call
131+
prefixMap.text!.toolInvocations![
132+
partialToolCalls[value.toolCallId].prefixMapIndex
133+
] = { state: 'call', ...value };
134+
} else {
135+
// create message if it doesn't exist
136+
if (prefixMap.text == null) {
137+
prefixMap.text = {
138+
id: generateId(),
139+
role: 'assistant',
140+
content: '',
141+
createdAt,
142+
};
143+
}
144+
145+
if (prefixMap.text.toolInvocations == null) {
146+
prefixMap.text.toolInvocations = [];
147+
}
148+
149+
prefixMap.text.toolInvocations.push({
150+
state: 'call',
151+
...value,
152+
});
153+
}
98154

99155
// invoke the onToolCall callback if it exists. This is blocking.
100156
// In the future we should make this non-blocking, which
@@ -103,9 +159,9 @@ export async function parseComplexResponse({
103159
const result = await onToolCall({ toolCall: value });
104160
if (result != null) {
105161
// store the result in the tool invocation
106-
prefixMap.text.toolInvocations[
107-
prefixMap.text.toolInvocations.length - 1
108-
] = { ...value, result };
162+
prefixMap.text!.toolInvocations![
163+
prefixMap.text!.toolInvocations!.length - 1
164+
] = { state: 'result', ...value, result };
109165
}
110166
}
111167
} else if (type === 'tool_result') {
@@ -129,10 +185,11 @@ export async function parseComplexResponse({
129185
invocation => invocation.toolCallId === value.toolCallId,
130186
);
131187

188+
const result = { state: 'result' as const, ...value };
132189
if (toolInvocationIndex !== -1) {
133-
prefixMap.text.toolInvocations[toolInvocationIndex] = value;
190+
prefixMap.text.toolInvocations[toolInvocationIndex] = result;
134191
} else {
135-
prefixMap.text.toolInvocations.push(value);
192+
prefixMap.text.toolInvocations.push(result);
136193
}
137194
}
138195

‎packages/ui-utils/src/stream-parts.test.ts

+39
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,42 @@ describe('tool_result stream part', () => {
156156
});
157157
});
158158
});
159+
160+
describe('tool_call_streaming_start stream part', () => {
161+
it('should format a tool_call_streaming_start stream part', () => {
162+
expect(
163+
formatStreamPart('tool_call_streaming_start', {
164+
toolCallId: 'tc_0',
165+
toolName: 'example_tool',
166+
}),
167+
).toEqual(`b:{"toolCallId":"tc_0","toolName":"example_tool"}\n`);
168+
});
169+
170+
it('should parse a tool_call_streaming_start stream part', () => {
171+
const input = `b:{"toolCallId":"tc_0","toolName":"example_tool"}`;
172+
173+
expect(parseStreamPart(input)).toEqual({
174+
type: 'tool_call_streaming_start',
175+
value: { toolCallId: 'tc_0', toolName: 'example_tool' },
176+
});
177+
});
178+
});
179+
180+
describe('tool_call_delta stream part', () => {
181+
it('should format a tool_call_delta stream part', () => {
182+
expect(
183+
formatStreamPart('tool_call_delta', {
184+
toolCallId: 'tc_0',
185+
argsTextDelta: 'delta',
186+
}),
187+
).toEqual(`c:{"toolCallId":"tc_0","argsTextDelta":"delta"}\n`);
188+
});
189+
190+
it('should parse a tool_call_delta stream part', () => {
191+
const input = `c:{"toolCallId":"tc_0","argsTextDelta":"delta"}`;
192+
expect(parseStreamPart(input)).toEqual({
193+
type: 'tool_call_delta',
194+
value: { toolCallId: 'tc_0', argsTextDelta: 'delta' },
195+
});
196+
});
197+
});

‎packages/ui-utils/src/stream-parts.ts

+72-2
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,65 @@ const toolResultStreamPart: StreamPart<
303303
},
304304
};
305305

306+
const toolCallStreamingStartStreamPart: StreamPart<
307+
'b',
308+
'tool_call_streaming_start',
309+
{ toolCallId: string; toolName: string }
310+
> = {
311+
code: 'b',
312+
name: 'tool_call_streaming_start',
313+
parse: (value: JSONValue) => {
314+
if (
315+
value == null ||
316+
typeof value !== 'object' ||
317+
!('toolCallId' in value) ||
318+
typeof value.toolCallId !== 'string' ||
319+
!('toolName' in value) ||
320+
typeof value.toolName !== 'string'
321+
) {
322+
throw new Error(
323+
'"tool_call_streaming_start" parts expect an object with a "toolCallId" and "toolName" property.',
324+
);
325+
}
326+
327+
return {
328+
type: 'tool_call_streaming_start',
329+
value: value as unknown as { toolCallId: string; toolName: string },
330+
};
331+
},
332+
};
333+
334+
const toolCallDeltaStreamPart: StreamPart<
335+
'c',
336+
'tool_call_delta',
337+
{ toolCallId: string; argsTextDelta: string }
338+
> = {
339+
code: 'c',
340+
name: 'tool_call_delta',
341+
parse: (value: JSONValue) => {
342+
if (
343+
value == null ||
344+
typeof value !== 'object' ||
345+
!('toolCallId' in value) ||
346+
typeof value.toolCallId !== 'string' ||
347+
!('argsTextDelta' in value) ||
348+
typeof value.argsTextDelta !== 'string'
349+
) {
350+
throw new Error(
351+
'"tool_call_delta" parts expect an object with a "toolCallId" and "argsTextDelta" property.',
352+
);
353+
}
354+
355+
return {
356+
type: 'tool_call_delta',
357+
value: value as unknown as {
358+
toolCallId: string;
359+
argsTextDelta: string;
360+
},
361+
};
362+
},
363+
};
364+
306365
const streamParts = [
307366
textStreamPart,
308367
functionCallStreamPart,
@@ -315,6 +374,8 @@ const streamParts = [
315374
messageAnnotationsStreamPart,
316375
toolCallStreamPart,
317376
toolResultStreamPart,
377+
toolCallStreamingStartStreamPart,
378+
toolCallDeltaStreamPart,
318379
] as const;
319380

320381
// union type of all stream parts
@@ -329,7 +390,9 @@ type StreamParts =
329390
| typeof toolCallsStreamPart
330391
| typeof messageAnnotationsStreamPart
331392
| typeof toolCallStreamPart
332-
| typeof toolResultStreamPart;
393+
| typeof toolResultStreamPart
394+
| typeof toolCallStreamingStartStreamPart
395+
| typeof toolCallDeltaStreamPart;
333396

334397
/**
335398
* Maps the type of a stream part to its value type.
@@ -349,7 +412,9 @@ export type StreamPartType =
349412
| ReturnType<typeof toolCallsStreamPart.parse>
350413
| ReturnType<typeof messageAnnotationsStreamPart.parse>
351414
| ReturnType<typeof toolCallStreamPart.parse>
352-
| ReturnType<typeof toolResultStreamPart.parse>;
415+
| ReturnType<typeof toolResultStreamPart.parse>
416+
| ReturnType<typeof toolCallStreamingStartStreamPart.parse>
417+
| ReturnType<typeof toolCallDeltaStreamPart.parse>;
353418

354419
export const streamPartsByCode = {
355420
[textStreamPart.code]: textStreamPart,
@@ -363,6 +428,8 @@ export const streamPartsByCode = {
363428
[messageAnnotationsStreamPart.code]: messageAnnotationsStreamPart,
364429
[toolCallStreamPart.code]: toolCallStreamPart,
365430
[toolResultStreamPart.code]: toolResultStreamPart,
431+
[toolCallStreamingStartStreamPart.code]: toolCallStreamingStartStreamPart,
432+
[toolCallDeltaStreamPart.code]: toolCallDeltaStreamPart,
366433
} as const;
367434

368435
/**
@@ -399,6 +466,9 @@ export const StreamStringPrefixes = {
399466
[messageAnnotationsStreamPart.name]: messageAnnotationsStreamPart.code,
400467
[toolCallStreamPart.name]: toolCallStreamPart.code,
401468
[toolResultStreamPart.name]: toolResultStreamPart.code,
469+
[toolCallStreamingStartStreamPart.name]:
470+
toolCallStreamingStartStreamPart.code,
471+
[toolCallDeltaStreamPart.name]: toolCallDeltaStreamPart.code,
402472
} as const;
403473

404474
export const validCodes = streamParts.map(part => part.code);

‎packages/ui-utils/src/types.ts

+13-2
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,9 @@ there is one tool invocation. While the call is in progress, the invocation is a
114114
Once the call is complete, the invocation is a tool result.
115115
*/
116116
export type ToolInvocation =
117-
| CoreToolCall<string, any>
118-
| CoreToolResult<string, any, any>;
117+
| ({ state: 'partial-call' } & CoreToolCall<string, any>)
118+
| ({ state: 'call' } & CoreToolCall<string, any>)
119+
| ({ state: 'result' } & CoreToolResult<string, any, any>);
119120

120121
/**
121122
* An attachment that can be sent along with a message.
@@ -142,9 +143,19 @@ export interface Attachment {
142143
* AI SDK UI Messages. They are used in the client and to communicate between the frontend and the API routes.
143144
*/
144145
export interface Message {
146+
/**
147+
A unique identifier for the message.
148+
*/
145149
id: string;
150+
151+
/**
152+
The timestamp of the message.
153+
*/
146154
createdAt?: Date;
147155

156+
/**
157+
Text content of the message.
158+
*/
148159
content: string;
149160

150161
/**

0 commit comments

Comments
 (0)
Please sign in to comment.