Commit 70bd2ac

lgrammel and MaxLeiter authored Nov 16, 2023
Solid.js: Add complex response parsing and StreamData support to useChat (#738)
Co-authored-by: Max Leiter <max.leiter@vercel.com>
1 parent 69ca8f5 commit 70bd2ac

File tree

12 files changed (+546, -302 lines)
 

‎.changeset/wild-carpets-move.md

+5
@@ -0,0 +1,5 @@
+---
+'ai': patch
+---
+
+ai/solid: add experimental_StreamData support to useChat
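
On the client, the new `data` accessor returned by `useChat` in `ai/solid` exposes whatever the server appends via `experimental_StreamData`. A minimal sketch of reading it from a Solid component (the component name and route here are illustrative, not part of this commit):

```tsx
import { useChat } from 'ai/solid';
import { For } from 'solid-js';

// Hypothetical component: data() is the new Accessor<JSONValue[] | undefined>
// that useChat returns once the server streams StreamData alongside the text.
export default function StreamDataDebug() {
  const { data } = useChat({ api: '/api/chat' });

  return (
    <ul>
      <For each={data() ?? []}>{item => <li>{JSON.stringify(item)}</li>}</For>
    </ul>
  );
}
```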

‎examples/next-openai/app/stream-react-response/action.tsx

+1-2
@@ -1,7 +1,6 @@
 'use server';

 import {
-  JSONValue,
   Message,
   OpenAIStream,
   experimental_StreamData,
@@ -108,7 +107,7 @@ export async function handler({ messages }: { messages: Message[] }) {
   return new experimental_StreamingReactResponse(stream, {
     data,
     ui({ content, data }) {
-      if (data != null) {
+      if (data?.[0] != null) {
        const value = data[0] as any;

        switch (value.type) {
New file (Solid Start example): API route for chat with OpenAI function calling and StreamData

+97

@@ -0,0 +1,97 @@
+import {
+  OpenAIStream,
+  StreamingTextResponse,
+  experimental_StreamData,
+} from 'ai';
+import OpenAI from 'openai';
+import type { ChatCompletionCreateParams } from 'openai/resources/chat';
+
+import { APIEvent } from 'solid-start/api';
+
+// Create an OpenAI API client
+const openai = new OpenAI({
+  apiKey: process.env['OPENAI_API_KEY'] || '',
+});
+
+const functions: ChatCompletionCreateParams.Function[] = [
+  {
+    name: 'get_current_weather',
+    description: 'Get the current weather.',
+    parameters: {
+      type: 'object',
+      properties: {
+        format: {
+          type: 'string',
+          enum: ['celsius', 'fahrenheit'],
+          description: 'The temperature unit to use.',
+        },
+      },
+      required: ['format'],
+    },
+  },
+  {
+    name: 'eval_code_in_browser',
+    description: 'Execute javascript code in the browser with eval().',
+    parameters: {
+      type: 'object',
+      properties: {
+        code: {
+          type: 'string',
+          description: `Javascript code that will be directly executed via eval(). Do not use backticks in your response.
+           DO NOT include any newlines in your response, and be sure to provide only valid JSON when providing the arguments object.
+           The output of the eval() will be returned directly by the function.`,
+        },
+      },
+      required: ['code'],
+    },
+  },
+];
+
+export const POST = async (event: APIEvent) => {
+  const { messages } = await event.request.json();
+
+  const response = await openai.chat.completions.create({
+    model: 'gpt-3.5-turbo-0613',
+    stream: true,
+    messages,
+    functions,
+  });
+
+  const data = new experimental_StreamData();
+  const stream = OpenAIStream(response, {
+    experimental_onFunctionCall: async (
+      { name, arguments: args },
+      createFunctionCallMessages,
+    ) => {
+      if (name === 'get_current_weather') {
+        // Call a weather API here
+        const weatherData = {
+          temperature: 20,
+          unit: args.format === 'celsius' ? 'C' : 'F',
+        };
+
+        data.append({
+          text: 'Some custom data',
+        });
+
+        const newMessages = createFunctionCallMessages(weatherData);
+        return openai.chat.completions.create({
+          messages: [...messages, ...newMessages],
+          stream: true,
+          model: 'gpt-3.5-turbo-0613',
+        });
+      }
+    },
+    onFinal() {
+      data.close();
+    },
+    experimental_streamData: true,
+  });
+
+  data.append({
+    text: 'Hello, how are you?',
+  });
+
+  // Respond with the stream
+  return new StreamingTextResponse(stream, {}, data);
+};
New file (Solid Start example): chat page using function calls and StreamData

+88

@@ -0,0 +1,88 @@
+import { FunctionCallHandler, Message, nanoid } from 'ai';
+import { useChat } from 'ai/solid';
+import { For, JSX } from 'solid-js';
+
+export default function Chat() {
+  const functionCallHandler: FunctionCallHandler = async (
+    chatMessages,
+    functionCall,
+  ) => {
+    if (functionCall.name === 'eval_code_in_browser') {
+      if (functionCall.arguments) {
+        // Parsing here does not always work since it seems that some characters in generated code aren't escaped properly.
+        const parsedFunctionCallArguments: { code: string } = JSON.parse(
+          functionCall.arguments,
+        );
+        // WARNING: Do NOT do this in real-world applications!
+        eval(parsedFunctionCallArguments.code);
+        const functionResponse = {
+          messages: [
+            ...chatMessages,
+            {
+              id: nanoid(),
+              name: 'eval_code_in_browser',
+              role: 'function' as const,
+              content: parsedFunctionCallArguments.code,
+            },
+          ],
+        };
+        return functionResponse;
+      }
+    }
+  };
+
+  const { messages, input, setInput, handleSubmit, data } = useChat({
+    api: '/api/chat-with-functions',
+    experimental_onFunctionCall: functionCallHandler,
+  });
+
+  // Generate a map of message role to text color
+  const roleToColorMap: Record<Message['role'], string> = {
+    system: 'red',
+    user: 'black',
+    function: 'blue',
+    assistant: 'green',
+  };
+
+  const handleInputChange: JSX.ChangeEventHandlerUnion<
+    HTMLInputElement,
+    Event
+  > = e => {
+    setInput(e.target.value);
+  };
+
+  return (
+    <div class="flex flex-col w-full max-w-md py-24 mx-auto stretch">
+      <div class="bg-gray-200 mb-8">
+        <For each={data()}>
+          {item => (
+            <pre class="whitespace-pre-wrap">{JSON.stringify(item)}</pre>
+          )}
+        </For>
+      </div>
+
+      <For each={messages()}>
+        {m => (
+          <div
+            class="whitespace-pre-wrap"
+            style={{ color: roleToColorMap[m.role] }}
+          >
+            <strong>{`${m.role}: `}</strong>
+            {m.content || JSON.stringify(m.function_call)}
+            <br />
+            <br />
+          </div>
+        )}
+      </For>
+
+      <form onSubmit={handleSubmit}>
+        <input
+          class="fixed bottom-0 w-full max-w-md p-2 mb-8 border border-gray-300 rounded shadow-xl"
+          value={input()}
+          placeholder="Say something..."
+          onChange={handleInputChange}
+        />
+      </form>
+    </div>
+  );
+}

‎packages/core/react/use-chat.ts

+51-198
@@ -1,22 +1,23 @@
 import { useCallback, useEffect, useId, useRef, useState } from 'react';
 import useSWR, { KeyedMutator } from 'swr';
-import { nanoid, createChunkDecoder, COMPLEX_HEADER } from '../shared/utils';
+import { nanoid } from '../shared/utils';

 import type {
   ChatRequest,
+  ChatRequestOptions,
   CreateMessage,
+  JSONValue,
   Message,
   UseChatOptions,
-  ChatRequestOptions,
-  FunctionCall,
 } from '../shared/types';
-import { parseComplexResponse } from './parse-complex-response';

+import { callApi } from '../shared/call-api';
+import { processChatStream } from '../shared/process-chat-stream';
 import type {
   ReactResponseRow,
   experimental_StreamingReactResponse,
 } from '../streams/streaming-react-response';
-export type { Message, CreateMessage, UseChatOptions };
+export type { CreateMessage, Message, UseChatOptions };

 export type UseChatHelpers = {
   /** Current messages in the chat */
@@ -70,7 +71,7 @@ export type UseChatHelpers = {
   /** Whether the API request is in progress */
   isLoading: boolean;
   /** Additional data added on the server via StreamData */
-  data?: any;
+  data?: JSONValue[] | undefined;
 };

 type StreamingReactResponseAction = (payload: {
@@ -82,8 +83,8 @@ const getStreamedResponse = async (
   api: string | StreamingReactResponseAction,
   chatRequest: ChatRequest,
   mutate: KeyedMutator<Message[]>,
-  mutateStreamData: KeyedMutator<any[]>,
-  existingData: any,
+  mutateStreamData: KeyedMutator<JSONValue[] | undefined>,
+  existingData: JSONValue[] | undefined,
   extraMetadataRef: React.MutableRefObject<any>,
   messagesRef: React.MutableRefObject<Message[]>,
   abortControllerRef: React.MutableRefObject<AbortController | null>,
@@ -152,10 +153,10 @@
     return responseMessage;
  }

-  const res = await fetch(api, {
-    method: 'POST',
-    body: JSON.stringify({
-      messages: constructedMessagesPayload,
+  return await callApi({
+    api,
+    messages: constructedMessagesPayload,
+    body: {
      data: chatRequest.data,
      ...extraMetadataRef.current.body,
      ...chatRequest.options?.body,
@@ -165,111 +166,26 @@
      ...(chatRequest.function_call !== undefined && {
        function_call: chatRequest.function_call,
      }),
-    }),
+    },
    credentials: extraMetadataRef.current.credentials,
    headers: {
      ...extraMetadataRef.current.headers,
      ...chatRequest.options?.headers,
    },
-    ...(abortControllerRef.current !== null && {
-      signal: abortControllerRef.current.signal,
-    }),
-  }).catch(err => {
-    // Restore the previous messages if the request fails.
-    mutate(previousMessages, false);
-    throw err;
+    abortController: () => abortControllerRef.current,
+    appendMessage(message) {
+      mutate([...chatRequest.messages, message], false);
+    },
+    restoreMessagesOnFailure() {
+      mutate(previousMessages, false);
+    },
+    onResponse,
+    onUpdate(merged, data) {
+      mutate([...chatRequest.messages, ...merged], false);
+      mutateStreamData([...(existingData || []), ...(data || [])], false);
+    },
+    onFinish,
  });
-
-  if (onResponse) {
-    try {
-      await onResponse(res);
-    } catch (err) {
-      throw err;
-    }
-  }
-
-  if (!res.ok) {
-    // Restore the previous messages if the request fails.
-    mutate(previousMessages, false);
-    throw new Error((await res.text()) || 'Failed to fetch the chat response.');
-  }
-
-  if (!res.body) {
-    throw new Error('The response body is empty.');
-  }
-
-  const isComplexMode = res.headers.get(COMPLEX_HEADER) === 'true';
-  const reader = res.body.getReader();
-
-  if (isComplexMode) {
-    return await parseComplexResponse({
-      reader,
-      abortControllerRef,
-      update(merged, data) {
-        mutate([...chatRequest.messages, ...merged], false);
-        mutateStreamData([...(existingData || []), ...(data || [])], false);
-      },
-      onFinish(prefixMap) {
-        if (onFinish && prefixMap.text != null) {
-          onFinish(prefixMap.text);
-        }
-      },
-    });
-  } else {
-    const createdAt = new Date();
-    const decode = createChunkDecoder(false);
-
-    // TODO-STREAMDATA: Remove this once Strem Data is not experimental
-    let streamedResponse = '';
-    const replyId = nanoid();
-    let responseMessage: Message = {
-      id: replyId,
-      createdAt,
-      content: '',
-      role: 'assistant',
-    };
-
-    // TODO-STREAMDATA: Remove this once Strem Data is not experimental
-    while (true) {
-      const { done, value } = await reader.read();
-      if (done) {
-        break;
-      }
-      // Update the chat state with the new message tokens.
-      streamedResponse += decode(value);
-
-      if (streamedResponse.startsWith('{"function_call":')) {
-        // While the function call is streaming, it will be a string.
-        responseMessage['function_call'] = streamedResponse;
-      } else {
-        responseMessage['content'] = streamedResponse;
-      }
-
-      mutate([...chatRequest.messages, { ...responseMessage }], false);
-
-      // The request has been aborted, stop reading the stream.
-      if (abortControllerRef.current === null) {
-        reader.cancel();
-        break;
-      }
-    }
-
-    if (streamedResponse.startsWith('{"function_call":')) {
-      // Once the stream is complete, the function call is parsed into an object.
-      const parsedFunctionCall: FunctionCall =
-        JSON.parse(streamedResponse).function_call;
-
-      responseMessage['function_call'] = parsedFunctionCall;
-
-      mutate([...chatRequest.messages, { ...responseMessage }]);
-    }
-
-    if (onFinish) {
-      onFinish(responseMessage);
-    }
-
-    return responseMessage;
-  }
 };

 export function useChat({
@@ -308,10 +224,9 @@ export function useChat({
    null,
  );

-  const { data: streamData, mutate: mutateStreamData } = useSWR<any>(
-    [chatId, 'streamData'],
-    null,
-  );
+  const { data: streamData, mutate: mutateStreamData } = useSWR<
+    JSONValue[] | undefined
+  >([chatId, 'streamData'], null);

  // Keep the latest messages in a ref.
  const messagesRef = useRef<Message[]>(messages || []);
@@ -348,89 +263,27 @@
      const abortController = new AbortController();
      abortControllerRef.current = abortController;

-      while (true) {
-        // TODO-STREAMDATA: This should be { const { messages: streamedResponseMessages, data } =
-        // await getStreamedResponse(} once Stream Data is not experimental
-        const messagesAndDataOrJustMessage = await getStreamedResponse(
-          api,
-          chatRequest,
-          mutate,
-          mutateStreamData,
-          streamData,
-          extraMetadataRef,
-          messagesRef,
-          abortControllerRef,
-          onFinish,
-          onResponse,
-          sendExtraMessageFields,
-        );
-
-        // Using experimental stream data
-        if ('messages' in messagesAndDataOrJustMessage) {
-          let hasFollowingResponse = false;
-          for (const message of messagesAndDataOrJustMessage.messages) {
-            if (
-              message.function_call === undefined ||
-              typeof message.function_call === 'string'
-            ) {
-              continue;
-            }
-            hasFollowingResponse = true;
-            // Streamed response is a function call, invoke the function call handler if it exists.
-            if (experimental_onFunctionCall) {
-              const functionCall = message.function_call;
-
-              // User handles the function call in their own functionCallHandler.
-              // The "arguments" key of the function call object will still be a string which will have to be parsed in the function handler.
-              // If the "arguments" JSON is malformed due to model error the user will have to handle that themselves.
-
-              const functionCallResponse: ChatRequest | void =
-                await experimental_onFunctionCall(
-                  messagesRef.current,
-                  functionCall,
-                );
-
-              // If the user does not return anything as a result of the function call, the loop will break.
-              if (functionCallResponse === undefined) {
-                hasFollowingResponse = false;
-                break;
-              }
-
-              // A function call response was returned.
-              // The updated chat with function call response will be sent to the API in the next iteration of the loop.
-              chatRequest = functionCallResponse;
-            }
-          }
-          if (!hasFollowingResponse) {
-            break;
-          }
-        } else {
-          const streamedResponseMessage = messagesAndDataOrJustMessage;
-          // TODO-STREAMDATA: Remove this once Stream Data is not experimental
-          if (
-            streamedResponseMessage.function_call === undefined ||
-            typeof streamedResponseMessage.function_call === 'string'
-          ) {
-            break;
-          }
-
-          // Streamed response is a function call, invoke the function call handler if it exists.
-          if (experimental_onFunctionCall) {
-            const functionCall = streamedResponseMessage.function_call;
-            const functionCallResponse: ChatRequest | void =
-              await experimental_onFunctionCall(
-                messagesRef.current,
-                functionCall,
-              );
-
-            // If the user does not return anything as a result of the function call, the loop will break.
-            if (functionCallResponse === undefined) break;
-            // A function call response was returned.
-            // The updated chat with function call response will be sent to the API in the next iteration of the loop.
-            chatRequest = functionCallResponse;
-          }
-        }
-      }
+      await processChatStream({
+        getStreamedResponse: () =>
+          getStreamedResponse(
+            api,
+            chatRequest,
+            mutate,
+            mutateStreamData,
+            streamData!,
+            extraMetadataRef,
+            messagesRef,
+            abortControllerRef,
+            onFinish,
+            onResponse,
+            sendExtraMessageFields,
+          ),
+        experimental_onFunctionCall,
+        updateChatRequest: chatRequestParam => {
+          chatRequest = chatRequestParam;
+        },
+        getCurrentMessages: () => messagesRef.current,
+      });

      abortControllerRef.current = null;
    } catch (err) {
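
The hook's public surface stays the same, but `data` is now typed as `JSONValue[] | undefined` instead of `any`, and the request/stream handling is delegated to the shared `callApi` and `processChatStream` helpers. A rough sketch of what a consumer sees after this change (the component is illustrative):

```tsx
import { useChat } from 'ai/react';

// Hypothetical debug panel: data is JSONValue[] | undefined rather than any.
function StreamDataPanel() {
  const { data } = useChat();
  return <pre>{JSON.stringify(data ?? [], null, 2)}</pre>;
}
```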

‎packages/core/react/use-chat.ui.test.tsx

+5
@@ -4,6 +4,11 @@ import userEvent from '@testing-library/user-event';
 import { mockFetch } from '../tests/utils/mock-fetch';
 import { useChat } from './use-chat';

+// mock nanoid import
+jest.mock('nanoid', () => ({
+  nanoid: () => Math.random().toString(36).slice(2, 9),
+}));
+
 describe('useChat', () => {
   afterEach(() => {
     jest.restoreAllMocks();

‎packages/core/shared/call-api.ts

+134
@@ -0,0 +1,134 @@
+import { nanoid } from 'nanoid';
+import { parseComplexResponse } from './parse-complex-response';
+import { FunctionCall, JSONValue, Message } from './types';
+import { COMPLEX_HEADER, createChunkDecoder } from './utils';
+
+export async function callApi({
+  api,
+  messages,
+  body,
+  credentials,
+  headers,
+  abortController,
+  appendMessage,
+  restoreMessagesOnFailure,
+  onResponse,
+  onUpdate,
+  onFinish,
+}: {
+  api: string;
+  messages: Omit<Message, 'id'>[];
+  body: Record<string, any>;
+  credentials?: RequestCredentials;
+  headers?: HeadersInit;
+  abortController?: () => AbortController | null;
+  restoreMessagesOnFailure: () => void;
+  appendMessage: (message: Message) => void;
+  onResponse?: (response: Response) => void | Promise<void>;
+  onUpdate: (merged: Message[], data: JSONValue[] | undefined) => void;
+  onFinish?: (message: Message) => void;
+}) {
+  const response = await fetch(api, {
+    method: 'POST',
+    body: JSON.stringify({
+      messages,
+      ...body,
+    }),
+    headers,
+    signal: abortController?.()?.signal,
+    credentials,
+  }).catch(err => {
+    restoreMessagesOnFailure();
+    throw err;
+  });
+
+  if (onResponse) {
+    try {
+      await onResponse(response);
+    } catch (err) {
+      throw err;
+    }
+  }
+
+  if (!response.ok) {
+    restoreMessagesOnFailure();
+    throw new Error(
+      (await response.text()) || 'Failed to fetch the chat response.',
+    );
+  }
+
+  if (!response.body) {
+    throw new Error('The response body is empty.');
+  }
+
+  const reader = response.body.getReader();
+  const isComplexMode = response.headers.get(COMPLEX_HEADER) === 'true';
+
+  if (isComplexMode) {
+    return await parseComplexResponse({
+      reader,
+      abortControllerRef:
+        abortController != null ? { current: abortController() } : undefined,
+      update: onUpdate,
+      onFinish(prefixMap) {
+        if (onFinish && prefixMap.text != null) {
+          onFinish(prefixMap.text);
+        }
+      },
+    });
+  } else {
+    const createdAt = new Date();
+    const decode = createChunkDecoder(false);
+
+    // TODO-STREAMDATA: Remove this once Stream Data is not experimental
+    let streamedResponse = '';
+    const replyId = nanoid();
+    let responseMessage: Message = {
+      id: replyId,
+      createdAt,
+      content: '',
+      role: 'assistant',
+    };
+
+    // TODO-STREAMDATA: Remove this once Stream Data is not experimental
+    while (true) {
+      const { done, value } = await reader.read();
+      if (done) {
+        break;
+      }
+      // Update the chat state with the new message tokens.
+      streamedResponse += decode(value);
+
+      if (streamedResponse.startsWith('{"function_call":')) {
+        // While the function call is streaming, it will be a string.
+        responseMessage['function_call'] = streamedResponse;
+      } else {
+        responseMessage['content'] = streamedResponse;
+      }
+
+      appendMessage({ ...responseMessage });
+
+      // The request has been aborted, stop reading the stream.
+      if (abortController?.() === null) {
+        reader.cancel();
+        break;
+      }
+    }
+
+    if (streamedResponse.startsWith('{"function_call":')) {
+      // Once the stream is complete, the function call is parsed into an object.
+      const parsedFunctionCall: FunctionCall =
+        JSON.parse(streamedResponse).function_call;
+
+      responseMessage['function_call'] = parsedFunctionCall;
+
+      appendMessage({ ...responseMessage });
+    }
+
+    if (onFinish) {
+      onFinish(responseMessage);
+    }
+
+    return responseMessage;
+  }
+}
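
For orientation, a hedged sketch of how a caller might wire `callApi` up, based only on the parameter types above; the plain-variable state handling is illustrative and not how the hooks themselves do it (they mutate SWR / swr-store state):

```ts
import { callApi } from './call-api';
import type { JSONValue, Message } from './types';

// Illustrative state; the real hooks use SWR (React) or swr-store (Solid) mutations.
let messages: Message[] = [];
let streamData: JSONValue[] | undefined;

async function send(outgoing: Omit<Message, 'id'>[]) {
  const previous = [...messages];
  const abortController = new AbortController();

  return callApi({
    api: '/api/chat',
    messages: outgoing,
    body: {},
    abortController: () => abortController,
    appendMessage: message => {
      // Called repeatedly while a plain text stream grows.
      messages = [...previous, message];
    },
    restoreMessagesOnFailure: () => {
      messages = previous;
    },
    onUpdate: (merged, data) => {
      // Called for complex-mode responses (COMPLEX_HEADER === 'true').
      messages = [...previous, ...merged];
      streamData = [...(streamData ?? []), ...(data ?? [])];
    },
  });
}
```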

‎packages/core/react/parse-complex-response.ts → ‎packages/core/shared/parse-complex-response.ts (renamed)

+2-2
@@ -1,5 +1,5 @@
-import type { FunctionCall, JSONValue, Message } from '../shared/types';
-import { createChunkDecoder, nanoid } from '../shared/utils';
+import type { FunctionCall, JSONValue, Message } from './types';
+import { createChunkDecoder, nanoid } from './utils';

 type PrefixMap = {
   text?: Message;
‎packages/core/shared/process-chat-stream.ts (new file)

+87

@@ -0,0 +1,87 @@
+import { ChatRequest, FunctionCall, JSONValue, Message } from './types';
+
+export async function processChatStream({
+  getStreamedResponse,
+  experimental_onFunctionCall,
+  updateChatRequest,
+  getCurrentMessages,
+}: {
+  getStreamedResponse: () => Promise<
+    Message | { messages: Message[]; data: JSONValue[] }
+  >;
+  experimental_onFunctionCall?: (
+    chatMessages: Message[],
+    functionCall: FunctionCall,
+  ) => Promise<void | ChatRequest>;
+  updateChatRequest: (chatRequest: ChatRequest) => void;
+  getCurrentMessages: () => Message[];
+}) {
+  while (true) {
+    // TODO-STREAMDATA: This should be { const { messages: streamedResponseMessages, data } =
+    // await getStreamedResponse(} once Stream Data is not experimental
+    const messagesAndDataOrJustMessage = await getStreamedResponse();
+
+    // Using experimental stream data
+    if ('messages' in messagesAndDataOrJustMessage) {
+      let hasFollowingResponse = false;
+      for (const message of messagesAndDataOrJustMessage.messages) {
+        if (
+          message.function_call === undefined ||
+          typeof message.function_call === 'string'
+        ) {
+          continue;
+        }
+        hasFollowingResponse = true;
+        // Streamed response is a function call, invoke the function call handler if it exists.
+        if (experimental_onFunctionCall) {
+          const functionCall = message.function_call;
+
+          // User handles the function call in their own functionCallHandler.
+          // The "arguments" key of the function call object will still be a string which will have to be parsed in the function handler.
+          // If the "arguments" JSON is malformed due to model error the user will have to handle that themselves.
+
+          const functionCallResponse: ChatRequest | void =
+            await experimental_onFunctionCall(
+              getCurrentMessages(),
+              functionCall,
+            );
+
+          // If the user does not return anything as a result of the function call, the loop will break.
+          if (functionCallResponse === undefined) {
+            hasFollowingResponse = false;
+            break;
+          }
+
+          // A function call response was returned.
+          // The updated chat with function call response will be sent to the API in the next iteration of the loop.
+          updateChatRequest(functionCallResponse);
+        }
+      }
+      if (!hasFollowingResponse) {
+        break;
+      }
+    } else {
+      const streamedResponseMessage = messagesAndDataOrJustMessage;
+      // TODO-STREAMDATA: Remove this once Stream Data is not experimental
+      if (
+        streamedResponseMessage.function_call === undefined ||
+        typeof streamedResponseMessage.function_call === 'string'
+      ) {
+        break;
+      }
+
+      // Streamed response is a function call, invoke the function call handler if it exists.
+      if (experimental_onFunctionCall) {
+        const functionCall = streamedResponseMessage.function_call;
+        const functionCallResponse: ChatRequest | void =
+          await experimental_onFunctionCall(getCurrentMessages(), functionCall);
+
+        // If the user does not return anything as a result of the function call, the loop will break.
+        if (functionCallResponse === undefined) break;
+        // A function call response was returned.
+        // The updated chat with function call response will be sent to the API in the next iteration of the loop.
+        updateChatRequest(functionCallResponse);
+      }
+    }
+  }
+}
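
A minimal sketch of how the two shared helpers compose, mirroring what the React and Solid hooks now do; the surrounding state and the no-op callbacks are illustrative only:

```ts
import { callApi } from './call-api';
import { processChatStream } from './process-chat-stream';
import type { ChatRequest } from './types';

async function runChat(initial: ChatRequest) {
  let chatRequest = initial;

  await processChatStream({
    // Each loop iteration re-sends the (possibly updated) chat request.
    getStreamedResponse: () =>
      callApi({
        api: '/api/chat',
        messages: chatRequest.messages,
        body: {},
        appendMessage: () => {},
        restoreMessagesOnFailure: () => {},
        onUpdate: () => {},
      }),
    // Optional: turn a streamed function call into a follow-up ChatRequest.
    experimental_onFunctionCall: async (_chatMessages, functionCall) => {
      console.log('model requested function:', functionCall.name);
      return undefined; // returning undefined ends the loop
    },
    updateChatRequest: next => {
      chatRequest = next;
    },
    getCurrentMessages: () => chatRequest.messages,
  });
}
```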

‎packages/core/solid/use-chat.ts

+75-99
@@ -1,14 +1,17 @@
 import { Accessor, Resource, Setter, createSignal } from 'solid-js';
 import { useSWRStore } from 'solid-swr-store';
 import { createSWRStore } from 'swr-store';
-
+import { callApi } from '../shared/call-api';
+import { processChatStream } from '../shared/process-chat-stream';
 import type {
+  ChatRequest,
   CreateMessage,
+  JSONValue,
   Message,
   RequestOptions,
   UseChatOptions,
 } from '../shared/types';
-import { createChunkDecoder, nanoid } from '../shared/utils';
+import { nanoid } from '../shared/utils';

 export type { CreateMessage, Message, UseChatOptions };

@@ -51,6 +54,8 @@ export type UseChatHelpers = {
   handleSubmit: (e: any) => void;
   /** Whether the API request is in progress */
   isLoading: Accessor<boolean>;
+  /** Additional data added on the server via StreamData */
+  data: Accessor<JSONValue[] | undefined>;
 };

 let uniqueId = 0;
@@ -68,6 +73,7 @@ export function useChat({
   initialMessages = [],
   initialInput = '',
   sendExtraMessageFields,
+  experimental_onFunctionCall,
   onResponse,
   onFinish,
   onError,
@@ -79,9 +85,11 @@ export function useChat({
   const chatId = id || `chat-${uniqueId++}`;

   const key = `${api}|${chatId}`;
-  const data = useSWRStore(chatApiStore, () => [key], {
+
+  // Because of the `initialData` option, the `data` will never be `undefined`:
+  const messages = useSWRStore(chatApiStore, () => [key], {
     initialData: initialMessages,
-  });
+  }) as Resource<Message[]>;

   const mutate = (data: Message[]) => {
     store[key] = data;
@@ -91,10 +99,10 @@ export function useChat({
    });
  };

-  // Because of the `initialData` option, the `data` will never be `undefined`.
-  const messages = data as Resource<Message[]>;
-
  const [error, setError] = createSignal<undefined | Error>(undefined);
+  const [streamData, setStreamData] = createSignal<JSONValue[] | undefined>(
+    undefined,
+  );
  const [isLoading, setIsLoading] = createSignal(false);

  let abortController: AbortController | null = null;
@@ -108,107 +116,74 @@ export function useChat({

      abortController = new AbortController();

+      const getCurrentMessages = () =>
+        chatApiStore.get([key], {
+          shouldRevalidate: false,
+        });
+
      // Do an optimistic update to the chat state to show the updated messages
      // immediately.
-      const previousMessages = chatApiStore.get([key], {
-        shouldRevalidate: false,
-      });
+      const previousMessages = getCurrentMessages();
      mutate(messagesSnapshot);

-      const res = await fetch(api, {
-        method: 'POST',
-        body: JSON.stringify({
-          messages: sendExtraMessageFields
-            ? messagesSnapshot
-            : messagesSnapshot.map(
-                ({ role, content, name, function_call }) => ({
-                  role,
-                  content,
-                  ...(name !== undefined && { name }),
-                  ...(function_call !== undefined && {
-                    function_call: function_call,
+      let chatRequest: ChatRequest = {
+        messages: messagesSnapshot,
+        options,
+      };
+
+      await processChatStream({
+        getStreamedResponse: async () => {
+          const existingData = streamData() ?? [];
+
+          return await callApi({
+            api,
+            messages: sendExtraMessageFields
+              ? chatRequest.messages
+              : chatRequest.messages.map(
+                  ({ role, content, name, function_call }) => ({
+                    role,
+                    content,
+                    ...(name !== undefined && { name }),
+                    ...(function_call !== undefined && {
+                      function_call,
+                    }),
                   }),
-                }),
-              ),
-          ...body,
-          ...options?.body,
-        }),
-        headers: {
-          ...headers,
-          ...options?.headers,
+                ),
+            body: {
+              ...body,
+              ...options?.body,
+            },
+            headers: {
+              ...headers,
+              ...options?.headers,
+            },
+            abortController: () => abortController,
+            credentials,
+            onResponse,
+            onUpdate(merged, data) {
+              mutate([...chatRequest.messages, ...merged]);
+              setStreamData([...existingData, ...(data ?? [])]);
+            },
+            onFinish,
+            appendMessage(message) {
+              mutate([...chatRequest.messages, message]);
+            },
+            restoreMessagesOnFailure() {
+              // Restore the previous messages if the request fails.
+              if (previousMessages.status === 'success') {
+                mutate(previousMessages.data);
+              }
+            },
+          });
+        },
+        experimental_onFunctionCall,
+        updateChatRequest(newChatRequest) {
+          chatRequest = newChatRequest;
        },
-        signal: abortController.signal,
-        credentials,
-      }).catch(err => {
-        // Restore the previous messages if the request fails.
-        if (previousMessages.status === 'success') {
-          mutate(previousMessages.data);
-        }
-        throw err;
+        getCurrentMessages: () => getCurrentMessages().data,
      });

-      if (onResponse) {
-        try {
-          await onResponse(res);
-        } catch (err) {
-          throw err;
-        }
-      }
-
-      if (!res.ok) {
-        // Restore the previous messages if the request fails.
-        if (previousMessages.status === 'success') {
-          mutate(previousMessages.data);
-        }
-        throw new Error(
-          (await res.text()) || 'Failed to fetch the chat response.',
-        );
-      }
-      if (!res.body) {
-        throw new Error('The response body is empty.');
-      }
-
-      let result = '';
-      const createdAt = new Date();
-      const replyId = nanoid();
-      const reader = res.body.getReader();
-      const decoder = createChunkDecoder();
-
-      while (true) {
-        const { done, value } = await reader.read();
-        if (done) {
-          break;
-        }
-        // Update the chat state with the new message tokens.
-        result += decoder(value);
-        mutate([
-          ...messagesSnapshot,
-          {
-            id: replyId,
-            createdAt,
-            content: result,
-            role: 'assistant',
-          },
-        ]);
-
-        // The request has been aborted, stop reading the stream.
-        if (abortController === null) {
-          reader.cancel();
-          break;
-        }
-      }
-
-      if (onFinish) {
-        onFinish({
-          id: replyId,
-          createdAt,
-          content: result,
-          role: 'assistant',
-        });
-      }
-
      abortController = null;
-      return result;
    } catch (err) {
      // Ignore abort errors as they are expected.
      if ((err as any).name === 'AbortError') {
@@ -283,5 +258,6 @@ export function useChat({
    setInput,
    handleSubmit,
    isLoading,
+    data: streamData,
  };
 }

‎packages/core/streams/streaming-react-response.ts

+1-1
@@ -8,7 +8,7 @@
  * between the rows, but flushing the full payload on each row.
  */

-import { parseComplexResponse } from '../react/parse-complex-response';
+import { parseComplexResponse } from '../shared/parse-complex-response';
 import { JSONValue } from '../shared/types';
 import { createChunkDecoder } from '../shared/utils';
 import { experimental_StreamData } from './stream-data';

0 commit comments
