Skip to content

Commit ed1e278

Browse files
nick-inkeepMaxLeiter
andauthoredFeb 9, 2024
fix: Message annotations are compatible for any message type (#959)
Co-authored-by: Max Leiter <maxwell.leiter@gmail.com>
1 parent 2076ae6 commit ed1e278

File tree

6 files changed

+158
-45
lines changed

6 files changed

+158
-45
lines changed
 

‎.changeset/stupid-news-greet.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'ai': patch
3+
---
4+
5+
Message annotations handling for all Message types

‎examples/next-openai/app/api/chat-with-functions/route.ts

+2
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ export async function POST(req: Request) {
7575
text: 'Some custom data',
7676
});
7777

78+
data.appendMessageAnnotation({ current_weather: weatherData });
79+
7880
const newMessages = createFunctionCallMessages(weatherData);
7981
return openai.chat.completions.create({
8082
messages: [...messages, ...newMessages],

‎examples/next-openai/app/function-calling/page.tsx

+6
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ export default function Chat() {
6161
>
6262
<strong>{`${m.role}: `}</strong>
6363
{m.content || JSON.stringify(m.function_call)}
64+
{m.annotations ? (
65+
<div>
66+
<br />
67+
<em>Annotations:</em> {JSON.stringify(m.annotations)}
68+
</div>
69+
) : null}
6470
<br />
6571
<br />
6672
</div>

‎packages/core/shared/parse-complex-response.test.ts

+89-4
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,8 @@ describe('parseComplexResponse function', () => {
231231
// Execute the parser function
232232
const result = await parseComplexResponse({
233233
reader: createTestReader([
234-
'0:"Sample text message."\n',
235234
'8:[{"key":"value"}, 2]\n',
235+
'0:"Sample text message."\n',
236236
]),
237237
abortControllerRef: { current: new AbortController() },
238238
update: mockUpdate,
@@ -243,9 +243,7 @@ describe('parseComplexResponse function', () => {
243243
// check the mockUpdate call:
244244
expect(mockUpdate).toHaveBeenCalledTimes(2);
245245

246-
expect(mockUpdate.mock.calls[0][0]).toEqual([
247-
assistantTextMessage('Sample text message.'),
248-
]);
246+
expect(mockUpdate.mock.calls[0][0]).toEqual([]);
249247

250248
expect(mockUpdate.mock.calls[1][0]).toEqual([
251249
{
@@ -265,4 +263,91 @@ describe('parseComplexResponse function', () => {
265263
data: [],
266264
});
267265
});
266+
267+
it('should parse a combination of a function_call and message annotations', async () => {
268+
const mockUpdate = vi.fn();
269+
270+
// Execute the parser function
271+
const result = await parseComplexResponse({
272+
reader: createTestReader([
273+
'1:{"function_call":{"name":"get_current_weather","arguments":"{\\n\\"location\\": \\"Charlottesville, Virginia\\",\\n\\"format\\": \\"celsius\\"\\n}"}}\n',
274+
'8:[{"key":"value"}, 2]\n',
275+
'8:[null,false,"text"]\n',
276+
]),
277+
abortControllerRef: { current: new AbortController() },
278+
update: mockUpdate,
279+
generateId: () => 'test-id',
280+
getCurrentDate: () => new Date(0),
281+
});
282+
283+
// check the mockUpdate call:
284+
expect(mockUpdate).toHaveBeenCalledTimes(3);
285+
286+
expect(mockUpdate.mock.calls[0][0]).toEqual([
287+
{
288+
content: '',
289+
createdAt: new Date(0),
290+
id: 'test-id',
291+
role: 'assistant',
292+
function_call: {
293+
name: 'get_current_weather',
294+
arguments:
295+
'{\n"location": "Charlottesville, Virginia",\n"format": "celsius"\n}',
296+
},
297+
name: 'get_current_weather',
298+
},
299+
]);
300+
301+
expect(mockUpdate.mock.calls[1][0]).toEqual([
302+
{
303+
content: '',
304+
createdAt: new Date(0),
305+
id: 'test-id',
306+
role: 'assistant',
307+
function_call: {
308+
name: 'get_current_weather',
309+
arguments:
310+
'{\n"location": "Charlottesville, Virginia",\n"format": "celsius"\n}',
311+
},
312+
name: 'get_current_weather',
313+
annotations: [{ key: 'value' }, 2],
314+
},
315+
]);
316+
317+
expect(mockUpdate.mock.calls[2][0]).toEqual([
318+
{
319+
content: '',
320+
createdAt: new Date(0),
321+
id: 'test-id',
322+
role: 'assistant',
323+
function_call: {
324+
name: 'get_current_weather',
325+
arguments:
326+
'{\n"location": "Charlottesville, Virginia",\n"format": "celsius"\n}',
327+
},
328+
name: 'get_current_weather',
329+
annotations: [{ key: 'value' }, 2, null, false, 'text'],
330+
},
331+
]);
332+
333+
// check the result
334+
expect(result).toEqual({
335+
messages: [
336+
{
337+
content: '',
338+
createdAt: new Date(0),
339+
id: 'test-id',
340+
role: 'assistant',
341+
function_call: {
342+
name: 'get_current_weather',
343+
arguments:
344+
'{\n"location": "Charlottesville, Virginia",\n"format": "celsius"\n}',
345+
},
346+
name: 'get_current_weather',
347+
annotations: [{ key: 'value' }, 2, null, false, 'text'],
348+
},
349+
],
350+
data: [],
351+
});
352+
});
268353
});

‎packages/core/shared/parse-complex-response.ts

+53-39
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,12 @@ type PrefixMap = {
1515
data: JSONValue[];
1616
};
1717

18-
function initializeMessage({
19-
generateId,
20-
...rest
21-
}: {
22-
generateId: () => string;
23-
content: string;
24-
createdAt: Date;
25-
annotations?: JSONValue[];
26-
}): Message {
27-
return {
28-
id: generateId(),
29-
role: 'assistant',
30-
...rest,
31-
};
18+
function assignAnnotationsToMessage<T extends Message | null | undefined>(
19+
message: T,
20+
annotations: JSONValue[] | undefined,
21+
): T {
22+
if (!message || !annotations || !annotations.length) return message;
23+
return { ...message, annotations: [...annotations] } as T;
3224
}
3325

3426
export async function parseComplexResponse({
@@ -53,6 +45,9 @@ export async function parseComplexResponse({
5345
data: [],
5446
};
5547

48+
// keep list of current message annotations for message
49+
let message_annotations: JSONValue[] | undefined = undefined;
50+
5651
// we create a map of each prefix, and for each prefixed message we push to the map
5752
for await (const { type, value } of readDataStream(reader, {
5853
isAborted: () => abortControllerRef?.current === null,
@@ -73,24 +68,7 @@ export async function parseComplexResponse({
7368
}
7469
}
7570

76-
if (type == 'message_annotations') {
77-
if (prefixMap['text']) {
78-
prefixMap['text'] = {
79-
...prefixMap['text'],
80-
annotations: [...(prefixMap['text'].annotations || []), ...value],
81-
};
82-
} else {
83-
prefixMap['text'] = {
84-
id: generateId(),
85-
role: 'assistant',
86-
content: '',
87-
annotations: [...value],
88-
createdAt,
89-
};
90-
}
91-
}
92-
93-
let functionCallMessage: Message | null = null;
71+
let functionCallMessage: Message | null | undefined = null;
9472

9573
if (type === 'function_call') {
9674
prefixMap['function_call'] = {
@@ -105,7 +83,7 @@ export async function parseComplexResponse({
10583
functionCallMessage = prefixMap['function_call'];
10684
}
10785

108-
let toolCallMessage: Message | null = null;
86+
let toolCallMessage: Message | null | undefined = null;
10987

11088
if (type === 'tool_calls') {
11189
prefixMap['tool_calls'] = {
@@ -123,14 +101,50 @@ export async function parseComplexResponse({
123101
prefixMap['data'].push(...value);
124102
}
125103

126-
const responseMessage = prefixMap['text'];
104+
let responseMessage = prefixMap['text'];
105+
106+
if (type === 'message_annotations') {
107+
if (!message_annotations) {
108+
message_annotations = [...value];
109+
} else {
110+
message_annotations.push(...value);
111+
}
112+
113+
// Update any existing message with the latest annotations
114+
functionCallMessage = assignAnnotationsToMessage(
115+
prefixMap['function_call'],
116+
message_annotations,
117+
);
118+
toolCallMessage = assignAnnotationsToMessage(
119+
prefixMap['tool_calls'],
120+
message_annotations,
121+
);
122+
responseMessage = assignAnnotationsToMessage(
123+
prefixMap['text'],
124+
message_annotations,
125+
);
126+
}
127+
128+
// keeps the prefixMap up to date with the latest annotations, even if annotations preceded the message
129+
if (message_annotations?.length) {
130+
const messagePrefixKeys: (keyof PrefixMap)[] = [
131+
'text',
132+
'function_call',
133+
'tool_calls',
134+
];
135+
messagePrefixKeys.forEach(key => {
136+
if (prefixMap[key]) {
137+
(prefixMap[key] as Message).annotations = [...message_annotations!];
138+
}
139+
});
140+
}
127141

128142
// We add function & tool calls and response messages to the messages[], but data is its own thing
129-
const merged = [
130-
functionCallMessage,
131-
toolCallMessage,
132-
responseMessage,
133-
].filter(Boolean) as Message[];
143+
const merged = [functionCallMessage, toolCallMessage, responseMessage]
144+
.filter(Boolean)
145+
.map(message => ({
146+
...assignAnnotationsToMessage(message, message_annotations),
147+
})) as Message[];
134148

135149
update(merged, [...prefixMap['data']]); // make a copy of the data array
136150
}

‎packages/core/streams/stream-data.ts

+3-2
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,11 @@ export class experimental_StreamData {
4242
}
4343

4444
if (self.messageAnnotations.length) {
45-
const encodedmessageAnnotations = self.encoder.encode(
45+
const encodedMessageAnnotations = self.encoder.encode(
4646
formatStreamPart('message_annotations', self.messageAnnotations),
4747
);
48-
controller.enqueue(encodedmessageAnnotations);
48+
self.messageAnnotations = [];
49+
controller.enqueue(encodedMessageAnnotations);
4950
}
5051

5152
controller.enqueue(chunk);

0 commit comments

Comments
 (0)
Please sign in to comment.