Skip to content

Commit d1b1880

Browse files
authoredMay 8, 2024··
ai/core: allow reading the stream from streamText multiple times. (#1510)
1 parent b15b8af commit d1b1880

File tree

3 files changed

+75
-3
lines changed

3 files changed

+75
-3
lines changed
 

‎.changeset/late-toys-perform.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'ai': patch
3+
---
4+
5+
fix (ai/core): allow reading streams in streamText result multiple times

‎packages/core/core/generate-text/stream-text.test.ts

+53
Original file line numberDiff line numberDiff line change
@@ -439,3 +439,56 @@ describe('result.toTextStreamResponse', () => {
439439
assert.deepStrictEqual(chunks, ['Hello', ', ', 'world!']);
440440
});
441441
});
442+
443+
describe('multiple stream consumption', () => {
444+
it('should support text stream, ai stream, full stream on single result object', async () => {
445+
const result = await streamText({
446+
model: new MockLanguageModelV1({
447+
doStream: async () => {
448+
return {
449+
stream: convertArrayToReadableStream([
450+
{ type: 'text-delta', textDelta: 'Hello' },
451+
{ type: 'text-delta', textDelta: ', ' },
452+
{ type: 'text-delta', textDelta: 'world!' },
453+
{
454+
type: 'finish',
455+
finishReason: 'stop',
456+
logprobs: undefined,
457+
usage: { completionTokens: 10, promptTokens: 3 },
458+
},
459+
]),
460+
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
461+
};
462+
},
463+
}),
464+
prompt: 'test-input',
465+
});
466+
467+
assert.deepStrictEqual(
468+
await convertAsyncIterableToArray(result.textStream),
469+
['Hello', ', ', 'world!'],
470+
);
471+
472+
assert.deepStrictEqual(
473+
await convertReadableStreamToArray(
474+
result.toAIStream().pipeThrough(new TextDecoderStream()),
475+
),
476+
['0:"Hello"\n', '0:", "\n', '0:"world!"\n'],
477+
);
478+
479+
assert.deepStrictEqual(
480+
await convertAsyncIterableToArray(result.fullStream),
481+
[
482+
{ type: 'text-delta', textDelta: 'Hello' },
483+
{ type: 'text-delta', textDelta: ', ' },
484+
{ type: 'text-delta', textDelta: 'world!' },
485+
{
486+
type: 'finish',
487+
finishReason: 'stop',
488+
logprobs: undefined,
489+
usage: { completionTokens: 10, promptTokens: 3, totalTokens: 13 },
490+
},
491+
],
492+
);
493+
});
494+
});

‎packages/core/core/generate-text/stream-text.ts

+17-3
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ export type TextStreamPart<TOOLS extends Record<string, CoreTool>> =
140140
A result object for accessing different stream types and additional information.
141141
*/
142142
export class StreamTextResult<TOOLS extends Record<string, CoreTool>> {
143-
private readonly originalStream: ReadableStream<TextStreamPart<TOOLS>>;
143+
private originalStream: ReadableStream<TextStreamPart<TOOLS>>;
144144

145145
/**
146146
Warnings from the model provider (e.g. unsupported settings)
@@ -173,13 +173,27 @@ Response headers.
173173
this.rawResponse = rawResponse;
174174
}
175175

176+
/**
177+
Split out a new stream from the original stream.
178+
The original stream is replaced to allow for further splitting,
179+
since we do not know how many times the stream will be split.
180+
181+
Note: this leads to buffering the stream content on the server.
182+
However, the LLM results are expected to be small enough to not cause issues.
183+
*/
184+
private teeStream() {
185+
const [stream1, stream2] = this.originalStream.tee();
186+
this.originalStream = stream2;
187+
return stream1;
188+
}
189+
176190
/**
177191
A text stream that returns only the generated text deltas. You can use it
178192
as either an AsyncIterable or a ReadableStream. When an error occurs, the
179193
stream will throw the error.
180194
*/
181195
get textStream(): AsyncIterableStream<string> {
182-
return createAsyncIterableStream(this.originalStream, {
196+
return createAsyncIterableStream(this.teeStream(), {
183197
transform(chunk, controller) {
184198
if (chunk.type === 'text-delta') {
185199
// do not stream empty text deltas:
@@ -200,7 +214,7 @@ You can use it as either an AsyncIterable or a ReadableStream. When an error occ
200214
stream will throw the error.
201215
*/
202216
get fullStream(): AsyncIterableStream<TextStreamPart<TOOLS>> {
203-
return createAsyncIterableStream(this.originalStream, {
217+
return createAsyncIterableStream(this.teeStream(), {
204218
transform(chunk, controller) {
205219
if (chunk.type === 'text-delta') {
206220
// do not stream empty text deltas:

0 commit comments

Comments
 (0)
Please sign in to comment.