Skip to content

Commit 3e2299e

Browse files
lgrammelMaxLeiter
andauthoredNov 15, 2023
Add types for data stream lines. (#741)
Co-authored-by: Max Leiter <max.leiter@vercel.com>
1 parent a176761 commit 3e2299e

16 files changed

+418
-208
lines changed
 

‎.changeset/great-boxes-glow.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'ai': patch
3+
---
4+
5+
experimental_StreamData/StreamingReactResponse: optimize parsing, improve types

‎examples/next-openai/app/stream-react-response/action.tsx

+1-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ export async function handler({ messages }: { messages: Message[] }) {
109109
data,
110110
ui({ content, data }) {
111111
if (data != null) {
112-
const value = (data as JSONValue[])[0] as any;
112+
const value = data[0] as any;
113113

114114
switch (value.type) {
115115
case 'weather': {

‎packages/core/react/parse-complex-response.test.ts

+101-31
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ describe('parseComplexResponse function', () => {
2929
reader: createTestReader(['0:"Hello"\n']),
3030
abortControllerRef: { current: new AbortController() },
3131
update: mockUpdate,
32+
generateId: () => 'test-id',
33+
getCurrentDate: () => new Date(0),
3234
});
3335

3436
const expectedMessage = assistantTextMessage('Hello');
@@ -37,9 +39,18 @@ describe('parseComplexResponse function', () => {
3739
expect(mockUpdate).toHaveBeenCalledTimes(1);
3840
expect(mockUpdate.mock.calls[0][0]).toEqual([expectedMessage]);
3941

40-
// check the prefix map:
41-
expect(result).toHaveProperty('text');
42-
expect(result.text).toEqual(expectedMessage);
42+
// check the result
43+
expect(result).toEqual({
44+
messages: [
45+
{
46+
content: 'Hello',
47+
createdAt: new Date(0),
48+
id: 'test-id',
49+
role: 'assistant',
50+
},
51+
],
52+
data: [],
53+
});
4354
});
4455

4556
it('should parse a sequence of text messages', async () => {
@@ -54,6 +65,8 @@ describe('parseComplexResponse function', () => {
5465
]),
5566
abortControllerRef: { current: new AbortController() },
5667
update: mockUpdate,
68+
generateId: () => 'test-id',
69+
getCurrentDate: () => new Date(0),
5770
});
5871

5972
// check the mockUpdate call:
@@ -71,20 +84,31 @@ describe('parseComplexResponse function', () => {
7184
assistantTextMessage('Hello, world.'),
7285
]);
7386

74-
// check the prefix map:
75-
expect(result).toHaveProperty('text');
76-
expect(result.text).toEqual(assistantTextMessage('Hello, world.'));
87+
// check the result
88+
expect(result).toEqual({
89+
messages: [
90+
{
91+
content: 'Hello, world.',
92+
createdAt: new Date(0),
93+
id: 'test-id',
94+
role: 'assistant',
95+
},
96+
],
97+
data: [],
98+
});
7799
});
78100

79101
it('should parse a function call', async () => {
80102
const mockUpdate = jest.fn();
81103

82104
const result = await parseComplexResponse({
83105
reader: createTestReader([
84-
'1:"{\\"function_call\\": {\\"name\\": \\"get_current_weather\\", \\"arguments\\": \\"{\\\\n\\\\\\"location\\\\\\": \\\\\\"Charlottesville, Virginia\\\\\\",\\\\n\\\\\\"format\\\\\\": \\\\\\"celsius\\\\\\"\\\\n}\\"}}"\n',
106+
'1:{"function_call":{"name":"get_current_weather","arguments":"{\\n\\"location\\": \\"Charlottesville, Virginia\\",\\n\\"format\\": \\"celsius\\"\\n}"}}\n',
85107
]),
86108
abortControllerRef: { current: new AbortController() },
87109
update: mockUpdate,
110+
generateId: () => 'test-id',
111+
getCurrentDate: () => new Date(0),
88112
});
89113

90114
// check the mockUpdate call:
@@ -104,21 +128,24 @@ describe('parseComplexResponse function', () => {
104128
},
105129
]);
106130

107-
// check the prefix map:
108-
expect(result.function_call).toEqual({
109-
id: expect.any(String),
110-
role: 'assistant',
111-
content: '',
112-
name: 'get_current_weather',
113-
function_call: {
114-
name: 'get_current_weather',
115-
arguments:
116-
'{\n"location": "Charlottesville, Virginia",\n"format": "celsius"\n}',
117-
},
118-
createdAt: expect.any(Date),
131+
// check the result
132+
expect(result).toEqual({
133+
messages: [
134+
{
135+
content: '',
136+
createdAt: new Date(0),
137+
id: 'test-id',
138+
role: 'assistant',
139+
function_call: {
140+
name: 'get_current_weather',
141+
arguments:
142+
'{\n"location": "Charlottesville, Virginia",\n"format": "celsius"\n}',
143+
},
144+
name: 'get_current_weather',
145+
},
146+
],
147+
data: [],
119148
});
120-
expect(result).not.toHaveProperty('text');
121-
expect(result).not.toHaveProperty('data');
122149
});
123150

124151
it('should parse a combination of a data and a text message', async () => {
@@ -127,31 +154,74 @@ describe('parseComplexResponse function', () => {
127154
// Execute the parser function
128155
const result = await parseComplexResponse({
129156
reader: createTestReader([
130-
'2:"[{\\"t1\\":\\"v1\\"}]"\n',
157+
'2:[{"t1":"v1"}]\n',
131158
'0:"Sample text message."\n',
132159
]),
133160
abortControllerRef: { current: new AbortController() },
134161
update: mockUpdate,
162+
generateId: () => 'test-id',
163+
getCurrentDate: () => new Date(0),
135164
});
136165

137-
const expectedData = [{ t1: 'v1' }];
138-
139166
// check the mockUpdate call:
140167
expect(mockUpdate).toHaveBeenCalledTimes(2);
141168

142169
expect(mockUpdate.mock.calls[0][0]).toEqual([]);
143-
expect(mockUpdate.mock.calls[0][1]).toEqual(expectedData);
170+
expect(mockUpdate.mock.calls[0][1]).toEqual([{ t1: 'v1' }]);
144171

145172
expect(mockUpdate.mock.calls[1][0]).toEqual([
146173
assistantTextMessage('Sample text message.'),
147174
]);
148-
expect(mockUpdate.mock.calls[1][1]).toEqual(expectedData);
175+
expect(mockUpdate.mock.calls[1][1]).toEqual([{ t1: 'v1' }]);
176+
177+
// check the result
178+
expect(result).toEqual({
179+
messages: [
180+
{
181+
content: 'Sample text message.',
182+
createdAt: new Date(0),
183+
id: 'test-id',
184+
role: 'assistant',
185+
},
186+
],
187+
data: [{ t1: 'v1' }],
188+
});
189+
});
149190

150-
// check the prefix map:
151-
expect(result).toHaveProperty('data');
152-
expect(result.data).toEqual(expectedData);
191+
it('should parse multiple data messages incl. primitive values', async () => {
192+
const mockUpdate = jest.fn();
153193

154-
expect(result).toHaveProperty('text');
155-
expect(result.text).toEqual(assistantTextMessage('Sample text message.'));
194+
// Execute the parser function
195+
const result = await parseComplexResponse({
196+
reader: createTestReader([
197+
'2:[{"t1":"v1"}, 3]\n',
198+
'2:[null,false,"text"]\n',
199+
]),
200+
abortControllerRef: { current: new AbortController() },
201+
update: mockUpdate,
202+
generateId: () => 'test-id',
203+
getCurrentDate: () => new Date(0),
204+
});
205+
206+
// check the mockUpdate call:
207+
expect(mockUpdate).toHaveBeenCalledTimes(2);
208+
209+
expect(mockUpdate.mock.calls[0][0]).toEqual([]);
210+
expect(mockUpdate.mock.calls[0][1]).toEqual([{ t1: 'v1' }, 3]);
211+
212+
expect(mockUpdate.mock.calls[1][0]).toEqual([]);
213+
expect(mockUpdate.mock.calls[1][1]).toEqual([
214+
{ t1: 'v1' },
215+
3,
216+
null,
217+
false,
218+
'text',
219+
]);
220+
221+
// check the result
222+
expect(result).toEqual({
223+
messages: [],
224+
data: [{ t1: 'v1' }, 3, null, false, 'text'],
225+
});
156226
});
157227
});

‎packages/core/react/parse-complex-response.ts

+35-38
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,40 @@
11
import type { FunctionCall, JSONValue, Message } from '../shared/types';
2-
import { nanoid, createChunkDecoder } from '../shared/utils';
2+
import { createChunkDecoder, nanoid } from '../shared/utils';
33

44
type PrefixMap = {
55
text?: Message;
6-
function_call?:
7-
| string
8-
| Pick<Message, 'function_call' | 'role' | 'content' | 'name'>;
9-
data?: JSONValue[];
6+
function_call?: Message & {
7+
role: 'assistant';
8+
function_call: FunctionCall;
9+
};
10+
data: JSONValue[];
1011
};
1112

1213
export async function parseComplexResponse({
1314
reader,
1415
abortControllerRef,
1516
update,
1617
onFinish,
18+
generateId = nanoid,
19+
getCurrentDate = () => new Date(),
1720
}: {
1821
reader: ReadableStreamDefaultReader<Uint8Array>;
1922
abortControllerRef?: {
2023
current: AbortController | null;
2124
};
2225
update: (merged: Message[], data: JSONValue[] | undefined) => void;
2326
onFinish?: (prefixMap: PrefixMap) => void;
27+
generateId?: () => string;
28+
getCurrentDate?: () => Date;
2429
}) {
30+
const createdAt = getCurrentDate();
31+
2532
const decode = createChunkDecoder(true);
26-
const createdAt = new Date();
27-
const prefixMap: PrefixMap = {};
33+
const prefixMap: PrefixMap = {
34+
data: [],
35+
};
2836
const NEWLINE = '\n'.charCodeAt(0);
29-
let chunks: Uint8Array[] = [];
37+
const chunks: Uint8Array[] = [];
3038
let totalLength = 0;
3139

3240
while (true) {
@@ -73,7 +81,7 @@ export async function parseComplexResponse({
7381
};
7482
} else {
7583
prefixMap['text'] = {
76-
id: nanoid(),
84+
id: generateId(),
7785
role: 'assistant',
7886
content: value,
7987
createdAt,
@@ -84,46 +92,30 @@ export async function parseComplexResponse({
8492
let functionCallMessage: Message | null = null;
8593

8694
if (type === 'function_call') {
87-
prefixMap['function_call'] = value;
88-
89-
let functionCall = prefixMap['function_call'];
90-
// Ensure it hasn't been parsed
91-
if (functionCall && typeof functionCall === 'string') {
92-
const parsedFunctionCall: FunctionCall = JSON.parse(
93-
functionCall as string,
94-
).function_call;
95-
96-
functionCallMessage = {
97-
id: nanoid(),
98-
role: 'assistant',
99-
content: '',
100-
function_call: parsedFunctionCall,
101-
name: parsedFunctionCall.name,
102-
createdAt,
103-
};
104-
105-
prefixMap['function_call'] = functionCallMessage as any;
106-
}
95+
prefixMap['function_call'] = {
96+
id: generateId(),
97+
role: 'assistant',
98+
content: '',
99+
function_call: value.function_call,
100+
name: value.function_call.name,
101+
createdAt,
102+
};
103+
104+
functionCallMessage = prefixMap['function_call'];
107105
}
108106

109107
if (type === 'data') {
110-
const parsedValue = JSON.parse(value);
111-
if (prefixMap['data']) {
112-
prefixMap['data'] = [...prefixMap['data'], ...parsedValue];
113-
} else {
114-
prefixMap['data'] = parsedValue;
115-
}
108+
prefixMap['data'].push(...value);
116109
}
117110

118-
const data = prefixMap['data'];
119111
const responseMessage = prefixMap['text'];
120112

121113
// We add function calls and response messages to the messages[], but data is its own thing
122114
const merged = [functionCallMessage, responseMessage].filter(
123115
Boolean,
124116
) as Message[];
125117

126-
update(merged, data);
118+
update(merged, [...prefixMap['data']]); // make a copy of the data array
127119

128120
// The request has been aborted, stop reading the stream.
129121
// If abortControllerRef is undefined, this is intentionally not executed.
@@ -136,5 +128,10 @@ export async function parseComplexResponse({
136128

137129
onFinish?.(prefixMap);
138130

139-
return prefixMap;
131+
return {
132+
messages: [prefixMap.text, prefixMap.function_call].filter(
133+
Boolean,
134+
) as Message[],
135+
data: prefixMap.data,
136+
};
140137
}

‎packages/core/react/use-chat.ts

+6-17
Original file line numberDiff line numberDiff line change
@@ -199,33 +199,22 @@ const getStreamedResponse = async (
199199
}
200200

201201
const isComplexMode = res.headers.get(COMPLEX_HEADER) === 'true';
202-
let responseMessages: Message[] = [];
203202
const reader = res.body.getReader();
204203

205-
// END TODO-STREAMDATA
206-
let responseData: any = [];
207-
208204
if (isComplexMode) {
209-
const prefixMap = await parseComplexResponse({
205+
return await parseComplexResponse({
210206
reader,
211207
abortControllerRef,
212208
update(merged, data) {
213209
mutate([...chatRequest.messages, ...merged], false);
214210
mutateStreamData([...(existingData || []), ...(data || [])], false);
215211
},
212+
onFinish(prefixMap) {
213+
if (onFinish && prefixMap.text != null) {
214+
onFinish(prefixMap.text);
215+
}
216+
},
216217
});
217-
218-
for (const [type, item] of Object.entries(prefixMap)) {
219-
if (onFinish && type === 'text') {
220-
onFinish(item as Message);
221-
}
222-
if (type === 'data') {
223-
responseData.push(item);
224-
} else {
225-
responseMessages.push(item as Message);
226-
}
227-
}
228-
return { messages: responseMessages, data: responseData };
229218
} else {
230219
const createdAt = new Date();
231220
const decode = createChunkDecoder(false);
+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import { formatStreamPart, parseStreamPart } from './stream-parts';
2+
3+
describe('stream-parts', () => {
4+
describe('formatStreamPart', () => {
5+
it('should escape newlines in text', () => {
6+
expect(formatStreamPart('text', 'value\nvalue')).toEqual(
7+
'0:"value\\nvalue"\n',
8+
);
9+
});
10+
11+
it('should escape newlines in data objects', () => {
12+
expect(formatStreamPart('data', [{ test: 'value\nvalue' }])).toEqual(
13+
'2:[{"test":"value\\nvalue"}]\n',
14+
);
15+
});
16+
});
17+
18+
describe('parseStreamPart', () => {
19+
it('should parse a text line', () => {
20+
const input = '0:"Hello, world!"';
21+
22+
expect(parseStreamPart(input)).toEqual({
23+
type: 'text',
24+
value: 'Hello, world!',
25+
});
26+
});
27+
28+
it('should parse a function call line', () => {
29+
const input =
30+
'1:{"function_call": {"name":"get_current_weather","arguments":"{\\"location\\": \\"Charlottesville, Virginia\\",\\"format\\": \\"celsius\\"}"}}';
31+
32+
expect(parseStreamPart(input)).toEqual({
33+
type: 'function_call',
34+
value: {
35+
function_call: {
36+
name: 'get_current_weather',
37+
arguments:
38+
'{"location": "Charlottesville, Virginia","format": "celsius"}',
39+
},
40+
},
41+
});
42+
});
43+
44+
it('should parse a data line', () => {
45+
const input = '2:[{"test":"value"}]';
46+
const expectedOutput = { type: 'data', value: [{ test: 'value' }] };
47+
expect(parseStreamPart(input)).toEqual(expectedOutput);
48+
});
49+
50+
it('should throw an error if the input does not contain a colon separator', () => {
51+
const input = 'invalid stream string';
52+
expect(() => parseStreamPart(input)).toThrow();
53+
});
54+
55+
it('should throw an error if the input contains an invalid type', () => {
56+
const input = '55:test';
57+
expect(() => parseStreamPart(input)).toThrow();
58+
});
59+
60+
it("should throw error if the input's JSON is invalid", () => {
61+
const input = '0:{"test":"value"';
62+
expect(() => parseStreamPart(input)).toThrow();
63+
});
64+
});
65+
});

‎packages/core/shared/stream-parts.ts

+148
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import { FunctionCall, JSONValue } from './types';
2+
import { StreamString } from './utils';
3+
4+
export interface StreamPart<CODE extends string, NAME extends string, TYPE> {
5+
code: CODE;
6+
name: NAME;
7+
parse: (value: JSONValue) => { type: NAME; value: TYPE };
8+
}
9+
10+
export const textStreamPart: StreamPart<'0', 'text', string> = {
11+
code: '0',
12+
name: 'text',
13+
parse: (value: JSONValue) => {
14+
if (typeof value !== 'string') {
15+
throw new Error('"text" parts expect a string value.');
16+
}
17+
return { type: 'text', value };
18+
},
19+
};
20+
21+
export const functionCallStreamPart: StreamPart<
22+
'1',
23+
'function_call',
24+
{ function_call: FunctionCall }
25+
> = {
26+
code: '1',
27+
name: 'function_call',
28+
parse: (value: JSONValue) => {
29+
if (
30+
value == null ||
31+
typeof value !== 'object' ||
32+
!('function_call' in value)
33+
) {
34+
throw new Error(
35+
'"function_call" parts expect an object with a "function_call" property.',
36+
);
37+
}
38+
39+
const functionCall = value.function_call;
40+
41+
if (
42+
functionCall == null ||
43+
typeof functionCall !== 'object' ||
44+
!('name' in functionCall) ||
45+
!('arguments' in functionCall)
46+
) {
47+
throw new Error(
48+
'"function_call" parts expect an object with a "name" and "arguments" property.',
49+
);
50+
}
51+
52+
return {
53+
type: 'function_call',
54+
value: value as unknown as { function_call: FunctionCall },
55+
};
56+
},
57+
};
58+
59+
export const dataStreamPart: StreamPart<'2', 'data', Array<JSONValue>> = {
60+
code: '2',
61+
name: 'data',
62+
parse: (value: JSONValue) => {
63+
if (!Array.isArray(value)) {
64+
throw new Error('"data" parts expect an array value.');
65+
}
66+
67+
return { type: 'data', value };
68+
},
69+
};
70+
71+
const streamParts = [
72+
textStreamPart,
73+
functionCallStreamPart,
74+
dataStreamPart,
75+
] as const;
76+
77+
// union type of all stream parts
78+
type StreamParts =
79+
| typeof textStreamPart
80+
| typeof functionCallStreamPart
81+
| typeof dataStreamPart;
82+
83+
/**
84+
* Maps the type of a stream part to its value type.
85+
*/
86+
type StreamPartValueType = {
87+
[P in StreamParts as P['name']]: ReturnType<P['parse']>['value'];
88+
};
89+
90+
export type StreamPartType =
91+
| ReturnType<typeof textStreamPart.parse>
92+
| ReturnType<typeof functionCallStreamPart.parse>
93+
| ReturnType<typeof dataStreamPart.parse>;
94+
95+
export const streamPartsByCode = {
96+
[textStreamPart.code]: textStreamPart,
97+
[functionCallStreamPart.code]: functionCallStreamPart,
98+
[dataStreamPart.code]: dataStreamPart,
99+
} as const;
100+
101+
export const validCodes = streamParts.map(part => part.code);
102+
103+
/**
104+
* Parses a stream part from a string.
105+
*
106+
* @param line The string to parse.
107+
* @returns The parsed stream part.
108+
* @throws An error if the string cannot be parsed.
109+
*/
110+
export const parseStreamPart = (line: string): StreamPartType => {
111+
const firstSeperatorIndex = line.indexOf(':');
112+
113+
if (firstSeperatorIndex === -1) {
114+
throw new Error('Failed to parse stream string. No seperator found.');
115+
}
116+
117+
const prefix = line.slice(0, firstSeperatorIndex);
118+
119+
if (!validCodes.includes(prefix as keyof typeof streamPartsByCode)) {
120+
throw new Error(`Failed to parse stream string. Invalid code ${prefix}.`);
121+
}
122+
123+
const code = prefix as keyof typeof streamPartsByCode;
124+
125+
const textValue = line.slice(firstSeperatorIndex + 1);
126+
const jsonValue: JSONValue = JSON.parse(textValue);
127+
128+
return streamPartsByCode[code].parse(jsonValue);
129+
};
130+
131+
/**
132+
* Prepends a string with a prefix from the `StreamChunkPrefixes`, JSON-ifies it,
133+
* and appends a new line.
134+
*
135+
* It ensures type-safety for the part type and value.
136+
*/
137+
export function formatStreamPart<T extends keyof StreamPartValueType>(
138+
type: T,
139+
value: StreamPartValueType[T],
140+
): StreamString {
141+
const streamPart = streamParts.find(part => part.name === type);
142+
143+
if (!streamPart) {
144+
throw new Error(`Invalid stream part type: ${type}`);
145+
}
146+
147+
return `${streamPart.code}:${JSON.stringify(value)}\n`;
148+
}

‎packages/core/shared/utils.test.ts

+16-45
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
import { createChunkDecoder, getStreamString } from './utils';
2-
import { getStreamStringTypeAndValue } from './utils';
1+
import { formatStreamPart } from './stream-parts';
2+
import { createChunkDecoder } from './utils';
33

44
describe('utils', () => {
55
describe('createChunkDecoder', () => {
66
it('should correctly decode text chunk in complex mode', () => {
77
const decoder = createChunkDecoder(true);
88

99
const encoder = new TextEncoder();
10-
const chunk = encoder.encode(getStreamString('text', 'Hello, world!'));
10+
const chunk = encoder.encode(formatStreamPart('text', 'Hello, world!'));
1111
const values = decoder(chunk);
1212

1313
expect(values).toStrictEqual([{ type: 'text', value: 'Hello, world!' }]);
@@ -24,22 +24,29 @@ describe('utils', () => {
2424

2525
const encoder = new TextEncoder();
2626
const chunk = encoder.encode(
27-
getStreamString('function_call', functionCall),
27+
formatStreamPart('function_call', {
28+
function_call: functionCall,
29+
}),
2830
);
2931
const values = decoder(chunk);
3032

3133
expect(values).toStrictEqual([
32-
{ type: 'function_call', value: functionCall },
34+
{
35+
type: 'function_call',
36+
value: {
37+
function_call: functionCall,
38+
},
39+
},
3340
]);
3441
});
3542

3643
it('should correctly decode data chunk in complex mode', () => {
37-
const data = { test: 'value' };
44+
const data = [{ test: 'value' }];
3845

3946
const decoder = createChunkDecoder(true);
4047

4148
const encoder = new TextEncoder();
42-
const chunk = encoder.encode(getStreamString('data', data));
49+
const chunk = encoder.encode(formatStreamPart('data', data));
4350
const values = decoder(chunk);
4451

4552
expect(values).toStrictEqual([{ type: 'data', value: data }]);
@@ -57,10 +64,10 @@ describe('utils', () => {
5764

5865
const enqueuedChunks = [];
5966
enqueuedChunks.push(
60-
encoder.encode(getStreamString('text', normalDecode(chunk1))),
67+
encoder.encode(formatStreamPart('text', normalDecode(chunk1))),
6168
);
6269
enqueuedChunks.push(
63-
encoder.encode(getStreamString('text', normalDecode(chunk2))),
70+
encoder.encode(formatStreamPart('text', normalDecode(chunk2))),
6471
);
6572

6673
let fullDecodedString = '';
@@ -91,40 +98,4 @@ describe('utils', () => {
9198
expect(values + secondValues).toBe('♥');
9299
});
93100
});
94-
95-
describe('getStreamStringTypeAndValue', () => {
96-
it('should correctly parse a text stream string', () => {
97-
const input = '0:Hello, world!';
98-
99-
expect(getStreamStringTypeAndValue(input)).toEqual({
100-
type: 'text',
101-
value: 'Hello, world!',
102-
});
103-
});
104-
105-
it('should correctly parse a function call stream string', () => {
106-
const input =
107-
'1:{"name":"get_current_weather","arguments":"{\\"location\\": \\"Charlottesville, Virginia\\",\\"format\\": \\"celsius\\"}"}';
108-
109-
expect(getStreamStringTypeAndValue(input)).toEqual({
110-
type: 'function_call',
111-
value: {
112-
name: 'get_current_weather',
113-
arguments:
114-
'{"location": "Charlottesville, Virginia","format": "celsius"}',
115-
},
116-
});
117-
});
118-
119-
it('should correctly parse a data stream string', () => {
120-
const input = '2:{"test":"value"}';
121-
const expectedOutput = { type: 'data', value: { test: 'value' } };
122-
expect(getStreamStringTypeAndValue(input)).toEqual(expectedOutput);
123-
});
124-
125-
it('should throw an error if the input is not a valid stream string', () => {
126-
const input = 'invalid stream string';
127-
expect(() => getStreamStringTypeAndValue(input)).toThrow();
128-
});
129-
});
130101
});

‎packages/core/shared/utils.ts

+18-58
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
import { customAlphabet } from 'nanoid/non-secure';
22
import { JSONValue } from './types';
3+
import {
4+
StreamPartType,
5+
dataStreamPart,
6+
functionCallStreamPart,
7+
parseStreamPart,
8+
streamPartsByCode,
9+
textStreamPart,
10+
} from './stream-parts';
311

412
// 7-character random string
513
export const nanoid = customAlphabet(
@@ -13,19 +21,13 @@ function createChunkDecoder(
1321
complex: false,
1422
): (chunk: Uint8Array | undefined) => string;
1523
// complex decoder signature:
16-
function createChunkDecoder(complex: true): (chunk: Uint8Array | undefined) => {
17-
type: keyof typeof StreamStringPrefixes;
18-
value: string;
19-
}[];
24+
function createChunkDecoder(
25+
complex: true,
26+
): (chunk: Uint8Array | undefined) => StreamPartType[];
2027
// combined signature for when the client calls this function with a boolean:
21-
function createChunkDecoder(complex?: boolean): (
22-
chunk: Uint8Array | undefined,
23-
) =>
24-
| {
25-
type: keyof typeof StreamStringPrefixes;
26-
value: string;
27-
}[]
28-
| string;
28+
function createChunkDecoder(
29+
complex?: boolean,
30+
): (chunk: Uint8Array | undefined) => StreamPartType[] | string;
2931
function createChunkDecoder(complex?: boolean) {
3032
const decoder = new TextDecoder();
3133

@@ -42,7 +44,7 @@ function createChunkDecoder(complex?: boolean) {
4244
.split('\n')
4345
.filter(line => line !== ''); // splitting leaves an empty string at the end
4446

45-
return decoded.map(getStreamStringTypeAndValue).filter(Boolean) as any;
47+
return decoded.map(parseStreamPart).filter(Boolean);
4648
};
4749
}
4850

@@ -69,10 +71,9 @@ export { createChunkDecoder };
6971
*```
7072
*/
7173
export const StreamStringPrefixes = {
72-
text: 0,
73-
function_call: 1,
74-
data: 2,
75-
// user_err: 3?
74+
[textStreamPart.name]: textStreamPart.code,
75+
[functionCallStreamPart.name]: functionCallStreamPart.code,
76+
[dataStreamPart.name]: dataStreamPart.code,
7677
} as const;
7778

7879
export const isStreamStringEqualToType = (
@@ -81,50 +82,9 @@ export const isStreamStringEqualToType = (
8182
): value is StreamString =>
8283
value.startsWith(`${StreamStringPrefixes[type]}:`) && value.endsWith('\n');
8384

84-
/**
85-
* Prepends a string with a prefix from the `StreamChunkPrefixes`, JSON-ifies it, and appends a new line.
86-
*/
87-
export const getStreamString = (
88-
type: keyof typeof StreamStringPrefixes,
89-
value: JSONValue,
90-
): StreamString => `${StreamStringPrefixes[type]}:${JSON.stringify(value)}\n`;
91-
9285
export type StreamString =
9386
`${(typeof StreamStringPrefixes)[keyof typeof StreamStringPrefixes]}:${string}\n`;
9487

95-
export const getStreamStringTypeAndValue = (
96-
line: string,
97-
): { type: keyof typeof StreamStringPrefixes; value: JSONValue } => {
98-
const firstSeperatorIndex = line.indexOf(':');
99-
100-
if (firstSeperatorIndex === -1) {
101-
throw new Error('Failed to parse stream string');
102-
}
103-
104-
const prefix = line.slice(0, firstSeperatorIndex);
105-
const type = Object.keys(StreamStringPrefixes).find(
106-
key =>
107-
StreamStringPrefixes[key as keyof typeof StreamStringPrefixes] ===
108-
Number(prefix),
109-
) as keyof typeof StreamStringPrefixes;
110-
111-
const val = line.slice(firstSeperatorIndex + 1);
112-
113-
let parsedVal = val;
114-
115-
if (!val) {
116-
return { type, value: '' };
117-
}
118-
119-
try {
120-
parsedVal = JSON.parse(val);
121-
} catch (e) {
122-
console.error('Failed to parse JSON value:', val);
123-
}
124-
125-
return { type, value: parsedVal };
126-
};
127-
12888
/**
12989
* A header sent to the client so it knows how to handle parsing the stream (as a deprecated text response or using the new prefixed protocol)
13090
*/

‎packages/core/streams/cohere-stream.test.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ describe('CohereStream', () => {
9191
const response = new StreamingTextResponse(stream, {}, data);
9292

9393
expect(await readAllChunks(response)).toEqual([
94-
'2:"[{\\"t1\\":\\"v1\\"}]"\n',
94+
'2:[{"t1":"v1"}]\n',
9595
'0:" Hello"\n',
9696
'0:","\n',
9797
'0:" world"\n',

‎packages/core/streams/huggingface-stream.test.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ describe('HuggingFace stream', () => {
113113
const response = new StreamingTextResponse(stream, {}, data);
114114

115115
expect(await readAllChunks(response)).toEqual([
116-
'2:"[{\\"t1\\":\\"v1\\"}]"\n',
116+
'2:[{"t1":"v1"}]\n',
117117
'0:"Hello"\n',
118118
'0:","\n',
119119
'0:" world"\n',

‎packages/core/streams/langchain-stream.test.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ describe('LangchainStream', () => {
156156
);
157157

158158
expect(await readAllChunks(response)).toEqual([
159-
'2:"[{\\"t1\\":\\"v1\\"}]"\n',
159+
'2:[{"t1":"v1"}]\n',
160160
'0:""\n',
161161
'0:"Hello"\n',
162162
'0:","\n',
@@ -270,7 +270,7 @@ describe('LangchainStream', () => {
270270
const response = new StreamingTextResponse(stream, {}, data);
271271

272272
expect(await readAllChunks(response)).toEqual([
273-
'2:"[{\\"t1\\":\\"v1\\"}]"\n',
273+
'2:[{"t1":"v1"}]\n',
274274
'0:""\n',
275275
'0:"Hello"\n',
276276
'0:","\n',

‎packages/core/streams/openai-stream.test.tsx

+5-5
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ describe('OpenAIStream', () => {
230230
const chunks = await client.readAll();
231231

232232
expect(chunks).toEqual([
233-
'1:"{\\"function_call\\": {\\"name\\": \\"get_current_weather\\", \\"arguments\\": \\"{\\\\n\\\\\\"location\\\\\\": \\\\\\"Charlottesville, Virginia\\\\\\",\\\\n\\\\\\"format\\\\\\": \\\\\\"celsius\\\\\\"\\\\n}\\"}}"\n',
233+
'1:{"function_call":{"name":"get_current_weather","arguments":"{\\n\\"location\\": \\"Charlottesville, Virginia\\",\\n\\"format\\": \\"celsius\\"\\n}"}}\n',
234234
]);
235235
});
236236

@@ -263,8 +263,8 @@ describe('OpenAIStream', () => {
263263
const chunks = await client.readAll();
264264

265265
expect(chunks).toEqual([
266-
'2:"[{\\"fn\\":\\"get_current_weather\\"}]"\n',
267-
'1:"{\\"function_call\\": {\\"name\\": \\"get_current_weather\\", \\"arguments\\": \\"{\\\\n\\\\\\"location\\\\\\": \\\\\\"Charlottesville, Virginia\\\\\\",\\\\n\\\\\\"format\\\\\\": \\\\\\"celsius\\\\\\"\\\\n}\\"}}"\n',
266+
'2:[{"fn":"get_current_weather"}]\n',
267+
'1:{"function_call":{"name":"get_current_weather","arguments":"{\\n\\"location\\": \\"Charlottesville, Virginia\\",\\n\\"format\\": \\"celsius\\"\\n}"}}\n',
268268
]);
269269
});
270270

@@ -327,7 +327,7 @@ describe('OpenAIStream', () => {
327327
const chunks = await client.readAll();
328328

329329
expect(chunks).toEqual([
330-
'2:"[{\\"fn\\":\\"get_current_weather\\"}]"\n',
330+
'2:[{"fn":"get_current_weather"}]\n',
331331
'0:"experimental_onFunctionCall-return-value"\n',
332332
]);
333333
});
@@ -358,7 +358,7 @@ describe('OpenAIStream', () => {
358358
const chunks = await client.readAll();
359359

360360
expect(chunks).toEqual([
361-
'2:"[{\\"t1\\":\\"v1\\"}]"\n',
361+
'2:[{"t1":"v1"}]\n',
362362
'0:"Hello"\n',
363363
'0:","\n',
364364
'0:" world"\n',

‎packages/core/streams/openai-stream.ts

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import { formatStreamPart } from '../shared/stream-parts';
12
import {
23
CreateMessage,
34
FunctionCall,
45
JSONValue,
56
Message,
67
} from '../shared/types';
7-
import { createChunkDecoder, getStreamString } from '../shared/utils';
8+
import { createChunkDecoder } from '../shared/utils';
89

910
import {
1011
AIStream,
@@ -487,7 +488,7 @@ function createFunctionCallTransformer(
487488
if (!isFunctionStreamingIn) {
488489
controller.enqueue(
489490
isComplexMode
490-
? textEncoder.encode(getStreamString('text', message))
491+
? textEncoder.encode(formatStreamPart('text', message))
491492
: chunk,
492493
);
493494
return;
@@ -546,7 +547,11 @@ function createFunctionCallTransformer(
546547
controller.enqueue(
547548
textEncoder.encode(
548549
isComplexMode
549-
? getStreamString('function_call', aggregatedResponse)
550+
? formatStreamPart(
551+
'function_call',
552+
// parse to prevent double-encoding:
553+
JSON.parse(aggregatedResponse),
554+
)
550555
: aggregatedResponse,
551556
),
552557
);
@@ -555,7 +560,7 @@ function createFunctionCallTransformer(
555560
// The user returned a string, so we just return it as a message
556561
controller.enqueue(
557562
isComplexMode
558-
? textEncoder.encode(getStreamString('text', functionResponse))
563+
? textEncoder.encode(formatStreamPart('text', functionResponse))
559564
: textEncoder.encode(functionResponse),
560565
);
561566
return;

‎packages/core/streams/replicate-stream.test.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ describe('ReplicateStream', () => {
104104
const response = new StreamingTextResponse(stream, {}, data);
105105

106106
expect(await readAllChunks(response)).toEqual([
107-
'2:"[{\\"t1\\":\\"v1\\"}]"\n',
107+
'2:[{"t1":"v1"}]\n',
108108
'0:" Hello,"\n',
109109
'0:" world"\n',
110110
'0:"."\n',

‎packages/core/streams/stream-data.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
import { formatStreamPart } from '../shared/stream-parts';
12
import { JSONValue } from '../shared/types';
2-
import { getStreamString } from '../shared/utils';
33

44
/**
55
* A stream wrapper to send custom JSON-encoded data back to the client.
@@ -33,7 +33,7 @@ export class experimental_StreamData {
3333
// add buffered data to the stream
3434
if (self.data.length > 0) {
3535
const encodedData = self.encoder.encode(
36-
getStreamString('data', JSON.stringify(self.data)),
36+
formatStreamPart('data', self.data),
3737
);
3838
self.data = [];
3939
controller.enqueue(encodedData);
@@ -60,7 +60,7 @@ export class experimental_StreamData {
6060

6161
if (self.data.length) {
6262
const encodedData = self.encoder.encode(
63-
getStreamString('data', JSON.stringify(self.data)),
63+
formatStreamPart('data', self.data),
6464
);
6565
controller.enqueue(encodedData);
6666
}
@@ -109,7 +109,7 @@ export function createStreamDataTransformer(
109109
return new TransformStream({
110110
transform: async (chunk, controller) => {
111111
const message = decoder.decode(chunk);
112-
controller.enqueue(encoder.encode(getStreamString('text', message)));
112+
controller.enqueue(encoder.encode(formatStreamPart('text', message)));
113113
},
114114
});
115115
}

0 commit comments

Comments
 (0)
Please sign in to comment.