Skip to content

Commit 102ca22

Browse files
authored Jun 6, 2024
feat (core): add object promise to streamObject result (#1858)
1 parent eea33ae commit 102ca22

File tree

13 files changed

+286
-44
lines changed

13 files changed

+286
-44
lines changed
 

Diff for: ‎.changeset/metal-ducks-compete.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@ai-sdk/provider': patch
3+
---
4+
5+
fix (@ai-sdk/provider): fix TypeValidationError.isTypeValidationError

Diff for: ‎.changeset/orange-bananas-roll.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'ai': patch
3+
---
4+
5+
feat (core): add object promise to streamObject result

Diff for: ‎content/docs/03-ai-sdk-core/10-generating-structured-data.mdx

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ While some models (like OpenAI) natively support object generation, others requi
4747
- `json`: The JSON schema and an instruction are injected into the prompt. If the provider supports JSON mode, it is enabled.
4848
- `grammar`: The provider is instructed to convert the JSON schema into a provider-specific grammar and use it to select the output tokens.
4949

50-
<Note>Please note that most providers do not support all modes.</Note>
50+
<Note>Please note that not every provider supports all generation modes.</Note>
5151

5252
## Streaming Objects
5353

Diff for: ‎content/docs/07-reference/ai-sdk-core/04-stream-object.mdx

+10
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,12 @@ for await (const partialObject of partialObjectStream) {
342342
},
343343
],
344344
},
345+
{
346+
name: 'object',
347+
type: 'Promise<T>',
348+
description:
349+
'The generated object (typed according to the schema). Resolved when the response is finished.',
350+
},
345351
{
346352
name: 'partialObjectStream',
347353
type: 'AsyncIterableStream<DeepPartial<T>>',
@@ -450,5 +456,9 @@ for await (const partialObject of partialObjectStream) {
450456
title: 'Recording Token Usage',
451457
link: '/examples/node/streaming-structured-data/token-usage',
452458
},
459+
{
460+
title: 'Recording Final Object',
461+
link: '/examples/node/streaming-structured-data/object',
462+
},
453463
]}
454464
/>

Diff for: ‎content/examples/03-node/02-streaming-structured-data/10-token-usage.mdx

+4
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,8 @@ result.usage.then(recordTokenUsage);
4242

