|
2 | 2 | import { withTestServer } from '@ai-sdk/provider-utils/test';
|
3 | 3 | import { formatStreamPart, getTextFromDataUrl } from '@ai-sdk/ui-utils';
|
4 | 4 | import '@testing-library/jest-dom/vitest';
|
5 |
| -import { cleanup, findByText, render, screen } from '@testing-library/react'; |
| 5 | +import { |
| 6 | + RenderResult, |
| 7 | + cleanup, |
| 8 | + findByText, |
| 9 | + render, |
| 10 | + screen, |
| 11 | + waitFor, |
| 12 | +} from '@testing-library/react'; |
6 | 13 | import userEvent from '@testing-library/user-event';
|
7 | 14 | import React, { useRef, useState } from 'react';
|
8 | 15 | import { useChat } from './use-chat';
|
@@ -447,6 +454,169 @@ describe('onToolCall', () => {
|
447 | 454 | );
|
448 | 455 | });
|
449 | 456 |
|
| 457 | +describe('tool invocations', () => { |
| 458 | + let rerender: RenderResult['rerender']; |
| 459 | + |
| 460 | + const TestComponent = () => { |
| 461 | + const { messages, append } = useChat(); |
| 462 | + |
| 463 | + return ( |
| 464 | + <div> |
| 465 | + {messages.map((m, idx) => ( |
| 466 | + <div data-testid={`message-${idx}`} key={m.id}> |
| 467 | + {m.toolInvocations?.map((toolInvocation, toolIdx) => { |
| 468 | + return ( |
| 469 | + <div key={toolIdx} data-testid={`tool-invocation-${toolIdx}`}> |
| 470 | + {JSON.stringify(toolInvocation)} |
| 471 | + </div> |
| 472 | + ); |
| 473 | + })} |
| 474 | + </div> |
| 475 | + ))} |
| 476 | + |
| 477 | + <button |
| 478 | + data-testid="do-append" |
| 479 | + onClick={() => { |
| 480 | + append({ role: 'user', content: 'hi' }); |
| 481 | + }} |
| 482 | + /> |
| 483 | + </div> |
| 484 | + ); |
| 485 | + }; |
| 486 | + |
| 487 | + beforeEach(() => { |
| 488 | + const result = render(<TestComponent />); |
| 489 | + rerender = result.rerender; |
| 490 | + }); |
| 491 | + |
| 492 | + afterEach(() => { |
| 493 | + vi.restoreAllMocks(); |
| 494 | + cleanup(); |
| 495 | + }); |
| 496 | + |
| 497 | + it( |
| 498 | + 'should display partial tool call, tool call, and tool result', |
| 499 | + withTestServer( |
| 500 | + { url: '/api/chat', type: 'controlled-stream' }, |
| 501 | + async ({ streamController }) => { |
| 502 | + await userEvent.click(screen.getByTestId('do-append')); |
| 503 | + |
| 504 | + streamController.enqueue( |
| 505 | + formatStreamPart('tool_call_streaming_start', { |
| 506 | + toolCallId: 'tool-call-0', |
| 507 | + toolName: 'test-tool', |
| 508 | + }), |
| 509 | + ); |
| 510 | + |
| 511 | + await waitFor(() => { |
| 512 | + expect(screen.getByTestId('message-1')).toHaveTextContent( |
| 513 | + '{"state":"partial-call","toolCallId":"tool-call-0","toolName":"test-tool"}', |
| 514 | + ); |
| 515 | + }); |
| 516 | + |
| 517 | + streamController.enqueue( |
| 518 | + formatStreamPart('tool_call_delta', { |
| 519 | + toolCallId: 'tool-call-0', |
| 520 | + argsTextDelta: '{"testArg":"t', |
| 521 | + }), |
| 522 | + ); |
| 523 | + |
| 524 | + await waitFor(() => { |
| 525 | + rerender(<TestComponent />); |
| 526 | + expect(screen.getByTestId('message-1')).toHaveTextContent( |
| 527 | + '{"state":"partial-call","toolCallId":"tool-call-0","toolName":"test-tool","args":{"testArg":"t"}}', |
| 528 | + ); |
| 529 | + }); |
| 530 | + |
| 531 | + streamController.enqueue( |
| 532 | + formatStreamPart('tool_call_delta', { |
| 533 | + toolCallId: 'tool-call-0', |
| 534 | + argsTextDelta: 'est-value"}}', |
| 535 | + }), |
| 536 | + ); |
| 537 | + |
| 538 | + await waitFor(() => { |
| 539 | + rerender(<TestComponent />); |
| 540 | + expect(screen.getByTestId('message-1')).toHaveTextContent( |
| 541 | + '{"state":"partial-call","toolCallId":"tool-call-0","toolName":"test-tool","args":{"testArg":"test-value"}}', |
| 542 | + ); |
| 543 | + }); |
| 544 | + |
| 545 | + streamController.enqueue( |
| 546 | + formatStreamPart('tool_call', { |
| 547 | + toolCallId: 'tool-call-0', |
| 548 | + toolName: 'test-tool', |
| 549 | + args: { testArg: 'test-value' }, |
| 550 | + }), |
| 551 | + ); |
| 552 | + |
| 553 | + await waitFor(() => { |
| 554 | + rerender(<TestComponent />); |
| 555 | + expect(screen.getByTestId('message-1')).toHaveTextContent( |
| 556 | + '{"state":"call","toolCallId":"tool-call-0","toolName":"test-tool","args":{"testArg":"test-value"}}', |
| 557 | + ); |
| 558 | + }); |
| 559 | + |
| 560 | + streamController.enqueue( |
| 561 | + formatStreamPart('tool_result', { |
| 562 | + toolCallId: 'tool-call-0', |
| 563 | + toolName: 'test-tool', |
| 564 | + args: { testArg: 'test-value' }, |
| 565 | + result: 'test-result', |
| 566 | + }), |
| 567 | + ); |
| 568 | + streamController.close(); |
| 569 | + |
| 570 | + await waitFor(() => { |
| 571 | + expect(screen.getByTestId('message-1')).toHaveTextContent( |
| 572 | + '{"state":"result","toolCallId":"tool-call-0","toolName":"test-tool","args":{"testArg":"test-value"},"result":"test-result"}', |
| 573 | + ); |
| 574 | + }); |
| 575 | + }, |
| 576 | + ), |
| 577 | + ); |
| 578 | + |
| 579 | + it( |
| 580 | + 'should display partial tool call and tool result (when there is no tool call streaming)', |
| 581 | + withTestServer( |
| 582 | + { url: '/api/chat', type: 'controlled-stream' }, |
| 583 | + async ({ streamController }) => { |
| 584 | + await userEvent.click(screen.getByTestId('do-append')); |
| 585 | + |
| 586 | + streamController.enqueue( |
| 587 | + formatStreamPart('tool_call', { |
| 588 | + toolCallId: 'tool-call-0', |
| 589 | + toolName: 'test-tool', |
| 590 | + args: { testArg: 'test-value' }, |
| 591 | + }), |
| 592 | + ); |
| 593 | + |
| 594 | + await waitFor(() => { |
| 595 | + expect(screen.getByTestId('message-1')).toHaveTextContent( |
| 596 | + '{"state":"call","toolCallId":"tool-call-0","toolName":"test-tool","args":{"testArg":"test-value"}}', |
| 597 | + ); |
| 598 | + }); |
| 599 | + |
| 600 | + streamController.enqueue( |
| 601 | + formatStreamPart('tool_result', { |
| 602 | + toolCallId: 'tool-call-0', |
| 603 | + toolName: 'test-tool', |
| 604 | + args: { testArg: 'test-value' }, |
| 605 | + result: 'test-result', |
| 606 | + }), |
| 607 | + ); |
| 608 | + streamController.close(); |
| 609 | + |
| 610 | + await waitFor(() => { |
| 611 | + expect(screen.getByTestId('message-1')).toHaveTextContent( |
| 612 | + '{"state":"result","toolCallId":"tool-call-0","toolName":"test-tool","args":{"testArg":"test-value"},"result":"test-result"}', |
| 613 | + ); |
| 614 | + }); |
| 615 | + }, |
| 616 | + ), |
| 617 | + ); |
| 618 | +}); |
| 619 | + |
450 | 620 | describe('maxToolRoundtrips', () => {
|
451 | 621 | describe('single automatic tool roundtrip', () => {
|
452 | 622 | let onToolCallInvoked = false;
|
|
0 commit comments