
Commit 0f6bc4e

Authored May 14, 2024
feat (ai/core): add embed function (#1575)
1 parent 1009594 commit 0f6bc4e

26 files changed: +963 −22 lines
 

Diff for: .changeset/witty-beds-sell.md (+8)

---
'@ai-sdk/provider': patch
'@ai-sdk/mistral': patch
'@ai-sdk/openai': patch
'ai': patch
---

feat (ai/core): add embed function

Diff for: content/docs/03-ai-sdk-core/30-embeddings.mdx (+24)

---
title: Embeddings
description: Learn how to embed values with the Vercel AI SDK.
---

# Embeddings

Embeddings are a way to represent words, phrases, or images as vectors in a high-dimensional space.
In this space, similar words are close to each other, and the distance between their vectors can be used to measure similarity.

## Embedding a Single Value

The Vercel AI SDK provides the `embed` function to embed single values, which is useful for tasks such as finding similar words
or phrases, or clustering text. You can use it with embedding models, e.g. `openai.embedding('text-embedding-3-large')` or `mistral.embedding('mistral-embed')`.

```tsx
import { embed } from 'ai';
import { openai } from '@ai-sdk/openai';

const { embedding } = await embed({
  model: openai.embedding('text-embedding-3-small'),
  value: 'sunny day at the beach',
});
```
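
The docs above note that the distance between embedding vectors measures similarity. As of this commit the SDK ships no similarity helper, so here is a minimal sketch that compares two embeddings with plain cosine similarity (the second input string is invented for the example):

```ts
import { embed } from 'ai';
import { openai } from '@ai-sdk/openai';

// Cosine similarity: dot(a, b) / (|a| * |b|); close to 1 means similar.
function cosineSimilarity(a: number[], b: number[]): number {
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

const model = openai.embedding('text-embedding-3-small');

const { embedding: first } = await embed({
  model,
  value: 'sunny day at the beach',
});
const { embedding: second } = await embed({
  model,
  value: 'warm afternoon on the shore',
});

console.log(cosineSimilarity(first, second));
```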

Diff for: examples/ai-core/src/embed/mistral.ts (+16)

import { mistral } from '@ai-sdk/mistral';
import { embed } from 'ai';
import dotenv from 'dotenv';

dotenv.config();

async function main() {
  const { embedding } = await embed({
    model: mistral.embedding('mistral-embed'),
    value: 'sunny day at the beach',
  });

  console.log(embedding);
}

main().catch(console.error);

Diff for: examples/ai-core/src/embed/openai.ts (+16)

import { openai } from '@ai-sdk/openai';
import { embed } from 'ai';
import dotenv from 'dotenv';

dotenv.config();

async function main() {
  const { embedding } = await embed({
    model: openai.embedding('text-embedding-3-small'),
    value: 'sunny day at the beach',
  });

  console.log(embedding);
}

main().catch(console.error);

Diff for: packages/core/core/embed/embed.test.ts (+25)

import assert from 'node:assert';
import { MockEmbeddingModelV1 } from '../test/mock-embedding-model-v1';
import { embed } from './embed';

const dummyEmbedding = [0.1, 0.2, 0.3];
const testValue = 'sunny day at the beach';

describe('result.embedding', () => {
  it('should generate embedding', async () => {
    const result = await embed({
      model: new MockEmbeddingModelV1({
        doEmbed: async ({ values }) => {
          assert.deepStrictEqual(values, [testValue]);

          return {
            embeddings: [dummyEmbedding],
          };
        },
      }),
      value: testValue,
    });

    assert.deepStrictEqual(result.embedding, dummyEmbedding);
  });
});

Diff for: packages/core/core/embed/embed.ts (+95)

import { Embedding, EmbeddingModel } from '../types';
import { retryWithExponentialBackoff } from '../util/retry-with-exponential-backoff';

/**
Embed a value using an embedding model. The type of the value is defined by the embedding model.

@param model - The embedding model to use.
@param value - The value that should be embedded.

@param maxRetries - Maximum number of retries. Set to 0 to disable retries. Default: 2.
@param abortSignal - An optional abort signal that can be used to cancel the call.

@returns A result object that contains the embedding, the value, and additional information.
*/
export async function embed<VALUE>({
  model,
  value,
  maxRetries,
  abortSignal,
}: {
  /**
The embedding model to use.
*/
  model: EmbeddingModel<VALUE>;

  /**
The value that should be embedded.
*/
  value: VALUE;

  /**
Maximum number of retries per embedding model call. Set to 0 to disable retries.

@default 2
*/
  maxRetries?: number;

  /**
Abort signal.
*/
  abortSignal?: AbortSignal;
}): Promise<EmbedResult<VALUE>> {
  const retry = retryWithExponentialBackoff({ maxRetries });

  const modelResponse = await retry(() =>
    model.doEmbed({
      values: [value],
      abortSignal,
    }),
  );

  return new EmbedResult({
    value,
    embedding: modelResponse.embeddings[0],
    rawResponse: modelResponse.rawResponse,
  });
}

/**
The result of an `embed` call.
It contains the embedding, the value, and additional information.
*/
export class EmbedResult<VALUE> {
  /**
The value that was embedded.
*/
  readonly value: VALUE;

  /**
The embedding of the value.
*/
  readonly embedding: Embedding;

  /**
Optional raw response data.
*/
  readonly rawResponse?: {
    /**
Response headers.
*/
    headers?: Record<string, string>;
  };

  constructor(options: {
    value: VALUE;
    embedding: Embedding;
    rawResponse?: {
      headers?: Record<string, string>;
    };
  }) {
    this.value = options.value;
    this.embedding = options.embedding;
    this.rawResponse = options.rawResponse;
  }
}
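
A brief usage sketch of the `maxRetries` and `abortSignal` options documented above; it assumes a runtime with `AbortSignal.timeout`, e.g. Node 18+:

```ts
import { embed } from 'ai';
import { openai } from '@ai-sdk/openai';

const { embedding, rawResponse } = await embed({
  model: openai.embedding('text-embedding-3-small'),
  value: 'sunny day at the beach',
  maxRetries: 0, // disable the default exponential-backoff retries (default: 2)
  abortSignal: AbortSignal.timeout(5000), // cancel the call after 5 seconds
});

console.log(embedding.length); // dimensionality of the embedding
console.log(rawResponse?.headers); // optional provider response headers
```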

Diff for: packages/core/core/embed/index.ts (+1)

export * from './embed';

Diff for: packages/core/core/index.ts (+1)

@@ -1,3 +1,4 @@
+export * from './embed';
 export * from './generate-object';
 export * from './generate-text';
 export * from './prompt';

Diff for: packages/core/core/test/mock-embedding-model-v1.ts (+36)

import { EmbeddingModelV1 } from '@ai-sdk/provider';

export class MockEmbeddingModelV1<VALUE> implements EmbeddingModelV1<VALUE> {
  readonly specificationVersion = 'v1';

  readonly provider: EmbeddingModelV1<VALUE>['provider'];
  readonly modelId: EmbeddingModelV1<VALUE>['modelId'];
  readonly maxEmbeddingsPerCall: EmbeddingModelV1<VALUE>['maxEmbeddingsPerCall'];
  readonly supportsParallelCalls: EmbeddingModelV1<VALUE>['supportsParallelCalls'];

  doEmbed: EmbeddingModelV1<VALUE>['doEmbed'];

  constructor({
    provider = 'mock-provider',
    modelId = 'mock-model-id',
    maxEmbeddingsPerCall = 1,
    supportsParallelCalls = false,
    doEmbed = notImplemented,
  }: {
    provider?: EmbeddingModelV1<VALUE>['provider'];
    modelId?: EmbeddingModelV1<VALUE>['modelId'];
    maxEmbeddingsPerCall?: EmbeddingModelV1<VALUE>['maxEmbeddingsPerCall'];
    supportsParallelCalls?: EmbeddingModelV1<VALUE>['supportsParallelCalls'];
    doEmbed?: EmbeddingModelV1<VALUE>['doEmbed'];
  }) {
    this.provider = provider;
    this.modelId = modelId;
    this.maxEmbeddingsPerCall = maxEmbeddingsPerCall;
    this.supportsParallelCalls = supportsParallelCalls;
    this.doEmbed = doEmbed;
  }
}

function notImplemented(): never {
  throw new Error('Not implemented');
}

Diff for: packages/core/core/types/embedding-model.ts (+11)

import { EmbeddingModelV1, EmbeddingModelV1Embedding } from '@ai-sdk/provider';

/**
Embedding model that is used by the AI SDK Core functions.
*/
export type EmbeddingModel<VALUE> = EmbeddingModelV1<VALUE>;

/**
Embedding.
*/
export type Embedding = EmbeddingModelV1Embedding;

Diff for: packages/core/core/types/index.ts (+1)

@@ -1,2 +1,3 @@
+export * from './embedding-model';
 export * from './errors';
 export * from './language-model';

Diff for: packages/mistral/src/mistral-embedding-model.test.ts (+106)

import { EmbeddingModelV1Embedding } from '@ai-sdk/provider';
import { JsonTestServer } from '@ai-sdk/provider-utils/test';
import { createMistral } from './mistral-provider';

const dummyEmbeddings = [
  [0.1, 0.2, 0.3, 0.4, 0.5],
  [0.6, 0.7, 0.8, 0.9, 1.0],
];
const testValues = ['sunny day at the beach', 'rainy day in the city'];

const provider = createMistral({ apiKey: 'test-api-key' });
const model = provider.embedding('mistral-embed');

describe('doEmbed', () => {
  const server = new JsonTestServer('https://api.mistral.ai/v1/embeddings');

  server.setupTestEnvironment();

  function prepareJsonResponse({
    embeddings = dummyEmbeddings,
  }: {
    embeddings?: EmbeddingModelV1Embedding[];
  } = {}) {
    server.responseBodyJson = {
      id: 'b322cfc2b9d34e2f8e14fc99874faee5',
      object: 'list',
      data: embeddings.map((embedding, i) => ({
        object: 'embedding',
        embedding,
        index: i,
      })),
      model: 'mistral-embed',
      usage: { prompt_tokens: 8, total_tokens: 8, completion_tokens: 0 },
    };
  }

  it('should extract embedding', async () => {
    prepareJsonResponse();

    const { embeddings } = await model.doEmbed({ values: testValues });

    expect(embeddings).toStrictEqual(dummyEmbeddings);
  });

  it('should expose the raw response headers', async () => {
    prepareJsonResponse();

    server.responseHeaders = {
      'test-header': 'test-value',
    };

    const { rawResponse } = await model.doEmbed({ values: testValues });

    expect(rawResponse?.headers).toStrictEqual({
      // default headers:
      'content-type': 'application/json',

      // custom header
      'test-header': 'test-value',
    });
  });

  it('should pass the model and the values', async () => {
    prepareJsonResponse();

    await model.doEmbed({ values: testValues });

    expect(await server.getRequestBodyJson()).toStrictEqual({
      model: 'mistral-embed',
      input: testValues,
      encoding_format: 'float',
    });
  });

  it('should pass custom headers', async () => {
    prepareJsonResponse();

    const provider = createMistral({
      apiKey: 'test-api-key',
      headers: {
        'Custom-Header': 'test-header',
      },
    });

    await provider.embedding('mistral-embed').doEmbed({
      values: testValues,
    });

    const requestHeaders = await server.getRequestHeaders();
    expect(requestHeaders.get('Custom-Header')).toStrictEqual('test-header');
  });

  it('should pass the api key as Authorization header', async () => {
    prepareJsonResponse();

    const provider = createMistral({ apiKey: 'test-api-key' });

    await provider.embedding('mistral-embed').doEmbed({
      values: testValues,
    });

    expect(
      (await server.getRequestHeaders()).get('Authorization'),
    ).toStrictEqual('Bearer test-api-key');
  });
});

Diff for: packages/mistral/src/mistral-embedding-model.ts (+98)

import {
  EmbeddingModelV1,
  TooManyEmbeddingValuesForCallError,
} from '@ai-sdk/provider';
import {
  createJsonResponseHandler,
  postJsonToApi,
} from '@ai-sdk/provider-utils';
import { z } from 'zod';
import {
  MistralEmbeddingModelId,
  MistralEmbeddingSettings,
} from './mistral-embedding-settings';
import { mistralFailedResponseHandler } from './mistral-error';

type MistralEmbeddingConfig = {
  provider: string;
  baseURL: string;
  headers: () => Record<string, string | undefined>;
};

export class MistralEmbeddingModel implements EmbeddingModelV1<string> {
  readonly specificationVersion = 'v1';
  readonly modelId: MistralEmbeddingModelId;

  private readonly config: MistralEmbeddingConfig;
  private readonly settings: MistralEmbeddingSettings;

  get provider(): string {
    return this.config.provider;
  }

  get maxEmbeddingsPerCall(): number {
    return this.settings.maxEmbeddingsPerCall ?? 32;
  }

  get supportsParallelCalls(): boolean {
    // Parallel calls are technically possible,
    // but I have been hitting rate limits, so they are disabled for now.
    return this.settings.supportsParallelCalls ?? false;
  }

  constructor(
    modelId: MistralEmbeddingModelId,
    settings: MistralEmbeddingSettings,
    config: MistralEmbeddingConfig,
  ) {
    this.modelId = modelId;
    this.settings = settings;
    this.config = config;
  }

  async doEmbed({
    values,
    abortSignal,
  }: Parameters<EmbeddingModelV1<string>['doEmbed']>[0]): Promise<
    Awaited<ReturnType<EmbeddingModelV1<string>['doEmbed']>>
  > {
    if (values.length > this.maxEmbeddingsPerCall) {
      throw new TooManyEmbeddingValuesForCallError({
        provider: this.provider,
        modelId: this.modelId,
        maxEmbeddingsPerCall: this.maxEmbeddingsPerCall,
        values,
      });
    }

    const { responseHeaders, value: response } = await postJsonToApi({
      url: `${this.config.baseURL}/embeddings`,
      headers: this.config.headers(),
      body: {
        model: this.modelId,
        input: values,
        encoding_format: 'float',
      },
      failedResponseHandler: mistralFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(
        MistralTextEmbeddingResponseSchema,
      ),
      abortSignal,
    });

    return {
      embeddings: response.data.map(item => item.embedding),
      rawResponse: { headers: responseHeaders },
    };
  }
}

// minimal version of the schema, focused on what is needed for the implementation
// this approach limits breakages when the API changes and increases efficiency
const MistralTextEmbeddingResponseSchema = z.object({
  data: z.array(
    z.object({
      embedding: z.array(z.number()),
    }),
  ),
});
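
Because `doEmbed` throws `TooManyEmbeddingValuesForCallError` when the input exceeds `maxEmbeddingsPerCall` (32 by default here), callers embedding larger batches must chunk the input themselves at this stage of the SDK. A sketch of such chunking against the `EmbeddingModelV1` interface; the helper `embedInChunks` is hypothetical and not part of this commit:

```ts
import { EmbeddingModelV1, EmbeddingModelV1Embedding } from '@ai-sdk/provider';

// Hypothetical helper: splits the input into chunks that respect
// maxEmbeddingsPerCall and concatenates the resulting embeddings.
async function embedInChunks(
  model: EmbeddingModelV1<string>,
  values: string[],
): Promise<EmbeddingModelV1Embedding[]> {
  // undefined means the model declares no per-call limit:
  const chunkSize = model.maxEmbeddingsPerCall ?? values.length;
  const embeddings: EmbeddingModelV1Embedding[] = [];

  for (let i = 0; i < values.length; i += chunkSize) {
    const { embeddings: chunk } = await model.doEmbed({
      values: values.slice(i, i + chunkSize),
    });
    embeddings.push(...chunk); // results are in input order per the spec
  }

  return embeddings;
}
```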

Diff for: packages/mistral/src/mistral-embedding-settings.ts (+13)

export type MistralEmbeddingModelId = 'mistral-embed' | (string & {});

export interface MistralEmbeddingSettings {
  /**
Override the maximum number of embeddings per call.
*/
  maxEmbeddingsPerCall?: number;

  /**
Override the parallelism of embedding calls.
*/
  supportsParallelCalls?: boolean;
}

Diff for: packages/mistral/src/mistral-provider.ts (+45, −14)

@@ -8,17 +8,33 @@
   MistralChatModelId,
   MistralChatSettings,
 } from './mistral-chat-settings';
+import {
+  MistralEmbeddingModelId,
+  MistralEmbeddingSettings,
+} from './mistral-embedding-settings';
+import { MistralEmbeddingModel } from './mistral-embedding-model';

 export interface MistralProvider {
   (
     modelId: MistralChatModelId,
     settings?: MistralChatSettings,
   ): MistralChatLanguageModel;

+  /**
+Creates a model for text generation.
+*/
   chat(
     modelId: MistralChatModelId,
     settings?: MistralChatSettings,
   ): MistralChatLanguageModel;
+
+  /**
+Creates a model for text embeddings.
+*/
+  embedding(
+    modelId: MistralEmbeddingModelId,
+    settings?: MistralEmbeddingSettings,
+  ): MistralEmbeddingModel;
 }

 export interface MistralProviderSettings {
@@ -53,26 +69,40 @@ Create a Mistral AI provider instance.
 export function createMistral(
   options: MistralProviderSettings = {},
 ): MistralProvider {
-  const createModel = (
+  const baseURL =
+    withoutTrailingSlash(options.baseURL ?? options.baseUrl) ??
+    'https://api.mistral.ai/v1';
+
+  const getHeaders = () => ({
+    Authorization: `Bearer ${loadApiKey({
+      apiKey: options.apiKey,
+      environmentVariableName: 'MISTRAL_API_KEY',
+      description: 'Mistral',
+    })}`,
+    ...options.headers,
+  });
+
+  const createChatModel = (
     modelId: MistralChatModelId,
     settings: MistralChatSettings = {},
   ) =>
     new MistralChatLanguageModel(modelId, settings, {
       provider: 'mistral.chat',
-      baseURL:
-        withoutTrailingSlash(options.baseURL ?? options.baseUrl) ??
-        'https://api.mistral.ai/v1',
-      headers: () => ({
-        Authorization: `Bearer ${loadApiKey({
-          apiKey: options.apiKey,
-          environmentVariableName: 'MISTRAL_API_KEY',
-          description: 'Mistral',
-        })}`,
-        ...options.headers,
-      }),
+      baseURL,
+      headers: getHeaders,
       generateId: options.generateId ?? generateId,
     });

+  const createEmbeddingModel = (
+    modelId: MistralEmbeddingModelId,
+    settings: MistralEmbeddingSettings = {},
+  ) =>
+    new MistralEmbeddingModel(modelId, settings, {
+      provider: 'mistral.embedding',
+      baseURL,
+      headers: getHeaders,
+    });
+
   const provider = function (
     modelId: MistralChatModelId,
     settings?: MistralChatSettings,
@@ -83,10 +113,11 @@ export function createMistral(
       );
     }

-    return createModel(modelId, settings);
+    return createChatModel(modelId, settings);
   };

-  provider.chat = createModel;
+  provider.chat = createChatModel;
+  provider.embedding = createEmbeddingModel;

   return provider as MistralProvider;
 }
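
After this refactor, `embedding` is available on provider instances alongside `chat`; a short usage sketch (omitting `apiKey` falls back to the `MISTRAL_API_KEY` environment variable, per `loadApiKey` above):

```ts
import { createMistral } from '@ai-sdk/mistral';
import { embed } from 'ai';

// Custom provider instance; settings mirror MistralProviderSettings.
const mistral = createMistral({ apiKey: process.env.MISTRAL_API_KEY });

const { embedding } = await embed({
  model: mistral.embedding('mistral-embed'),
  value: 'sunny day at the beach',
});
```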

Diff for: packages/openai/src/openai-embedding-model.test.ts (+127)

import { EmbeddingModelV1Embedding } from '@ai-sdk/provider';
import { JsonTestServer } from '@ai-sdk/provider-utils/test';
import { createOpenAI } from './openai-provider';

const dummyEmbeddings = [
  [0.1, 0.2, 0.3, 0.4, 0.5],
  [0.6, 0.7, 0.8, 0.9, 1.0],
];
const testValues = ['sunny day at the beach', 'rainy day in the city'];

const provider = createOpenAI({ apiKey: 'test-api-key' });
const model = provider.embedding('text-embedding-3-large');

describe('doEmbed', () => {
  const server = new JsonTestServer('https://api.openai.com/v1/embeddings');

  server.setupTestEnvironment();

  function prepareJsonResponse({
    embeddings = dummyEmbeddings,
  }: {
    embeddings?: EmbeddingModelV1Embedding[];
  } = {}) {
    server.responseBodyJson = {
      object: 'list',
      data: embeddings.map((embedding, i) => ({
        object: 'embedding',
        index: i,
        embedding,
      })),
      model: 'text-embedding-3-large',
      usage: { prompt_tokens: 8, total_tokens: 8 },
    };
  }

  it('should extract embedding', async () => {
    prepareJsonResponse();

    const { embeddings } = await model.doEmbed({ values: testValues });

    expect(embeddings).toStrictEqual(dummyEmbeddings);
  });

  it('should expose the raw response headers', async () => {
    prepareJsonResponse();

    server.responseHeaders = {
      'test-header': 'test-value',
    };

    const { rawResponse } = await model.doEmbed({ values: testValues });

    expect(rawResponse?.headers).toStrictEqual({
      // default headers:
      'content-type': 'application/json',

      // custom header
      'test-header': 'test-value',
    });
  });

  it('should pass the model and the values', async () => {
    prepareJsonResponse();

    await model.doEmbed({ values: testValues });

    expect(await server.getRequestBodyJson()).toStrictEqual({
      model: 'text-embedding-3-large',
      input: testValues,
      encoding_format: 'float',
    });
  });

  it('should pass the dimensions setting', async () => {
    prepareJsonResponse();

    await provider
      .embedding('text-embedding-3-large', { dimensions: 64 })
      .doEmbed({ values: testValues });

    expect(await server.getRequestBodyJson()).toStrictEqual({
      model: 'text-embedding-3-large',
      input: testValues,
      encoding_format: 'float',
      dimensions: 64,
    });
  });

  it('should pass custom headers', async () => {
    prepareJsonResponse();

    const provider = createOpenAI({
      apiKey: 'test-api-key',
      organization: 'test-organization',
      project: 'test-project',
      headers: {
        'Custom-Header': 'test-header',
      },
    });

    await provider.embedding('text-embedding-3-large').doEmbed({
      values: testValues,
    });

    const requestHeaders = await server.getRequestHeaders();

    expect(requestHeaders.get('OpenAI-Organization')).toStrictEqual(
      'test-organization',
    );
    expect(requestHeaders.get('OpenAI-Project')).toStrictEqual('test-project');
    expect(requestHeaders.get('Custom-Header')).toStrictEqual('test-header');
  });

  it('should pass the api key as Authorization header', async () => {
    prepareJsonResponse();

    const provider = createOpenAI({ apiKey: 'test-api-key' });

    await provider.embedding('text-embedding-3-large').doEmbed({
      values: testValues,
    });

    expect(
      (await server.getRequestHeaders()).get('Authorization'),
    ).toStrictEqual('Bearer test-api-key');
  });
});

Diff for: packages/openai/src/openai-embedding-model.ts (+98)

import {
  EmbeddingModelV1,
  TooManyEmbeddingValuesForCallError,
} from '@ai-sdk/provider';
import {
  createJsonResponseHandler,
  postJsonToApi,
} from '@ai-sdk/provider-utils';
import { z } from 'zod';
import {
  OpenAIEmbeddingModelId,
  OpenAIEmbeddingSettings,
} from './openai-embedding-settings';
import { openaiFailedResponseHandler } from './openai-error';

type OpenAIEmbeddingConfig = {
  provider: string;
  baseURL: string;
  headers: () => Record<string, string | undefined>;
};

export class OpenAIEmbeddingModel implements EmbeddingModelV1<string> {
  readonly specificationVersion = 'v1';
  readonly modelId: OpenAIEmbeddingModelId;

  private readonly config: OpenAIEmbeddingConfig;
  private readonly settings: OpenAIEmbeddingSettings;

  get provider(): string {
    return this.config.provider;
  }

  get maxEmbeddingsPerCall(): number {
    return this.settings.maxEmbeddingsPerCall ?? 2048;
  }

  get supportsParallelCalls(): boolean {
    return this.settings.supportsParallelCalls ?? true;
  }

  constructor(
    modelId: OpenAIEmbeddingModelId,
    settings: OpenAIEmbeddingSettings,
    config: OpenAIEmbeddingConfig,
  ) {
    this.modelId = modelId;
    this.settings = settings;
    this.config = config;
  }

  async doEmbed({
    values,
    abortSignal,
  }: Parameters<EmbeddingModelV1<string>['doEmbed']>[0]): Promise<
    Awaited<ReturnType<EmbeddingModelV1<string>['doEmbed']>>
  > {
    if (values.length > this.maxEmbeddingsPerCall) {
      throw new TooManyEmbeddingValuesForCallError({
        provider: this.provider,
        modelId: this.modelId,
        maxEmbeddingsPerCall: this.maxEmbeddingsPerCall,
        values,
      });
    }

    const { responseHeaders, value: response } = await postJsonToApi({
      url: `${this.config.baseURL}/embeddings`,
      headers: this.config.headers(),
      body: {
        model: this.modelId,
        input: values,
        encoding_format: 'float',
        dimensions: this.settings.dimensions,
        user: this.settings.user,
      },
      failedResponseHandler: openaiFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(
        openaiTextEmbeddingResponseSchema,
      ),
      abortSignal,
    });

    return {
      embeddings: response.data.map(item => item.embedding),
      rawResponse: { headers: responseHeaders },
    };
  }
}

// minimal version of the schema, focused on what is needed for the implementation
// this approach limits breakages when the API changes and increases efficiency
const openaiTextEmbeddingResponseSchema = z.object({
  data: z.array(
    z.object({
      embedding: z.array(z.number()),
    }),
  ),
});

Diff for: packages/openai/src/openai-embedding-settings.ts (+29)

export type OpenAIEmbeddingModelId =
  | 'text-embedding-3-small'
  | 'text-embedding-3-large'
  | 'text-embedding-ada-002'
  | (string & {});

export interface OpenAIEmbeddingSettings {
  /**
Override the maximum number of embeddings per call.
*/
  maxEmbeddingsPerCall?: number;

  /**
Override the parallelism of embedding calls.
*/
  supportsParallelCalls?: boolean;

  /**
The number of dimensions the resulting output embeddings should have.
Only supported in text-embedding-3 and later models.
*/
  dimensions?: number;

  /**
A unique identifier representing your end-user, which can help OpenAI to
monitor and detect abuse. Learn more.
*/
  user?: string;
}
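
The `dimensions` and `user` settings above are forwarded directly in the request body by `OpenAIEmbeddingModel.doEmbed` (and `dimensions` is exercised by the tests earlier in this diff). A brief usage sketch; the user identifier value is a placeholder:

```ts
import { embed } from 'ai';
import { openai } from '@ai-sdk/openai';

// Request 64-dimensional embeddings (text-embedding-3 models only)
// and tag the request with an end-user identifier.
const { embedding } = await embed({
  model: openai.embedding('text-embedding-3-large', {
    dimensions: 64,
    user: 'user-1234',
  }),
  value: 'sunny day at the beach',
});

console.log(embedding.length); // 64
```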

Diff for: packages/openai/src/openai-provider.ts (+74, −8)

@@ -1,11 +1,16 @@
+import { loadApiKey, withoutTrailingSlash } from '@ai-sdk/provider-utils';
 import { OpenAIChatLanguageModel } from './openai-chat-language-model';
 import { OpenAIChatModelId, OpenAIChatSettings } from './openai-chat-settings';
 import { OpenAICompletionLanguageModel } from './openai-completion-language-model';
 import {
   OpenAICompletionModelId,
   OpenAICompletionSettings,
 } from './openai-completion-settings';
-import { OpenAI } from './openai-facade';
+import { OpenAIEmbeddingModel } from './openai-embedding-model';
+import {
+  OpenAIEmbeddingModelId,
+  OpenAIEmbeddingSettings,
+} from './openai-embedding-settings';

 export interface OpenAIProvider {
   (
@@ -17,15 +22,29 @@ export interface OpenAIProvider {
     settings?: OpenAIChatSettings,
   ): OpenAIChatLanguageModel;

+  /**
+Creates an OpenAI chat model for text generation.
+*/
   chat(
     modelId: OpenAIChatModelId,
     settings?: OpenAIChatSettings,
   ): OpenAIChatLanguageModel;

+  /**
+Creates an OpenAI completion model for text generation.
+*/
   completion(
     modelId: OpenAICompletionModelId,
     settings?: OpenAICompletionSettings,
   ): OpenAICompletionLanguageModel;
+
+  /**
+Creates a model for text embeddings.
+*/
+  embedding(
+    modelId: OpenAIEmbeddingModelId,
+    settings?: OpenAIEmbeddingSettings,
+  ): OpenAIEmbeddingModel;
 }

 export interface OpenAIProviderSettings {
@@ -66,7 +85,50 @@ Create an OpenAI provider instance.
 export function createOpenAI(
   options: OpenAIProviderSettings = {},
 ): OpenAIProvider {
-  const openai = new OpenAI(options);
+  const baseURL =
+    withoutTrailingSlash(options.baseURL ?? options.baseUrl) ??
+    'https://api.openai.com/v1';
+
+  const getHeaders = () => ({
+    Authorization: `Bearer ${loadApiKey({
+      apiKey: options.apiKey,
+      environmentVariableName: 'OPENAI_API_KEY',
+      description: 'OpenAI',
+    })}`,
+    'OpenAI-Organization': options.organization,
+    'OpenAI-Project': options.project,
+    ...options.headers,
+  });
+
+  const createChatModel = (
+    modelId: OpenAIChatModelId,
+    settings: OpenAIChatSettings = {},
+  ) =>
+    new OpenAIChatLanguageModel(modelId, settings, {
+      provider: 'openai.chat',
+      baseURL,
+      headers: getHeaders,
+    });
+
+  const createCompletionModel = (
+    modelId: OpenAICompletionModelId,
+    settings: OpenAICompletionSettings = {},
+  ) =>
+    new OpenAICompletionLanguageModel(modelId, settings, {
+      provider: 'openai.completion',
+      baseURL,
+      headers: getHeaders,
+    });
+
+  const createEmbeddingModel = (
+    modelId: OpenAIEmbeddingModelId,
+    settings: OpenAIEmbeddingSettings = {},
+  ) =>
+    new OpenAIEmbeddingModel(modelId, settings, {
+      provider: 'openai.embedding',
+      baseURL,
+      headers: getHeaders,
+    });

 const provider = function (
   modelId: OpenAIChatModelId | OpenAICompletionModelId,
@@ -79,19 +141,23 @@ export function createOpenAI(
     }

     if (modelId === 'gpt-3.5-turbo-instruct') {
-      return openai.completion(modelId, settings as OpenAICompletionSettings);
-    } else {
-      return openai.chat(modelId, settings as OpenAIChatSettings);
+      return createCompletionModel(
+        modelId,
+        settings as OpenAICompletionSettings,
+      );
     }
+
+    return createChatModel(modelId, settings as OpenAIChatSettings);
   };

-  provider.chat = openai.chat.bind(openai);
-  provider.completion = openai.completion.bind(openai);
+  provider.chat = createChatModel;
+  provider.completion = createCompletionModel;
+  provider.embedding = createEmbeddingModel;

   return provider as OpenAIProvider;
 }

 /**
- * Default OpenAI provider instance.
+Default OpenAI provider instance.
 */
 export const openai = createOpenAI();

Diff for: packages/provider/src/embedding-model/index.ts (+1)

export * from './v1/index';

Diff for: packages/provider/src/embedding-model/v1/embedding-model-v1-embedding.ts (+5)

/**
An embedding is a vector, i.e. an array of numbers.
It is e.g. used to represent a text as a vector of word embeddings.
*/
export type EmbeddingModelV1Embedding = Array<number>;

Diff for: packages/provider/src/embedding-model/v1/embedding-model-v1.ts (+73)

import { EmbeddingModelV1Embedding } from './embedding-model-v1-embedding';

/**
Experimental: Specification for an embedding model that implements the embedding model
interface version 1.

VALUE is the type of the values that the model can embed.
This will allow us to go beyond text embeddings in the future,
e.g. to support image embeddings.
*/
export type EmbeddingModelV1<VALUE> = {
  /**
The embedding model must specify which embedding model interface
version it implements. This will allow us to evolve the embedding
model interface and retain backwards compatibility. The different
implementation versions can be handled as a discriminated union
on our side.
*/
  readonly specificationVersion: 'v1';

  /**
Name of the provider for logging purposes.
*/
  readonly provider: string;

  /**
Provider-specific model ID for logging purposes.
*/
  readonly modelId: string;

  /**
Limit of how many embeddings can be generated in a single API call.
*/
  readonly maxEmbeddingsPerCall: number | undefined;

  /**
True if the model can handle multiple embedding calls in parallel.
*/
  readonly supportsParallelCalls: boolean;

  /**
Generates a list of embeddings for the given input text.

Naming: "do" prefix to prevent accidental direct usage of the method
by the user.
*/
  doEmbed(options: {
    /**
List of values to embed.
*/
    values: Array<VALUE>;

    /**
Abort signal for cancelling the operation.
*/
    abortSignal?: AbortSignal;
  }): PromiseLike<{
    /**
Generated embeddings. They are in the same order as the input values.
*/
    embeddings: Array<EmbeddingModelV1Embedding>;

    /**
Optional raw response information for debugging purposes.
*/
    rawResponse?: {
      /**
Response headers.
*/
      headers?: Record<string, string>;
    };
  }>;
};
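
To illustrate the shape of the spec above, here is a sketch of a toy in-memory implementation. The hashing scheme is invented for the example and produces meaningless vectors; real implementations call an embeddings API, as the Mistral and OpenAI models in this commit do:

```ts
import { EmbeddingModelV1 } from '@ai-sdk/provider';

// Toy model: maps each string to a 4-dimensional vector of scaled
// character-code sums. Demonstrates the interface contract only.
const toyEmbeddingModel: EmbeddingModelV1<string> = {
  specificationVersion: 'v1',
  provider: 'toy-provider',
  modelId: 'toy-model',
  maxEmbeddingsPerCall: 16,
  supportsParallelCalls: true,

  async doEmbed({ values }) {
    return {
      embeddings: values.map(value => {
        const vector = new Array(4).fill(0);
        for (let i = 0; i < value.length; i++) {
          vector[i % 4] += value.charCodeAt(i) / 1000;
        }
        return vector; // same order as the input values, per the spec
      }),
    };
  },
};
```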

Diff for: packages/provider/src/embedding-model/v1/index.ts (+2)

export * from './embedding-model-v1';
export * from './embedding-model-v1-embedding';

Diff for: packages/provider/src/errors/index.ts (+1)

@@ -10,6 +10,7 @@ export * from './load-api-key-error';
 export * from './no-object-generated-error';
 export * from './no-such-tool-error';
 export * from './retry-error';
+export * from './too-many-embedding-values-for-call-error';
 export * from './tool-call-parse-error';
 export * from './type-validation-error';
 export * from './unsupported-functionality-error';

Diff for: packages/provider/src/errors/too-many-embedding-values-for-call-error.ts (+56)

export class TooManyEmbeddingValuesForCallError extends Error {
  readonly provider: string;
  readonly modelId: string;
  readonly maxEmbeddingsPerCall: number;
  readonly values: Array<unknown>;

  constructor(options: {
    provider: string;
    modelId: string;
    maxEmbeddingsPerCall: number;
    values: Array<unknown>;
  }) {
    super(
      `Too many values for a single embedding call. ` +
        `The ${options.provider} model "${options.modelId}" can only embed up to ` +
        `${options.maxEmbeddingsPerCall} values per call, but ${options.values.length} values were provided.`,
    );

    this.name = 'AI_TooManyEmbeddingValuesForCallError';

    this.provider = options.provider;
    this.modelId = options.modelId;
    this.maxEmbeddingsPerCall = options.maxEmbeddingsPerCall;
    this.values = options.values;
  }

  static isTooManyEmbeddingValuesForCallError(
    error: unknown,
  ): error is TooManyEmbeddingValuesForCallError {
    return (
      error instanceof Error &&
      error.name === 'AI_TooManyEmbeddingValuesForCallError' &&
      'provider' in error &&
      typeof error.provider === 'string' &&
      'modelId' in error &&
      typeof error.modelId === 'string' &&
      'maxEmbeddingsPerCall' in error &&
      typeof error.maxEmbeddingsPerCall === 'number' &&
      'values' in error &&
      Array.isArray(error.values)
    );
  }

  toJSON() {
    return {
      name: this.name,
      message: this.message,
      stack: this.stack,

      provider: this.provider,
      modelId: this.modelId,
      maxEmbeddingsPerCall: this.maxEmbeddingsPerCall,
      values: this.values,
    };
  }
}
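
Callers can branch on this error via the static type guard defined above; a minimal sketch (the `safeEmbed` wrapper is hypothetical):

```ts
import {
  EmbeddingModelV1,
  TooManyEmbeddingValuesForCallError,
} from '@ai-sdk/provider';

async function safeEmbed(model: EmbeddingModelV1<string>, values: string[]) {
  try {
    return await model.doEmbed({ values });
  } catch (error) {
    if (
      TooManyEmbeddingValuesForCallError.isTooManyEmbeddingValuesForCallError(
        error,
      )
    ) {
      // The guard narrows the type, exposing the limit and the input:
      console.error(
        `limit: ${error.maxEmbeddingsPerCall}, received: ${error.values.length}`,
      );
    }
    throw error;
  }
}
```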

Diff for: packages/provider/src/index.ts (+1)

@@ -1,2 +1,3 @@
+export * from './embedding-model/index';
 export * from './errors/index';
 export * from './language-model/index';