Skip to content

Commit e7e5898

Browse files
authoredApr 24, 2024··
use-assistant: fix missing message content. (#1425)
1 parent 1e6339b commit e7e5898

File tree

5 files changed

+170
-24
lines changed

5 files changed

+170
-24
lines changed
 

‎.changeset/dirty-countries-learn.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'ai': patch
3+
---
4+
5+
use-assistant: fix missing message content

‎examples/next-openai/app/assistant/page.tsx

+1-3
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@ const roleToColorMap: Record<Message['role'], string> = {
1414

1515
export default function Chat() {
1616
const { status, messages, input, submitMessage, handleInputChange, error } =
17-
useAssistant({
18-
api: '/api/assistant',
19-
});
17+
useAssistant({ api: '/api/assistant' });
2018

2119
// When status changes to accepting messages, focus the input:
2220
const inputRef = useRef<HTMLInputElement>(null);

‎packages/core/react/use-assistant.ts

+19-19
Original file line numberDiff line numberDiff line change
@@ -153,26 +153,26 @@ export function useAssistant({
153153

154154
setInput('');
155155

156-
const result = await fetch(api, {
157-
method: 'POST',
158-
credentials,
159-
headers: { 'Content-Type': 'application/json', ...headers },
160-
body: JSON.stringify({
161-
...body,
162-
// always use user-provided threadId when available:
163-
threadId: threadIdParam ?? threadId ?? null,
164-
message: input,
165-
166-
// optional request data:
167-
data: requestOptions?.data,
168-
}),
169-
});
170-
171-
if (result.body == null) {
172-
throw new Error('The response body is empty.');
173-
}
174-
175156
try {
157+
const result = await fetch(api, {
158+
method: 'POST',
159+
credentials,
160+
headers: { 'Content-Type': 'application/json', ...headers },
161+
body: JSON.stringify({
162+
...body,
163+
// always use user-provided threadId when available:
164+
threadId: threadIdParam ?? threadId ?? null,
165+
message: message.content,
166+
167+
// optional request data:
168+
data: requestOptions?.data,
169+
}),
170+
});
171+
172+
if (result.body == null) {
173+
throw new Error('The response body is empty.');
174+
}
175+
176176
for await (const { type, value } of readDataStream(
177177
result.body.getReader(),
178178
)) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import '@testing-library/jest-dom/vitest';
2+
import { cleanup, findByText, render, screen } from '@testing-library/react';
3+
import userEvent from '@testing-library/user-event';
4+
import { formatStreamPart } from '../streams';
5+
import {
6+
mockFetchDataStream,
7+
mockFetchDataStreamWithGenerator,
8+
} from '../tests/utils/mock-fetch';
9+
import { useAssistant } from './use-assistant';
10+
11+
describe('stream data stream', () => {
12+
const TestComponent = () => {
13+
const { status, messages, append } = useAssistant({
14+
api: '/api/assistant',
15+
});
16+
17+
return (
18+
<div>
19+
<div data-testid="status">{status}</div>
20+
{messages.map((m, idx) => (
21+
<div data-testid={`message-${idx}`} key={m.id}>
22+
{m.role === 'user' ? 'User: ' : 'AI: '}
23+
{m.content}
24+
</div>
25+
))}
26+
27+
<button
28+
data-testid="do-append"
29+
onClick={() => {
30+
append({ role: 'user', content: 'hi' });
31+
}}
32+
/>
33+
</div>
34+
);
35+
};
36+
37+
beforeEach(() => {
38+
render(<TestComponent />);
39+
});
40+
41+
afterEach(() => {
42+
vi.restoreAllMocks();
43+
cleanup();
44+
});
45+
46+
it('should show streamed response', async () => {
47+
const { requestBody } = mockFetchDataStream({
48+
url: 'https://example.com/api/assistant',
49+
chunks: [
50+
formatStreamPart('assistant_control_data', {
51+
threadId: 't0',
52+
messageId: 'm0',
53+
}),
54+
formatStreamPart('assistant_message', {
55+
id: 'm0',
56+
role: 'assistant',
57+
content: [{ type: 'text', text: { value: '' } }],
58+
}),
59+
// text parts:
60+
'0:"Hello"\n',
61+
'0:","\n',
62+
'0:" world"\n',
63+
'0:"."\n',
64+
],
65+
});
66+
67+
await userEvent.click(screen.getByTestId('do-append'));
68+
69+
await screen.findByTestId('message-0');
70+
expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi');
71+
72+
await screen.findByTestId('message-1');
73+
expect(screen.getByTestId('message-1')).toHaveTextContent(
74+
'AI: Hello, world.',
75+
);
76+
77+
// check that correct information was sent to the server:
78+
expect(await requestBody).toStrictEqual(
79+
JSON.stringify({
80+
threadId: null,
81+
message: 'hi',
82+
}),
83+
);
84+
});
85+
86+
describe('loading state', () => {
87+
it('should show loading state', async () => {
88+
let finishGeneration: ((value?: unknown) => void) | undefined;
89+
const finishGenerationPromise = new Promise(resolve => {
90+
finishGeneration = resolve;
91+
});
92+
93+
mockFetchDataStreamWithGenerator({
94+
url: 'https://example.com/api/chat',
95+
chunkGenerator: (async function* generate() {
96+
const encoder = new TextEncoder();
97+
98+
yield encoder.encode(
99+
formatStreamPart('assistant_control_data', {
100+
threadId: 't0',
101+
messageId: 'm1',
102+
}),
103+
);
104+
105+
yield encoder.encode(
106+
formatStreamPart('assistant_message', {
107+
id: 'm1',
108+
role: 'assistant',
109+
content: [{ type: 'text', text: { value: '' } }],
110+
}),
111+
);
112+
113+
yield encoder.encode('0:"Hello"\n');
114+
115+
await finishGenerationPromise;
116+
})(),
117+
});
118+
119+
await userEvent.click(screen.getByTestId('do-append'));
120+
121+
await screen.findByTestId('status');
122+
expect(screen.getByTestId('status')).toHaveTextContent('in_progress');
123+
124+
finishGeneration?.();
125+
126+
await findByText(await screen.findByTestId('status'), 'awaiting_message');
127+
expect(screen.getByTestId('status')).toHaveTextContent(
128+
'awaiting_message',
129+
);
130+
});
131+
});
132+
});

‎packages/core/tests/utils/mock-fetch.ts

+13-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ export function mockFetchDataStream({
5151
}
5252
}
5353

54-
mockFetchDataStreamWithGenerator({
54+
return mockFetchDataStreamWithGenerator({
5555
url,
5656
chunkGenerator: generateChunks(),
5757
});
@@ -64,7 +64,14 @@ export function mockFetchDataStreamWithGenerator({
6464
url: string;
6565
chunkGenerator: AsyncGenerator<Uint8Array, void, unknown>;
6666
}) {
67-
vi.spyOn(global, 'fetch').mockImplementation(async () => {
67+
let requestBodyResolve: ((value?: unknown) => void) | undefined;
68+
const requestBodyPromise = new Promise(resolve => {
69+
requestBodyResolve = resolve;
70+
});
71+
72+
vi.spyOn(global, 'fetch').mockImplementation(async (url, init) => {
73+
requestBodyResolve?.(init!.body as string);
74+
6875
return {
6976
url,
7077
ok: true,
@@ -83,6 +90,10 @@ export function mockFetchDataStreamWithGenerator({
8390
},
8491
} as unknown as Response;
8592
});
93+
94+
return {
95+
requestBody: requestBodyPromise,
96+
};
8697
}
8798

8899
export function mockFetchError({

0 commit comments

Comments
 (0)
Please sign in to comment.