Skip to content

Commit 3cabf07

Browse files
shudinglgrammel
andauthoredJun 11, 2024··
fix(ai/rsc): New createStreamableUI implementation (#1825)
Co-authored-by: Lars Grammel <lars.grammel@gmail.com>
1 parent 8e5ea89 commit 3cabf07

24 files changed

+448
-1154
lines changed
 

‎.changeset/early-hounds-flow.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'ai': patch
3+
---
4+
5+
fix(ai/rsc): Refactor streamable UI internal implementation

‎packages/core/rsc/__snapshots__/streamable.ui.test.tsx.snap

-91
This file was deleted.

‎packages/core/rsc/rsc-shared.mts

+1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ export {
88
useActions,
99
useSyncUIState,
1010
InternalAIProvider,
11+
InternalStreamableUIClient,
1112
} from './shared-client';
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
'use client';
2+
3+
import { useState, useEffect } from 'react';
4+
import { StreamableValue } from '../types';
5+
import { readStreamableValue } from './streamable';
6+
7+
export function InternalStreamableUIClient<T>({
8+
s,
9+
}: {
10+
s: StreamableValue<T>;
11+
}) {
12+
// Set the value to the initial value of the streamable, if it has one.
13+
const [value, setValue] = useState<T | undefined>(s.curr);
14+
15+
// Error state for the streamable. It might be errored initially and we want
16+
// to error out as soon as possible.
17+
const [error, setError] = useState<Error | undefined>(s.error);
18+
19+
useEffect(() => {
20+
let canceled = false;
21+
setError(undefined);
22+
23+
(async () => {
24+
try {
25+
// Read the streamable value and update the state with the new value.
26+
for await (const v of readStreamableValue(s)) {
27+
if (canceled) {
28+
break;
29+
}
30+
31+
setValue(v);
32+
}
33+
} catch (e) {
34+
if (canceled) {
35+
return;
36+
}
37+
38+
setError(e as Error);
39+
}
40+
})();
41+
42+
return () => {
43+
// If the component is unmounted, we want to cancel the stream.
44+
canceled = true;
45+
};
46+
}, [s]);
47+
48+
// This ensures that errors from the streamable UI are thrown during the
49+
// render phase, so that they can be caught by error boundary components.
50+
// This is necessary for React's declarative model.
51+
if (error) {
52+
throw error;
53+
}
54+
55+
return value;
56+
}

‎packages/core/rsc/shared-client/index.ts

+1
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ export {
88
useSyncUIState,
99
InternalAIProvider,
1010
} from './context';
11+
export { InternalStreamableUIClient } from './client-wrapper';

‎packages/core/rsc/shared-client/streamable.tsx

+14-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
import { startTransition, useLayoutEffect, useState } from 'react';
1+
import {
2+
ReactElement,
3+
startTransition,
4+
useLayoutEffect,
5+
useState,
6+
} from 'react';
27
import { STREAMABLE_VALUE_TYPE } from '../constants';
38
import type { StreamableValue } from '../types';
49

@@ -97,6 +102,14 @@ export function readStreamableValue<T = unknown>(
97102
(curr as string) = curr + row.diff[1];
98103
}
99104
break;
105+
case 1:
106+
(curr as ReactElement) = (
107+
<>
108+
{curr}
109+
{row.diff[1]}
110+
</>
111+
);
112+
break;
100113
}
101114
} else {
102115
curr = row.curr;

‎packages/core/rsc/stream-ui/__snapshots__/stream-ui.ui.test.tsx.snap

-157
This file was deleted.

‎packages/core/rsc/stream-ui/stream-ui.ui.test.tsx

-172
This file was deleted.

‎packages/core/rsc/streamable.tsx

+40-47
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import type { ReactNode } from 'react';
1+
import { type ReactNode, isValidElement, ReactElement } from 'react';
22
import type OpenAI from 'openai';
33
import { z } from 'zod';
44
import zodToJsonSchema from 'zod-to-json-schema';
@@ -10,12 +10,9 @@ import {
1010
STREAMABLE_VALUE_TYPE,
1111
DEV_DEFAULT_STREAMABLE_WARNING_TIME,
1212
} from './constants';
13-
import {
14-
createResolvablePromise,
15-
createSuspensedChunk,
16-
consumeStream,
17-
} from './utils';
13+
import { createResolvablePromise, consumeStream } from './utils';
1814
import type { StreamablePatch, StreamableValue } from './types';
15+
import { InternalStreamableUIClient } from './rsc-shared.mjs';
1916

2017
// It's necessary to define the type manually here, otherwise TypeScript compiler
2118
// will not be able to infer the correct return type as it's circular.
@@ -68,9 +65,9 @@ type StreamableUIWrapper = {
6865
* On the client side, it can be rendered as a normal React node.
6966
*/
7067
function createStreamableUI(initialValue?: React.ReactNode) {
71-
let currentValue = initialValue;
68+
const innerStreamable = createStreamableValue<React.ReactNode>(initialValue);
69+
7270
let closed = false;
73-
let { row, resolve, reject } = createSuspensedChunk(initialValue);
7471

7572
function assertStream(method: string) {
7673
if (closed) {
@@ -94,37 +91,19 @@ function createStreamableUI(initialValue?: React.ReactNode) {
9491
warnUnclosedStream();
9592

9693
const streamable: StreamableUIWrapper = {
97-
value: row,
94+
value: <InternalStreamableUIClient s={innerStreamable.value} />,
9895
update(value: React.ReactNode) {
9996
assertStream('.update()');
10097

101-
// There is no need to update the value if it's referentially equal.
102-
if (value === currentValue) {
103-
warnUnclosedStream();
104-
return streamable;
105-
}
106-
107-
const resolvable = createResolvablePromise();
108-
currentValue = value;
109-
110-
resolve({ value: currentValue, done: false, next: resolvable.promise });
111-
resolve = resolvable.resolve;
112-
reject = resolvable.reject;
113-
98+
innerStreamable.update(value);
11499
warnUnclosedStream();
115100

116101
return streamable;
117102
},
118103
append(value: React.ReactNode) {
119104
assertStream('.append()');
120105

121-
const resolvable = createResolvablePromise();
122-
currentValue = value;
123-
124-
resolve({ value, done: false, append: true, next: resolvable.promise });
125-
resolve = resolvable.resolve;
126-
reject = resolvable.reject;
127-
106+
innerStreamable.append(value);
128107
warnUnclosedStream();
129108

130109
return streamable;
@@ -136,7 +115,7 @@ function createStreamableUI(initialValue?: React.ReactNode) {
136115
clearTimeout(warningTimeout);
137116
}
138117
closed = true;
139-
reject(error);
118+
innerStreamable.error(error);
140119

141120
return streamable;
142121
},
@@ -148,11 +127,11 @@ function createStreamableUI(initialValue?: React.ReactNode) {
148127
}
149128
closed = true;
150129
if (args.length) {
151-
resolve({ value: args[0], done: true });
130+
innerStreamable.done(args[0]);
152131
return streamable;
153132
}
154-
resolve({ value: currentValue, done: true });
155133

134+
innerStreamable.done();
156135
return streamable;
157136
},
158137
};
@@ -377,31 +356,45 @@ function createStreamableValueImpl<T = any, E = any>(initialValue?: T) {
377356
append(value: T) {
378357
assertStream('.append()');
379358

380-
if (
381-
typeof currentValue !== 'string' &&
382-
typeof currentValue !== 'undefined'
383-
) {
359+
if (typeof value !== 'string' && !isValidElement(value)) {
384360
throw new Error(
385-
`.append(): The current value is not a string. Received: ${typeof currentValue}`,
361+
`.append(): The value type can't be appended to the stream. Received: ${typeof value}`,
386362
);
387363
}
388-
if (typeof value !== 'string') {
364+
365+
if (typeof currentValue === 'undefined') {
366+
currentPatchValue = undefined;
367+
currentValue = value;
368+
} else if (typeof currentValue === 'string') {
369+
if (typeof value === 'string') {
370+
currentPatchValue = [0, value];
371+
(currentValue as string) = currentValue + value;
372+
} else {
373+
currentPatchValue = [1, value];
374+
(currentValue as unknown as ReactElement) = (
375+
<>
376+
{currentValue}
377+
{value}
378+
</>
379+
);
380+
}
381+
} else if (isValidElement(currentValue)) {
382+
currentPatchValue = [1, value];
383+
(currentValue as ReactElement) = (
384+
<>
385+
{currentValue}
386+
{value}
387+
</>
388+
);
389+
} else {
389390
throw new Error(
390-
`.append(): The value is not a string. Received: ${typeof value}`,
391+
`.append(): The current value doesn't support appending data. Type: ${typeof currentValue}`,
391392
);
392393
}
393394

394395
const resolvePrevious = resolvable.resolve;
395396
resolvable = createResolvablePromise();
396397

397-
if (typeof currentValue === 'string') {
398-
currentPatchValue = [0, value];
399-
(currentValue as string) = currentValue + value;
400-
} else {
401-
currentPatchValue = undefined;
402-
currentValue = value;
403-
}
404-
405398
currentPromise = resolvable.promise;
406399
resolvePrevious(createWrapped());
407400

‎packages/core/rsc/streamable.ui.test.tsx

-484
This file was deleted.

‎packages/core/rsc/types.ts

+4-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,10 @@ export type MutableAIState<AIState> = {
8787
done: ((newState: AIState) => void) | (() => void);
8888
};
8989

90-
export type StreamablePatch = undefined | [0, string]; // Append string.
90+
export type StreamablePatch =
91+
| undefined
92+
| [0, string] // Append string
93+
| [1, React.ReactElement | string]; // Append element
9194

9295
declare const __internal_curr: unique symbol;
9396
declare const __internal_error: unique symbol;

‎packages/core/rsc/utils.tsx

-52
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import React, { Suspense } from 'react';
2-
31
export function createResolvablePromise<T = any>() {
42
let resolve: (value: T) => void, reject: (error: unknown) => void;
53
const promise = new Promise<T>((res, rej) => {
@@ -13,56 +11,6 @@ export function createResolvablePromise<T = any>() {
1311
};
1412
}
1513

16-
// Use the name `R` for `Row` as it will be shorter in the RSC payload.
17-
const R = [
18-
(async ({
19-
c, // current
20-
n, // next
21-
}: {
22-
c: React.ReactNode;
23-
n: Promise<any>;
24-
}) => {
25-
const chunk = await n;
26-
if (chunk.done) {
27-
return chunk.value;
28-
}
29-
30-
if (chunk.append) {
31-
return (
32-
<>
33-
{c}
34-
<Suspense fallback={chunk.value}>
35-
<R c={chunk.value} n={chunk.next} />
36-
</Suspense>
37-
</>
38-
);
39-
}
40-
41-
return (
42-
<Suspense fallback={chunk.value}>
43-
<R c={chunk.value} n={chunk.next} />
44-
</Suspense>
45-
);
46-
}) as unknown as React.FC<{
47-
c: React.ReactNode;
48-
n: Promise<any>;
49-
}>,
50-
][0];
51-
52-
export function createSuspensedChunk(initialValue: React.ReactNode) {
53-
const { promise, resolve, reject } = createResolvablePromise();
54-
55-
return {
56-
row: (
57-
<Suspense fallback={initialValue}>
58-
<R c={initialValue} n={promise} />
59-
</Suspense>
60-
) as React.ReactNode,
61-
resolve,
62-
reject,
63-
};
64-
}
65-
6614
export const isFunction = (x: unknown): x is Function =>
6715
typeof x === 'function';
6816

‎packages/core/tests/e2e/next-server/app/rsc/client-utils.js

-5
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
'use server';
2+
3+
import { streamUI } from 'ai/rsc';
4+
import { z } from 'zod';
5+
6+
import { MockLanguageModelV1 } from '../../../../../core/test/mock-language-model-v1';
7+
import { convertArrayToReadableStream } from '../../../../../core/test/convert-array-to-readable-stream';
8+
9+
const mockTextModel = new MockLanguageModelV1({
10+
doStream: async () => {
11+
return {
12+
stream: convertArrayToReadableStream([
13+
{ type: 'text-delta', textDelta: `"Hello, ` },
14+
{ type: 'text-delta', textDelta: `world` },
15+
{ type: 'text-delta', textDelta: `!"` },
16+
]),
17+
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
18+
};
19+
},
20+
});
21+
22+
const mockToolModel = new MockLanguageModelV1({
23+
doStream: async () => {
24+
return {
25+
stream: convertArrayToReadableStream([
26+
{
27+
type: 'tool-call',
28+
toolCallType: 'function',
29+
toolCallId: 'call-1',
30+
toolName: 'tool1',
31+
args: `{ "value": "value" }`,
32+
},
33+
]),
34+
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
35+
};
36+
},
37+
});
38+
39+
function sleep(ms = 0) {
40+
return new Promise(resolve => setTimeout(resolve, ms));
41+
}
42+
43+
export async function action(testCase) {
44+
switch (testCase) {
45+
case 'text': {
46+
const result = await streamUI({
47+
model: mockTextModel,
48+
prompt: '',
49+
});
50+
return result.value;
51+
}
52+
case 'wrapped-text': {
53+
const result = await streamUI({
54+
model: mockTextModel,
55+
prompt: '',
56+
text: ({ content }) => <p>AI: {content}</p>,
57+
});
58+
return result.value;
59+
}
60+
case 'tool': {
61+
const result = await streamUI({
62+
model: mockToolModel,
63+
prompt: '',
64+
tools: {
65+
tool1: {
66+
description: 'test tool 1',
67+
parameters: z.object({
68+
value: z.string(),
69+
}),
70+
generate: async function* ({ value }) {
71+
yield 'Loading...';
72+
await sleep(10);
73+
return <div>tool1: {value}</div>;
74+
},
75+
},
76+
},
77+
});
78+
return result.value;
79+
}
80+
}
81+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
'use client';
2+
3+
import { useState } from 'react';
4+
5+
export function Client({ action }) {
6+
const [log, setLog] = useState('');
7+
8+
return (
9+
<div>
10+
<pre id="log" style={{ border: '1px solid #ccc', padding: 5 }}>
11+
{log}
12+
</pre>
13+
14+
{/* Test suites */}
15+
<div style={{ display: 'inline-flex', flexDirection: 'column', gap: 5 }}>
16+
<button
17+
id="test-streamui-text"
18+
onClick={async () => {
19+
setLog(await action('text'));
20+
}}
21+
>
22+
Test streamUI() Text UI
23+
</button>
24+
<button
25+
id="test-streamui-wrapped-text"
26+
onClick={async () => {
27+
setLog(await action('wrapped-text'));
28+
}}
29+
>
30+
Test streamUI() Wrapped Text UI
31+
</button>
32+
<button
33+
id="test-streamui-tool"
34+
onClick={async () => {
35+
setLog(await action('tool'));
36+
}}
37+
>
38+
Test streamUI() Tool
39+
</button>
40+
</div>
41+
</div>
42+
);
43+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import { action } from './action';
2+
import { Client } from './client';
3+
4+
export default function Page() {
5+
return <Client action={action} />;
6+
}

‎packages/core/tests/e2e/next-server/app/rsc/actions.jsx ‎packages/core/tests/e2e/next-server/app/streamable/actions.jsx

+27-3
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,46 @@ function sleep(ms = 0) {
77
return new Promise(resolve => setTimeout(resolve, ms));
88
}
99

10+
export async function streamableUIError() {
11+
const streamable = createStreamableUI();
12+
(async () => {
13+
await sleep(10);
14+
streamable.update(<p>foo</p>);
15+
await sleep(10);
16+
streamable.error('This is an error');
17+
})();
18+
return streamable.value;
19+
}
20+
21+
export async function streamableUIAppend() {
22+
const streamable = createStreamableUI();
23+
(async () => {
24+
await sleep(10);
25+
streamable.update(<p>foo</p>);
26+
await sleep(10);
27+
streamable.append(<p>bar</p>);
28+
await sleep(10);
29+
streamable.done();
30+
})();
31+
return streamable.value;
32+
}
33+
1034
export async function streamableUI() {
1135
const streamable = createStreamableUI();
1236
(async () => {
13-
await sleep();
37+
await sleep(10);
1438
streamable.update(
1539
<ClientInfo>
1640
<p>I am a paragraph</p>
1741
</ClientInfo>,
1842
);
19-
await sleep();
43+
await sleep(10);
2044
streamable.update(
2145
<ClientInfo>
2246
<button>I am a button</button>
2347
</ClientInfo>,
2448
);
25-
await sleep();
49+
await sleep(10);
2650
streamable.done();
2751
})();
2852
return streamable.value;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
'use client';
2+
3+
import { useEffect, useState } from 'react';
4+
5+
export function ClientInfo({ children }) {
6+
const [renders, setRenders] = useState(0);
7+
8+
// Increment the render count whenever the children change.
9+
// This can be used to verify that the component is re-rendered with the
10+
// state kept.
11+
useEffect(() => {
12+
setRenders(r => r + 1);
13+
}, [children]);
14+
15+
return (
16+
<div>
17+
<p>{renders > 1 ? '(Rerendered) ' : ''}</p>
18+
{children}
19+
</div>
20+
);
21+
}

‎packages/core/tests/e2e/next-server/app/rsc/client.js ‎packages/core/tests/e2e/next-server/app/streamable/client.js

+36-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,26 @@
11
'use client';
22

3-
import { useState } from 'react';
3+
import React, { useState } from 'react';
44
import { readStreamableValue } from 'ai/rsc';
55

6+
class ErrorBoundary extends React.Component {
7+
constructor(props) {
8+
super(props);
9+
this.state = { error: null };
10+
}
11+
12+
static getDerivedStateFromError(error) {
13+
return { error };
14+
}
15+
16+
render() {
17+
if (this.state.error) {
18+
return <div>Caught by Error Boundary: {this.state.error}</div>;
19+
}
20+
return this.props.children;
21+
}
22+
}
23+
624
export function Client({ actions }) {
725
const [log, setLog] = useState('');
826

@@ -19,9 +37,20 @@ export function Client({ actions }) {
1937

2038
// Test `createStreamableUI` API
2139
async function testStreamableUI() {
40+
setLog(null);
2241
const value = await actions.streamableUI();
2342
setLog(value);
2443
}
44+
async function testStreamableUIAppend() {
45+
setLog(null);
46+
const value = await actions.streamableUIAppend();
47+
setLog(value);
48+
}
49+
async function testStreamableUIError() {
50+
setLog(null);
51+
const value = await actions.streamableUIError();
52+
setLog(<ErrorBoundary>{value}</ErrorBoundary>);
53+
}
2554

2655
return (
2756
<div>
@@ -37,6 +66,12 @@ export function Client({ actions }) {
3766
<button id="test-streamable-ui" onClick={testStreamableUI}>
3867
Test Streamable UI
3968
</button>
69+
<button id="test-streamable-ui-append" onClick={testStreamableUIAppend}>
70+
Test Streamable UI (Append)
71+
</button>
72+
<button id="test-streamable-ui-error" onClick={testStreamableUIError}>
73+
Test Streamable UI (Error)
74+
</button>
4075
</div>
4176
</div>
4277
);

‎packages/core/tests/e2e/next-server/app/rsc/page.js ‎packages/core/tests/e2e/next-server/app/streamable/page.js

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
import { streamableUI, streamableValue } from './actions';
1+
import {
2+
streamableUI,
3+
streamableValue,
4+
streamableUIAppend,
5+
streamableUIError,
6+
} from './actions';
27
import { Client } from './client';
38

49
export default function Page() {
@@ -7,6 +12,8 @@ export default function Page() {
712
actions={{
813
streamableUI,
914
streamableValue,
15+
streamableUIAppend,
16+
streamableUIError,
1017
}}
1118
/>
1219
);

‎packages/core/tests/e2e/next-server/package.json

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"next": "canary",
99
"react": "rc",
1010
"react-dom": "rc",
11+
"zod": "*",
1112
"ai": "workspace:*"
1213
}
1314
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import { test, expect } from '@playwright/test';
2+
3+
test('streamUI() with text', async ({ page }) => {
4+
await page.goto('/stream-ui');
5+
await page.click('#test-streamui-text');
6+
7+
const logs = page.locator('#log');
8+
await expect(logs).toHaveText('"Hello, world!"');
9+
});
10+
11+
test('streamUI() with wrapped text', async ({ page }) => {
12+
await page.goto('/stream-ui');
13+
await page.click('#test-streamui-wrapped-text');
14+
15+
const logs = page.locator('#log');
16+
await expect(logs).toHaveText('AI: "Hello, world!"');
17+
});
18+
19+
test('streamUI() with tool call', async ({ page }) => {
20+
await page.goto('/stream-ui');
21+
await page.click('#test-streamui-tool');
22+
23+
const logs = page.locator('#log');
24+
await expect(logs).toHaveText('tool1: value');
25+
});
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { test, expect } from '@playwright/test';
22

3-
test('createStreamableValue and readStreamableValue', async ({ page }) => {
4-
await page.goto('/rsc');
3+
test('createStreamableValue() and readStreamableValue()', async ({ page }) => {
4+
await page.goto('/streamable');
55
await page.click('#test-streamable-value');
66

77
const logs = page.locator('#log');
@@ -10,10 +10,29 @@ test('createStreamableValue and readStreamableValue', async ({ page }) => {
1010
);
1111
});
1212

13-
test('test-streamable-ui', async ({ page }) => {
14-
await page.goto('/rsc');
13+
test('createStreamableUI()', async ({ page }) => {
14+
await page.goto('/streamable');
1515
await page.click('#test-streamable-ui');
1616

1717
const logs = page.locator('#log');
18-
await expect(logs).toHaveText('I am a button');
18+
19+
// It should update the UI but reuse the same component instance and its state
20+
// to avoid re-mounting.
21+
await expect(logs).toHaveText('(Rerendered) I am a button');
22+
});
23+
24+
test('createStreamableUI() .append() method', async ({ page }) => {
25+
await page.goto('/streamable');
26+
await page.click('#test-streamable-ui-append');
27+
28+
const logs = page.locator('#log');
29+
await expect(logs).toHaveText('foobar');
30+
});
31+
32+
test('createStreamableUI() .error() method', async ({ page }) => {
33+
await page.goto('/streamable');
34+
await page.click('#test-streamable-ui-error');
35+
36+
const logs = page.locator('#log');
37+
await expect(logs).toHaveText('Caught by Error Boundary: This is an error');
1938
});

‎pnpm-lock.yaml

+55-134
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)
Please sign in to comment.