
Commit 2c8ffdb

Authored by knajjars, tomtobac, and MaxLeiter on Feb 9, 2024
feat(cohere): support AsyncIterable for Cohere (#871)
Co-authored-by: Tomeu <tomeu@cohere.com>
Co-authored-by: Max Leiter <max.leiter@vercel.com>
1 parent 4d0f5f1 commit 2c8ffdb
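
For context, the change lets the Cohere SDK's chat stream be passed straight to `CohereStream`. Below is a hypothetical consumer sketch (not part of this commit): it assumes a route-handler setting and a `COHERE_API_KEY` environment variable, while `CohereClient.chatStream`, `CohereStream`, and `StreamingTextResponse` are the same identifiers exercised in the test added below.

```ts
// Hypothetical usage sketch, assuming a route handler and a COHERE_API_KEY env var;
// only the CohereStream(asyncIterable) call reflects what this commit adds.
import { CohereClient } from 'cohere-ai';
import { CohereStream, StreamingTextResponse } from 'ai';

const cohere = new CohereClient({ token: process.env.COHERE_API_KEY! });

export async function POST(req: Request) {
  const { message } = await req.json();

  // chatStream() resolves to an AsyncIterable of chat events, which
  // CohereStream now accepts directly in addition to a fetch Response.
  const cohereStream = await cohere.chatStream({ message });

  return new StreamingTextResponse(CohereStream(cohereStream));
}
```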

File tree

6 files changed: +131 -20 lines changed

.changeset/eighty-days-turn.md (+5)

```diff
@@ -0,0 +1,5 @@
+---
+'ai': patch
+---
+
+cohere-stream: support AsyncIterable
```

packages/core/package.json (+1)

```diff
@@ -101,6 +101,7 @@
     "@vercel/ai-tsconfig": "workspace:*",
     "@vitejs/plugin-react": "4.2.0",
     "@vitejs/plugin-vue": "4.5.0",
+    "cohere-ai": "^7.6.2",
     "eslint": "^7.32.0",
     "eslint-config-vercel-ai": "workspace:*",
     "jsdom": "^23.0.0",
```

packages/core/streams/cohere-stream.test.ts (+26 -2)

```diff
@@ -1,13 +1,19 @@
+import { CohereClient } from 'cohere-ai';
 import {
   CohereStream,
   StreamingTextResponse,
   experimental_StreamData,
 } from '.';
-import { cohereChunks } from '../tests/snapshots/cohere';
+import { cohereChatChunks, cohereChunks } from '../tests/snapshots/cohere';
 import { readAllChunks } from '../tests/utils/mock-client';
 import { DEFAULT_TEST_URL, createMockServer } from '../tests/utils/mock-server';
 
 const server = createMockServer([
+  {
+    url: 'https://api.cohere.ai/v1/chat',
+    chunks: cohereChatChunks,
+    formatChunk: chunk => `${JSON.stringify(chunk)}\n`,
+  },
   {
     url: DEFAULT_TEST_URL,
     chunks: cohereChunks,
@@ -28,9 +34,27 @@ describe('CohereStream', () => {
     server.close();
   });
 
+  it('should be able to parse Chat Streaming API and receive the streamed response', async () => {
+    const co = new CohereClient({
+      token: 'cohere-token',
+    });
+    const cohereResponse = await co.chatStream({
+      message: 'hi there!',
+    });
+
+    const stream = CohereStream(cohereResponse);
+    const response = new StreamingTextResponse(stream);
+    expect(await readAllChunks(response)).toEqual([
+      ' Hello',
+      ',',
+      ' world',
+      '.',
+      ' ',
+    ]);
+  });
+
   it('should be able to parse SSE and receive the streamed response', async () => {
     const stream = CohereStream(await fetch(DEFAULT_TEST_URL));
-
     const response = new StreamingTextResponse(stream);
 
     expect(await readAllChunks(response)).toEqual([
```

packages/core/streams/cohere-stream.ts (+37 -6)

```diff
@@ -1,11 +1,25 @@
 import {
   type AIStreamCallbacksAndOptions,
   createCallbacksTransformer,
+  readableFromAsyncIterable,
 } from './ai-stream';
 import { createStreamDataTransformer } from './stream-data';
 
 const utf8Decoder = new TextDecoder('utf-8');
 
+// Full types
+// @see: https://github.com/cohere-ai/cohere-typescript/blob/c2eceb4a845098240ba0bc44e3787ccf75e268e8/src/api/types/StreamedChatResponse.ts
+interface StreamChunk {
+  text?: string;
+  eventType:
+    | 'stream-start'
+    | 'search-queries-generation'
+    | 'search-results'
+    | 'text-generation'
+    | 'citation-generation'
+    | 'stream-end';
+}
+
 async function processLines(
   lines: string[],
   controller: ReadableStreamDefaultController<string>,
@@ -63,13 +77,30 @@ function createParser(res: Response) {
   });
 }
 
+async function* streamable(stream: AsyncIterable<StreamChunk>) {
+  for await (const chunk of stream) {
+    if (chunk.eventType === 'text-generation') {
+      const text = chunk.text;
+      if (text) yield text;
+    }
+  }
+}
+
 export function CohereStream(
-  reader: Response,
+  reader: Response | AsyncIterable<StreamChunk>,
   callbacks?: AIStreamCallbacksAndOptions,
 ): ReadableStream {
-  return createParser(reader)
-    .pipeThrough(createCallbacksTransformer(callbacks))
-    .pipeThrough(
-      createStreamDataTransformer(callbacks?.experimental_streamData),
-    );
+  if (Symbol.asyncIterator in reader) {
+    return readableFromAsyncIterable(streamable(reader))
+      .pipeThrough(createCallbacksTransformer(callbacks))
+      .pipeThrough(
+        createStreamDataTransformer(callbacks?.experimental_streamData),
+      );
+  } else {
+    return createParser(reader)
+      .pipeThrough(createCallbacksTransformer(callbacks))
+      .pipeThrough(
+        createStreamDataTransformer(callbacks?.experimental_streamData),
+      );
+  }
 }
```
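
The `Symbol.asyncIterator in reader` check keeps the existing `Response`-based SSE path intact while routing SDK iterables through `readableFromAsyncIterable`. That helper's implementation is not part of this diff; as a rough sketch of what such an adapter typically does (an assumption, not the actual `./ai-stream` code):

```ts
// Assumed shape only; the real helper is imported from './ai-stream' and may differ.
function readableFromAsyncIterable<T>(iterable: AsyncIterable<T>): ReadableStream<T> {
  const it = iterable[Symbol.asyncIterator]();
  return new ReadableStream<T>({
    // Pull one value from the iterator each time the stream asks for data.
    async pull(controller) {
      const { value, done } = await it.next();
      if (done) controller.close();
      else controller.enqueue(value);
    },
    // Propagate cancellation back to the iterator so the source can clean up.
    async cancel(reason) {
      await it.return?.(reason);
    },
  });
}
```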

packages/core/tests/snapshots/cohere.ts (+24)

```diff
@@ -1,3 +1,27 @@
+export const cohereChatChunks = [
+  { text: ' Hello', is_finished: false, event_type: 'text-generation' },
+  { text: ',', is_finished: false, event_type: 'text-generation' },
+  { text: ' world', is_finished: false, event_type: 'text-generation' },
+  { text: '.', is_finished: false, event_type: 'text-generation' },
+  { text: ' ', is_finished: false, event_type: 'text-generation' },
+  {
+    is_finished: true,
+    event_type: 'text-generation',
+    finish_reason: 'COMPLETE',
+    response: {
+      id: 'a15fdf82-758e-4d35-a520-40fbb7308631',
+      generations: [
+        {
+          id: '9700d8ef-a8e5-4eb8-8288-3f0e0a1daed5',
+          text: ' Hello, world. ',
+          finish_reason: 'COMPLETE',
+        },
+      ],
+      prompt: 'Hello',
+    },
+  },
+];
+
 export const cohereChunks = [
   { text: ' Hello', is_finished: false },
   { text: ',', is_finished: false },
```

pnpm-lock.yaml (+38 -12)

Generated file; diff not rendered by default.
