Skip to content

Commit 68d1f78

Browse files
authoredJul 18, 2024··
fix (ai/core): do not construct object promise in streamObject result until requested (#2321)
1 parent bfc1a67 commit 68d1f78

File tree

5 files changed

+147
-37
lines changed

5 files changed

+147
-37
lines changed
 

‎.changeset/long-crews-watch.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'ai': patch
3+
---
4+
5+
fix (ai/core): do not construct object promise in streamObject result until requested

‎packages/core/core/generate-object/stream-object.test.ts

+33-24
Original file line numberDiff line numberDiff line change
@@ -423,18 +423,6 @@ describe('result.object', () => {
423423
const result = await streamObject({
424424
model: new MockLanguageModelV1({
425425
doStream: async ({ prompt, mode }) => {
426-
assert.deepStrictEqual(mode, { type: 'object-json' });
427-
assert.deepStrictEqual(prompt, [
428-
{
429-
role: 'system',
430-
content:
431-
'JSON schema:\n' +
432-
'{"type":"object","properties":{"content":{"type":"string"}},"required":["content"],"additionalProperties":false,"$schema":"http://json-schema.org/draft-07/schema#"}\n' +
433-
'You MUST answer with a JSON object that matches the JSON schema above.',
434-
},
435-
{ role: 'user', content: [{ type: 'text', text: 'prompt' }] },
436-
]);
437-
438426
return {
439427
stream: convertArrayToReadableStream([
440428
{ type: 'text-delta', textDelta: '{ ' },
@@ -470,18 +458,6 @@ describe('result.object', () => {
470458
const result = await streamObject({
471459
model: new MockLanguageModelV1({
472460
doStream: async ({ prompt, mode }) => {
473-
assert.deepStrictEqual(mode, { type: 'object-json' });
474-
assert.deepStrictEqual(prompt, [
475-
{
476-
role: 'system',
477-
content:
478-
'JSON schema:\n' +
479-
'{"type":"object","properties":{"content":{"type":"string"}},"required":["content"],"additionalProperties":false,"$schema":"http://json-schema.org/draft-07/schema#"}\n' +
480-
'You MUST answer with a JSON object that matches the JSON schema above.',
481-
},
482-
{ role: 'user', content: [{ type: 'text', text: 'prompt' }] },
483-
]);
484-
485461
return {
486462
stream: convertArrayToReadableStream([
487463
{ type: 'text-delta', textDelta: '{ ' },
@@ -516,6 +492,39 @@ describe('result.object', () => {
516492
expect(TypeValidationError.isTypeValidationError(error)).toBeTruthy();
517493
});
518494
});
495+
496+
it('should not lead to unhandled promise rejections when the streamed object does not match the schema', async () => {
497+
const result = await streamObject({
498+
model: new MockLanguageModelV1({
499+
doStream: async ({ prompt, mode }) => {
500+
return {
501+
stream: convertArrayToReadableStream([
502+
{ type: 'text-delta', textDelta: '{ ' },
503+
{ type: 'text-delta', textDelta: '"invalid": ' },
504+
{ type: 'text-delta', textDelta: `"Hello, ` },
505+
{ type: 'text-delta', textDelta: `world` },
506+
{ type: 'text-delta', textDelta: `!"` },
507+
{ type: 'text-delta', textDelta: ' }' },
508+
{
509+
type: 'finish',
510+
finishReason: 'stop',
511+
usage: { completionTokens: 10, promptTokens: 3 },
512+
},
513+
]),
514+
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
515+
};
516+
},
517+
}),
518+
schema: z.object({ content: z.string() }),
519+
mode: 'json',
520+
prompt: 'prompt',
521+
});
522+
523+
// consume stream (runs in parallel)
524+
convertAsyncIterableToArray(result.partialObjectStream);
525+
526+
// unhandled promise rejection should not be thrown (Vitest does this automatically)
527+
});
519528
});
520529

521530
describe('options.onFinish', () => {

‎packages/core/core/generate-object/stream-object.ts

+13-13
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import { convertZodToJSONSchema } from '../util/convert-zod-to-json-schema';
2828
import { prepareResponseHeaders } from '../util/prepare-response-headers';
2929
import { retryWithExponentialBackoff } from '../util/retry-with-exponential-backoff';
3030
import { injectJsonSchemaIntoSystem } from './inject-json-schema-into-system';
31+
import { DelayedPromise } from '../util/delayed-promise';
3132

3233
/**
3334
Generate a structured, typed object for a given prompt and schema using a language model.
@@ -316,17 +317,13 @@ The result of a `streamObject` call that contains the partial object stream and
316317
*/
317318
export class StreamObjectResult<T> {
318319
private readonly originalStream: ReadableStream<ObjectStreamPart<T>>;
320+
private readonly objectPromise: DelayedPromise<T>;
319321

320322
/**
321323
Warnings from the model provider (e.g. unsupported settings)
322324
*/
323325
readonly warnings: CallWarning[] | undefined;
324326

325-
/**
326-
The generated object (typed according to the schema). Resolved when the response is finished.
327-
*/
328-
readonly object: Promise<T>;
329-
330327
/**
331328
The token usage of the generated response. Resolved when the response is finished.
332329
*/
@@ -363,12 +360,7 @@ Response headers.
363360
this.rawResponse = rawResponse;
364361

365362
// initialize object promise
366-
let resolveObject: (value: T | PromiseLike<T>) => void;
367-
let rejectObject: (reason?: any) => void;
368-
this.object = new Promise<T>((resolve, reject) => {
369-
resolveObject = resolve;
370-
rejectObject = reject;
371-
});
363+
this.objectPromise = new DelayedPromise<T>();
372364

373365
// initialize usage promise
374366
let resolveUsage: (
@@ -388,6 +380,7 @@ Response headers.
388380
let delta = '';
389381
let latestObject: DeepPartial<T> | undefined = undefined;
390382

383+
const self = this;
391384
this.originalStream = stream.pipeThrough(
392385
new TransformStream<string | ObjectStreamInputPart, ObjectStreamPart<T>>({
393386
async transform(chunk, controller): Promise<void> {
@@ -445,10 +438,10 @@ Response headers.
445438

446439
if (validationResult.success) {
447440
object = validationResult.value;
448-
resolveObject(object);
441+
self.objectPromise.resolve(object);
449442
} else {
450443
error = validationResult.error;
451-
rejectObject(error);
444+
self.objectPromise.reject(error);
452445
}
453446

454447
break;
@@ -484,6 +477,13 @@ Response headers.
484477
);
485478
}
486479

480+
/**
481+
The generated object (typed according to the schema). Resolved when the response is finished.
482+
*/
483+
get object(): Promise<T> {
484+
return this.objectPromise.value;
485+
}
486+
487487
/**
488488
Stream of partial objects. It gets more complete as the stream progresses.
489489
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import { expect, it, describe } from 'vitest';
2+
import { DelayedPromise } from './delayed-promise';
3+
4+
describe('DelayedPromise', () => {
5+
it('should resolve when accessed after resolution', async () => {
6+
const dp = new DelayedPromise<string>();
7+
dp.resolve('success');
8+
expect(await dp.value).toBe('success');
9+
});
10+
11+
it('should reject when accessed after rejection', async () => {
12+
const dp = new DelayedPromise<string>();
13+
const error = new Error('failure');
14+
dp.reject(error);
15+
await expect(dp.value).rejects.toThrow('failure');
16+
});
17+
18+
it('should resolve when accessed before resolution', async () => {
19+
const dp = new DelayedPromise<string>();
20+
const promise = dp.value;
21+
dp.resolve('success');
22+
expect(await promise).toBe('success');
23+
});
24+
25+
it('should reject when accessed before rejection', async () => {
26+
const dp = new DelayedPromise<string>();
27+
const promise = dp.value;
28+
const error = new Error('failure');
29+
dp.reject(error);
30+
await expect(promise).rejects.toThrow('failure');
31+
});
32+
33+
it('should maintain the resolved state after multiple accesses', async () => {
34+
const dp = new DelayedPromise<string>();
35+
dp.resolve('success');
36+
expect(await dp.value).toBe('success');
37+
expect(await dp.value).toBe('success');
38+
});
39+
40+
it('should maintain the rejected state after multiple accesses', async () => {
41+
const dp = new DelayedPromise<string>();
42+
const error = new Error('failure');
43+
dp.reject(error);
44+
await expect(dp.value).rejects.toThrow('failure');
45+
await expect(dp.value).rejects.toThrow('failure');
46+
});
47+
});
+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/**
2+
* Delayed promise. It is only constructed once the value is accessed.
3+
* This is useful to avoid unhandled promise rejections when the promise is created
4+
* but not accessed.
5+
*/
6+
export class DelayedPromise<T> {
7+
private status:
8+
| { type: 'pending' }
9+
| { type: 'resolved'; value: T }
10+
| { type: 'rejected'; error: unknown } = { type: 'pending' };
11+
private promise: Promise<T> | undefined;
12+
private _resolve: undefined | ((value: T) => void) = undefined;
13+
private _reject: undefined | ((error: unknown) => void) = undefined;
14+
15+
get value(): Promise<T> {
16+
if (this.promise) {
17+
return this.promise;
18+
}
19+
20+
this.promise = new Promise<T>((resolve, reject) => {
21+
if (this.status.type === 'resolved') {
22+
resolve(this.status.value);
23+
} else if (this.status.type === 'rejected') {
24+
reject(this.status.error);
25+
}
26+
27+
this._resolve = resolve;
28+
this._reject = reject;
29+
});
30+
31+
return this.promise;
32+
}
33+
34+
resolve(value: T): void {
35+
this.status = { type: 'resolved', value };
36+
37+
if (this.promise) {
38+
this._resolve?.(value);
39+
}
40+
}
41+
42+
reject(error: unknown): void {
43+
this.status = { type: 'rejected', error };
44+
45+
if (this.promise) {
46+
this._reject?.(error);
47+
}
48+
}
49+
}

0 commit comments

Comments
 (0)
Please sign in to comment.