Skip to content

Commit 70d1800

Browse files
authoredJul 8, 2024··
feat (ai/react): add setThreadId helper to switch between threads for useAssistant (#2209)
1 parent b218ef2 commit 70d1800

File tree

4 files changed

+220
-4
lines changed

4 files changed

+220
-4
lines changed
 

‎.changeset/plenty-ties-remember.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@ai-sdk/react': patch
3+
---
4+
5+
add setThreadId helper to switch between threads for useAssistant

‎content/docs/07-reference/ai-sdk-ui/20-use-assistant.mdx

+6
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@ This works in conjunction with [`AssistantResponse`](./assistant-response) in th
107107
type: 'string | undefined',
108108
description: 'The current thread ID.',
109109
},
110+
{
111+
name: 'setThreadId',
112+
type: '(threadId: string | undefined) => void',
113+
description:
114+
"Set the current thread ID. Specifying a thread ID will switch to that thread, if it exists. If set to 'undefined', a new thread will be created. For both cases, `threadId` will be updated with the new value and `messages` will be cleared.",
115+
},
110116
{
111117
name: 'input',
112118
type: 'string',

‎packages/react/src/use-assistant.ts

+17-4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ export type UseAssistantHelpers = {
2828
*/
2929
threadId: string | undefined;
3030

31+
/**
32+
* Set the current thread ID. Specifying a thread ID will switch to that thread, if it exists. If set to 'undefined', a new thread will be created. For both cases, `threadId` will be updated with the new value and `messages` will be cleared.
33+
*/
34+
setThreadId: (threadId: string | undefined) => void;
35+
3136
/**
3237
* The current value of the input field.
3338
*/
@@ -97,7 +102,9 @@ export function useAssistant({
97102
}: UseAssistantOptions): UseAssistantHelpers {
98103
const [messages, setMessages] = useState<Message[]>([]);
99104
const [input, setInput] = useState('');
100-
const [threadId, setThreadId] = useState<string | undefined>(undefined);
105+
const [currentThreadId, setCurrentThreadId] = useState<string | undefined>(
106+
undefined,
107+
);
101108
const [status, setStatus] = useState<AssistantStatus>('awaiting_message');
102109
const [error, setError] = useState<undefined | Error>(undefined);
103110

@@ -151,7 +158,7 @@ export function useAssistant({
151158
body: JSON.stringify({
152159
...body,
153160
// always use user-provided threadId when available:
154-
threadId: threadIdParam ?? threadId ?? null,
161+
threadId: threadIdParam ?? currentThreadId ?? null,
155162
message: message.content,
156163

157164
// optional request data:
@@ -216,7 +223,7 @@ export function useAssistant({
216223
}
217224

218225
case 'assistant_control_data': {
219-
setThreadId(value.threadId);
226+
setCurrentThreadId(value.threadId);
220227

221228
// set id of last message:
222229
setMessages(messages => {
@@ -267,11 +274,17 @@ export function useAssistant({
267274
append({ role: 'user', content: input }, requestOptions);
268275
};
269276

277+
const setThreadId = (threadId: string | undefined) => {
278+
setCurrentThreadId(threadId);
279+
setMessages([]);
280+
};
281+
270282
return {
271283
append,
272284
messages,
273285
setMessages,
274-
threadId,
286+
threadId: currentThreadId,
287+
setThreadId,
275288
input,
276289
setInput,
277290
handleInputChange,

‎packages/react/src/use-assistant.ui.test.tsx

+192
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,195 @@ describe('stream data stream', () => {
143143
});
144144
});
145145
});
146+
147+
describe('thread management', () => {
148+
const TestComponent = () => {
149+
const { status, messages, error, append, setThreadId, threadId } =
150+
useAssistant({
151+
api: '/api/assistant',
152+
});
153+
154+
return (
155+
<div>
156+
<div data-testid="status">{status}</div>
157+
<div data-testid="thread-id">{threadId || 'undefined'}</div>
158+
{error && <div data-testid="error">{error.toString()}</div>}
159+
{messages.map((m, idx) => (
160+
<div data-testid={`message-${idx}`} key={idx}>
161+
{m.role === 'user' ? 'User: ' : 'AI: '}
162+
{m.content}
163+
</div>
164+
))}
165+
166+
<button
167+
data-testid="do-append"
168+
onClick={() => {
169+
append({ role: 'user', content: 'hi' });
170+
}}
171+
/>
172+
<button
173+
data-testid="do-new-thread"
174+
onClick={() => {
175+
setThreadId(undefined);
176+
}}
177+
/>
178+
<button
179+
data-testid="do-thread-3"
180+
onClick={() => {
181+
setThreadId('t3');
182+
}}
183+
/>
184+
</div>
185+
);
186+
};
187+
188+
beforeEach(() => {
189+
render(<TestComponent />);
190+
});
191+
192+
afterEach(() => {
193+
vi.restoreAllMocks();
194+
cleanup();
195+
});
196+
197+
it('create new thread', async () => {
198+
await screen.findByTestId('thread-id');
199+
expect(screen.getByTestId('thread-id')).toHaveTextContent('undefined');
200+
});
201+
202+
it('should show streamed response', async () => {
203+
const { requestBody } = mockFetchDataStream({
204+
url: 'https://example.com/api/assistant',
205+
chunks: [
206+
formatStreamPart('assistant_control_data', {
207+
threadId: 't0',
208+
messageId: 'm0',
209+
}),
210+
formatStreamPart('assistant_message', {
211+
id: 'm0',
212+
role: 'assistant',
213+
content: [{ type: 'text', text: { value: '' } }],
214+
}),
215+
// text parts:
216+
'0:"Hello"\n',
217+
'0:","\n',
218+
'0:" world"\n',
219+
'0:"."\n',
220+
],
221+
});
222+
223+
await userEvent.click(screen.getByTestId('do-append'));
224+
225+
await screen.findByTestId('message-0');
226+
expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi');
227+
228+
expect(screen.getByTestId('thread-id')).toHaveTextContent('t0');
229+
230+
await screen.findByTestId('message-1');
231+
expect(screen.getByTestId('message-1')).toHaveTextContent(
232+
'AI: Hello, world.',
233+
);
234+
235+
// check that correct information was sent to the server:
236+
expect(await requestBody).toStrictEqual(
237+
JSON.stringify({
238+
threadId: null,
239+
message: 'hi',
240+
}),
241+
);
242+
});
243+
244+
it('should switch to new thread on setting undefined threadId', async () => {
245+
await userEvent.click(screen.getByTestId('do-new-thread'));
246+
247+
expect(screen.queryByTestId('message-0')).toBeNull();
248+
expect(screen.queryByTestId('message-1')).toBeNull();
249+
250+
const { requestBody } = mockFetchDataStream({
251+
url: 'https://example.com/api/assistant',
252+
chunks: [
253+
formatStreamPart('assistant_control_data', {
254+
threadId: 't1',
255+
messageId: 'm0',
256+
}),
257+
formatStreamPart('assistant_message', {
258+
id: 'm0',
259+
role: 'assistant',
260+
content: [{ type: 'text', text: { value: '' } }],
261+
}),
262+
// text parts:
263+
'0:"Hello"\n',
264+
'0:","\n',
265+
'0:" world"\n',
266+
'0:"."\n',
267+
],
268+
});
269+
270+
await userEvent.click(screen.getByTestId('do-append'));
271+
272+
await screen.findByTestId('message-0');
273+
expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi');
274+
275+
expect(screen.getByTestId('thread-id')).toHaveTextContent('t1');
276+
277+
await screen.findByTestId('message-1');
278+
expect(screen.getByTestId('message-1')).toHaveTextContent(
279+
'AI: Hello, world.',
280+
);
281+
282+
// check that correct information was sent to the server:
283+
expect(await requestBody).toStrictEqual(
284+
JSON.stringify({
285+
threadId: null,
286+
message: 'hi',
287+
}),
288+
);
289+
});
290+
291+
it('should switch to thread on setting previously created threadId', async () => {
292+
await userEvent.click(screen.getByTestId('do-thread-3'));
293+
294+
expect(screen.queryByTestId('message-0')).toBeNull();
295+
expect(screen.queryByTestId('message-1')).toBeNull();
296+
297+
const { requestBody } = mockFetchDataStream({
298+
url: 'https://example.com/api/assistant',
299+
chunks: [
300+
formatStreamPart('assistant_control_data', {
301+
threadId: 't3',
302+
messageId: 'm0',
303+
}),
304+
formatStreamPart('assistant_message', {
305+
id: 'm0',
306+
role: 'assistant',
307+
content: [{ type: 'text', text: { value: '' } }],
308+
}),
309+
// text parts:
310+
'0:"Hello"\n',
311+
'0:","\n',
312+
'0:" world"\n',
313+
'0:"."\n',
314+
],
315+
});
316+
317+
await userEvent.click(screen.getByTestId('do-append'));
318+
319+
await screen.findByTestId('message-0');
320+
expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi');
321+
322+
expect(screen.getByTestId('thread-id')).toHaveTextContent('t3');
323+
324+
await screen.findByTestId('message-1');
325+
expect(screen.getByTestId('message-1')).toHaveTextContent(
326+
'AI: Hello, world.',
327+
);
328+
329+
// check that correct information was sent to the server:
330+
expect(await requestBody).toStrictEqual(
331+
JSON.stringify({
332+
threadId: 't3',
333+
message: 'hi',
334+
}),
335+
);
336+
});
337+
});

0 commit comments

Comments
 (0)
Please sign in to comment.