Implement Cloudflare Workers AI endpoint (#907) (#972)
* Implement Cloudflare Workers AI endpoint (#907)

* Renamed to Cloudflare Workers AI in docs

* Add note about sampling parameters

* clean up env example
nsarrazin committed Apr 3, 2024
1 parent a01ed5a commit cb000d3
Showing 5 changed files with 176 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .env
@@ -8,8 +8,11 @@ MONGODB_DIRECT_CONNECTION=false
COOKIE_NAME=hf-chat
HF_TOKEN=#hf_<token> from https://huggingface.co/settings/token
HF_API_ROOT=https://api-inference.huggingface.co/models

OPENAI_API_KEY=#your openai api key here
ANTHROPIC_API_KEY=#your anthropic api key here
CLOUDFLARE_ACCOUNT_ID=#your cloudflare account id here
CLOUDFLARE_API_TOKEN=#your cloudflare api token here

HF_ACCESS_TOKEN=#LEGACY! Use HF_TOKEN instead

32 changes: 32 additions & 0 deletions README.md
@@ -528,6 +528,38 @@ You can also set `"service" : "lambda"` to use a lambda instance.

You can get the `accessKey` and `secretKey` from your AWS user, under programmatic access.

#### Cloudflare Workers AI

You can also use Cloudflare Workers AI to run your own models with serverless inference.

You will need to have a Cloudflare account, then get your [account ID](https://developers.cloudflare.com/fundamentals/setup/find-account-and-zone-ids/) as well as your [API token](https://developers.cloudflare.com/workers-ai/get-started/rest-api/#1-get-an-api-token) for Workers AI.

You can either specify them directly in your `.env.local` using the `CLOUDFLARE_ACCOUNT_ID` and `CLOUDFLARE_API_TOKEN` variables, or you can set them directly in the endpoint config.
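
For example, in `.env.local` (placeholder values):

```env
CLOUDFLARE_ACCOUNT_ID=<your-account-id>
CLOUDFLARE_API_TOKEN=<your-api-token>
```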

You can find the list of models available on Cloudflare [here](https://developers.cloudflare.com/workers-ai/models/#text-generation).

```env
{
  "name" : "nousresearch/hermes-2-pro-mistral-7b",
  "tokenizer": "nousresearch/hermes-2-pro-mistral-7b",
  "parameters": {
    "stop": ["<|im_end|>"]
  },
  "endpoints" : [
    {
      "type" : "cloudflare"
      <!-- optionally specify these
      "accountId": "your-account-id",
      "authToken": "your-api-token"
      -->
    }
  ]
}
```
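
As with the other endpoint examples in this README, this model config goes inside the `MODELS` variable in your `.env.local`.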

> [!NOTE]
> Cloudflare Workers AI currently does not support custom sampling parameters like temperature, top_p, etc.

##### Google Vertex models

Chat UI can connect to the Google Vertex API endpoints ([List of supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models)).
134 changes: 134 additions & 0 deletions src/lib/server/endpoints/cloudflare/endpointCloudflare.ts
@@ -0,0 +1,134 @@
import { z } from "zod";
import type { Endpoint } from "../endpoints";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import { CLOUDFLARE_ACCOUNT_ID, CLOUDFLARE_API_TOKEN } from "$env/static/private";

export const endpointCloudflareParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("cloudflare"),
	accountId: z.string().default(CLOUDFLARE_ACCOUNT_ID),
	apiToken: z.string().default(CLOUDFLARE_API_TOKEN),
});

export async function endpointCloudflare(
	input: z.input<typeof endpointCloudflareParametersSchema>
): Promise<Endpoint> {
	const { accountId, apiToken, model } = endpointCloudflareParametersSchema.parse(input);
	const apiURL = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/@hf/${model.id}`;

	return async ({ messages, preprompt }) => {
		let messagesFormatted = messages.map((message) => ({
			role: message.from,
			content: message.content,
		}));

		if (messagesFormatted?.[0]?.role !== "system") {
			messagesFormatted = [{ role: "system", content: preprompt ?? "" }, ...messagesFormatted];
		}

		const payload = JSON.stringify({
			messages: messagesFormatted,
			stream: true,
		});

		const res = await fetch(apiURL, {
			method: "POST",
			headers: {
				Authorization: `Bearer ${apiToken}`,
				"Content-Type": "application/json",
			},
			body: payload,
		});

		if (!res.ok) {
			throw new Error(`Failed to generate text: ${await res.text()}`);
		}

		const decoder = new TextDecoderStream();
		const reader = res.body?.pipeThrough(decoder).getReader();

		return (async function* () {
			let stop = false;
			let generatedText = "";
			let tokenId = 0;
			let accumulatedData = ""; // Buffer to accumulate data chunks

			while (!stop) {
				const out = await reader?.read();

				// If the stream is done, cancel the reader and stop
				if (out?.done) {
					reader?.cancel();
					return;
				}

				if (!out?.value) {
					return;
				}

				// Accumulate the data chunk
				accumulatedData += out.value;

				// Process each complete JSON object in the accumulated data
				while (accumulatedData.includes("\n")) {
					// Assuming each JSON object ends with a newline
					const endIndex = accumulatedData.indexOf("\n");
					let jsonString = accumulatedData.substring(0, endIndex).trim();

					// Remove the processed part from the buffer
					accumulatedData = accumulatedData.substring(endIndex + 1);

					if (jsonString.startsWith("data: ")) {
						jsonString = jsonString.slice(6);
						let data = null;

						if (jsonString === "[DONE]") {
							stop = true;

							yield {
								token: {
									id: tokenId++,
									text: "",
									logprob: 0,
									special: true,
								},
								generated_text: generatedText,
								details: null,
							} satisfies TextGenerationStreamOutput;
							reader?.cancel();

							continue;
						}

						try {
							data = JSON.parse(jsonString);
						} catch (e) {
							console.error("Failed to parse JSON", e);
							console.error("Problematic JSON string:", jsonString);
							continue; // Skip this iteration and try the next chunk
						}

						// Handle the parsed data
						if (data.response) {
							generatedText += data.response ?? "";
							const output: TextGenerationStreamOutput = {
								token: {
									id: tokenId++,
									text: data.response ?? "",
									logprob: 0,
									special: false,
								},
								generated_text: null,
								details: null,
							};
							yield output;
						}
					}
				}
			}
		})();
	};
}

export default endpointCloudflare;
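
For context, here is a minimal sketch of how the returned endpoint might be consumed; the model object and message shapes below are simplified assumptions for illustration, not part of this diff:

```ts
import endpointCloudflare from "./endpointCloudflare";

// Hypothetical, pared-down model object; chat-ui's Model type carries more fields.
const endpoint = await endpointCloudflare({
	type: "cloudflare",
	model: { id: "nousresearch/hermes-2-pro-mistral-7b" },
});

// Request a streamed completion and print tokens as they arrive.
const stream = await endpoint({
	messages: [{ from: "user", content: "Hello!" }],
	preprompt: "You are a helpful assistant.",
});

for await (const output of stream) {
	process.stdout.write(output.token.text);
}
```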
5 changes: 5 additions & 0 deletions src/lib/server/endpoints/endpoints.ts
@@ -13,6 +13,9 @@ import {
	endpointAnthropicParametersSchema,
} from "./anthropic/endpointAnthropic";
import type { Model } from "$lib/types/Model";
import endpointCloudflare, {
	endpointCloudflareParametersSchema,
} from "./cloudflare/endpointCloudflare";

// parameters passed when generating text
export interface EndpointParameters {
@@ -42,6 +45,7 @@ export const endpoints = {
	llamacpp: endpointLlamacpp,
	ollama: endpointOllama,
	vertex: endpointVertex,
	cloudflare: endpointCloudflare,
};

export const endpointSchema = z.discriminatedUnion("type", [
@@ -52,5 +56,6 @@ export const endpointSchema = z.discriminatedUnion("type", [
	endpointLlamacppParametersSchema,
	endpointOllamaParametersSchema,
	endpointVertexParametersSchema,
	endpointCloudflareParametersSchema,
]);
export default endpoints;
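
As a rough illustration (assuming a minimal config object), the discriminated union dispatches on the `type` literal, so a Cloudflare config is validated by `endpointCloudflareParametersSchema` and picks up its defaults:

```ts
// Sketch: "type" selects the Cloudflare schema; weight defaults to 1 and
// accountId/apiToken fall back to the CLOUDFLARE_* environment variables.
const config = endpointSchema.parse({
	type: "cloudflare",
	model: {},
});
```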
2 changes: 2 additions & 0 deletions src/lib/server/models.ts
Expand Up @@ -130,6 +130,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
			return endpoints.ollama(args);
		case "vertex":
			return await endpoints.vertex(args);
		case "cloudflare":
			return await endpoints.cloudflare(args);
		default:
			// for legacy reasons
			return endpoints.tgi(args);
