@@ -8,6 +8,7 @@ import { cleanup, findByText, render, screen } from '@testing-library/react';
8
8
import userEvent from '@testing-library/user-event' ;
9
9
import React from 'react' ;
10
10
import { useChat } from './use-chat' ;
11
+ import { formatStreamPart } from '@ai-sdk/ui-utils' ;
11
12
12
13
describe ( 'stream data stream' , ( ) => {
13
14
const TestComponent = ( ) => {
@@ -207,3 +208,204 @@ describe('text stream', () => {
207
208
) ;
208
209
} ) ;
209
210
} ) ;
211
+
212
+ describe ( 'onToolCall' , ( ) => {
213
+ const TestComponent = ( ) => {
214
+ const { messages, append } = useChat ( {
215
+ async onToolCall ( { toolCall } ) {
216
+ return `test-tool-response: ${ toolCall . toolName } ${
217
+ toolCall . toolCallId
218
+ } ${ JSON . stringify ( toolCall . args ) } `;
219
+ } ,
220
+ } ) ;
221
+
222
+ return (
223
+ < div >
224
+ { messages . map ( ( m , idx ) => (
225
+ < div data-testid = { `message-${ idx } ` } key = { m . id } >
226
+ { m . toolInvocations ?. map ( ( toolInvocation , toolIdx ) =>
227
+ 'result' in toolInvocation ? (
228
+ < div key = { toolIdx } data-testid = { `tool-invocation-${ toolIdx } ` } >
229
+ { toolInvocation . result }
230
+ </ div >
231
+ ) : null ,
232
+ ) }
233
+ </ div >
234
+ ) ) }
235
+
236
+ < button
237
+ data-testid = "do-append"
238
+ onClick = { ( ) => {
239
+ append ( { role : 'user' , content : 'hi' } ) ;
240
+ } }
241
+ />
242
+ </ div >
243
+ ) ;
244
+ } ;
245
+
246
+ beforeEach ( ( ) => {
247
+ render ( < TestComponent /> ) ;
248
+ } ) ;
249
+
250
+ afterEach ( ( ) => {
251
+ vi . restoreAllMocks ( ) ;
252
+ cleanup ( ) ;
253
+ } ) ;
254
+
255
+ it ( "should invoke onToolCall when a tool call is received from the server's response" , async ( ) => {
256
+ mockFetchDataStream ( {
257
+ url : 'https://example.com/api/chat' ,
258
+ chunks : [
259
+ formatStreamPart ( 'tool_call' , {
260
+ toolCallId : 'tool-call-0' ,
261
+ toolName : 'test-tool' ,
262
+ args : { testArg : 'test-value' } ,
263
+ } ) ,
264
+ ] ,
265
+ } ) ;
266
+
267
+ await userEvent . click ( screen . getByTestId ( 'do-append' ) ) ;
268
+
269
+ await screen . findByTestId ( 'message-1' ) ;
270
+ expect ( screen . getByTestId ( 'message-1' ) ) . toHaveTextContent (
271
+ 'test-tool-response: test-tool tool-call-0 {"testArg":"test-value"}' ,
272
+ ) ;
273
+ } ) ;
274
+ } ) ;
275
+
276
+ describe ( 'maxToolRoundtrips' , ( ) => {
277
+ describe ( 'single automatic tool roundtrip' , ( ) => {
278
+ const TestComponent = ( ) => {
279
+ const { messages, append } = useChat ( {
280
+ async onToolCall ( { toolCall } ) {
281
+ mockFetchDataStream ( {
282
+ url : 'https://example.com/api/chat' ,
283
+ chunks : [ formatStreamPart ( 'text' , 'final result' ) ] ,
284
+ } ) ;
285
+
286
+ return `test-tool-response: ${ toolCall . toolName } ${
287
+ toolCall . toolCallId
288
+ } ${ JSON . stringify ( toolCall . args ) } `;
289
+ } ,
290
+ maxToolRoundtrips : 5 ,
291
+ } ) ;
292
+
293
+ return (
294
+ < div >
295
+ { messages . map ( ( m , idx ) => (
296
+ < div data-testid = { `message-${ idx } ` } key = { m . id } >
297
+ { m . content }
298
+ </ div >
299
+ ) ) }
300
+
301
+ < button
302
+ data-testid = "do-append"
303
+ onClick = { ( ) => {
304
+ append ( { role : 'user' , content : 'hi' } ) ;
305
+ } }
306
+ />
307
+ </ div >
308
+ ) ;
309
+ } ;
310
+
311
+ beforeEach ( ( ) => {
312
+ render ( < TestComponent /> ) ;
313
+ } ) ;
314
+
315
+ afterEach ( ( ) => {
316
+ vi . restoreAllMocks ( ) ;
317
+ cleanup ( ) ;
318
+ } ) ;
319
+
320
+ it ( 'should automatically call api when tool call gets executed via onToolCall' , async ( ) => {
321
+ mockFetchDataStream ( {
322
+ url : 'https://example.com/api/chat' ,
323
+ chunks : [
324
+ formatStreamPart ( 'tool_call' , {
325
+ toolCallId : 'tool-call-0' ,
326
+ toolName : 'test-tool' ,
327
+ args : { testArg : 'test-value' } ,
328
+ } ) ,
329
+ ] ,
330
+ } ) ;
331
+
332
+ await userEvent . click ( screen . getByTestId ( 'do-append' ) ) ;
333
+
334
+ await screen . findByTestId ( 'message-2' ) ;
335
+ expect ( screen . getByTestId ( 'message-2' ) ) . toHaveTextContent ( 'final result' ) ;
336
+ } ) ;
337
+ } ) ;
338
+
339
+ describe ( 'single roundtrip with error response' , ( ) => {
340
+ const TestComponent = ( ) => {
341
+ const { messages, append, error } = useChat ( {
342
+ async onToolCall ( { toolCall } ) {
343
+ mockFetchDataStream ( {
344
+ url : 'https://example.com/api/chat' ,
345
+ chunks : [ formatStreamPart ( 'error' , 'some failure' ) ] ,
346
+ maxCalls : 1 ,
347
+ } ) ;
348
+
349
+ return `test-tool-response: ${ toolCall . toolName } ${
350
+ toolCall . toolCallId
351
+ } ${ JSON . stringify ( toolCall . args ) } `;
352
+ } ,
353
+ maxToolRoundtrips : 5 ,
354
+ } ) ;
355
+
356
+ return (
357
+ < div >
358
+ { error && < div data-testid = "error" > { error . toString ( ) } </ div > }
359
+
360
+ { messages . map ( ( m , idx ) => (
361
+ < div data-testid = { `message-${ idx } ` } key = { m . id } >
362
+ { m . toolInvocations ?. map ( ( toolInvocation , toolIdx ) =>
363
+ 'result' in toolInvocation ? (
364
+ < div key = { toolIdx } data-testid = { `tool-invocation-${ toolIdx } ` } >
365
+ { toolInvocation . result }
366
+ </ div >
367
+ ) : null ,
368
+ ) }
369
+ </ div >
370
+ ) ) }
371
+
372
+ < button
373
+ data-testid = "do-append"
374
+ onClick = { ( ) => {
375
+ append ( { role : 'user' , content : 'hi' } ) ;
376
+ } }
377
+ />
378
+ </ div >
379
+ ) ;
380
+ } ;
381
+
382
+ beforeEach ( ( ) => {
383
+ render ( < TestComponent /> ) ;
384
+ } ) ;
385
+
386
+ afterEach ( ( ) => {
387
+ vi . restoreAllMocks ( ) ;
388
+ cleanup ( ) ;
389
+ } ) ;
390
+
391
+ it ( 'should automatically call api when tool call gets executed via onToolCall' , async ( ) => {
392
+ mockFetchDataStream ( {
393
+ url : 'https://example.com/api/chat' ,
394
+ chunks : [
395
+ formatStreamPart ( 'tool_call' , {
396
+ toolCallId : 'tool-call-0' ,
397
+ toolName : 'test-tool' ,
398
+ args : { testArg : 'test-value' } ,
399
+ } ) ,
400
+ ] ,
401
+ } ) ;
402
+
403
+ await userEvent . click ( screen . getByTestId ( 'do-append' ) ) ;
404
+
405
+ await screen . findByTestId ( 'error' ) ;
406
+ expect ( screen . getByTestId ( 'error' ) ) . toHaveTextContent (
407
+ 'Error: Too many calls' ,
408
+ ) ;
409
+ } ) ;
410
+ } ) ;
411
+ } ) ;
0 commit comments