4343
// use with async/await:
4444
recordTokenUsage(await result.usage);
45+
46+
// note: the stream needs to be consumed because of backpressure
47+
for await (const partialObject of result.partialObjectStream) {
48+
}
4549
```
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
---
2+
title: Recording Final Object
3+
description: Examples of how to record the final object when streaming structured data.
4+
---
5+
6+
# Recording Final Object
7+
8+
When you're streaming structured data, you may want to record the final object for logging or other purposes.
9+
10+
The [`streamObject`](/docs/reference/ai-sdk-core/stream-object) result contains an `object` promise that resolves to the final object.
11+
The object is fully typed. If the final object fails validation against the schema, the promise is rejected with a `TypeValidationError`.
12+
13+
```ts file='index.ts' highlight={"17-26"}
14+
import { openai } from '@ai-sdk/openai';
15+
import { streamObject, TokenUsage } from 'ai';
16+
import { z } from 'zod';
17+
18+
const result = await streamObject({
19+
model: openai('gpt-4-turbo'),
20+
schema: z.object({
21+
recipe: z.object({
22+
name: z.string(),
23+
ingredients: z.array(z.string()),
24+
steps: z.array(z.string()),
25+
}),
26+
}),
27+
prompt: 'Generate a lasagna recipe.',
28+
});
29+
30+
result.object
31+
.then(({ recipe }) => {
32+
// do something with the fully typed, final object:
33+
console.log('Recipe:', JSON.stringify(recipe, null, 2));
34+
})
35+
.catch(error => {
36+
// handle type validation failure
37+
// (when the object does not match the schema):
38+
console.error(error);
39+
});
40+
41+
// note: the stream needs to be consumed because of backpressure
42+
for await (const partialObject of result.partialObjectStream) {
43+
}
44+
```

Diff for: ‎content/examples/03-node/02-streaming-structured-data/index.mdx

+4
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,9 @@ The following sections will guide you through streaming structured data with Nod
1717
title: 'Recording Token Usage',
1818
href: '/examples/node/streaming-structured-data/token-usage',
1919
},
20+
{
21+
title: 'Recording Final Object',
22+
href: '/examples/node/streaming-structured-data/object',
23+
},
2024
]}
2125
/>

Diff for: ‎examples/ai-core/src/stream-object/openai-object.ts

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import { openai } from '@ai-sdk/openai';
2+
import { streamObject } from 'ai';
3+
import dotenv from 'dotenv';
4+
import { z } from 'zod';
5+
6+
dotenv.config();
7+
8+
async function main() {
9+
const result = await streamObject({
10+
model: openai('gpt-4-turbo'),
11+
schema: z.object({
12+
recipe: z.object({
13+
name: z.string(),
14+
ingredients: z.array(z.string()),
15+
steps: z.array(z.string()),
16+
}),
17+
}),
18+
prompt: 'Generate a lasagna recipe.',
19+
});
20+
21+
result.object
22+
.then(({ recipe }) => {
23+
// do something with the fully typed, final object:
24+
console.log('Recipe:', JSON.stringify(recipe, null, 2));
25+
})
26+
.catch(error => {
27+
// handle type validation failure
28+
// (when the object does not match the schema):
29+
console.error(error);
30+
});
31+
32+
// note: the stream needs to be consumed because of backpressure
33+
for await (const partialObject of result.partialObjectStream) {
34+
}
35+
}
36+
37+
main().catch(console.error);
+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import { openai } from '@ai-sdk/openai';
2+
import { streamObject, TokenUsage } from 'ai';
3+
import dotenv from 'dotenv';
4+
import { z } from 'zod';
5+
6+
dotenv.config();
7+
8+
async function main() {
9+
const result = await streamObject({
10+
model: openai('gpt-4-turbo'),
11+
schema: z.object({
12+
recipe: z.object({
13+
name: z.string(),
14+
ingredients: z.array(z.string()),
15+
steps: z.array(z.string()),
16+
}),
17+
}),
18+
prompt: 'Generate a lasagna recipe.',
19+
});
20+
21+
// your custom function to record token usage:
22+
function recordTokenUsage({
23+
promptTokens,
24+
completionTokens,
25+
totalTokens,
26+
}: TokenUsage) {
27+
console.log('Prompt tokens:', promptTokens);
28+
console.log('Completion tokens:', completionTokens);
29+
console.log('Total tokens:', totalTokens);
30+
}
31+
32+
// use as promise:
33+
result.usage.then(recordTokenUsage);
34+
35+
// use with async/await:
36+
recordTokenUsage(await result.usage);
37+
38+
// note: the stream needs to be consumed because of backpressure
39+
for await (const partialObject of result.partialObjectStream) {
40+
}
41+
}
42+
43+
main().catch(console.error);

Diff for: ‎examples/ai-core/src/stream-object/token-usage.ts

-32
This file was deleted.

Diff for: ‎packages/core/core/generate-object/stream-object.test.ts

+101
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { convertArrayToReadableStream } from '../test/convert-array-to-readable-
44
import { convertAsyncIterableToArray } from '../test/convert-async-iterable-to-array';
55
import { MockLanguageModelV1 } from '../test/mock-language-model-v1';
66
import { streamObject } from './stream-object';
7+
import { TypeValidationError } from '@ai-sdk/provider';
78

89
describe('result.objectStream', () => {
910
it('should send object deltas with json mode', async () => {
@@ -255,3 +256,103 @@ describe('result.usage', () => {
255256
});
256257
});
257258
});
259+
260+
describe('result.object', () => {
261+
it('should resolve with typed object', async () => {
262+
const result = await streamObject({
263+
model: new MockLanguageModelV1({
264+
doStream: async ({ prompt, mode }) => {
265+
assert.deepStrictEqual(mode, { type: 'object-json' });
266+
assert.deepStrictEqual(prompt, [
267+
{
268+
role: 'system',
269+
content:
270+
'JSON schema:\n' +
271+
'{"type":"object","properties":{"content":{"type":"string"}},"required":["content"],"additionalProperties":false,"$schema":"http://json-schema.org/draft-07/schema#"}\n' +
272+
'You MUST answer with a JSON object that matches the JSON schema above.',
273+
},
274+
{ role: 'user', content: [{ type: 'text', text: 'prompt' }] },
275+
]);
276+
277+
return {
278+
stream: convertArrayToReadableStream([
279+
{ type: 'text-delta', textDelta: '{ ' },
280+
{ type: 'text-delta', textDelta: '"content": ' },
281+
{ type: 'text-delta', textDelta: `"Hello, ` },
282+
{ type: 'text-delta', textDelta: `world` },
283+
{ type: 'text-delta', textDelta: `!"` },
284+
{ type: 'text-delta', textDelta: ' }' },
285+
{
286+
type: 'finish',
287+
finishReason: 'stop',
288+
usage: { completionTokens: 10, promptTokens: 3 },
289+
},
290+
]),
291+
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
292+
};
293+
},
294+
}),
295+
schema: z.object({ content: z.string() }),
296+
mode: 'json',
297+
prompt: 'prompt',
298+
});
299+
300+
// consume stream (runs in parallel)
301+
convertAsyncIterableToArray(result.partialObjectStream);
302+
303+
assert.deepStrictEqual(await result.object, {
304+
content: 'Hello, world!',
305+
});
306+
});
307+
308+
it('should reject object promise when the streamed object does not match the schema', async () => {
309+
const result = await streamObject({
310+
model: new MockLanguageModelV1({
311+
doStream: async ({ prompt, mode }) => {
312+
assert.deepStrictEqual(mode, { type: 'object-json' });
313+
assert.deepStrictEqual(prompt, [
314+
{
315+
role: 'system',
316+
content:
317+
'JSON schema:\n' +
318+
'{"type":"object","properties":{"content":{"type":"string"}},"required":["content"],"additionalProperties":false,"$schema":"http://json-schema.org/draft-07/schema#"}\n' +
319+
'You MUST answer with a JSON object that matches the JSON schema above.',
320+
},
321+
{ role: 'user', content: [{ type: 'text', text: 'prompt' }] },
322+
]);
323+
324+
return {
325+
stream: convertArrayToReadableStream([
326+
{ type: 'text-delta', textDelta: '{ ' },
327+
{ type: 'text-delta', textDelta: '"invalid": ' },
328+
{ type: 'text-delta', textDelta: `"Hello, ` },
329+
{ type: 'text-delta', textDelta: `world` },
330+
{ type: 'text-delta', textDelta: `!"` },
331+
{ type: 'text-delta', textDelta: ' }' },
332+
{
333+
type: 'finish',
334+
finishReason: 'stop',
335+
usage: { completionTokens: 10, promptTokens: 3 },
336+
},
337+
]),
338+
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
339+
};
340+
},
341+
}),
342+
schema: z.object({ content: z.string() }),
343+
mode: 'json',
344+
prompt: 'prompt',
345+
});
346+
347+
// consume stream (runs in parallel)
348+
convertAsyncIterableToArray(result.partialObjectStream);
349+
350+
await result.object
351+
.then(() => {
352+
assert.fail('Expected object promise to be rejected');
353+
})
354+
.catch(error => {
355+
expect(TypeValidationError.isTypeValidationError(error)).toBeTruthy();
356+
});
357+
});
358+
});

