Skip to content

Commit d158a47

Browse files
authoredMar 3, 2024
Fix potential race condition of renderer (#1072)
1 parent d6e933d commit d158a47

File tree

6 files changed

+378
-17
lines changed

6 files changed

+378
-17
lines changed
 

‎.changeset/hot-birds-develop.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'ai': patch
3+
---
4+
5+
fix potential race conditions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html
2+
3+
exports[`rsc - render() > should emit React Nodes with async render function 1`] = `
4+
{
5+
"children": {
6+
"children": {},
7+
"props": {
8+
"current": undefined,
9+
"next": {
10+
"done": false,
11+
"next": {
12+
"done": true,
13+
"value": <div>
14+
Weather
15+
</div>,
16+
},
17+
"value": <div>
18+
Weather
19+
</div>,
20+
},
21+
},
22+
"type": "Row",
23+
},
24+
"props": {
25+
"fallback": undefined,
26+
},
27+
"type": "Symbol(react.suspense)",
28+
}
29+
`;
30+
31+
exports[`rsc - render() > should emit React Nodes with generator render function 1`] = `
32+
{
33+
"children": {
34+
"children": {},
35+
"props": {
36+
"current": undefined,
37+
"next": {
38+
"done": false,
39+
"next": {
40+
"done": false,
41+
"next": {
42+
"done": true,
43+
"value": <div>
44+
Weather
45+
</div>,
46+
},
47+
"value": <div>
48+
Weather
49+
</div>,
50+
},
51+
"value": <div>
52+
Loading...
53+
</div>,
54+
},
55+
},
56+
"type": "Row",
57+
},
58+
"props": {
59+
"fallback": undefined,
60+
},
61+
"type": "Symbol(react.suspense)",
62+
}
63+
`;
64+
65+
exports[`rsc - render() > should emit React Nodes with sync render function 1`] = `
66+
{
67+
"children": {
68+
"children": {},
69+
"props": {
70+
"current": undefined,
71+
"next": {
72+
"done": false,
73+
"next": {
74+
"done": true,
75+
"value": <div>
76+
Weather
77+
</div>,
78+
},
79+
"value": <div>
80+
Weather
81+
</div>,
82+
},
83+
},
84+
"type": "Row",
85+
},
86+
"props": {
87+
"fallback": undefined,
88+
},
89+
"type": "Symbol(react.suspense)",
90+
}
91+
`;
92+
93+
exports[`rsc - streamable > should emit React Nodes with async render function 1`] = `
94+
{
95+
"children": {
96+
"children": {},
97+
"props": {
98+
"current": undefined,
99+
"next": {
100+
"done": false,
101+
"next": {
102+
"done": true,
103+
"value": <div>
104+
Weather
105+
</div>,
106+
},
107+
"value": <div>
108+
Weather
109+
</div>,
110+
},
111+
},
112+
"type": "Row",
113+
},
114+
"props": {
115+
"fallback": undefined,
116+
},
117+
"type": "Symbol(react.suspense)",
118+
}
119+
`;
120+
121+
exports[`rsc - streamable > should emit React Nodes with generator render function 1`] = `
122+
{
123+
"children": {
124+
"children": {},
125+
"props": {
126+
"current": undefined,
127+
"next": {
128+
"done": false,
129+
"next": {
130+
"done": false,
131+
"next": {
132+
"done": true,
133+
"value": <div>
134+
Weather
135+
</div>,
136+
},
137+
"value": <div>
138+
Weather
139+
</div>,
140+
},
141+
"value": <div>
142+
Loading...
143+
</div>,
144+
},
145+
},
146+
"type": "Row",
147+
},
148+
"props": {
149+
"fallback": undefined,
150+
},
151+
"type": "Symbol(react.suspense)",
152+
}
153+
`;
154+
155+
exports[`rsc - streamable > should emit React Nodes with sync render function 1`] = `
156+
{
157+
"children": {
158+
"children": {},
159+
"props": {
160+
"current": undefined,
161+
"next": {
162+
"done": false,
163+
"next": {
164+
"done": true,
165+
"value": <div>
166+
Weather
167+
</div>,
168+
},
169+
"value": <div>
170+
Weather
171+
</div>,
172+
},
173+
},
174+
"type": "Row",
175+
},
176+
"props": {
177+
"fallback": undefined,
178+
},
179+
"type": "Symbol(react.suspense)",
180+
}
181+
`;

‎packages/core/rsc/streamable.tsx

+24-13
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,11 @@ export function render<
197197
}): ReactNode {
198198
const ui = createStreamableUI(options.initial);
199199

200+
// The default text renderer just returns the content as string.
201+
const text = options.text
202+
? options.text
203+
: ({ content }: { content: string }) => content;
204+
200205
const functions = options.functions
201206
? Object.entries(options.functions).map(
202207
([name, { description, parameters }]) => {
@@ -233,7 +238,7 @@ export function render<
233238
);
234239
}
235240

236-
let finished: ReturnType<typeof createResolvablePromise> | undefined;
241+
let finished: Promise<void> | undefined;
237242

238243
async function handleRender(
239244
args: any,
@@ -242,8 +247,14 @@ export function render<
242247
) {
243248
if (!renderer) return;
244249

245-
if (finished) await finished.promise;
246-
finished = createResolvablePromise();
250+
const resolvable = createResolvablePromise<void>();
251+
252+
if (finished) {
253+
finished = finished.then(() => resolvable.promise);
254+
} else {
255+
finished = resolvable.promise;
256+
}
257+
247258
const value = renderer(args);
248259
if (
249260
value instanceof Promise ||
@@ -254,7 +265,7 @@ export function render<
254265
) {
255266
const node = await (value as Promise<React.ReactNode>);
256267
res.update(node);
257-
finished?.resolve(void 0);
268+
resolvable.resolve(void 0);
258269
} else if (
259270
value &&
260271
typeof value === 'object' &&
@@ -270,24 +281,24 @@ export function render<
270281
res.update(value);
271282
if (done) break;
272283
}
273-
finished?.resolve(void 0);
284+
resolvable.resolve(void 0);
274285
} else if (value && typeof value === 'object' && Symbol.iterator in value) {
275286
const it = value as Generator<React.ReactNode, React.ReactNode, void>;
276287
while (true) {
277288
const { done, value } = it.next();
278289
res.update(value);
279290
if (done) break;
280291
}
281-
finished?.resolve(void 0);
292+
resolvable.resolve(void 0);
282293
} else {
283294
res.update(value);
284-
finished?.resolve(void 0);
295+
resolvable.resolve(void 0);
285296
}
286297
}
287298

