Skip to content

Commit a085d42

Browse files
authored May 16, 2024
fix (ai/ui): decouple StreamData chunks from LLM stream (#1613)
1 parent 1659aba commit a085d42

File tree

11 files changed

+337
-74
lines changed

11 files changed

+337
-74
lines changed
 

‎.changeset/quick-drinks-sort.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'ai': patch
3+
---
4+
5+
fix (ai/ui): decouple StreamData chunks from LLM stream

‎examples/next-openai/app/api/chat/route.ts

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { openai } from '@ai-sdk/openai';
22
import { streamText } from 'ai';
33

44
export const dynamic = 'force-dynamic';
5+
export const maxDuration = 60;
56

67
export async function POST(req: Request) {
78
// Extract the `messages` from the body of the request

‎examples/next-openai/app/api/completion/route.ts

+5-3
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,27 @@ import { openai } from '@ai-sdk/openai';
22
import { StreamData, StreamingTextResponse, streamText } from 'ai';
33

44
export const dynamic = 'force-dynamic';
5+
export const maxDuration = 60;
56

67
export async function POST(req: Request) {
78
// Extract the `prompt` from the body of the request
89
const { prompt } = await req.json();
910

1011
const result = await streamText({
11-
model: openai.completion('gpt-3.5-turbo-instruct'),
12+
model: openai('gpt-3.5-turbo-instruct'),
1213
maxTokens: 2000,
1314
prompt,
1415
});
1516

1617
// optional: use stream data
1718
const data = new StreamData();
1819

19-
data.append({ test: 'value' });
20+
data.append('call started');
2021

21-
// Convert the response into a friendly text-stream
22+
// Convert the response to an AI data stream
2223
const stream = result.toAIStream({
2324
onFinal(completion) {
25+
data.append('call completed');
2426
data.close();
2527
},
2628
});
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import { openai } from '@ai-sdk/openai';
2+
import { StreamData, StreamingTextResponse, streamText } from 'ai';
3+
4+
export const dynamic = 'force-dynamic';
5+
export const maxDuration = 60;
6+
7+
/**
 * Route handler for the `useChat` + `StreamData` example.
 *
 * Reads `messages` from the JSON request body, starts a `gpt-4-turbo` text
 * stream for them, and answers with a `StreamingTextResponse` that carries
 * both the AI stream and a `StreamData` instance.
 *
 * Two values are appended to the `StreamData`: `'initialized call'` right
 * after it is created, and `'call completed'` from the `onFinal` callback,
 * which also closes it.
 */
export async function POST(req: Request) {
  const body = await req.json();
  const { messages } = body;

  const result = await streamText({ model: openai('gpt-4-turbo'), messages });

  // optional: use stream data
  const data = new StreamData();
  data.append('initialized call');

  const aiStream = result.toAIStream({
    onFinal() {
      // Last value first, then close: appending after close throws.
      data.append('call completed');
      data.close();
    },
  });

  return new StreamingTextResponse(aiStream, {}, data);
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
'use client';
2+
3+
import { Message, useChat } from 'ai/react';
4+
5+
/**
 * Chat page wired to the `/api/use-chat-streamdata` endpoint through `useChat`.
 *
 * Renders, top to bottom: the hook's `data` value as pretty-printed JSON
 * (only when it is truthy), one entry per chat message, and a form holding
 * the prompt input.
 */
export default function Chat() {
  const chat = useChat({ api: '/api/use-chat-streamdata' });
  const { messages, input, handleInputChange, handleSubmit, data } = chat;

  // Short-circuit keeps the panel out of the tree until `data` is truthy.
  const dataPanel = data && (
    <pre className="p-4 text-sm bg-gray-100">
      {JSON.stringify(data, null, 2)}
    </pre>
  );

  const messageList = messages?.map((message: Message) => (
    <div key={message.id} className="whitespace-pre-wrap">
      <strong>{message.role + ': '}</strong>
      {message.content}
      <br />
      <br />
    </div>
  ));

  return (
    <div className="flex flex-col w-full max-w-md py-24 mx-auto stretch">
      {dataPanel}

      {messageList}

      <form onSubmit={handleSubmit}>
        <input
          className="fixed bottom-0 w-full max-w-md p-2 mb-8 border border-gray-300 rounded shadow-xl"
          value={input}
          placeholder="Say something..."
          onChange={handleInputChange}
        />
      </form>
    </div>
  );
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import { expect, it } from 'vitest';
2+
import { mergeStreams } from './merge-streams';
3+
import { convertReadableStreamToArray } from '../test/convert-readable-stream-to-array';
4+
import { convertArrayToReadableStream } from '../test/convert-array-to-readable-stream';
5+
6+
// When both sources already have every value buffered, the merged stream must
// drain the first source completely before emitting anything from the second.
it('should prioritize the first stream over the second stream', async () => {
  const first = convertArrayToReadableStream(['1a', '1b', '1c']);
  const second = convertArrayToReadableStream(['2a', '2b', '2c']);

  const emitted = await convertReadableStreamToArray(
    mergeStreams(first, second),
  );

  expect(emitted).toEqual(['1a', '1b', '1c', '2a', '2b', '2c']);
});
21+
22+
// Drives both sources by hand through their controllers to check the
// interleaving: values from stream 2 flow while stream 1 is idle, and as soon
// as stream 1 has values they are emitted ahead of stream 2's queued values.
it('should return values from the 2nd stream until the 1st stream has values', async () => {
  let stream1Controller: ReadableStreamDefaultController<string> | undefined;
  const stream1 = new ReadableStream<string>({
    start(controller) {
      stream1Controller = controller;
    },
  });

  let stream2Controller: ReadableStreamDefaultController<string> | undefined;
  const stream2 = new ReadableStream<string>({
    start(controller) {
      stream2Controller = controller;
    },
  });

  const mergedStream = mergeStreams(stream1, stream2);

  const result: string[] = [];
  const reader = mergedStream.getReader();

  // Reads exactly one chunk from the merged stream and records it.
  async function pull() {
    const { value } = await reader.read();
    result.push(value!);
  }

  // Only stream 2 has values at this point.
  stream2Controller!.enqueue('2a');
  stream2Controller!.enqueue('2b');

  await pull();
  await pull();

  // Interleave enqueues on both sources before reading again.
  stream2Controller!.enqueue('2c');
  stream2Controller!.enqueue('2d'); // comes later
  stream1Controller!.enqueue('1a');
  stream2Controller!.enqueue('2e'); // comes later
  stream1Controller!.enqueue('1b');
  stream1Controller!.enqueue('1c');
  stream2Controller!.enqueue('2f');

  await pull();
  await pull();
  await pull();
  await pull();
  await pull();

  stream1Controller!.close();
  stream2Controller!.close();

  // The two remaining stream 2 values are still delivered after both close.
  await pull();
  await pull();

  expect(result).toEqual([
    '2a',
    '2b',
    '2c',
    '1a',
    '1b',
    '1c',
    '2d',
    '2e',
    '2f',
  ]);
});
+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/**
 * Merges two readable streams into a single readable stream, emitting values
 * from each stream as they become available.
 *
 * The first stream is prioritized over the second stream. If both streams have
 * values available, the first stream's value is emitted first.
 *
 * At most one read is outstanding per source stream: a pending read is cached
 * (`lastRead1` / `lastRead2`) and reused by the next `pull` until it settles.
 * The merged stream closes once both sources are done, errors when a source
 * read rejects, and cancels both sources when it is cancelled.
 *
 * @template VALUE1 - The type of values emitted by the first stream.
 * @template VALUE2 - The type of values emitted by the second stream.
 * @param stream1 - The first (prioritized) readable stream.
 * @param stream2 - The second readable stream.
 * @returns A new readable stream that emits values from both input streams.
 */
export function mergeStreams<VALUE1, VALUE2>(
  stream1: ReadableStream<VALUE1>,
  stream2: ReadableStream<VALUE2>,
): ReadableStream<VALUE1 | VALUE2> {
  const reader1 = stream1.getReader();
  const reader2 = stream2.getReader();

  // Pending (or settled but not yet consumed) read per source stream.
  let lastRead1: Promise<ReadableStreamReadResult<VALUE1>> | undefined =
    undefined;
  let lastRead2: Promise<ReadableStreamReadResult<VALUE2>> | undefined =
    undefined;

  let stream1Done = false;
  let stream2Done = false;

  // only use when stream 2 is done:
  async function readStream1(
    controller: ReadableStreamDefaultController<VALUE1 | VALUE2>,
  ): Promise<void> {
    try {
      if (lastRead1 == null) {
        lastRead1 = reader1.read();
      }

      const result = await lastRead1;
      lastRead1 = undefined;

      if (!result.done) {
        controller.enqueue(result.value);
      } else {
        // both sources are exhausted
        controller.close();
      }
    } catch (error) {
      controller.error(error);
    }
  }

  // only use when stream 1 is done:
  async function readStream2(
    controller: ReadableStreamDefaultController<VALUE1 | VALUE2>,
  ): Promise<void> {
    try {
      if (lastRead2 == null) {
        lastRead2 = reader2.read();
      }

      const result = await lastRead2;
      lastRead2 = undefined;

      if (!result.done) {
        controller.enqueue(result.value);
      } else {
        // both sources are exhausted
        controller.close();
      }
    } catch (error) {
      controller.error(error);
    }
  }

  return new ReadableStream<VALUE1 | VALUE2>({
    async pull(controller) {
      try {
        // The single-source reads below MUST be awaited. If `pull` settled
        // while such a read was still pending, the stream could invoke `pull`
        // again, a second reader would await the same cached promise, and the
        // same chunk would be enqueued more than once.

        // stream 1 is done, we can only read from stream 2:
        if (stream1Done) {
          await readStream2(controller);
          return;
        }

        // stream 2 is done, we can only read from stream 1:
        if (stream2Done) {
          await readStream1(controller);
          return;
        }

        // start a read on each stream that has no read outstanding:
        if (lastRead1 == null) {
          lastRead1 = reader1.read();
        }
        if (lastRead2 == null) {
          lastRead2 = reader2.read();
        }

        // Note on Promise.race (prioritizing stream 1 over stream 2):
        // If the iterable contains one or more non-promise values and/or an already settled promise,
        // then Promise.race() will settle to the first of these values found in the iterable.
        const { result, reader } = await Promise.race([
          lastRead1.then(result => ({ result, reader: reader1 })),
          lastRead2.then(result => ({ result, reader: reader2 })),
        ]);

        if (!result.done) {
          controller.enqueue(result.value);
        }

        // Only the winning read is consumed; the losing read stays cached and
        // takes part in the next race.
        if (reader === reader1) {
          lastRead1 = undefined;
          if (result.done) {
            // stream 1 is done, we can only read from stream 2:
            stream1Done = true;
            await readStream2(controller);
          }
        } else {
          lastRead2 = undefined;
          if (result.done) {
            // stream 2 is done, we can only read from stream 1:
            stream2Done = true;
            await readStream1(controller);
          }
        }
      } catch (error) {
        controller.error(error);
      }
    },
    async cancel(reason) {
      // Forward the reason and wait for both sources, so a failing source
      // cancellation rejects the merged cancel instead of floating unhandled.
      await Promise.all([reader1.cancel(reason), reader2.cancel(reason)]);
    },
  });
}

‎packages/core/streams/inkeep-stream.test.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ describe('InkeepStream', () => {
8787
'0:","\n',
8888
'0:" world"\n',
8989
'0:"."\n',
90-
`2:[{"onFinalMetadata":{"chat_session_id":"12345",${recordsCitedSerialized}}}]\n`,
9190
`8:[{${recordsCitedSerialized}}]\n`,
91+
`2:[{"onFinalMetadata":{"chat_session_id":"12345",${recordsCitedSerialized}}}]\n`,
9292
]);
9393
});
9494
});

‎packages/core/streams/stream-data.ts

+38-68
Original file line numberDiff line numberDiff line change
@@ -7,80 +7,33 @@ import { JSONValue } from '../shared/types';
77
export class StreamData {
88
private encoder = new TextEncoder();
99

10-
private controller: TransformStreamDefaultController<Uint8Array> | null =
11-
null;
12-
public stream: TransformStream<Uint8Array, Uint8Array>;
13-
14-
// closing the stream is synchronous, but we want to return a promise
15-
// in case we're doing async work
16-
private isClosedPromise: Promise<void> | null = null;
17-
private isClosedPromiseResolver: undefined | (() => void) = undefined;
18-
private isClosed: boolean = false;
10+
private controller: ReadableStreamController<Uint8Array> | null = null;
11+
public stream: ReadableStream<Uint8Array>;
1912

20-
// array to store appended data
21-
private data: JSONValue[] = [];
22-
private messageAnnotations: JSONValue[] = [];
13+
private isClosed: boolean = false;
14+
private warningTimeout: NodeJS.Timeout | null = null;
2315

2416
constructor() {
25-
this.isClosedPromise = new Promise(resolve => {
26-
this.isClosedPromiseResolver = resolve;
27-
});
28-
2917
const self = this;
30-
this.stream = new TransformStream({
18+
19+
this.stream = new ReadableStream({
3120
start: async controller => {
3221
self.controller = controller;
33-
},
34-
transform: async (chunk, controller) => {
35-
// add buffered data to the stream
36-
if (self.data.length > 0) {
37-
const encodedData = self.encoder.encode(
38-
formatStreamPart('data', self.data),
39-
);
40-
self.data = [];
41-
controller.enqueue(encodedData);
42-
}
4322

44-
if (self.messageAnnotations.length) {
45-
const encodedMessageAnnotations = self.encoder.encode(
46-
formatStreamPart('message_annotations', self.messageAnnotations),
47-
);
48-
self.messageAnnotations = [];
49-
controller.enqueue(encodedMessageAnnotations);
23+
// Set a timeout to show a warning if the stream is not closed within 3 seconds
24+
if (process.env.NODE_ENV === 'development') {
25+
self.warningTimeout = setTimeout(() => {
26+
console.warn(
27+
'The data stream is hanging. Did you forget to close it with `data.close()`?',
28+
);
29+
}, 3000);
5030
}
51-
52-
controller.enqueue(chunk);
5331
},
54-
async flush(controller) {
55-
// Show a warning during dev if the data stream is hanging after 3 seconds.
56-
const warningTimeout =
57-
process.env.NODE_ENV === 'development'
58-
? setTimeout(() => {
59-
console.warn(
60-
'The data stream is hanging. Did you forget to close it with `data.close()`?',
61-
);
62-
}, 3000)
63-
: null;
64-
65-
await self.isClosedPromise;
66-
67-
if (warningTimeout !== null) {
68-
clearTimeout(warningTimeout);
69-
}
70-
71-
if (self.data.length) {
72-
const encodedData = self.encoder.encode(
73-
formatStreamPart('data', self.data),
74-
);
75-
controller.enqueue(encodedData);
76-
}
77-
78-
if (self.messageAnnotations.length) {
79-
const encodedData = self.encoder.encode(
80-
formatStreamPart('message_annotations', self.messageAnnotations),
81-
);
82-
controller.enqueue(encodedData);
83-
}
32+
pull: controller => {
33+
// No-op: we don't need to do anything special on pull
34+
},
35+
cancel: reason => {
36+
this.isClosed = true;
8437
},
8538
});
8639
}
@@ -94,24 +47,41 @@ export class StreamData {
9447
throw new Error('Stream controller is not initialized.');
9548
}
9649

97-
this.isClosedPromiseResolver?.();
50+
this.controller.close();
9851
this.isClosed = true;
52+
53+
// Clear the warning timeout if the stream is closed
54+
if (this.warningTimeout) {
55+
clearTimeout(this.warningTimeout);
56+
}
9957
}
10058

10159
append(value: JSONValue): void {
10260
if (this.isClosed) {
10361
throw new Error('Data Stream has already been closed.');
10462
}
10563

106-
this.data.push(value);
64+
if (!this.controller) {
65+
throw new Error('Stream controller is not initialized.');
66+
}
67+
68+
this.controller.enqueue(
69+
this.encoder.encode(formatStreamPart('data', [value])),
70+
);
10771
}
10872

10973
appendMessageAnnotation(value: JSONValue): void {
11074
if (this.isClosed) {
11175
throw new Error('Data Stream has already been closed.');
11276
}
11377

114-
this.messageAnnotations.push(value);
78+
if (!this.controller) {
79+
throw new Error('Stream controller is not initialized.');
80+
}
81+
82+
this.controller.enqueue(
83+
this.encoder.encode(formatStreamPart('message_annotations', [value])),
84+
);
11585
}
11686
}
11787

‎packages/core/streams/streaming-react-response.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
* between the rows, but flushing the full payload on each row.
99
*/
1010

11+
import { mergeStreams } from '../core/util/merge-streams';
1112
import { parseComplexResponse } from '../shared/parse-complex-response';
1213
import { IdGenerator, JSONValue } from '../shared/types';
1314
import { nanoid } from '../shared/utils';
@@ -50,7 +51,7 @@ export class experimental_StreamingReactResponse {
5051
});
5152

5253
const processedStream: ReadableStream<Uint8Array> =
53-
options?.data != null ? res.pipeThrough(options?.data?.stream) : res;
54+
options?.data != null ? mergeStreams(options?.data?.stream, res) : res;
5455

5556
let lastPayload: Payload | undefined = undefined;
5657

‎packages/core/streams/streaming-text-response.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import type { ServerResponse } from 'node:http';
22
import { StreamData } from './stream-data';
3+
import { mergeStreams } from '../core/util/merge-streams';
34

45
/**
56
* A utility class for streaming text responses.
@@ -9,7 +10,7 @@ export class StreamingTextResponse extends Response {
910
let processedStream = res;
1011

1112
if (data) {
12-
processedStream = res.pipeThrough(data.stream);
13+
processedStream = mergeStreams(data.stream, res);
1314
}
1415

1516
super(processedStream as any, {

0 commit comments

Comments
 (0)
Please sign in to comment.