fix(ChatCompletionStream): abort on async iterator break and handle errors (#699)

Previously, `break`-ing out of the async iterator did not abort the underlying request, which could increase usage. Errors are now handled in the async iterator as well: `abort` and `error` events reject pending reads instead of leaving them unsettled.
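
For illustration, a minimal consumer-side sketch of what this changes, assuming the public `openai.beta.chat.completions.stream()` helper exercised in the tests below (the early-exit condition and logging are hypothetical):

import OpenAI from 'openai';

const openai = new OpenAI();

async function main() {
  const stream = openai.beta.chat.completions.stream({
    model: 'gpt-3.5-turbo',
    messages: [{ role: 'user', content: 'Say hello there!' }],
  });

  try {
    for await (const chunk of stream) {
      const delta = chunk.choices[0]?.delta?.content ?? '';
      process.stdout.write(delta);
      // Hypothetical early exit: with this change, `break` invokes the iterator's
      // `return()`, which aborts the underlying request instead of letting it run on.
      if (delta.includes('hello')) break;
    }
  } catch (err) {
    // With this change, connection/stream errors reject the pending read and
    // surface here, rather than leaving the `for await` loop waiting.
    console.error('stream failed:', err);
  }
}

main();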
stainless-bot committed Mar 1, 2024
1 parent 64041fd commit ac417a2
Showing 2 changed files with 81 additions and 7 deletions.
53 changes: 52 additions & 1 deletion src/lib/ChatCompletionRunFunctions.test.ts
@@ -1,5 +1,5 @@
 import OpenAI from 'openai';
-import { OpenAIError } from 'openai/error';
+import { OpenAIError, APIConnectionError } from 'openai/error';
 import { PassThrough } from 'stream';
 import {
   ParsingToolFunction,
@@ -2207,6 +2207,7 @@ describe('resource completions', () => {
       await listener.sanityCheck();
     });
   });
+
   describe('stream', () => {
     test('successful flow', async () => {
       const { fetch, handleRequest } = mockStreamingChatCompletionFetch();
@@ -2273,5 +2274,55 @@ describe('resource completions', () => {
       expect(listener.finalMessage).toEqual({ role: 'assistant', content: 'The weather is great today!' });
       await listener.sanityCheck();
     });
+    test('handles network errors', async () => {
+      const { fetch, handleRequest } = mockFetch();
+
+      const openai = new OpenAI({ apiKey: '...', fetch });
+
+      const stream = openai.beta.chat.completions.stream(
+        {
+          max_tokens: 1024,
+          model: 'gpt-3.5-turbo',
+          messages: [{ role: 'user', content: 'Say hello there!' }],
+        },
+        { maxRetries: 0 },
+      );
+
+      handleRequest(async () => {
+        throw new Error('mock request error');
+      }).catch(() => {});
+
+      async function runStream() {
+        await stream.done();
+      }
+
+      await expect(runStream).rejects.toThrow(APIConnectionError);
+    });
+    test('handles network errors on async iterator', async () => {
+      const { fetch, handleRequest } = mockFetch();
+
+      const openai = new OpenAI({ apiKey: '...', fetch });
+
+      const stream = openai.beta.chat.completions.stream(
+        {
+          max_tokens: 1024,
+          model: 'gpt-3.5-turbo',
+          messages: [{ role: 'user', content: 'Say hello there!' }],
+        },
+        { maxRetries: 0 },
+      );
+
+      handleRequest(async () => {
+        throw new Error('mock request error');
+      }).catch(() => {});
+
+      async function runStream() {
+        for await (const _event of stream) {
+          continue;
+        }
+      }
+
+      await expect(runStream).rejects.toThrow(APIConnectionError);
+    });
   });
 });
35 changes: 29 additions & 6 deletions src/lib/ChatCompletionStream.ts
@@ -210,13 +210,16 @@ export class ChatCompletionStream

   [Symbol.asyncIterator](): AsyncIterator<ChatCompletionChunk> {
     const pushQueue: ChatCompletionChunk[] = [];
-    const readQueue: ((chunk: ChatCompletionChunk | undefined) => void)[] = [];
+    const readQueue: {
+      resolve: (chunk: ChatCompletionChunk | undefined) => void;
+      reject: (err: unknown) => void;
+    }[] = [];
     let done = false;

     this.on('chunk', (chunk) => {
       const reader = readQueue.shift();
       if (reader) {
-        reader(chunk);
+        reader.resolve(chunk);
       } else {
         pushQueue.push(chunk);
       }
@@ -225,7 +228,23 @@
     this.on('end', () => {
       done = true;
       for (const reader of readQueue) {
-        reader(undefined);
+        reader.resolve(undefined);
+      }
+      readQueue.length = 0;
+    });
+
+    this.on('abort', (err) => {
+      done = true;
+      for (const reader of readQueue) {
+        reader.reject(err);
+      }
+      readQueue.length = 0;
+    });
+
+    this.on('error', (err) => {
+      done = true;
+      for (const reader of readQueue) {
+        reader.reject(err);
       }
       readQueue.length = 0;
     });
@@ -236,13 +255,17 @@
           if (done) {
             return { value: undefined, done: true };
           }
-          return new Promise<ChatCompletionChunk | undefined>((resolve) => readQueue.push(resolve)).then(
-            (chunk) => (chunk ? { value: chunk, done: false } : { value: undefined, done: true }),
-          );
+          return new Promise<ChatCompletionChunk | undefined>((resolve, reject) =>
+            readQueue.push({ resolve, reject }),
+          ).then((chunk) => (chunk ? { value: chunk, done: false } : { value: undefined, done: true }));
         }
         const chunk = pushQueue.shift()!;
         return { value: chunk, done: false };
       },
+      return: async () => {
+        this.abort();
+        return { value: undefined, done: true };
+      },
     };
   }

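As an aside, the iterator above follows a common event-emitter-to-async-iterator bridge. A standalone sketch of that pattern (not part of the SDK; names and event wiring are illustrative): buffered chunks go into a push queue, waiting readers into a read queue, and pending readers are resolved on 'chunk'/'end' but rejected on 'error'.

import { EventEmitter } from 'events';

// Bridge an emitter of 'chunk'/'end'/'error' events to an async iterator.
// Chunks arriving with no waiting reader are buffered in pushQueue; readers
// arriving with no buffered chunk wait in readQueue and are resolved by the
// next 'chunk', resolved as done by 'end', or rejected by 'error'.
function toAsyncIterator<T>(emitter: EventEmitter): AsyncIterableIterator<T> {
  const pushQueue: T[] = [];
  const readQueue: { resolve: (chunk: T | undefined) => void; reject: (err: unknown) => void }[] = [];
  let done = false;

  emitter.on('chunk', (chunk: T) => {
    const reader = readQueue.shift();
    if (reader) reader.resolve(chunk);
    else pushQueue.push(chunk);
  });

  emitter.on('end', () => {
    done = true;
    for (const reader of readQueue) reader.resolve(undefined);
    readQueue.length = 0;
  });

  emitter.on('error', (err: unknown) => {
    done = true;
    for (const reader of readQueue) reader.reject(err);
    readQueue.length = 0;
  });

  const iterator: AsyncIterableIterator<T> = {
    [Symbol.asyncIterator]() {
      return iterator;
    },
    async next(): Promise<IteratorResult<T>> {
      if (!pushQueue.length) {
        if (done) return { value: undefined, done: true };
        return new Promise<T | undefined>((resolve, reject) => readQueue.push({ resolve, reject })).then(
          (chunk): IteratorResult<T> =>
            chunk ? { value: chunk, done: false } : { value: undefined, done: true },
        );
      }
      return { value: pushQueue.shift()!, done: false };
    },
    async return(): Promise<IteratorResult<T>> {
      // This is the hook that `break` reaches; the SDK calls `this.abort()` here.
      done = true;
      return { value: undefined, done: true };
    },
  };
  return iterator;
}

The `return()` member is the piece this commit adds to ChatCompletionStream: `break`-ing a `for await` loop calls it, and the stream uses it to abort the in-flight request.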