Diff for: ‎packages/core/core/generate-object/stream-object.ts

+31-5
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import { isDeepEqualData } from '../util/is-deep-equal-data';
2020
import { parsePartialJson } from '../util/parse-partial-json';
2121
import { retryWithExponentialBackoff } from '../util/retry-with-exponential-backoff';
2222
import { injectJsonSchemaIntoSystem } from './inject-json-schema-into-system';
23+
import { safeValidateTypes } from '@ai-sdk/provider-utils';
2324

2425
/**
2526
Generate a structured, typed object for a given prompt and schema using a language model.
@@ -227,6 +228,7 @@ Default and recommended: 'auto' (best mode for the model).
227228
stream: result.stream.pipeThrough(new TransformStream(transformer)),
228229
warnings: result.warnings,
229230
rawResponse: result.rawResponse,
231+
schema,
230232
});
231233
}
232234

@@ -257,13 +259,18 @@ export type ObjectStreamPart<T> =
257259
The result of a `streamObject` call that contains the partial object stream and additional information.
258260
*/
259261
export class StreamObjectResult<T> {
260-
readonly originalStream: ReadableStream<ObjectStreamPart<T>>;
262+
private readonly originalStream: ReadableStream<ObjectStreamPart<T>>;
261263

262264
/**
263265
Warnings from the model provider (e.g. unsupported settings)
264266
*/
265267
readonly warnings: CallWarning[] | undefined;
266268

269+
/**
270+
The generated object (typed according to the schema). Resolved when the response is finished.
271+
*/
272+
readonly object: Promise<T>;
273+
267274
/**
268275
The token usage of the generated response. Resolved when the response is finished.
269276
*/
@@ -283,16 +290,26 @@ Response headers.
283290
stream,
284291
warnings,
285292
rawResponse,
293+
schema,
286294
}: {
287295
stream: ReadableStream<string | ObjectStreamInputPart>;
288296
warnings: CallWarning[] | undefined;
289297
rawResponse?: {
290298
headers?: Record<string, string>;
291299
};
300+
schema: z.Schema<T>;
292301
}) {
293302
this.warnings = warnings;
294303
this.rawResponse = rawResponse;
295304

305+
// initialize object promise
306+
let resolveObject: (value: T | PromiseLike<T>) => void;
307+
let rejectObject: (reason?: any) => void;
308+
this.object = new Promise<T>((resolve, reject) => {
309+
resolveObject = resolve;
310+
rejectObject = reject;
311+
});
312+
296313
// initialize usage promise
297314
let resolveUsage: (value: TokenUsage | PromiseLike<TokenUsage>) => void;
298315
this.usage = new Promise<TokenUsage>(resolve => {
@@ -331,14 +348,23 @@ Response headers.
331348
// store usage for promises and onFinish callback:
332349
usage = calculateTokenUsage(chunk.usage);
333350

334-
controller.enqueue({
335-
...chunk,
336-
usage,
337-
});
351+
controller.enqueue({ ...chunk, usage });
338352

339353
// resolve promises that can be resolved now:
340354
resolveUsage(usage);
341355

356+
// resolve the object promise with the latest object:
357+
const validationResult = safeValidateTypes({
358+
value: latestObject,
359+
schema,
360+
});
361+
362+
if (validationResult.success) {
363+
resolveObject(validationResult.value);
364+
} else {
365+
rejectObject(validationResult.error);
366+
}
367+
342368
break;
343369
}
344370

Diff for: ‎packages/provider/src/errors/type-validation-error.ts

+1-6
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,7 @@ export class TypeValidationError extends Error {
1818
}
1919

2020
static isTypeValidationError(error: unknown): error is TypeValidationError {
21-
return (
22-
error instanceof Error &&
23-
error.name === 'AI_TypeValidationError' &&
24-
typeof (error as TypeValidationError).value === 'string' &&
25-
typeof (error as TypeValidationError).cause === 'string'
26-
);
21+
return error instanceof Error && error.name === 'AI_TypeValidationError';
2722
}
2823

2924
toJSON() {

0 commit comments

Comments
 (0)
Please sign in to comment.