Skip to content

Commit

Permalink
Add native run binding for Workers AI (#5371)
Browse files Browse the repository at this point in the history
* Improve workers ai binding using wrappedBindings

* Update .changeset/hot-geese-deliver.md

Co-authored-by: MrBBot <me@mrbbot.dev>

---------

Co-authored-by: MrBBot <me@mrbbot.dev>
  • Loading branch information
G4brym and mrbbot committed Apr 1, 2024
1 parent d994066 commit 77152f3
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 21 deletions.
7 changes: 7 additions & 0 deletions .changeset/hot-geese-deliver.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"wrangler": minor
---

feature: remove requirement for `@cloudflare/ai` package to use Workers AI

Previously, to get the correct Workers AI API, you needed to wrap your `env.AI` binding with `new Ai()` from `@cloudflare/ai`. This change moves the contents of `@cloudflare/ai` into the Workers runtime itself, meaning `env.AI` is now an instance of `Ai`, without the need for wrapping.
3 changes: 2 additions & 1 deletion fixtures/ai-app/src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ export default {

return Response.json({
binding: env.AI,
fetcher: env.AI.fetch.toString(),
run: typeof env.AI.run,
fetch: typeof env.AI.fetch,
});
},
};
10 changes: 8 additions & 2 deletions fixtures/ai-app/tests/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,14 @@ describe("'wrangler dev' correctly renders pages", () => {
const response = await fetch(`http://${ip}:${port}/`);
const content = await response.json();
expect(content).toEqual({
binding: {},
fetcher: "function fetch() { [native code] }",
binding: {
fetcher: {},
lastRequestId: null,
logs: [],
options: {},
},
fetch: "function",
run: "function",
});
});
});
12 changes: 11 additions & 1 deletion packages/wrangler/src/ai/fetcher.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,17 @@ import { performApiFetch } from "../cfetch/internal";
import { getAccountId } from "../user";
import type { Request } from "miniflare";

export async function AIFetcher(request: Request) {
export const EXTERNAL_AI_WORKER_NAME = "__WRANGLER_EXTERNAL_AI_WORKER";

export const EXTERNAL_AI_WORKER_SCRIPT = `
import { Ai } from 'cloudflare-internal:ai-api'
export default function (env) {
return new Ai(env.FETCHER);
}
`;

export async function AIFetcher(request: Request): Promise<Response> {
const accountId = await getAccountId();

request.headers.delete("Host");
Expand Down
17 changes: 8 additions & 9 deletions packages/wrangler/src/api/integrations/platform/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,13 @@ async function getMiniflareOptionsFromConfig(
durableObjects: rawConfig["durable_objects"],
});

const { bindingOptions, externalDurableObjectWorker } =
buildMiniflareBindingOptions({
name: undefined,
bindings,
workerDefinitions,
queueConsumers: undefined,
serviceBindings: {},
});
const { bindingOptions, externalWorkers } = buildMiniflareBindingOptions({
name: undefined,
bindings,
workerDefinitions,
queueConsumers: undefined,
serviceBindings: {},
});

const persistOptions = getMiniflarePersistOptions(options.persist);

Expand All @@ -156,7 +155,7 @@ async function getMiniflareOptionsFromConfig(
...bindingOptions.serviceBindings,
},
},
externalDurableObjectWorker,
...externalWorkers,
],
...persistOptions,
};
Expand Down
41 changes: 33 additions & 8 deletions packages/wrangler/src/dev/miniflare.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ import { randomUUID } from "node:crypto";
import { readFileSync, realpathSync } from "node:fs";
import path from "node:path";
import { Log, LogLevel, Miniflare, Mutex, TypedEventTarget } from "miniflare";
import { AIFetcher } from "../ai/fetcher";
import {
AIFetcher,
EXTERNAL_AI_WORKER_NAME,
EXTERNAL_AI_WORKER_SCRIPT,
} from "../ai/fetcher";
import { ModuleTypeToRuleType } from "../deployment-bundle/module-collection";
import { withSourceURLs } from "../deployment-bundle/source-url";
import { getHttpsOptions } from "../https-options";
Expand Down Expand Up @@ -253,6 +257,7 @@ type WorkerOptionsBindings = Pick<
| "hyperdrives"
| "durableObjects"
| "serviceBindings"
| "wrappedBindings"
>;

type MiniflareBindingsConfig = Pick<
Expand All @@ -270,7 +275,7 @@ type MiniflareBindingsConfig = Pick<
export function buildMiniflareBindingOptions(config: MiniflareBindingsConfig): {
bindingOptions: WorkerOptionsBindings;
internalObjects: CfDurableObject[];
externalDurableObjectWorker: WorkerOptions;
externalWorkers: WorkerOptions[];
} {
const bindings = config.bindings;

Expand Down Expand Up @@ -299,13 +304,14 @@ export function buildMiniflareBindingOptions(config: MiniflareBindingsConfig): {
// registered in the dev registry)
const internalObjects: CfDurableObject[] = [];
const externalObjects: CfDurableObject[] = [];
const externalWorkers: WorkerOptions[] = [];
for (const binding of bindings.durable_objects?.bindings ?? []) {
const internal =
binding.script_name === undefined || binding.script_name === config.name;
(internal ? internalObjects : externalObjects).push(binding);
}
// Setup Durable Object bindings and proxy worker
const externalDurableObjectWorker: WorkerOptions = {
externalWorkers.push({
name: EXTERNAL_DURABLE_OBJECTS_WORKER_NAME,
// Bind all internal objects, so they're accessible by all other sessions
// that proxy requests for our objects to this worker
Expand Down Expand Up @@ -353,10 +359,27 @@ export function buildMiniflareBindingOptions(config: MiniflareBindingsConfig): {
}
})
.join("\n"),
};
});

const wrappedBindings: WorkerOptions["wrappedBindings"] = {};
if (bindings.ai?.binding) {
config.serviceBindings[bindings.ai.binding] = AIFetcher;
externalWorkers.push({
name: EXTERNAL_AI_WORKER_NAME,
modules: [
{
type: "ESModule",
path: "index.mjs",
contents: EXTERNAL_AI_WORKER_SCRIPT,
},
],
serviceBindings: {
FETCHER: AIFetcher,
},
});

wrappedBindings[bindings.ai.binding] = {
scriptName: EXTERNAL_AI_WORKER_NAME,
};
}

const bindingOptions = {
Expand Down Expand Up @@ -411,13 +434,15 @@ export function buildMiniflareBindingOptions(config: MiniflareBindingsConfig): {
]),

serviceBindings: config.serviceBindings,

wrappedBindings: wrappedBindings,
// TODO: check multi worker service bindings also supported
};

return {
bindingOptions,
internalObjects,
externalDurableObjectWorker,
externalWorkers,
};
}

Expand Down Expand Up @@ -584,7 +609,7 @@ async function buildMiniflareOptions(
: undefined;

const sourceOptions = await buildSourceOptions(config);
const { bindingOptions, internalObjects, externalDurableObjectWorker } =
const { bindingOptions, internalObjects, externalWorkers } =
buildMiniflareBindingOptions(config);
const sitesOptions = buildSitesOptions(config);
const persistOptions = buildPersistOptions(config.localPersistencePath);
Expand Down Expand Up @@ -625,7 +650,7 @@ async function buildMiniflareOptions(
...bindingOptions,
...sitesOptions,
},
externalDurableObjectWorker,
...externalWorkers,
],
};
return { options, internalObjects };
Expand Down

0 comments on commit 77152f3

Please sign in to comment.