Commit 28427d3

feat (core): add streamObject onFinish callback. (#1864)

Authored Jun 6, 2024 · 1 parent: c363148 · commit: 28427d3

9 files changed: +404 −9 lines

.changeset/blue-pets-punch.md (new file, +5)

```md
---
'ai': patch
---

feat (core): add streamObject onFinish callback
```

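For orientation, here is a minimal sketch of how the new callback is used, assembled from the documentation and examples added in this commit; the model, schema, and prompt are illustrative choices rather than part of the change itself:

```ts
import { openai } from '@ai-sdk/openai';
import { streamObject } from 'ai';
import { z } from 'zod';

const result = await streamObject({
  model: openai('gpt-4-turbo'), // illustrative model choice
  schema: z.object({ answer: z.string() }),
  prompt: 'Answer with a single word: what color is the sky?',
  // new in this commit: called once the stream has finished and the final
  // object has been validated against the schema. The event also carries
  // optional rawResponse and warnings fields.
  onFinish({ usage, object, error }) {
    console.log('Token usage:', usage);
    if (object === undefined) {
      // e.g. a TypeValidationError when the final object does not match the schema
      console.error('Validation failed:', error);
    } else {
      console.log('Final object:', object);
    }
  },
});

// onFinish fires after the partial object stream has been fully consumed:
for await (const partialObject of result.partialObjectStream) {
  // ...
}
```
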
content/docs/07-reference/ai-sdk-core/04-stream-object.mdx (+78 −1)

```diff
@@ -305,7 +305,84 @@ for await (const partialObject of partialObjectStream) {
       isOptional: true,
       description:
         'An optional abort signal that can be used to cancel the call.'
-    }
+    },
+    {
+      name: 'onFinish',
+      type: '(result: OnFinishResult) => void',
+      isOptional: true,
+      description:
+        'Callback that is called when the LLM response and the final object validation are finished.',
+      properties: [
+        {
+          type: 'OnFinishResult',
+          parameters: [
+            {
+              name: 'usage',
+              type: 'TokenUsage',
+              description: 'The token usage of the generated text.',
+              properties: [
+                {
+                  type: 'TokenUsage',
+                  parameters: [
+                    {
+                      name: 'promptTokens',
+                      type: 'number',
+                      description: 'The total number of tokens in the prompt.',
+                    },
+                    {
+                      name: 'completionTokens',
+                      type: 'number',
+                      description:
+                        'The total number of tokens in the completion.',
+                    },
+                    {
+                      name: 'totalTokens',
+                      type: 'number',
+                      description: 'The total number of tokens generated.',
+                    },
+                  ],
+                },
+              ],
+            },
+            {
+              name: 'object',
+              type: 'T | undefined',
+              description:
+                'The generated object (typed according to the schema). Can be undefined if the final object does not match the schema.',
+            },
+            {
+              name: 'error',
+              type: 'unknown | undefined',
+              description:
+                'Optional error object. This is e.g. a TypeValidationError when the final object does not match the schema.',
+            },
+            {
+              name: 'warnings',
+              type: 'Warning[] | undefined',
+              description:
+                'Warnings from the model provider (e.g. unsupported settings).',
+            },
+            {
+              name: 'rawResponse',
+              type: 'RawResponse',
+              description: 'Optional raw response data.',
+              properties: [
+                {
+                  type: 'RawResponse',
+                  parameters: [
+                    {
+                      name: 'headers',
+                      optional: true,
+                      type: 'Record<string, string>',
+                      description: 'Response headers.',
+                    },
+                  ],
+                },
+              ],
+            },
+          ],
+        },
+      ],
+    },
 
   ]}
 />
```

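The reference entry above corresponds to roughly the following event shape. This is only a reading aid, derived from the documented parameters and the `onFinish` type added to `stream-object.ts` later in this commit; it is not an exported type, and the interface names are illustrative:

```ts
// Sketch of the onFinish event shape described by the reference entry above.
interface TokenUsage {
  promptTokens: number;
  completionTokens: number;
  totalTokens: number;
}

interface StreamObjectOnFinishEvent<T> {
  usage: TokenUsage;
  // undefined when the final object does not match the schema
  object: T | undefined;
  // e.g. a TypeValidationError in that case
  error: unknown | undefined;
  rawResponse?: { headers?: Record<string, string> };
  // CallWarning[] in the SDK; typed loosely here
  warnings?: unknown[];
}
```
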
content/examples/03-node/02-streaming-structured-data/10-token-usage.mdx (+30 −1)

```diff
@@ -5,7 +5,36 @@ description: Examples of how to record token usage when streaming structured dat
 
 # Recording Token Usage
 
-When you're streaming structured data, you may want to record the token usage for billing purposes.
+When you're streaming structured data with [`streamObject`](/docs/reference/ai-sdk-core/stream-object),
+you may want to record the token usage for billing purposes.
+
+## `onFinish` Callback
+
+You can use the `onFinish` callback to record token usage.
+It is called when the stream is finished.
+
+```ts file='index.ts' highlight={"15-17"}
+import { openai } from '@ai-sdk/openai';
+import { streamObject } from 'ai';
+import { z } from 'zod';
+
+const result = await streamObject({
+  model: openai('gpt-4-turbo'),
+  schema: z.object({
+    recipe: z.object({
+      name: z.string(),
+      ingredients: z.array(z.string()),
+      steps: z.array(z.string()),
+    }),
+  }),
+  prompt: 'Generate a lasagna recipe.',
+  onFinish({ usage }) {
+    console.log('Token usage:', usage);
+  },
+});
+```
+
+## `usage` Promise
 
 The [`streamObject`](/docs/reference/ai-sdk-core/stream-object) result contains a `usage` promise that resolves to the total token usage.
```

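As an alternative to the callback, the `usage` promise mentioned above can be awaited once the stream has been consumed. A minimal sketch, reusing the `result` from the example:

```ts
// after result.partialObjectStream has been fully consumed:
const usage = await result.usage;
console.log('Total tokens:', usage.totalTokens);
```
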
content/examples/03-node/02-streaming-structured-data/12-object.mdx (+37)

```diff
@@ -7,6 +7,43 @@ description: Examples of how to record the final object when streaming structure
 
 When you're streaming structured data, you may want to record the final object for logging or other purposes.
 
+## `onFinish` Callback
+
+You can use the `onFinish` callback to record the final object.
+It is called when the stream is finished.
+
+The `object` parameter contains the final object, or `undefined` if the type validation fails.
+There is also an `error` parameter that contains the error, e.g. when the object does not match the schema.
+
+```ts file='index.ts' highlight={"15-23"}
+import { openai } from '@ai-sdk/openai';
+import { streamObject } from 'ai';
+import { z } from 'zod';
+
+const result = await streamObject({
+  model: openai('gpt-4-turbo'),
+  schema: z.object({
+    recipe: z.object({
+      name: z.string(),
+      ingredients: z.array(z.string()),
+      steps: z.array(z.string()),
+    }),
+  }),
+  prompt: 'Generate a lasagna recipe.',
+  onFinish({ object, error }) {
+    // handle type validation failure (when the object does not match the schema):
+    if (object === undefined) {
+      console.error('Error:', error);
+      return;
+    }
+
+    console.log('Final object:', JSON.stringify(object, null, 2));
+  },
+});
+```
+
+## `object` Promise
+
 The [`streamObject`](/docs/reference/ai-sdk-core/stream-object) result contains an `object` promise that resolves to the final object.
 The object is fully typed. When the type validation according to the schema fails, the promise will be rejected with a `TypeValidationError`.
 
```

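For the promise-based variant described above, a minimal sketch reusing the `result` from the example; per the docs text, the promise rejects with a `TypeValidationError` when validation fails:

```ts
try {
  const object = await result.object; // fully typed according to the schema
  console.log('Final object:', JSON.stringify(object, null, 2));
} catch (error) {
  // rejected with a TypeValidationError when the final object does not match the schema
  console.error('Validation failed:', error);
}
```
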
New file (+43)

```ts
import { openai } from '@ai-sdk/openai';
import { streamObject } from 'ai';
import dotenv from 'dotenv';
import { z } from 'zod';

dotenv.config();

async function main() {
  const result = await streamObject({
    model: openai('gpt-4-turbo'),
    schema: z.object({
      characters: z.array(
        z.object({
          name: z.string(),
          class: z
            .string()
            .describe('Character class, e.g. warrior, mage, or thief.'),
          description: z.string(),
        }),
      ),
    }),
    prompt:
      'Generate 3 character descriptions for a fantasy role playing game.',
    onFinish({ usage, object, error }) {
      console.log();
      console.log('onFinish');
      console.log('Token usage:', usage);

      // handle type validation failure (when the object does not match the schema):
      if (object === undefined) {
        console.error('Error:', error);
      } else {
        console.log('Final object:', JSON.stringify(object, null, 2));
      }
    },
  });

  // consume the partialObjectStream:
  for await (const partialObject of result.partialObjectStream) {
  }
}

main().catch(console.error);
```

examples/ai-core/src/stream-object/openai.ts (−1)

```diff
@@ -8,7 +8,6 @@ dotenv.config();
 async function main() {
   const result = await streamObject({
     model: openai('gpt-4-turbo'),
-    maxTokens: 2000,
     schema: z.object({
       characters: z.array(
         z.object({
```

examples/ai-core/src/stream-text/openai-on-finish.ts (−4)

```diff
@@ -1,16 +1,12 @@
 import { openai } from '@ai-sdk/openai';
 import { streamText } from 'ai';
 import dotenv from 'dotenv';
-import { weatherTool } from '../tools/weather-tool';
 
 dotenv.config();
 
 async function main() {
   const result = await streamText({
     model: openai('gpt-4-turbo'),
-    maxTokens: 128,
-    temperature: 0.3,
-    maxRetries: 5,
     prompt: 'Invent a new holiday and describe its traditions.',
     onFinish({
       usage,
```

packages/core/core/generate-object/stream-object.test.ts (+146)

Appended after the existing `result.object` tests (hunk `@@ -356,3 +356,149 @@`):

```ts
describe('onFinish callback', () => {
  describe('with successfully validated object', () => {
    let result: Parameters<
      Required<Parameters<typeof streamObject>[0]>['onFinish']
    >[0];

    beforeEach(async () => {
      const { partialObjectStream } = await streamObject({
        model: new MockLanguageModelV1({
          doStream: async ({ prompt, mode }) => {
            assert.deepStrictEqual(mode, { type: 'object-json' });
            assert.deepStrictEqual(prompt, [
              {
                role: 'system',
                content:
                  'JSON schema:\n' +
                  '{"type":"object","properties":{"content":{"type":"string"}},"required":["content"],"additionalProperties":false,"$schema":"http://json-schema.org/draft-07/schema#"}\n' +
                  'You MUST answer with a JSON object that matches the JSON schema above.',
              },
              { role: 'user', content: [{ type: 'text', text: 'prompt' }] },
            ]);

            return {
              stream: convertArrayToReadableStream([
                { type: 'text-delta', textDelta: '{ ' },
                { type: 'text-delta', textDelta: '"content": ' },
                { type: 'text-delta', textDelta: `"Hello, ` },
                { type: 'text-delta', textDelta: `world` },
                { type: 'text-delta', textDelta: `!"` },
                { type: 'text-delta', textDelta: ' }' },
                {
                  type: 'finish',
                  finishReason: 'stop',
                  usage: { completionTokens: 10, promptTokens: 3 },
                },
              ]),
              rawCall: { rawPrompt: 'prompt', rawSettings: {} },
            };
          },
        }),
        schema: z.object({ content: z.string() }),
        mode: 'json',
        prompt: 'prompt',
        onFinish: async event => {
          result = event as unknown as typeof result;
        },
      });

      // consume stream
      await convertAsyncIterableToArray(partialObjectStream);
    });

    it('should contain token usage', async () => {
      assert.deepStrictEqual(result.usage, {
        completionTokens: 10,
        promptTokens: 3,
        totalTokens: 13,
      });
    });

    it('should contain the full object', async () => {
      assert.deepStrictEqual(result.object, {
        content: 'Hello, world!',
      });
    });

    it('should not contain an error object', async () => {
      assert.deepStrictEqual(result.error, undefined);
    });
  });

  describe("with object that doesn't match the schema", () => {
    let result: Parameters<
      Required<Parameters<typeof streamObject>[0]>['onFinish']
    >[0];

    beforeEach(async () => {
      const { partialObjectStream, object } = await streamObject({
        model: new MockLanguageModelV1({
          doStream: async ({ prompt, mode }) => {
            assert.deepStrictEqual(mode, { type: 'object-json' });
            assert.deepStrictEqual(prompt, [
              {
                role: 'system',
                content:
                  'JSON schema:\n' +
                  '{"type":"object","properties":{"content":{"type":"string"}},"required":["content"],"additionalProperties":false,"$schema":"http://json-schema.org/draft-07/schema#"}\n' +
                  'You MUST answer with a JSON object that matches the JSON schema above.',
              },
              { role: 'user', content: [{ type: 'text', text: 'prompt' }] },
            ]);

            return {
              stream: convertArrayToReadableStream([
                { type: 'text-delta', textDelta: '{ ' },
                { type: 'text-delta', textDelta: '"invalid": ' },
                { type: 'text-delta', textDelta: `"Hello, ` },
                { type: 'text-delta', textDelta: `world` },
                { type: 'text-delta', textDelta: `!"` },
                { type: 'text-delta', textDelta: ' }' },
                {
                  type: 'finish',
                  finishReason: 'stop',
                  usage: { completionTokens: 10, promptTokens: 3 },
                },
              ]),
              rawCall: { rawPrompt: 'prompt', rawSettings: {} },
            };
          },
        }),
        schema: z.object({ content: z.string() }),
        mode: 'json',
        prompt: 'prompt',
        onFinish: async event => {
          result = event as unknown as typeof result;
        },
      });

      // consume stream
      await convertAsyncIterableToArray(partialObjectStream);

      // consume expected error rejection
      await object.catch(() => {});
    });

    it('should contain token usage', async () => {
      assert.deepStrictEqual(result.usage, {
        completionTokens: 10,
        promptTokens: 3,
        totalTokens: 13,
      });
    });

    it('should not contain a full object', async () => {
      assert.deepStrictEqual(result.object, undefined);
    });

    it('should contain an error object', async () => {
      assert.deepStrictEqual(
        TypeValidationError.isTypeValidationError(result.error),
        true,
      );
    });
  });
});
```

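The `result` declaration in these tests recovers the `onFinish` event type from the public `streamObject` signature with a chain of TypeScript utility types. A standalone sketch of the same pattern, using an illustrative function instead of the SDK:

```ts
// Illustrative stand-in for the streamObject signature.
declare function streamObjectLike(options: {
  prompt: string;
  onFinish?: (event: { usage: { totalTokens: number } }) => void;
}): Promise<void>;

// Parameters<typeof fn>[0]   -> the options object type
// Required<...>['onFinish']  -> the callback type, with the optional modifier removed
// Parameters<...>[0]         -> the event object the callback receives
type OnFinishEvent = Parameters<
  Required<Parameters<typeof streamObjectLike>[0]>['onFinish']
>[0];

const example: OnFinishEvent = { usage: { totalTokens: 13 } };
```
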
packages/core/core/generate-object/stream-object.ts (+65 −2)

```diff
@@ -67,6 +67,7 @@ export async function streamObject<T>({
   messages,
   maxRetries,
   abortSignal,
+  onFinish,
   ...settings
 }: CallSettings &
   Prompt & {
@@ -95,6 +96,41 @@ Please note that most providers do not support all modes.
     Default and recommended: 'auto' (best mode for the model).
      */
     mode?: 'auto' | 'json' | 'tool' | 'grammar';
+
+    /**
+    Callback that is called when the LLM response and the final object validation are finished.
+     */
+    onFinish?: (event: {
+      /**
+      The token usage of the generated response.
+      */
+      usage: TokenUsage;
+
+      /**
+      The generated object (typed according to the schema). Can be undefined if the final object does not match the schema.
+      */
+      object: T | undefined;
+
+      /**
+      Optional error object. This is e.g. a TypeValidationError when the final object does not match the schema.
+      */
+      error: unknown | undefined;
+
+      /**
+      Optional raw response data.
+      */
+      rawResponse?: {
+        /**
+        Response headers.
+        */
+        headers?: Record<string, string>;
+      };
+
+      /**
+      Warnings from the model provider (e.g. unsupported settings).
+      */
+      warnings?: CallWarning[];
+    }) => Promise<void> | void;
 }): Promise<StreamObjectResult<T>> {
   const retry = retryWithExponentialBackoff({ maxRetries });
   const jsonSchema = convertZodToJSONSchema(schema);
@@ -229,6 +265,7 @@ Default and recommended: 'auto' (best mode for the model).
       warnings: result.warnings,
       rawResponse: result.rawResponse,
       schema,
+      onFinish,
     });
   }
 
@@ -291,13 +328,15 @@ Response headers.
     warnings,
     rawResponse,
     schema,
+    onFinish,
   }: {
     stream: ReadableStream<string | ObjectStreamInputPart>;
     warnings: CallWarning[] | undefined;
     rawResponse?: {
       headers?: Record<string, string>;
     };
     schema: z.Schema<T>;
+    onFinish: Parameters<typeof streamObject<T>>[0]['onFinish'];
   }) {
     this.warnings = warnings;
     this.rawResponse = rawResponse;
@@ -318,6 +357,8 @@ Response headers.
 
     // store information for onFinish callback:
     let usage: TokenUsage | undefined;
+    let object: T | undefined;
+    let error: unknown | undefined;
 
     // pipe chunks through a transformation stream that extracts metadata:
     let accumulatedText = '';
@@ -360,9 +401,11 @@ Response headers.
             });
 
             if (validationResult.success) {
-              resolveObject(validationResult.value);
+              object = validationResult.value;
+              resolveObject(object);
             } else {
-              rejectObject(validationResult.error);
+              error = validationResult.error;
+              rejectObject(error);
             }
 
             break;
@@ -374,6 +417,26 @@ Response headers.
           }
         }
       },
+
+      // invoke onFinish callback when the stream is about to close:
+      async flush(controller) {
+        try {
+          // call onFinish callback:
+          await onFinish?.({
+            usage: usage ?? {
+              promptTokens: NaN,
+              completionTokens: NaN,
+              totalTokens: NaN,
+            },
+            object,
+            error,
+            rawResponse,
+            warnings,
+          });
+        } catch (error) {
+          controller.error(error);
+        }
+      },
     }),
   );
 }
```