Skip to content

Commit ce009e2

Browse files
Authored on Mar 13, 2024
OpenAI Assistants streaming (#1146)
1 parent 397a89b commit ce009e2

File tree

8 files changed

+558
-226
lines changed

8 files changed

+558
-226
lines changed
 

‎.changeset/nine-planets-flow.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'ai': patch
3+
---
4+
5+
Added OpenAI assistants streaming.

‎docs/pages/docs/api-reference/providers/assistant-response.mdx

+69-106
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ The process parameter is a callback in which you can run the assistant on thread
3535

3636
It gets invoked with the following functions that you can use to send messages and data messages to the client:
3737

38-
- `sendMessage: (message: AssistantMessage) => void`: Sends an assistant message to the client.
38+
- `forwardStream: (stream: AssistantStream) => void`: Forwards the assistant response stream to the client.
3939
- `sendDataMessage: (message: DataMessage) => void`: Send a data message to the client. You can use this to provide information for rendering custom UIs while the assistant is processing the thread.
4040

4141
## Example
@@ -50,7 +50,6 @@ Server:
5050
```tsx filename="app/api/assistant/route.ts"
5151
import { experimental_AssistantResponse } from 'ai';
5252
import OpenAI from 'openai';
53-
import { MessageContentText } from 'openai/resources/beta/threads/messages/messages';
5453

5554
// Create an OpenAI API client (that's edge friendly!)
5655
const openai = new OpenAI({
@@ -86,118 +85,82 @@ export async function POST(req: Request) {
8685

8786
return experimental_AssistantResponse(
8887
{ threadId, messageId: createdMessage.id },
89-
async ({ threadId, sendMessage, sendDataMessage }) => {
88+
async ({ forwardStream, sendDataMessage }) => {
9089
// Run the assistant on the thread
91-
const run = await openai.beta.threads.runs.create(threadId, {
90+
const runStream = openai.beta.threads.runs.createAndStream(threadId, {
9291
assistant_id:
9392
process.env.ASSISTANT_ID ??
9493
(() => {
9594
throw new Error('ASSISTANT_ID is not set');
9695
})(),
9796
});
9897

99-
async function waitForRun(run: OpenAI.Beta.Threads.Runs.Run) {
100-
// Poll for status change
101-
while (run.status === 'queued' || run.status === 'in_progress') {
102-
// delay for 500ms:
103-
await new Promise(resolve => setTimeout(resolve, 500));
104-
105-
run = await openai.beta.threads.runs.retrieve(threadId!, run.id);
106-
}
107-
108-
// Check the run status
109-
if (
110-
run.status === 'cancelled' ||
111-
run.status === 'cancelling' ||
112-
run.status === 'failed' ||
113-
run.status === 'expired'
114-
) {
115-
throw new Error(run.status);
116-
}
117-
118-
if (run.status === 'requires_action') {
119-
if (run.required_action?.type === 'submit_tool_outputs') {
120-
const tool_outputs =
121-
run.required_action.submit_tool_outputs.tool_calls.map(
122-
toolCall => {
123-
const parameters = JSON.parse(toolCall.function.arguments);
124-
125-
switch (toolCall.function.name) {
126-
case 'getRoomTemperature': {
127-
const temperature =
128-
homeTemperatures[
129-
parameters.room as keyof typeof homeTemperatures
130-
];
131-
132-
return {
133-
tool_call_id: toolCall.id,
134-
output: temperature.toString(),
135-
};
136-
}
137-
138-
case 'setRoomTemperature': {
139-
const oldTemperature =
140-
homeTemperatures[
141-
parameters.room as keyof typeof homeTemperatures
142-
];
143-
144-
homeTemperatures[
145-
parameters.room as keyof typeof homeTemperatures
146-
] = parameters.temperature;
147-
148-
sendDataMessage({
149-
role: 'data',
150-
data: {
151-
oldTemperature,
152-
newTemperature: parameters.temperature,
153-
description: `Temperature in ${parameters.room} changed from ${oldTemperature} to ${parameters.temperature}`,
154-
},
155-
});
156-
157-
return {
158-
tool_call_id: toolCall.id,
159-
output: `temperature set successfully`,
160-
};
161-
}
162-
163-
default:
164-
throw new Error(
165-
`Unknown tool call function: ${toolCall.function.name}`,
166-
);
167-
}
168-
},
169-
);
170-
171-
run = await openai.beta.threads.runs.submitToolOutputs(
172-
threadId!,
173-
run.id,
174-
{ tool_outputs },
175-
);
176-
177-
await waitForRun(run);
178-
}
179-
}
180-
}
181-
182-
await waitForRun(run);
183-
184-
// Get new thread messages (after our message)
185-
const responseMessages = (
186-
await openai.beta.threads.messages.list(threadId, {
187-
after: createdMessage.id,
188-
order: 'asc',
189-
})
190-
).data;
191-
192-
// Send the messages
193-
for (const message of responseMessages) {
194-
sendMessage({
195-
id: message.id,
196-
role: 'assistant',
197-
content: message.content.filter(
198-
content => content.type === 'text',
199-
) as Array<MessageContentText>,
200-
});
98+
// forward run status would stream message deltas
99+
let runResult = await forwardStream(runStream);
100+
101+
// status can be: queued, in_progress, requires_action, cancelling, cancelled, failed, completed, or expired
102+
while (
103+
runResult.status === 'requires_action' &&
104+
runResult.required_action?.type === 'submit_tool_outputs'
105+
) {
106+
const tool_outputs =
107+
runResult.required_action.submit_tool_outputs.tool_calls.map(
108+
(toolCall: any) => {
109+
const parameters = JSON.parse(toolCall.function.arguments);
110+
111+
switch (toolCall.function.name) {
112+
case 'getRoomTemperature': {
113+
const temperature =
114+
homeTemperatures[
115+
parameters.room as keyof typeof homeTemperatures
116+
];
117+
118+
return {
119+
tool_call_id: toolCall.id,
120+
output: temperature.toString(),
121+
};
122+
}
123+
124+
case 'setRoomTemperature': {
125+
const oldTemperature =
126+
homeTemperatures[
127+
parameters.room as keyof typeof homeTemperatures
128+
];
129+
130+
homeTemperatures[
131+
parameters.room as keyof typeof homeTemperatures
132+
] = parameters.temperature;
133+
134+
sendDataMessage({
135+
role: 'data',
136+
data: {
137+
oldTemperature,
138+
newTemperature: parameters.temperature,
139+
description: `Temperature in ${parameters.room} changed from ${oldTemperature} to ${parameters.temperature}`,
140+
},
141+
});
142+
143+
return {
144+
tool_call_id: toolCall.id,
145+
output: `temperature set successfully`,
146+
};
147+
}
148+
149+
default:
150+
throw new Error(
151+
`Unknown tool call function: ${toolCall.function.name}`,
152+
);
153+
}
154+
},
155+
);
156+
157+
runResult = await forwardStream(
158+
openai.beta.threads.runs.submitToolOutputsStream(
159+
threadId,
160+
runResult.id,
161+
{ tool_outputs },
162+
),
163+
);
201164
}
202165
},
203166
);

‎examples/next-openai/app/api/assistant/route.ts

+68-105
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import { experimental_AssistantResponse } from 'ai';
22
import OpenAI from 'openai';
3-
import { MessageContentText } from 'openai/resources/beta/threads/messages/messages';
43

54
// Create an OpenAI API client (that's edge friendly!)
65
const openai = new OpenAI({
@@ -36,118 +35,82 @@ export async function POST(req: Request) {
3635

3736
return experimental_AssistantResponse(
3837
{ threadId, messageId: createdMessage.id },
39-
async ({ threadId, sendMessage, sendDataMessage }) => {
38+
async ({ forwardStream, sendDataMessage }) => {
4039
// Run the assistant on the thread
41-
const run = await openai.beta.threads.runs.create(threadId, {
40+
const runStream = openai.beta.threads.runs.createAndStream(threadId, {
4241
assistant_id:
4342
process.env.ASSISTANT_ID ??
4443
(() => {
4544
throw new Error('ASSISTANT_ID is not set');
4645
})(),
4746
});
4847

49-
async function waitForRun(run: OpenAI.Beta.Threads.Runs.Run) {
50-
// Poll for status change
51-
while (run.status === 'queued' || run.status === 'in_progress') {
52-
// delay for 500ms:
53-
await new Promise(resolve => setTimeout(resolve, 500));
54-
55-
run = await openai.beta.threads.runs.retrieve(threadId!, run.id);
56-
}
57-
58-
// Check the run status
59-
if (
60-
run.status === 'cancelled' ||
61-
run.status === 'cancelling' ||
62-
run.status === 'failed' ||
63-
run.status === 'expired'
64-
) {
65-
throw new Error(run.status);
66-
}
67-
68-
if (run.status === 'requires_action') {
69-
if (run.required_action?.type === 'submit_tool_outputs') {
70-
const tool_outputs =
71-
run.required_action.submit_tool_outputs.tool_calls.map(
72-
toolCall => {
73-
const parameters = JSON.parse(toolCall.function.arguments);
74-
75-
switch (toolCall.function.name) {
76-
case 'getRoomTemperature': {
77-
const temperature =
78-
homeTemperatures[
79-
parameters.room as keyof typeof homeTemperatures
80-
];
81-
82-
return {
83-
tool_call_id: toolCall.id,
84-
output: temperature.toString(),
85-
};
86-
}
87-
88-
case 'setRoomTemperature': {
89-
const oldTemperature =
90-
homeTemperatures[
91-
parameters.room as keyof typeof homeTemperatures
92-
];
93-
94-
homeTemperatures[
95-
parameters.room as keyof typeof homeTemperatures
96-
] = parameters.temperature;
97-
98-
sendDataMessage({
99-
role: 'data',
100-
data: {
101-
oldTemperature,
102-
newTemperature: parameters.temperature,
103-
description: `Temperature in ${parameters.room} changed from ${oldTemperature} to ${parameters.temperature}`,
104-
},
105-
});
106-
107-
return {
108-
tool_call_id: toolCall.id,
109-
output: `temperature set successfully`,
110-
};
111-
}
112-
113-
default:
114-
throw new Error(
115-
`Unknown tool call function: ${toolCall.function.name}`,
116-
);
117-
}
118-
},
119-
);
120-
121-
run = await openai.beta.threads.runs.submitToolOutputs(
122-
threadId!,
123-
run.id,
124-
{ tool_outputs },
125-
);
126-
127-
await waitForRun(run);
128-
}
129-
}
130-
}
131-
132-
await waitForRun(run);
133-
134-
// Get new thread messages (after our message)
135-
const responseMessages = (
136-
await openai.beta.threads.messages.list(threadId, {
137-
after: createdMessage.id,
138-
order: 'asc',
139-
})
140-
).data;
141-
142-
// Send the messages
143-
for (const message of responseMessages) {
144-
sendMessage({
145-
id: message.id,
146-
role: 'assistant',
147-
content: message.content.filter(
148-
content => content.type === 'text',
149-
) as Array<MessageContentText>,
150-
});
48+
// forward run status would stream message deltas
49+
let runResult = await forwardStream(runStream);
50+
51+
// status can be: queued, in_progress, requires_action, cancelling, cancelled, failed, completed, or expired
52+
while (
53+
runResult.status === 'requires_action' &&
54+
runResult.required_action?.type === 'submit_tool_outputs'
55+
) {
56+
const tool_outputs =
57+
runResult.required_action.submit_tool_outputs.tool_calls.map(
58+
(toolCall: any) => {
59+
const parameters = JSON.parse(toolCall.function.arguments);
60+
61+
switch (toolCall.function.name) {
62+
case 'getRoomTemperature': {
63+
const temperature =
64+
homeTemperatures[
65+
parameters.room as keyof typeof homeTemperatures
66+
];
67+
68+
return {
69+
tool_call_id: toolCall.id,
70+
output: temperature.toString(),
71+
};
72+
}
73+
74+
case 'setRoomTemperature': {
75+
const oldTemperature =
76+
homeTemperatures[
77+
parameters.room as keyof typeof homeTemperatures
78+
];
79+
80+
homeTemperatures[
81+
parameters.room as keyof typeof homeTemperatures
82+
] = parameters.temperature;
83+
84+
sendDataMessage({
85+
role: 'data',
86+
data: {
87+
oldTemperature,
88+
newTemperature: parameters.temperature,
89+
description: `Temperature in ${parameters.room} changed from ${oldTemperature} to ${parameters.temperature}`,
90+
},
91+
});
92+
93+
return {
94+
tool_call_id: toolCall.id,
95+
output: `temperature set successfully`,
96+
};
97+
}
98+
99+
default:
100+
throw new Error(
101+
`Unknown tool call function: ${toolCall.function.name}`,
102+
);
103+
}
104+
},
105+
);
106+
107+
runResult = await forwardStream(
108+
openai.beta.threads.runs.submitToolOutputsStream(
109+
threadId,
110+
runResult.id,
111+
{ tool_outputs },
112+
),
113+
);
151114
}
152115
},
153116
);

‎examples/next-openai/package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"dependencies": {
1212
"ai": "3.0.10",
1313
"next": "14.1.1",
14-
"openai": "4.16.1",
14+
"openai": "4.29.0",
1515
"react": "18.2.0",
1616
"react-dom": "^18.2.0"
1717
},

‎packages/core/package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@
109109
"jsdom": "^23.0.0",
110110
"langchain": "0.0.196",
111111
"msw": "2.0.9",
112-
"openai": "4.28.4",
112+
"openai": "4.29.0",
113113
"react-dom": "^18.2.0",
114114
"react-server-dom-webpack": "18.3.0-canary-eb33bd747-20240312",
115115
"solid-js": "^1.8.7",

‎packages/core/react/use-assistant.ts

+19-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import { useState } from 'react';
44
import { readDataStream } from '../shared/read-data-stream';
55
import { Message } from '../shared/types';
6+
import { nanoid } from 'nanoid';
67

78
export type AssistantStatus = 'in_progress' | 'awaiting_message';
89

@@ -172,11 +173,28 @@ export function experimental_useAssistant({
172173
break;
173174
}
174175

176+
case 'text': {
177+
// text delta - add to last message:
178+
setMessages(messages => {
179+
const lastMessage = messages[messages.length - 1];
180+
return [
181+
...messages.slice(0, messages.length - 1),
182+
{
183+
id: lastMessage.id,
184+
role: lastMessage.role,
185+
content: lastMessage.content + value,
186+
},
187+
];
188+
});
189+
190+
break;
191+
}
192+
175193
case 'data_message': {
176194
setMessages(messages => [
177195
...messages,
178196
{
179-
id: value.id ?? '',
197+
id: value.id ?? nanoid(),
180198
role: 'data',
181199
content: '',
182200
data: value.data,

‎packages/core/streams/assistant-response.ts

+47-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { AssistantStream } from 'openai/lib/AssistantStream';
12
import { formatStreamPart } from '../shared/stream-parts';
23
import { AssistantMessage, DataMessage } from '../shared/types';
34

@@ -6,11 +7,12 @@ type AssistantResponseSettings = {
67
messageId: string;
78
};
89

9-
type AssistantResponseCallback = (stream: {
10+
type AssistantResponseCallback = (options: {
1011
threadId: string;
1112
messageId: string;
1213
sendMessage: (message: AssistantMessage) => void;
1314
sendDataMessage: (message: DataMessage) => void;
15+
forwardStream: (stream: AssistantStream) => Promise<any>;
1416
}) => Promise<void>;
1517

1618
export function experimental_AssistantResponse(
@@ -39,6 +41,49 @@ export function experimental_AssistantResponse(
3941
);
4042
};
4143

44+
const forwardStream = async (stream: AssistantStream) => {
45+
let result: any = undefined;
46+
47+
for await (const value of stream) {
48+
switch (value.event) {
49+
case 'thread.message.created': {
50+
controller.enqueue(
51+
textEncoder.encode(
52+
formatStreamPart('assistant_message', {
53+
id: value.data.id,
54+
role: 'assistant',
55+
content: [{ type: 'text', text: { value: '' } }],
56+
}),
57+
),
58+
);
59+
break;
60+
}
61+
62+
case 'thread.message.delta': {
63+
const content = value.data.delta.content?.[0];
64+
65+
if (content?.type === 'text' && content.text?.value != null) {
66+
controller.enqueue(
67+
textEncoder.encode(
68+
formatStreamPart('text', content.text.value),
69+
),
70+
);
71+
}
72+
73+
break;
74+
}
75+
76+
case 'thread.run.completed':
77+
case 'thread.run.requires_action': {
78+
result = value.data;
79+
break;
80+
}
81+
}
82+
}
83+
84+
return result;
85+
};
86+
4287
// send the threadId and messageId as the first message:
4388
controller.enqueue(
4489
textEncoder.encode(
@@ -55,6 +100,7 @@ export function experimental_AssistantResponse(
55100
messageId,
56101
sendMessage,
57102
sendDataMessage,
103+
forwardStream,
58104
});
59105
} catch (error) {
60106
sendError((error as any).message ?? `${error}`);

‎pnpm-lock.yaml

+348-11
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments (0)
Please sign in to comment.