Skip to content

Commit 5b7b3bb

Browse files
authoredJul 19, 2024··
fix (ai/ui): tool call streaming (#2345)
1 parent a44a8f3 commit 5b7b3bb

File tree

7 files changed

+141
-12
lines changed

7 files changed

+141
-12
lines changed
 

‎.changeset/eleven-zoos-dream.md

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
'@ai-sdk/ui-utils': patch
3+
'@ai-sdk/react': patch
4+
---
5+
6+
fix (ai/ui): tool call streaming
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import { openai } from '@ai-sdk/openai';
2+
import { convertToCoreMessages, streamText } from 'ai';
3+
import { z } from 'zod';
4+
5+
// Allow streaming responses up to 30 seconds
6+
export const maxDuration = 30;
7+
8+
export async function POST(req: Request) {
9+
const { messages } = await req.json();
10+
11+
const result = await streamText({
12+
model: openai('gpt-4-turbo'),
13+
messages: convertToCoreMessages(messages),
14+
experimental_toolCallStreaming: true,
15+
system:
16+
'You are a helpful assistant that answers questions about the weather in a given city.' +
17+
'You use the showWeatherInformation tool to show the weather information to the user instead of talking about it.',
18+
tools: {
19+
// server-side tool with execute function:
20+
getWeatherInformation: {
21+
description: 'show the weather in a given city to the user',
22+
parameters: z.object({ city: z.string() }),
23+
execute: async ({}: { city: string }) => {
24+
const weatherOptions = ['sunny', 'cloudy', 'rainy', 'snowy', 'windy'];
25+
return {
26+
weather:
27+
weatherOptions[Math.floor(Math.random() * weatherOptions.length)],
28+
temperature: Math.floor(Math.random() * 50 - 10),
29+
};
30+
},
31+
},
32+
// client-side tool that displays whether information to the user:
33+
showWeatherInformation: {
34+
description:
35+
'Show the weather information to the user. Always use this tool to tell weather information to the user.',
36+
parameters: z.object({
37+
city: z.string(),
38+
weather: z.string(),
39+
temperature: z.number(),
40+
typicalWeather: z
41+
.string()
42+
.describe(
43+
'2-3 sentences about the typical weather in the city during spring.',
44+
),
45+
}),
46+
},
47+
},
48+
});
49+
50+
return result.toAIStreamResponse();
51+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
'use client';
2+
3+
import { ToolInvocation } from 'ai';
4+
import { Message, useChat } from 'ai/react';
5+
6+
export default function Chat() {
7+
const { messages, input, handleInputChange, handleSubmit } = useChat({
8+
api: '/api/use-chat-streaming-tool-calls',
9+
maxToolRoundtrips: 5,
10+
11+
// run client-side tools that are automatically executed:
12+
async onToolCall({ toolCall }) {
13+
if (toolCall.toolName === 'showWeatherInformation') {
14+
// display tool. add tool result that informs the llm that the tool was executed.
15+
return 'Weather information was shown to the user.';
16+
}
17+
},
18+
});
19+
20+
// used to only render the role when it changes:
21+
let lastRole: string | undefined = undefined;
22+
23+
return (
24+
<div className="flex flex-col w-full max-w-md py-24 mx-auto stretch">
25+
{messages?.map((m: Message) => {
26+
const isNewRole = m.role !== lastRole;
27+
lastRole = m.role;
28+
29+
return (
30+
<div key={m.id} className="whitespace-pre-wrap">
31+
{isNewRole && <strong>{`${m.role}: `}</strong>}
32+
{m.content}
33+
{m.toolInvocations?.map((toolInvocation: ToolInvocation) => {
34+
const { toolCallId, args } = toolInvocation;
35+
36+
// render display weather tool calls:
37+
if (toolInvocation.toolName === 'showWeatherInformation') {
38+
return (
39+
<div
40+
key={toolCallId}
41+
className="p-4 my-2 text-gray-500 border border-gray-300 rounded"
42+
>
43+
<h4 className="mb-2">{args?.city ?? ''}</h4>
44+
<div className="flex flex-col gap-2">
45+
<div className="flex gap-2">
46+
{args?.weather && <b>{args.weather}</b>}
47+
{args?.temperature && <b>{args.temperature} &deg;C</b>}
48+
</div>
49+
{args?.typicalWeather && <div>{args.typicalWeather}</div>}
50+
</div>
51+
</div>
52+
);
53+
}
54+
})}
55+
</div>
56+
);
57+
})}
58+
59+
<form onSubmit={handleSubmit}>
60+
<input
61+
className="fixed bottom-0 w-full max-w-md p-2 mb-8 border border-gray-300 rounded shadow-xl"
62+
value={input}
63+
placeholder="Say something..."
64+
onChange={handleInputChange}
65+
/>
66+
</form>
67+
</div>
68+
);
69+
}

‎packages/react/package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
"dependencies": {
2828
"@ai-sdk/provider-utils": "1.0.2",
2929
"@ai-sdk/ui-utils": "0.0.15",
30-
"swr": "2.2.0"
30+
"swr": "2.2.5"
3131
},
3232
"devDependencies": {
3333
"@testing-library/jest-dom": "^6.4.5",

‎packages/react/src/use-chat.ui.test.tsx

+1-7
Original file line numberDiff line numberDiff line change
@@ -455,8 +455,6 @@ describe('onToolCall', () => {
455455
});
456456

457457
describe('tool invocations', () => {
458-
let rerender: RenderResult['rerender'];
459-
460458
const TestComponent = () => {
461459
const { messages, append } = useChat();
462460

@@ -485,8 +483,7 @@ describe('tool invocations', () => {
485483
};
486484

487485
beforeEach(() => {
488-
const result = render(<TestComponent />);
489-
rerender = result.rerender;
486+
render(<TestComponent />);
490487
});
491488

492489
afterEach(() => {
@@ -522,7 +519,6 @@ describe('tool invocations', () => {
522519
);
523520

524521
await waitFor(() => {
525-
rerender(<TestComponent />);
526522
expect(screen.getByTestId('message-1')).toHaveTextContent(
527523
'{"state":"partial-call","toolCallId":"tool-call-0","toolName":"test-tool","args":{"testArg":"t"}}',
528524
);
@@ -536,7 +532,6 @@ describe('tool invocations', () => {
536532
);
537533

538534
await waitFor(() => {
539-
rerender(<TestComponent />);
540535
expect(screen.getByTestId('message-1')).toHaveTextContent(
541536
'{"state":"partial-call","toolCallId":"tool-call-0","toolName":"test-tool","args":{"testArg":"test-value"}}',
542537
);
@@ -551,7 +546,6 @@ describe('tool invocations', () => {
551546
);
552547

553548
await waitFor(() => {
554-
rerender(<TestComponent />);
555549
expect(screen.getByTestId('message-1')).toHaveTextContent(
556550
'{"state":"call","toolCallId":"tool-call-0","toolName":"test-tool","args":{"testArg":"test-value"}}',
557551
);

‎packages/ui-utils/src/parse-complex-response.ts

+8
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ export async function parseComplexResponse({
125125
toolName: partialToolCall.toolName,
126126
args: parsePartialJson(partialToolCall.text),
127127
};
128+
129+
// trigger update for streaming by copying adding a update id that changes
130+
// (without it, the changes get stuck in SWR and are not forwarded to rendering):
131+
(prefixMap.text! as any).internalUpdateId = generateId();
128132
} else if (type === 'tool_call') {
129133
if (partialToolCalls[value.toolCallId] != null) {
130134
// change the partial tool call to a full tool call
@@ -152,6 +156,10 @@ export async function parseComplexResponse({
152156
});
153157
}
154158

159+
// trigger update for streaming by copying adding a update id that changes
160+
// (without it, the changes get stuck in SWR and are not forwarded to rendering):
161+
(prefixMap.text! as any).internalUpdateId = generateId();
162+
155163
// invoke the onToolCall callback if it exists. This is blocking.
156164
// In the future we should make this non-blocking, which
157165
// requires additional state management for error handling etc.

‎pnpm-lock.yaml

+5-4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)
Please sign in to comment.