Skip to content

Commit b9a831e

Browse files
authoredApr 26, 2024··
New experimental_streamUI() API (#1379)
1 parent 3e304eb commit b9a831e

File tree

8 files changed

+683
-0
lines changed

8 files changed

+683
-0
lines changed
 

‎.changeset/wise-cycles-mix.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'ai': patch
3+
---
4+
5+
ai/rsc: add experimental_streamUI()

‎packages/core/rsc/index.ts

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ export type {
33
getMutableAIState,
44
createStreamableUI,
55
createStreamableValue,
6+
experimental_streamUI,
67
render,
78
createAI,
89
} from './rsc-server';

‎packages/core/rsc/rsc-server.ts

+1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ export {
44
createStreamableValue,
55
render,
66
} from './streamable';
7+
export { experimental_streamUI } from './stream-ui';
78
export { createAI } from './provider';
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html
2+
3+
exports[`result.value > should render text 1`] = `
4+
{
5+
"children": {
6+
"children": {},
7+
"props": {
8+
"c": undefined,
9+
"n": {
10+
"done": false,
11+
"next": {
12+
"done": false,
13+
"next": {
14+
"done": false,
15+
"next": {
16+
"done": false,
17+
"next": {
18+
"done": false,
19+
"next": {
20+
"done": false,
21+
"next": {
22+
"done": true,
23+
"value": "{ \\"content\\": \\"Hello, world!\\" }",
24+
},
25+
"value": "{ \\"content\\": \\"Hello, world!\\" }",
26+
},
27+
"value": "{ \\"content\\": \\"Hello, world!\\"",
28+
},
29+
"value": "{ \\"content\\": \\"Hello, world",
30+
},
31+
"value": "{ \\"content\\": \\"Hello, ",
32+
},
33+
"value": "{ \\"content\\": ",
34+
},
35+
"value": "{ ",
36+
},
37+
},
38+
"type": "",
39+
},
40+
"props": {
41+
"fallback": undefined,
42+
},
43+
"type": "Symbol(react.suspense)",
44+
}
45+
`;
46+
47+
exports[`result.value > should render text function returned ui 1`] = `
48+
{
49+
"children": {
50+
"children": {},
51+
"props": {
52+
"c": undefined,
53+
"n": {
54+
"done": false,
55+
"next": {
56+
"done": false,
57+
"next": {
58+
"done": false,
59+
"next": {
60+
"done": false,
61+
"next": {
62+
"done": false,
63+
"next": {
64+
"done": false,
65+
"next": {
66+
"done": false,
67+
"next": {
68+
"done": true,
69+
"value": <h1>
70+
{ "content": "Hello, world!" }
71+
</h1>,
72+
},
73+
"value": <h1>
74+
{ "content": "Hello, world!" }
75+
</h1>,
76+
},
77+
"value": <h1>
78+
{ "content": "Hello, world!" }
79+
</h1>,
80+
},
81+
"value": <h1>
82+
{ "content": "Hello, world!"
83+
</h1>,
84+
},
85+
"value": <h1>
86+
{ "content": "Hello, world
87+
</h1>,
88+
},
89+
"value": <h1>
90+
{ "content": "Hello,
91+
</h1>,
92+
},
93+
"value": <h1>
94+
{ "content":
95+
</h1>,
96+
},
97+
"value": <h1>
98+
{
99+
</h1>,
100+
},
101+
},
102+
"type": "",
103+
},
104+
"props": {
105+
"fallback": undefined,
106+
},
107+
"type": "Symbol(react.suspense)",
108+
}
109+
`;
110+
111+
exports[`result.value > should render tool call results 1`] = `
112+
{
113+
"children": {
114+
"children": {},
115+
"props": {
116+
"c": undefined,
117+
"n": {
118+
"done": false,
119+
"next": {
120+
"done": false,
121+
"next": {
122+
"done": true,
123+
"value": <div>
124+
tool1:
125+
value
126+
</div>,
127+
},
128+
"value": <div>
129+
tool1:
130+
value
131+
</div>,
132+
},
133+
"value": "",
134+
},
135+
},
136+
"type": "",
137+
},
138+
"props": {
139+
"fallback": undefined,
140+
},
141+
"type": "Symbol(react.suspense)",
142+
}
143+
`;
144+
145+
exports[`result.value > should render tool call results with generator render function 1`] = `
146+
{
147+
"children": {
148+
"children": {},
149+
"props": {
150+
"c": undefined,
151+
"n": {
152+
"done": false,
153+
"next": {
154+
"done": false,
155+
"next": {
156+
"done": false,
157+
"next": {
158+
"done": true,
159+
"value": <div>
160+
tool:
161+
value
162+
</div>,
163+
},
164+
"value": <div>
165+
tool:
166+
value
167+
</div>,
168+
},
169+
"value": "",
170+
},
171+
"value": <div>
172+
Loading...
173+
</div>,
174+
},
175+
},
176+
"type": "",
177+
},
178+
"props": {
179+
"fallback": undefined,
180+
},
181+
"type": "Symbol(react.suspense)",
182+
}
183+
`;
184+
185+
exports[`result.value > should show better error messages if legacy options are passed 1`] = `[Error: Tool definition in \`experimental_streamUI\` should not have \`render\` property. Use \`generate\` instead. Found in tool: tool1]`;

‎packages/core/rsc/stream-ui/index.tsx

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
export { experimental_streamUI } from './stream-ui';
+317
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,317 @@
1+
import {
2+
InvalidToolArgumentsError,
3+
LanguageModelV1,
4+
NoSuchToolError,
5+
} from '@ai-sdk/provider';
6+
import { ReactNode } from 'react';
7+
import { z } from 'zod';
8+
9+
import { CallSettings } from '../../core/prompt/call-settings';
10+
import { Prompt } from '../../core/prompt/prompt';
11+
import { createStreamableUI } from '../streamable';
12+
import { retryWithExponentialBackoff } from '../../core/util/retry-with-exponential-backoff';
13+
import { getValidatedPrompt } from '../../core/prompt/get-validated-prompt';
14+
import { convertZodToJSONSchema } from '../../core/util/convert-zod-to-json-schema';
15+
import { prepareCallSettings } from '../../core/prompt/prepare-call-settings';
16+
import { convertToLanguageModelPrompt } from '../../core/prompt/convert-to-language-model-prompt';
17+
import { createResolvablePromise } from '../utils';
18+
import { safeParseJSON } from '@ai-sdk/provider-utils';
19+
20+
type Streamable = ReactNode | Promise<ReactNode>;
21+
22+
type Renderer<T extends Array<any>> = (
23+
...args: T
24+
) =>
25+
| Streamable
26+
| Generator<Streamable, Streamable, void>
27+
| AsyncGenerator<Streamable, Streamable, void>;
28+
29+
type RenderTool<PARAMETERS extends z.ZodTypeAny = any> = {
30+
description?: string;
31+
parameters: PARAMETERS;
32+
generate?: Renderer<
33+
[
34+
z.infer<PARAMETERS>,
35+
{
36+
toolName: string;
37+
toolCallId: string;
38+
},
39+
]
40+
>;
41+
};
42+
43+
type RenderText = Renderer<
44+
[
45+
{
46+
/**
47+
* The full text content from the model so far.
48+
*/
49+
content: string;
50+
/**
51+
* The new appended text content from the model since the last `text` call.
52+
*/
53+
delta: string;
54+
/**
55+
* Whether the model is done generating text.
56+
* If `true`, the `content` will be the final output and this call will be the last.
57+
*/
58+
done: boolean;
59+
},
60+
]
61+
>;
62+
63+
type RenderResult = {
64+
value: ReactNode;
65+
} & Awaited<ReturnType<LanguageModelV1['doStream']>>;
66+
67+
const defaultTextRenderer: RenderText = ({ content }: { content: string }) =>
68+
content;
69+
70+
/**
71+
* `experimental_streamUI` is a helper function to create a streamable UI from LLMs.
72+
*/
73+
export async function experimental_streamUI<
74+
TOOLS extends Record<string, RenderTool>,
75+
>({
76+
model,
77+
tools,
78+
system,
79+
prompt,
80+
messages,
81+
maxRetries,
82+
abortSignal,
83+
initial,
84+
text,
85+
...settings
86+
}: CallSettings &
87+
Prompt & {
88+
/**
89+
* The language model to use.
90+
*/
91+
model: LanguageModelV1;
92+
93+
/**
94+
* The tools that the model can call. The model needs to support calling tools.
95+
*/
96+
tools?: TOOLS;
97+
98+
text?: RenderText;
99+
initial?: ReactNode;
100+
}): Promise<RenderResult> {
101+
// TODO: Remove these errors after the experimental phase.
102+
if (typeof model === 'string') {
103+
throw new Error(
104+
'`model` cannot be a string in `experimental_streamUI`. Use the actual model instance instead.',
105+
);
106+
}
107+
if ('functions' in settings) {
108+
throw new Error(
109+
'`functions` is not supported in `experimental_streamUI`, use `tools` instead.',
110+
);
111+
}
112+
if ('provider' in settings) {
113+
throw new Error(
114+
'`provider` is no longer needed in `experimental_streamUI`. Use `model` instead.',
115+
);
116+
}
117+
if (tools) {
118+
for (const [name, tool] of Object.entries(tools)) {
119+
if ('render' in tool) {
120+
throw new Error(
121+
'Tool definition in `experimental_streamUI` should not have `render` property. Use `generate` instead. Found in tool: ' +
122+
name,
123+
);
124+
}
125+
}
126+
}
127+
128+
const ui = createStreamableUI(initial);
129+
130+
// The default text renderer just returns the content as string.
131+
const textRender = text || defaultTextRenderer;
132+
133+
let finished: Promise<void> | undefined;
134+
135+
async function handleRender(
136+
args: [payload: any] | [payload: any, options: any],
137+
renderer: undefined | Renderer<any>,
138+
res: ReturnType<typeof createStreamableUI>,
139+
) {
140+
if (!renderer) return;
141+
142+
const resolvable = createResolvablePromise<void>();
143+
144+
if (finished) {
145+
finished = finished.then(() => resolvable.promise);
146+
} else {
147+
finished = resolvable.promise;
148+
}
149+
150+
const value = renderer(...args);
151+
if (
152+
value instanceof Promise ||
153+
(value &&
154+
typeof value === 'object' &&
155+
'then' in value &&
156+
typeof value.then === 'function')
157+
) {
158+
const node = await (value as Promise<React.ReactNode>);
159+
res.update(node);
160+
resolvable.resolve(void 0);
161+
} else if (
162+
value &&
163+
typeof value === 'object' &&
164+
Symbol.asyncIterator in value
165+
) {
166+
const it = value as AsyncGenerator<
167+
React.ReactNode,
168+
React.ReactNode,
169+
void
170+
>;
171+
while (true) {
172+
const { done, value } = await it.next();
173+
res.update(value);
174+
if (done) break;
175+
}
176+
resolvable.resolve(void 0);
177+
} else if (value && typeof value === 'object' && Symbol.iterator in value) {
178+
const it = value as Generator<React.ReactNode, React.ReactNode, void>;
179+
while (true) {
180+
const { done, value } = it.next();
181+
res.update(value);
182+
if (done) break;
183+
}
184+
resolvable.resolve(void 0);
185+
} else {
186+
res.update(value);
187+
resolvable.resolve(void 0);
188+
}
189+
}
190+
191+
const retry = retryWithExponentialBackoff({ maxRetries });
192+
const validatedPrompt = getValidatedPrompt({ system, prompt, messages });
193+
const result = await retry(() =>
194+
model.doStream({
195+
mode: {
196+
type: 'regular',
197+
tools:
198+
tools == null
199+
? undefined
200+
: Object.entries(tools).map(([name, tool]) => ({
201+
type: 'function',
202+
name,
203+
description: tool.description,
204+
parameters: convertZodToJSONSchema(tool.parameters),
205+
})),
206+
},
207+
...prepareCallSettings(settings),
208+
inputFormat: validatedPrompt.type,
209+
prompt: convertToLanguageModelPrompt(validatedPrompt),
210+
abortSignal,
211+
}),
212+
);
213+
214+
const [stream, forkedStream] = result.stream.tee();
215+
216+
(async () => {
217+
try {
218+
// Consume the forked stream asynchonously.
219+
220+
let content = '';
221+
let hasToolCall = false;
222+
223+
const reader = forkedStream.getReader();
224+
while (true) {
225+
const { done, value } = await reader.read();
226+
if (done) break;
227+
228+
switch (value.type) {
229+
case 'text-delta': {
230+
content += value.textDelta;
231+
handleRender(
232+
[{ content, done: false, delta: value.textDelta }],
233+
textRender,
234+
ui,
235+
);
236+
break;
237+
}
238+
239+
case 'tool-call-delta': {
240+
hasToolCall = true;
241+
break;
242+
}
243+
244+
case 'tool-call': {
245+
const toolName = value.toolName as keyof TOOLS & string;
246+
247+
if (!tools) {
248+
throw new NoSuchToolError({ toolName: toolName });
249+
}
250+
251+
const tool = tools[toolName];
252+
if (!tool) {
253+
throw new NoSuchToolError({
254+
toolName,
255+
availableTools: Object.keys(tools),
256+
});
257+
}
258+
259+
const parseResult = safeParseJSON({
260+
text: value.args,
261+
schema: tool.parameters,
262+
});
263+
264+
if (parseResult.success === false) {
265+
throw new InvalidToolArgumentsError({
266+
toolName,
267+
toolArgs: value.args,
268+
cause: parseResult.error,
269+
});
270+
}
271+
272+
handleRender(
273+
[
274+
parseResult.value,
275+
{
276+
toolName,
277+
toolCallId: value.toolCallId,
278+
},
279+
],
280+
tool.generate,
281+
ui,
282+
);
283+
284+
break;
285+
}
286+
287+
case 'error': {
288+
throw value.error;
289+
}
290+
291+
case 'finish': {
292+
// Nothing to do here.
293+
}
294+
}
295+
}
296+
297+
if (hasToolCall) {
298+
await finished;
299+
ui.done();
300+
} else {
301+
handleRender([{ content, done: true }], textRender, ui);
302+
await finished;
303+
ui.done();
304+
}
305+
} catch (error) {
306+
// During the stream rendering, we don't want to throw the error to the
307+
// parent scope but only let the React's error boundary to catch it.
308+
ui.error(error);
309+
}
310+
})();
311+
312+
return {
313+
...result,
314+
stream,
315+
value: ui.value,
316+
};
317+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
import { convertArrayToReadableStream } from '../../core/test/convert-array-to-readable-stream';
2+
import { MockLanguageModelV1 } from '../../core/test/mock-language-model-v1';
3+
import { experimental_streamUI } from './stream-ui';
4+
import { z } from 'zod';
5+
6+
async function recursiveResolve(val: any): Promise<any> {
7+
if (val && typeof val === 'object' && typeof val.then === 'function') {
8+
return await recursiveResolve(await val);
9+
}
10+
11+
if (Array.isArray(val)) {
12+
return await Promise.all(val.map(recursiveResolve));
13+
}
14+
15+
if (val && typeof val === 'object') {
16+
const result: any = {};
17+
for (const key in val) {
18+
result[key] = await recursiveResolve(val[key]);
19+
}
20+
return result;
21+
}
22+
23+
return val;
24+
}
25+
26+
async function simulateFlightServerRender(node: React.ReactNode) {
27+
async function traverse(node: any): Promise<any> {
28+
if (!node) return {};
29+
30+
// Let's only do one level of promise resolution here. As it's only for testing purposes.
31+
const props = await recursiveResolve({ ...node.props } || {});
32+
33+
const { type } = node;
34+
const { children, ...otherProps } = props;
35+
const typeName = typeof type === 'function' ? type.name : String(type);
36+
37+
return {
38+
type: typeName,
39+
props: otherProps,
40+
children:
41+
typeof children === 'string'
42+
? children
43+
: Array.isArray(children)
44+
? children.map(traverse)
45+
: await traverse(children),
46+
};
47+
}
48+
49+
return traverse(node);
50+
}
51+
52+
const mockTextModel = new MockLanguageModelV1({
53+
doStream: async () => {
54+
return {
55+
stream: convertArrayToReadableStream([
56+
{ type: 'text-delta', textDelta: '{ ' },
57+
{ type: 'text-delta', textDelta: '"content": ' },
58+
{ type: 'text-delta', textDelta: `"Hello, ` },
59+
{ type: 'text-delta', textDelta: `world` },
60+
{ type: 'text-delta', textDelta: `!"` },
61+
{ type: 'text-delta', textDelta: ' }' },
62+
]),
63+
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
64+
};
65+
},
66+
});
67+
68+
const mockToolModel = new MockLanguageModelV1({
69+
doStream: async () => {
70+
return {
71+
stream: convertArrayToReadableStream([
72+
{
73+
type: 'tool-call',
74+
toolCallType: 'function',
75+
toolCallId: 'call-1',
76+
toolName: 'tool1',
77+
args: `{ "value": "value" }`,
78+
},
79+
]),
80+
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
81+
};
82+
},
83+
});
84+
85+
describe('result.value', () => {
86+
it('should render text', async () => {
87+
const result = await experimental_streamUI({
88+
model: mockTextModel,
89+
prompt: '',
90+
});
91+
92+
const rendered = await simulateFlightServerRender(result.value);
93+
expect(rendered).toMatchSnapshot();
94+
});
95+
96+
it('should render text function returned ui', async () => {
97+
const result = await experimental_streamUI({
98+
model: mockTextModel,
99+
prompt: '',
100+
text: ({ content }) => <h1>{content}</h1>,
101+
});
102+
103+
const rendered = await simulateFlightServerRender(result.value);
104+
expect(rendered).toMatchSnapshot();
105+
});
106+
107+
it('should render tool call results', async () => {
108+
const result = await experimental_streamUI({
109+
model: mockToolModel,
110+
prompt: '',
111+
tools: {
112+
tool1: {
113+
description: 'test tool 1',
114+
parameters: z.object({
115+
value: z.string(),
116+
}),
117+
generate: async ({ value }) => {
118+
await new Promise(resolve => setTimeout(resolve, 100));
119+
return <div>tool1: {value}</div>;
120+
},
121+
},
122+
},
123+
});
124+
125+
const rendered = await simulateFlightServerRender(result.value);
126+
expect(rendered).toMatchSnapshot();
127+
});
128+
129+
it('should render tool call results with generator render function', async () => {
130+
const result = await experimental_streamUI({
131+
model: mockToolModel,
132+
prompt: '',
133+
tools: {
134+
tool1: {
135+
description: 'test tool 1',
136+
parameters: z.object({
137+
value: z.string(),
138+
}),
139+
generate: async function* ({ value }) {
140+
yield <div>Loading...</div>;
141+
await new Promise(resolve => setTimeout(resolve, 100));
142+
return <div>tool: {value}</div>;
143+
},
144+
},
145+
},
146+
});
147+
148+
const rendered = await simulateFlightServerRender(result.value);
149+
expect(rendered).toMatchSnapshot();
150+
});
151+
152+
it('should show better error messages if legacy options are passed', async () => {
153+
try {
154+
await experimental_streamUI({
155+
model: mockToolModel,
156+
prompt: '',
157+
tools: {
158+
tool1: {
159+
description: 'test tool 1',
160+
parameters: z.object({
161+
value: z.string(),
162+
}),
163+
render: async function* () {},
164+
},
165+
},
166+
});
167+
} catch (e) {
168+
expect(e).toMatchSnapshot();
169+
}
170+
});
171+
});

‎packages/core/rsc/streamable.tsx

+2
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,8 @@ type Renderer<T> = (
278278
/**
279279
* `render` is a helper function to create a streamable UI from some LLMs.
280280
* Currently, it only supports OpenAI's GPT models with Function Calling and Assistants Tools.
281+
*
282+
* @deprecated It's recommended to use the `experimental_streamUI` API for compatibility with the new core APIs.
281283
*/
282284
export function render<
283285
TS extends {

0 commit comments

Comments
 (0)
Please sign in to comment.