3
3
LanguageModelV1StreamPart ,
4
4
} from '@ai-sdk/provider' ;
5
5
import { z } from 'zod' ;
6
- import { calculateTokenUsage } from '../generate-text/token-usage' ;
6
+ import { TokenUsage , calculateTokenUsage } from '../generate-text/token-usage' ;
7
7
import { CallSettings } from '../prompt/call-settings' ;
8
8
import { convertToLanguageModelPrompt } from '../prompt/convert-to-language-model-prompt' ;
9
9
import { getValidatedPrompt } from '../prompt/get-validated-prompt' ;
@@ -230,7 +230,7 @@ Default and recommended: 'auto' (best mode for the model).
230
230
} ) ;
231
231
}
232
232
233
- export type ObjectStreamPartInput =
233
+ export type ObjectStreamInputPart =
234
234
| {
235
235
type : 'error' ;
236
236
error : unknown ;
@@ -247,7 +247,7 @@ export type ObjectStreamPartInput =
247
247
} ;
248
248
249
249
export type ObjectStreamPart < T > =
250
- | ObjectStreamPartInput
250
+ | ObjectStreamInputPart
251
251
| {
252
252
type : 'object' ;
253
253
object : DeepPartial < T > ;
@@ -257,15 +257,18 @@ export type ObjectStreamPart<T> =
257
257
The result of a `streamObject` call that contains the partial object stream and additional information.
258
258
*/
259
259
export class StreamObjectResult < T > {
260
- private readonly originalStream : ReadableStream <
261
- string | ObjectStreamPartInput
262
- > ;
260
+ readonly originalStream : ReadableStream < ObjectStreamPart < T > > ;
263
261
264
262
/**
265
263
Warnings from the model provider (e.g. unsupported settings)
266
264
*/
267
265
readonly warnings : CallWarning [ ] | undefined ;
268
266
267
+ /**
268
+ The token usage of the generated response. Resolved when the response is finished.
269
+ */
270
+ readonly usage : Promise < TokenUsage > ;
271
+
269
272
/**
270
273
Optional raw response data.
271
274
*/
@@ -281,75 +284,105 @@ Response headers.
281
284
warnings,
282
285
rawResponse,
283
286
} : {
284
- stream : ReadableStream < string | ObjectStreamPartInput > ;
287
+ stream : ReadableStream < string | ObjectStreamInputPart > ;
285
288
warnings : CallWarning [ ] | undefined ;
286
289
rawResponse ?: {
287
290
headers ?: Record < string , string > ;
288
291
} ;
289
292
} ) {
290
- this . originalStream = stream ;
291
293
this . warnings = warnings ;
292
294
this . rawResponse = rawResponse ;
293
- }
294
-
295
- get partialObjectStream ( ) : AsyncIterableStream < DeepPartial < T > > {
296
- let accumulatedText = '' ;
297
- let latestObject : DeepPartial < T > | undefined = undefined ;
298
-
299
- return createAsyncIterableStream ( this . originalStream , {
300
- transform ( chunk , controller ) {
301
- if ( typeof chunk === 'string' ) {
302
- accumulatedText += chunk ;
303
-
304
- const currentObject = parsePartialJson (
305
- accumulatedText ,
306
- ) as DeepPartial < T > ;
307
295
308
- if ( ! isDeepEqualData ( latestObject , currentObject ) ) {
309
- latestObject = currentObject ;
310
-
311
- controller . enqueue ( currentObject ) ;
312
- }
313
- } else if ( chunk . type === 'error' ) {
314
- throw chunk . error ;
315
- }
316
- } ,
296
+ // initialize usage promise
297
+ let resolveUsage : ( value : TokenUsage | PromiseLike < TokenUsage > ) => void ;
298
+ this . usage = new Promise < TokenUsage > ( resolve => {
299
+ resolveUsage = resolve ;
317
300
} ) ;
318
- }
319
301
320
- get fullStream ( ) : AsyncIterableStream < ObjectStreamPart < T > > {
302
+ // store information for onFinish callback:
303
+ let usage : TokenUsage | undefined ;
304
+
305
+ // pipe chunks through a transformation stream that extracts metadata:
321
306
let accumulatedText = '' ;
322
307
let latestObject : DeepPartial < T > | undefined = undefined ;
323
308
324
- return createAsyncIterableStream ( this . originalStream , {
325
- transform ( chunk , controller ) {
326
- if ( typeof chunk === 'string' ) {
327
- accumulatedText += chunk ;
328
- const currentObject = parsePartialJson (
329
- accumulatedText ,
330
- ) as DeepPartial < T > ;
309
+ this . originalStream = stream . pipeThrough (
310
+ new TransformStream < string | ObjectStreamInputPart , ObjectStreamPart < T > > ( {
311
+ async transform ( chunk , controller ) : Promise < void > {
312
+ // process partial text chunks
313
+ if ( typeof chunk === 'string' ) {
314
+ accumulatedText += chunk ;
315
+
316
+ const currentObject = parsePartialJson (
317
+ accumulatedText ,
318
+ ) as DeepPartial < T > ;
319
+
320
+ if ( ! isDeepEqualData ( latestObject , currentObject ) ) {
321
+ latestObject = currentObject ;
331
322
332
- if ( ! isDeepEqualData ( latestObject , currentObject ) ) {
333
- latestObject = currentObject ;
323
+ controller . enqueue ( { type : 'object' , object : currentObject } ) ;
324
+ }
334
325
335
- controller . enqueue ( { type : 'object' , object : currentObject } ) ;
326
+ return ;
336
327
}
337
- } else {
328
+
338
329
switch ( chunk . type ) {
339
- case 'finish' :
330
+ case 'finish' : {
331
+ // store usage for promises and onFinish callback:
332
+ usage = calculateTokenUsage ( chunk . usage ) ;
333
+
340
334
controller . enqueue ( {
341
335
...chunk ,
342
- usage : calculateTokenUsage ( chunk . usage ) ,
336
+ usage,
343
337
} ) ;
338
+
339
+ // resolve promises that can be resolved now:
340
+ resolveUsage ( usage ) ;
341
+
344
342
break ;
345
- default :
343
+ }
344
+
345
+ default : {
346
346
controller . enqueue ( chunk ) ;
347
347
break ;
348
+ }
349
+ }
350
+ } ,
351
+ } ) ,
352
+ ) ;
353
+ }
354
+
355
+ get partialObjectStream ( ) : AsyncIterableStream < DeepPartial < T > > {
356
+ return createAsyncIterableStream ( this . originalStream , {
357
+ transform ( chunk , controller ) {
358
+ switch ( chunk . type ) {
359
+ case 'object' :
360
+ controller . enqueue ( chunk . object ) ;
361
+ break ;
362
+
363
+ case 'finish' :
364
+ break ;
365
+
366
+ case 'error' :
367
+ controller . error ( chunk . error ) ;
368
+ break ;
369
+
370
+ default : {
371
+ const _exhaustiveCheck : never = chunk ;
372
+ throw new Error ( `Unsupported chunk type: ${ _exhaustiveCheck } ` ) ;
348
373
}
349
374
}
350
375
} ,
351
376
} ) ;
352
377
}
378
+
379
+ get fullStream ( ) : AsyncIterableStream < ObjectStreamPart < T > > {
380
+ return createAsyncIterableStream ( this . originalStream , {
381
+ transform ( chunk , controller ) {
382
+ controller . enqueue ( chunk ) ;
383
+ } ,
384
+ } ) ;
385
+ }
353
386
}
354
387
355
388
/**
0 commit comments