Skip to content

Commit 6109c6a

Browse files
authoredMay 24, 2024··
feat (ai/react): add experimental_maxAutomaticRoundtrips to useChat (#1681)
1 parent dd7ee98 commit 6109c6a

File tree

4 files changed

+138
-8
lines changed

4 files changed

+138
-8
lines changed
 

‎.changeset/great-bottles-burn.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'ai': patch
3+
---
4+
5+
feat (ai/react): add experimental_maxAutomaticRoundtrips to useChat
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import { openai } from '@ai-sdk/openai';
2+
import { convertToCoreMessages, streamText } from 'ai';
3+
import { z } from 'zod';
4+
5+
export const dynamic = 'force-dynamic';
6+
export const maxDuration = 60;
7+
8+
export async function POST(req: Request) {
9+
const { messages } = await req.json();
10+
11+
const result = await streamText({
12+
model: openai('gpt-4-turbo'),
13+
system:
14+
'You are a weather bot that can use the weather tool to get the weather in a given city. ' +
15+
'Respond to the user with weather information in a friendly and helpful manner.',
16+
messages: convertToCoreMessages(messages),
17+
tools: {
18+
weather: {
19+
description: 'show the weather in a given city to the user',
20+
parameters: z.object({ city: z.string() }),
21+
execute: async ({}: { city: string }) => {
22+
// Random delay between 1000ms (1s) and 3000ms (3s):
23+
const delay = Math.floor(Math.random() * (3000 - 1000 + 1)) + 1000;
24+
await new Promise(resolve => setTimeout(resolve, delay));
25+
26+
// Random weather:
27+
const weatherOptions = ['sunny', 'cloudy', 'rainy', 'snowy', 'windy'];
28+
return weatherOptions[
29+
Math.floor(Math.random() * weatherOptions.length)
30+
];
31+
},
32+
},
33+
},
34+
});
35+
36+
return result.toAIStreamResponse();
37+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
'use client';
2+
3+
import { useChat } from 'ai/react';
4+
5+
export default function Chat() {
6+
const { messages, input, handleInputChange, handleSubmit } = useChat({
7+
api: '/api/use-chat-tool-result-roundtrip',
8+
experimental_maxAutomaticRoundtrips: 2,
9+
});
10+
11+
return (
12+
<div className="flex flex-col w-full max-w-md py-24 mx-auto stretch">
13+
{messages
14+
.filter(m => m.content) // filter out empty messages
15+
.map(m => (
16+
<div key={m.id} className="whitespace-pre-wrap">
17+
<strong>{`${m.role}: `}</strong>
18+
{m.content}
19+
<br />
20+
<br />
21+
</div>
22+
))}
23+
24+
<form onSubmit={handleSubmit}>
25+
<input
26+
className="fixed bottom-0 w-full max-w-md p-2 mb-8 border border-gray-300 rounded shadow-xl"
27+
value={input}
28+
placeholder="Say something..."
29+
onChange={handleInputChange}
30+
/>
31+
</form>
32+
</div>
33+
);
34+
}

‎packages/core/react/use-chat.ts

+62-8
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ export function useChat({
219219
sendExtraMessageFields,
220220
experimental_onFunctionCall,
221221
experimental_onToolCall,
222+
experimental_maxAutomaticRoundtrips = 0,
222223
streamMode,
223224
onResponse,
224225
onFinish,
@@ -230,6 +231,19 @@ export function useChat({
230231
}: Omit<UseChatOptions, 'api'> & {
231232
api?: string | StreamingReactResponseAction;
232233
key?: string;
234+
/**
235+
Maximal number of automatic roundtrips for tool calls.
236+
237+
An automatic tool call roundtrip is a call to the server with the
238+
tool call results when all tool calls in the last assistant
239+
message have results.
240+
241+
A maximum number is required to prevent infinite loops in the
242+
case of misconfigured tools.
243+
244+
By default, it's set to 0, which will disable the feature.
245+
*/
246+
experimental_maxAutomaticRoundtrips?: number;
233247
} = {}): UseChatHelpers & {
234248
experimental_addToolResult: ({
235249
toolCallId,
@@ -343,6 +357,23 @@ export function useChat({
343357
} finally {
344358
mutateLoading(false);
345359
}
360+
361+
// auto-submit when all tool calls in the last assistant message have results:
362+
const messages = messagesRef.current;
363+
const lastMessage = messages[messages.length - 1];
364+
if (
365+
// ensure there is a last message:
366+
lastMessage != null &&
367+
// check if the feature is enabled:
368+
experimental_maxAutomaticRoundtrips > 0 &&
369+
// check that roundtrip is possible:
370+
isAssistantMessageWithCompletedToolCalls(lastMessage) &&
371+
// limit the number of automatic roundtrips:
372+
countTrailingAssistantMessages(messages) <=
373+
experimental_maxAutomaticRoundtrips
374+
) {
375+
await triggerRequest({ messages });
376+
}
346377
},
347378
[
348379
mutate,
@@ -359,6 +390,7 @@ export function useChat({
359390
sendExtraMessageFields,
360391
experimental_onFunctionCall,
361392
experimental_onToolCall,
393+
experimental_maxAutomaticRoundtrips,
362394
messagesRef,
363395
abortControllerRef,
364396
generateId,
@@ -526,16 +558,38 @@ export function useChat({
526558

527559
// auto-submit when all tool calls in the last assistant message have results:
528560
const lastMessage = updatedMessages[updatedMessages.length - 1];
529-
if (
530-
lastMessage.role === 'assistant' &&
531-
lastMessage.toolInvocations &&
532-
lastMessage.toolInvocations.length > 0 &&
533-
lastMessage.toolInvocations.every(
534-
toolInvocation => 'result' in toolInvocation,
535-
)
536-
) {
561+
if (isAssistantMessageWithCompletedToolCalls(lastMessage)) {
537562
triggerRequest({ messages: updatedMessages });
538563
}
539564
},
540565
};
541566
}
567+
568+
/**
569+
Check if the message is an assistant message with completed tool calls.
570+
The message must have at least one tool invocation and all tool invocations
571+
must have a result.
572+
*/
573+
function isAssistantMessageWithCompletedToolCalls(message: Message) {
574+
return (
575+
message.role === 'assistant' &&
576+
message.toolInvocations &&
577+
message.toolInvocations.length > 0 &&
578+
message.toolInvocations.every(toolInvocation => 'result' in toolInvocation)
579+
);
580+
}
581+
582+
/**
583+
Returns the number of trailing assistant messages in the array.
584+
*/
585+
function countTrailingAssistantMessages(messages: Message[]) {
586+
let count = 0;
587+
for (let i = messages.length - 1; i >= 0; i--) {
588+
if (messages[i].role === 'assistant') {
589+
count++;
590+
} else {
591+
break;
592+
}
593+
}
594+
return count;
595+
}

0 commit comments

Comments
 (0)
Please sign in to comment.