Skip to content

Commit 3cb103b

Browse files
authoredJun 21, 2024··
fix (ai/react): prevent infinite tool call loop (#2052)
1 parent 605cada commit 3cb103b

File tree

4 files changed

+223
-0
lines changed

4 files changed

+223
-0
lines changed
 

‎.changeset/yellow-foxes-jump.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@ai-sdk/react': patch
3+
---
4+
5+
fix (ai/react): prevent infinite tool call loop

‎packages/react/src/use-chat.ts

+4
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,8 @@ By default, it's set to 0, which will disable the feature.
282282

283283
const triggerRequest = useCallback(
284284
async (chatRequest: ChatRequest) => {
285+
const messageCount = messagesRef.current.length;
286+
285287
try {
286288
mutateLoading(true);
287289
setError(undefined);
@@ -336,6 +338,8 @@ By default, it's set to 0, which will disable the feature.
336338
const messages = messagesRef.current;
337339
const lastMessage = messages[messages.length - 1];
338340
if (
341+
// ensure we actually have new messages (to prevent infinite loops in case of errors):
342+
messages.length > messageCount &&
339343
// ensure there is a last message:
340344
lastMessage != null &&
341345
// check if the feature is enabled:

‎packages/react/src/use-chat.ui.test.tsx

+202
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import { cleanup, findByText, render, screen } from '@testing-library/react';
88
import userEvent from '@testing-library/user-event';
99
import React from 'react';
1010
import { useChat } from './use-chat';
11+
import { formatStreamPart } from '@ai-sdk/ui-utils';
1112

1213
describe('stream data stream', () => {
1314
const TestComponent = () => {
@@ -207,3 +208,204 @@ describe('text stream', () => {
207208
);
208209
});
209210
});
211+
212+
describe('onToolCall', () => {
213+
const TestComponent = () => {
214+
const { messages, append } = useChat({
215+
async onToolCall({ toolCall }) {
216+
return `test-tool-response: ${toolCall.toolName} ${
217+
toolCall.toolCallId
218+
} ${JSON.stringify(toolCall.args)}`;
219+
},
220+
});
221+
222+
return (
223+
<div>
224+
{messages.map((m, idx) => (
225+
<div data-testid={`message-${idx}`} key={m.id}>
226+
{m.toolInvocations?.map((toolInvocation, toolIdx) =>
227+
'result' in toolInvocation ? (
228+
<div key={toolIdx} data-testid={`tool-invocation-${toolIdx}`}>
229+
{toolInvocation.result}
230+
</div>
231+
) : null,
232+
)}
233+
</div>
234+
))}
235+
236+
<button
237+
data-testid="do-append"
238+
onClick={() => {
239+
append({ role: 'user', content: 'hi' });
240+
}}
241+
/>
242+
</div>
243+
);
244+
};
245+
246+
beforeEach(() => {
247+
render(<TestComponent />);
248+
});
249+
250+
afterEach(() => {
251+
vi.restoreAllMocks();
252+
cleanup();
253+
});
254+
255+
it("should invoke onToolCall when a tool call is received from the server's response", async () => {
256+
mockFetchDataStream({
257+
url: 'https://example.com/api/chat',
258+
chunks: [
259+
formatStreamPart('tool_call', {
260+
toolCallId: 'tool-call-0',
261+
toolName: 'test-tool',
262+
args: { testArg: 'test-value' },
263+
}),
264+
],
265+
});
266+
267+
await userEvent.click(screen.getByTestId('do-append'));
268+
269+
await screen.findByTestId('message-1');
270+
expect(screen.getByTestId('message-1')).toHaveTextContent(
271+
'test-tool-response: test-tool tool-call-0 {"testArg":"test-value"}',
272+
);
273+
});
274+
});
275+
276+
describe('maxToolRoundtrips', () => {
277+
describe('single automatic tool roundtrip', () => {
278+
const TestComponent = () => {
279+
const { messages, append } = useChat({
280+
async onToolCall({ toolCall }) {
281+
mockFetchDataStream({
282+
url: 'https://example.com/api/chat',
283+
chunks: [formatStreamPart('text', 'final result')],
284+
});
285+
286+
return `test-tool-response: ${toolCall.toolName} ${
287+
toolCall.toolCallId
288+
} ${JSON.stringify(toolCall.args)}`;
289+
},
290+
maxToolRoundtrips: 5,
291+
});
292+
293+
return (
294+
<div>
295+
{messages.map((m, idx) => (
296+
<div data-testid={`message-${idx}`} key={m.id}>
297+
{m.content}
298+
</div>
299+
))}
300+
301+
<button
302+
data-testid="do-append"
303+
onClick={() => {
304+
append({ role: 'user', content: 'hi' });
305+
}}
306+
/>
307+
</div>
308+
);
309+
};
310+
311+
beforeEach(() => {
312+
render(<TestComponent />);
313+
});
314+
315+
afterEach(() => {
316+
vi.restoreAllMocks();
317+
cleanup();
318+
});
319+
320+
it('should automatically call api when tool call gets executed via onToolCall', async () => {
321+
mockFetchDataStream({
322+
url: 'https://example.com/api/chat',
323+
chunks: [
324+
formatStreamPart('tool_call', {
325+
toolCallId: 'tool-call-0',
326+
toolName: 'test-tool',
327+
args: { testArg: 'test-value' },
328+
}),
329+
],
330+
});
331+
332+
await userEvent.click(screen.getByTestId('do-append'));
333+
334+
await screen.findByTestId('message-2');
335+
expect(screen.getByTestId('message-2')).toHaveTextContent('final result');
336+
});
337+
});
338+
339+
describe('single roundtrip with error response', () => {
340+
const TestComponent = () => {
341+
const { messages, append, error } = useChat({
342+
async onToolCall({ toolCall }) {
343+
mockFetchDataStream({
344+
url: 'https://example.com/api/chat',
345+
chunks: [formatStreamPart('error', 'some failure')],
346+
maxCalls: 1,
347+
});
348+
349+
return `test-tool-response: ${toolCall.toolName} ${
350+
toolCall.toolCallId
351+
} ${JSON.stringify(toolCall.args)}`;
352+
},
353+
maxToolRoundtrips: 5,
354+
});
355+
356+
return (
357+
<div>
358+
{error && <div data-testid="error">{error.toString()}</div>}
359+
360+
{messages.map((m, idx) => (
361+
<div data-testid={`message-${idx}`} key={m.id}>
362+
{m.toolInvocations?.map((toolInvocation, toolIdx) =>
363+
'result' in toolInvocation ? (
364+
<div key={toolIdx} data-testid={`tool-invocation-${toolIdx}`}>
365+
{toolInvocation.result}
366+
</div>
367+
) : null,
368+
)}
369+
</div>
370+
))}
371+
372+
<button
373+
data-testid="do-append"
374+
onClick={() => {
375+
append({ role: 'user', content: 'hi' });
376+
}}
377+
/>
378+
</div>
379+
);
380+
};
381+
382+
beforeEach(() => {
383+
render(<TestComponent />);
384+
});
385+
386+
afterEach(() => {
387+
vi.restoreAllMocks();
388+
cleanup();
389+
});
390+
391+
it('should automatically call api when tool call gets executed via onToolCall', async () => {
392+
mockFetchDataStream({
393+
url: 'https://example.com/api/chat',
394+
chunks: [
395+
formatStreamPart('tool_call', {
396+
toolCallId: 'tool-call-0',
397+
toolName: 'test-tool',
398+
args: { testArg: 'test-value' },
399+
}),
400+
],
401+
});
402+
403+
await userEvent.click(screen.getByTestId('do-append'));
404+
405+
await screen.findByTestId('error');
406+
expect(screen.getByTestId('error')).toHaveTextContent(
407+
'Error: Too many calls',
408+
);
409+
});
410+
});
411+
});

‎packages/ui-utils/src/test/mock-fetch.ts

+12
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { fail } from 'node:assert';
12
import { vi } from 'vitest';
23

34
export function mockFetchTextStream({
@@ -40,9 +41,11 @@ export function mockFetchTextStream({
4041
export function mockFetchDataStream({
4142
url,
4243
chunks,
44+
maxCalls,
4345
}: {
4446
url: string;
4547
chunks: string[];
48+
maxCalls?: number;
4649
}) {
4750
async function* generateChunks() {
4851
const encoder = new TextEncoder();
@@ -54,22 +57,31 @@ export function mockFetchDataStream({
5457
return mockFetchDataStreamWithGenerator({
5558
url,
5659
chunkGenerator: generateChunks(),
60+
maxCalls,
5761
});
5862
}
5963

6064
export function mockFetchDataStreamWithGenerator({
6165
url,
6266
chunkGenerator,
67+
maxCalls,
6368
}: {
6469
url: string;
6570
chunkGenerator: AsyncGenerator<Uint8Array, void, unknown>;
71+
maxCalls?: number;
6672
}) {
6773
let requestBodyResolve: ((value?: unknown) => void) | undefined;
6874
const requestBodyPromise = new Promise(resolve => {
6975
requestBodyResolve = resolve;
7076
});
7177

78+
let callCount = 0;
79+
7280
vi.spyOn(global, 'fetch').mockImplementation(async (url, init) => {
81+
if (maxCalls !== undefined && ++callCount >= maxCalls) {
82+
throw new Error('Too many calls');
83+
}
84+
7385
requestBodyResolve?.(init!.body as string);
7486

7587
return {

0 commit comments

Comments
 (0)
Please sign in to comment.