288299
(async () => {
289300
let hasFunction = false;
290-
let text = '';
301+
let content = '';
291302

292303
const parseFunctionCallArguments = (fn: {
293304
type: 'functions' | 'tools';
@@ -365,18 +376,18 @@ export function render<
365376
}
366377
: {}),
367378
onText(chunk) {
368-
text += chunk;
369-
handleRender({ content: text, done: false }, options.text, ui);
379+
content += chunk;
380+
handleRender({ content, done: false }, text, ui);
370381
},
371382
async onFinal() {
372383
if (hasFunction) {
373-
await finished?.promise;
384+
await finished;
374385
ui.done();
375386
return;
376387
}
377388

378-
handleRender({ content: text, done: true }, options.text, ui);
379-
await finished?.promise;
389+
handleRender({ content, done: true }, text, ui);
390+
await finished;
380391
ui.done();
381392
},
382393
},
+159
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import {
2+
openaiChatCompletionChunks,
3+
openaiFunctionCallChunks,
4+
} from '../tests/snapshots/openai-chat';
5+
import { DEFAULT_TEST_URL, createMockServer } from '../tests/utils/mock-server';
6+
import { render } from './streamable';
7+
import { z } from 'zod';
8+
9+
const FUNCTION_CALL_TEST_URL = DEFAULT_TEST_URL + 'mock-func-call';
10+
11+
const server = createMockServer([
12+
{
13+
url: DEFAULT_TEST_URL,
14+
chunks: openaiChatCompletionChunks,
15+
formatChunk: chunk => `data: ${JSON.stringify(chunk)}\n\n`,
16+
suffix: 'data: [DONE]',
17+
},
18+
{
19+
url: FUNCTION_CALL_TEST_URL,
20+
chunks: openaiFunctionCallChunks,
21+
formatChunk: chunk => `data: ${JSON.stringify(chunk)}\n\n`,
22+
suffix: 'data: [DONE]',
23+
},
24+
]);
25+
26+
beforeAll(() => {
27+
server.listen();
28+
});
29+
30+
afterEach(() => {
31+
server.resetHandlers();
32+
});
33+
34+
afterAll(() => {
35+
server.close();
36+
});
37+
38+
async function recursiveResolve(val: any): Promise<any> {
39+
if (val && typeof val === 'object' && typeof val.then === 'function') {
40+
return await recursiveResolve(await val);
41+
}
42+
43+
if (Array.isArray(val)) {
44+
return await Promise.all(val.map(recursiveResolve));
45+
}
46+
47+
if (val && typeof val === 'object') {
48+
const result: any = {};
49+
for (const key in val) {
50+
result[key] = await recursiveResolve(val[key]);
51+
}
52+
return result;
53+
}
54+
55+
return val;
56+
}
57+
58+
async function simulateFlightServerRender(node: React.ReactElement) {
59+
async function traverse(node: any): Promise<any> {
60+
if (!node) return {};
61+
62+
// Let's only do one level of promise resolution here. As it's only for testing purposes.
63+
const props = await recursiveResolve({ ...node.props } || {});
64+
65+
const { type } = node;
66+
const { children, ...otherProps } = props;
67+
const typeName = typeof type === 'function' ? type.name : String(type);
68+
69+
return {
70+
type: typeName,
71+
props: otherProps,
72+
children:
73+
typeof children === 'string'
74+
? children
75+
: Array.isArray(children)
76+
? children.map(traverse)
77+
: await traverse(children),
78+
};
79+
}
80+
81+
return traverse(node);
82+
}
83+
84+
function createMockUpProvider() {
85+
return {
86+
chat: {
87+
completions: {
88+
create: async () => {
89+
return await fetch(FUNCTION_CALL_TEST_URL);
90+
},
91+
},
92+
},
93+
} as any;
94+
}
95+
96+
describe('rsc - render()', () => {
97+
it('should emit React Nodes with sync render function', async () => {
98+
const ui = render({
99+
model: 'gpt-3.5-turbo',
100+
messages: [],
101+
provider: createMockUpProvider(),
102+
functions: {
103+
get_current_weather: {
104+
description: 'Get the current weather',
105+
parameters: z.object({}),
106+
render: () => {
107+
return <div>Weather</div>;
108+
},
109+
},
110+
},
111+
});
112+
113+
const rendered = await simulateFlightServerRender(ui as any);
114+
expect(rendered).toMatchSnapshot();
115+
});
116+
117+
it('should emit React Nodes with async render function', async () => {
118+
const ui = render({
119+
model: 'gpt-3.5-turbo',
120+
messages: [],
121+
provider: createMockUpProvider(),
122+
functions: {
123+
get_current_weather: {
124+
description: 'Get the current weather',
125+
parameters: z.object({}),
126+
render: async () => {
127+
await new Promise(resolve => setTimeout(resolve, 100));
128+
return <div>Weather</div>;
129+
},
130+
},
131+
},
132+
});
133+
134+
const rendered = await simulateFlightServerRender(ui as any);
135+
expect(rendered).toMatchSnapshot();
136+
});
137+
138+
it('should emit React Nodes with generator render function', async () => {
139+
const ui = render({
140+
model: 'gpt-3.5-turbo',
141+
messages: [],
142+
provider: createMockUpProvider(),
143+
functions: {
144+
get_current_weather: {
145+
description: 'Get the current weather',
146+
parameters: z.object({}),
147+
render: async function* () {
148+
yield <div>Loading...</div>;
149+
await new Promise(resolve => setTimeout(resolve, 100));
150+
return <div>Weather</div>;
151+
},
152+
},
153+
},
154+
});
155+
156+
const rendered = await simulateFlightServerRender(ui as any);
157+
expect(rendered).toMatchSnapshot();
158+
});
159+
});

‎packages/core/rsc/utils.tsx

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import React, { Suspense } from 'react';
22

3-
export function createResolvablePromise() {
4-
let resolve: (value: any) => void, reject: (error: any) => void;
5-
const promise = new Promise((res, rej) => {
3+
export function createResolvablePromise<T = any>() {
4+
let resolve: (value: T) => void, reject: (error: unknown) => void;
5+
const promise = new Promise<T>((res, rej) => {
66
resolve = res;
77
reject = rej;
88
});

‎packages/core/vitest.ui.react.config.js

+6-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ export default defineConfig({
77
test: {
88
environment: 'jsdom',
99
globals: true,
10-
include: ['react/**/*.ui.test.ts', 'react/**/*.ui.test.tsx'],
10+
include: [
11+
'react/**/*.ui.test.ts',
12+
'react/**/*.ui.test.tsx',
13+
'rsc/**/*.ui.test.ts',
14+
'rsc/**/*.ui.test.tsx',
15+
],
1116
},
1217
});

0 commit comments

Comments
 (0)
Please sign in to comment.