Commit 3a21030

Authored May 16, 2024
feat (ai/core): add embedMany function (#1617)
1 parent 339aafa commit 3a21030

File tree: 17 files changed (+444 −24 lines)
Diff for: .changeset/five-knives-deny.md (+5)

````diff
@@ -0,0 +1,5 @@
+---
+'ai': patch
+---
+
+feat (ai/core): add embedMany function
````

Diff for: content/docs/03-ai-sdk-core/30-embeddings.mdx (+29 −2)

````diff
@@ -10,15 +10,42 @@ In this space, similar words are close to each other, and the distance between w
 
 ## Embedding a Single Value
 
-The Vercel AI SDK provides the `embed` function to embed single values, which is useful for tasks such as finding similar words
-or phrases or clustering text. You can use it with embeddings models, e.g. `openai.embedding('text-embedding-3-large')` or `mistral.embedding('mistral-embed')`.
+The Vercel AI SDK provides the [`embed`](/docs/reference/ai-sdk-core/embed) function to embed single values, which is useful for tasks such as finding similar words
+or phrases or clustering text.
+You can use it with embeddings models, e.g. `openai.embedding('text-embedding-3-large')` or `mistral.embedding('mistral-embed')`.
 
 ```tsx
 import { embed } from 'ai';
 import { openai } from '@ai-sdk/openai';
 
+// 'embedding' is a single embedding object (number[])
 const { embedding } = await embed({
   model: openai.embedding('text-embedding-3-small'),
   value: 'sunny day at the beach',
 });
 ```
+
+## Embedding Many Values
+
+When loading data, e.g. when preparing a data store for retrieval-augmented generation (RAG),
+it is often useful to embed many values at once (batch embedding).
+
+The Vercel AI SDK provides the `embedMany` function for this purpose.
+Similar to `embed`, you can use it with embeddings models,
+e.g. `openai.embedding('text-embedding-3-large')` or `mistral.embedding('mistral-embed')`.
+
+```tsx
+import { openai } from '@ai-sdk/openai';
+import { embedMany } from 'ai';
+
+// 'embeddings' is an array of embedding objects (number[][]).
+// It is sorted in the same order as the input values.
+const { embeddings } = await embedMany({
+  model: openai.embedding('text-embedding-3-small'),
+  values: [
+    'sunny day at the beach',
+    'rainy afternoon in the city',
+    'snowy night in the mountains',
+  ],
+});
+```
````
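The embeddings returned by `embed` and `embedMany` are plain `number[]` vectors, so similarity between two embedded values can be computed directly. As a minimal sketch (the `cosineSimilarity` helper below is illustrative, not an export shown in this commit):

```typescript
// Illustrative helper (not part of the diff above): cosine similarity between
// two embedding vectors, usable for ranking embedMany results against a query.
function cosineSimilarity(a: number[], b: number[]): number {
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

// identical vectors → 1, orthogonal vectors → 0
console.log(cosineSimilarity([1, 0], [1, 0])); // 1
console.log(cosineSimilarity([1, 0], [0, 1])); // 0
```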
Diff for: (unnamed file) (+18)

````diff
@@ -0,0 +1,18 @@
+---
+title: embedMany
+description: Embed several values using the AI SDK Core (batch embedding)
+---
+
+# `embedMany`
+
+Embed several values using an embedding model. The type of the value is defined
+by the embedding model.
+
+`embedMany` automatically splits large requests into smaller chunks if the model
+has a limit on how many embeddings can be generated in a single call.
+
+## Import
+
+<Snippet text={`import { embedMany } from "ai"`} prompt={false} />
+
+<ReferenceTable packageName="core" functionName="embedMany" />
````

Diff for: content/docs/07-reference/ai-sdk-core/index.mdx (+6)

````diff
@@ -33,5 +33,11 @@ description: Reference documentation for the AI SDK Core
         'Generate an embedding for a single value using an embedding model.',
       href: '/docs/reference/ai-sdk-core/embed',
     },
+    {
+      title: 'embedMany',
+      description:
+        'Generate embeddings for several values using an embedding model (batch embedding).',
+      href: '/docs/reference/ai-sdk-core/embed-many',
+    },
   ]}
 />
````

Diff for: content/providers/01-ai-sdk-providers/01-openai.mdx (+46 −8)

````diff
@@ -72,9 +72,9 @@ You can use the following optional settings to customize the OpenAI provider ins
   and `compatible` when using 3rd party providers. In `compatible` mode, newer
   information such as streamOptions are not being sent. Defaults to 'compatible'.
 
-## Models
+## Language Models
 
-The OpenAI provider instance is a function that you can invoke to create a model:
+The OpenAI provider instance is a function that you can invoke to create a language model:
 
 ```ts
 const model = openai('gpt-3.5-turbo');
@@ -92,6 +92,14 @@ const model = openai('gpt-3.5-turbo', {
 The available options depend on the API that's automatically chosen for the model (see below).
 If you want to explicitly select a specific model API, you can use `.chat` or `.completion`.
 
+### Model Capabilities
+
+| Model           | Image Input         | Object Generation   | Tool Usage          | Tool Streaming      |
+| --------------- | ------------------- | ------------------- | ------------------- | ------------------- |
+| `gpt-4-turbo`   | <Check size={18} /> | <Check size={18} /> | <Check size={18} /> | <Check size={18} /> |
+| `gpt-4`         | <Cross size={18} /> | <Check size={18} /> | <Check size={18} /> | <Check size={18} /> |
+| `gpt-3.5-turbo` | <Cross size={18} /> | <Check size={18} /> | <Check size={18} /> | <Check size={18} /> |
+
 ### Chat Models
 
 You can create models that call the [OpenAI chat API](https://platform.openai.com/docs/api-reference/chat) using the `.chat()` factory method.
@@ -215,10 +223,40 @@ The following optional settings are available for OpenAI completion models:
   A unique identifier representing your end-user, which can help OpenAI to
   monitor and detect abuse. Learn more.
 
-## Model Capabilities
+## Embedding Models
 
-| Model           | Image Input         | Object Generation   | Tool Usage          | Tool Streaming      |
-| --------------- | ------------------- | ------------------- | ------------------- | ------------------- |
-| `gpt-4-turbo`   | <Check size={18} /> | <Check size={18} /> | <Check size={18} /> | <Check size={18} /> |
-| `gpt-4`         | <Cross size={18} /> | <Check size={18} /> | <Check size={18} /> | <Check size={18} /> |
-| `gpt-3.5-turbo` | <Cross size={18} /> | <Check size={18} /> | <Check size={18} /> | <Check size={18} /> |
+You can create models that call the [OpenAI embeddings API](https://platform.openai.com/docs/api-reference/embeddings)
+using the `.embedding()` factory method.
+
+```ts
+const model = openai.embedding('text-embedding-3-large');
+```
+
+OpenAI embedding models support several additional settings.
+You can pass them as an options argument:
+
+```ts
+const model = openai.embedding('text-embedding-3-large', {
+  dimensions: 512, // optional, number of dimensions for the embedding
+  user: 'test-user', // optional unique user identifier
+});
+```
+
+The following optional settings are available for OpenAI embedding models:
+
+- **dimensions**: _number_
+
+  The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models.
+
+- **user** _string_
+
+  A unique identifier representing your end-user, which can help OpenAI to
+  monitor and detect abuse. Learn more.
+
+### Model Capabilities
+
+| Model                    | Default Dimensions | Custom Dimensions   |
+| ------------------------ | ------------------ | ------------------- |
+| `text-embedding-3-large` | 3072               | <Check size={18} /> |
+| `text-embedding-3-small` | 1536               | <Check size={18} /> |
+| `text-embedding-ada-002` | 1536               | <Cross size={18} /> |
````
Diff for: content/providers/01-ai-sdk-providers/02-anthropic.mdx (+2 −2)

````diff
@@ -60,7 +60,7 @@ You can use the following optional settings to customize the Google Generative A
 
   Custom headers to include in the requests.
 
-## Models
+## Language Models
 
 You can create models that call the [Anthropic Messages API](https://docs.anthropic.com/claude/reference/messages_post) using the provider instance.
 The first argument is the model id, e.g. `claude-3-haiku-20240307`.
@@ -88,7 +88,7 @@ The following optional settings are available for Anthropic models:
   Used to remove "long tail" low probability responses.
   Recommended for advanced use cases only. You usually only need to use temperature.
 
-## Model Capabilities
+### Model Capabilities
 
 | Model                      | Image Input         | Object Generation   | Tool Usage          | Tool Streaming      |
 | -------------------------- | ------------------- | ------------------- | ------------------- | ------------------- |
````

Diff for: content/providers/01-ai-sdk-providers/03-google-generative-ai.mdx (+2 −2)

````diff
@@ -58,7 +58,7 @@ You can use the following optional settings to customize the Google Generative A
 
   Custom headers to include in the requests.
 
-## Models
+## Language Models
 
 You can create models that call the [Google Generative AI API](https://ai.google.dev/api/rest) using the provider instance.
 The first argument is the model id, e.g. `models/gemini-pro`.
@@ -87,7 +87,7 @@ The following optional settings are available for Google Generative AI models:
   Top-k sampling considers the set of topK most probable tokens.
   Models running with nucleus sampling don't allow topK setting.
 
-## Model Capabilities
+### Model Capabilities
 
 | Model                          | Image Input         | Object Generation   | Tool Usage          | Tool Streaming      |
 | ------------------------------ | ------------------- | ------------------- | ------------------- | ------------------- |
````

Diff for: content/providers/01-ai-sdk-providers/04-mistral.mdx (+12 −2)

````diff
@@ -23,6 +23,7 @@ The Mistral provider is available in the `@ai-sdk/mistral` module. You can insta
     <Snippet text="yarn add @ai-sdk/mistral" dark />
   </Tab>
 </Tabs>
+
 ## Provider Instance
 
 You can import the default provider instance `mistral` from `@ai-sdk/mistral`:
@@ -58,7 +59,7 @@ You can use the following optional settings to customize the Mistral provider in
 
   Custom headers to include in the requests.
 
-## Models
+## Language Models
 
 You can create models that call the [Mistral chat API](https://docs.mistral.ai/api/#operation/createChatCompletion) using provider instance.
 The first argument is the model id, e.g. `mistral-large-latest`.
@@ -85,9 +86,18 @@ The following optional settings are available for Mistral models:
 
   Defaults to `false`.
 
-## Model Capabilities
+### Model Capabilities
 
 | Model                  | Image Input         | Object Generation   | Tool Usage          | Tool Streaming      |
 | ---------------------- | ------------------- | ------------------- | ------------------- | ------------------- |
 | `mistral-large-latest` | <Cross size={18} /> | <Check size={18} /> | <Check size={18} /> | <Check size={18} /> |
 | `mistral-small-latest` | <Cross size={18} /> | <Check size={18} /> | <Check size={18} /> | <Check size={18} /> |
+
+## Embedding Models
+
+You can create models that call the [Mistral embeddings API](https://docs.mistral.ai/api/#operation/createEmbedding)
+using the `.embedding()` factory method.
+
+```ts
+const model = mistral.embedding('mistral-embed');
+```
````

Diff for: examples/ai-core/src/embed-many/mistral.ts (+20)

````diff
@@ -0,0 +1,20 @@
+import { mistral } from '@ai-sdk/mistral';
+import { embedMany } from 'ai';
+import dotenv from 'dotenv';
+
+dotenv.config();
+
+async function main() {
+  const { embeddings } = await embedMany({
+    model: mistral.embedding('mistral-embed'),
+    values: [
+      'sunny day at the beach',
+      'rainy afternoon in the city',
+      'snowy night in the mountains',
+    ],
+  });
+
+  console.log(embeddings);
+}
+
+main().catch(console.error);
````

Diff for: examples/ai-core/src/embed-many/openai.ts (+20)

````diff
@@ -0,0 +1,20 @@
+import { openai } from '@ai-sdk/openai';
+import { embedMany } from 'ai';
+import dotenv from 'dotenv';
+
+dotenv.config();
+
+async function main() {
+  const { embeddings } = await embedMany({
+    model: openai.embedding('text-embedding-3-small'),
+    values: [
+      'sunny day at the beach',
+      'rainy afternoon in the city',
+      'snowy night in the mountains',
+    ],
+  });
+
+  console.log(embeddings);
+}
+
+main().catch(console.error);
````

Diff for: packages/core/core/embed/embed-many.test.ts (+74)

````diff
@@ -0,0 +1,74 @@
+import assert from 'node:assert';
+import {
+  MockEmbeddingModelV1,
+  mockEmbed,
+} from '../test/mock-embedding-model-v1';
+import { embedMany } from './embed-many';
+
+const dummyEmbeddings = [
+  [0.1, 0.2, 0.3],
+  [0.4, 0.5, 0.6],
+  [0.7, 0.8, 0.9],
+];
+
+const testValues = [
+  'sunny day at the beach',
+  'rainy afternoon in the city',
+  'snowy night in the mountains',
+];
+
+describe('result.embedding', () => {
+  it('should generate embeddings', async () => {
+    const result = await embedMany({
+      model: new MockEmbeddingModelV1({
+        maxEmbeddingsPerCall: 5,
+        doEmbed: mockEmbed(testValues, dummyEmbeddings),
+      }),
+      values: testValues,
+    });
+
+    assert.deepStrictEqual(result.embeddings, dummyEmbeddings);
+  });
+
+  it('should generate embeddings when several calls are required', async () => {
+    let callCount = 0;
+
+    const result = await embedMany({
+      model: new MockEmbeddingModelV1({
+        maxEmbeddingsPerCall: 2,
+        doEmbed: async ({ values }) => {
+          if (callCount === 0) {
+            assert.deepStrictEqual(values, testValues.slice(0, 2));
+            callCount++;
+            return { embeddings: dummyEmbeddings.slice(0, 2) };
+          }
+
+          if (callCount === 1) {
+            assert.deepStrictEqual(values, testValues.slice(2));
+            callCount++;
+            return { embeddings: dummyEmbeddings.slice(2) };
+          }
+
+          throw new Error('Unexpected call');
+        },
+      }),
+      values: testValues,
+    });
+
+    assert.deepStrictEqual(result.embeddings, dummyEmbeddings);
+  });
+});
+
+describe('result.values', () => {
+  it('should include values in the result', async () => {
+    const result = await embedMany({
+      model: new MockEmbeddingModelV1({
+        maxEmbeddingsPerCall: 5,
+        doEmbed: mockEmbed(testValues, dummyEmbeddings),
+      }),
+      values: testValues,
+    });
+
+    assert.deepStrictEqual(result.values, testValues);
+  });
+});
````

Diff for: packages/core/core/embed/embed-many.ts (+98)

````diff
@@ -0,0 +1,98 @@
+import { Embedding, EmbeddingModel } from '../types';
+import { retryWithExponentialBackoff } from '../util/retry-with-exponential-backoff';
+import { splitArray } from '../util/split-array';
+
+/**
+Embed several values using an embedding model. The type of the value is defined
+by the embedding model.
+
+`embedMany` automatically splits large requests into smaller chunks if the model
+has a limit on how many embeddings can be generated in a single call.
+
+@param model - The embedding model to use.
+@param values - The values that should be embedded.
+
+@param maxRetries - Maximum number of retries. Set to 0 to disable retries. Default: 2.
+@param abortSignal - An optional abort signal that can be used to cancel the call.
+
+@returns A result object that contains the embeddings, the values, and additional information.
+ */
+export async function embedMany<VALUE>({
+  model,
+  values,
+  maxRetries,
+  abortSignal,
+}: {
+  /**
+The embedding model to use.
+   */
+  model: EmbeddingModel<VALUE>;
+
+  /**
+The values that should be embedded.
+   */
+  values: Array<VALUE>;
+
+  /**
+Maximum number of retries per embedding model call. Set to 0 to disable retries.
+
+@default 2
+   */
+  maxRetries?: number;
+
+  /**
+Abort signal.
+   */
+  abortSignal?: AbortSignal;
+}): Promise<EmbedManyResult<VALUE>> {
+  const retry = retryWithExponentialBackoff({ maxRetries });
+  const maxEmbeddingsPerCall = model.maxEmbeddingsPerCall;
+
+  // the model has not specified limits on
+  // how many embeddings can be generated in a single call
+  if (maxEmbeddingsPerCall == null) {
+    const modelResponse = await retry(() =>
+      model.doEmbed({ values, abortSignal }),
+    );
+
+    return new EmbedManyResult({
+      values,
+      embeddings: modelResponse.embeddings,
+    });
+  }
+
+  // split the values into chunks that are small enough for the model:
+  const valueChunks = splitArray(values, maxEmbeddingsPerCall);
+
+  // serially embed the chunks:
+  const embeddings: Array<Embedding> = [];
+  for (const chunk of valueChunks) {
+    const modelResponse = await retry(() =>
+      model.doEmbed({ values: chunk, abortSignal }),
+    );
+    embeddings.push(...modelResponse.embeddings);
+  }
+
+  return new EmbedManyResult({ values, embeddings });
+}
+
+/**
+The result of an `embedMany` call.
+It contains the embeddings, the values, and additional information.
+ */
+export class EmbedManyResult<VALUE> {
+  /**
+The values that were embedded.
+   */
+  readonly values: Array<VALUE>;
+
+  /**
+The embeddings. They are in the same order as the values.
+   */
+  readonly embeddings: Array<Embedding>;
+
+  constructor(options: { values: Array<VALUE>; embeddings: Array<Embedding> }) {
+    this.values = options.values;
+    this.embeddings = options.embeddings;
+  }
+}
````
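The implementation above wraps every `doEmbed` call in `retryWithExponentialBackoff({ maxRetries })`. The real helper lives in `../util/retry-with-exponential-backoff` and is not shown in this diff; a simplified sketch of the shape it must have (a factory returning an async `retry` function that re-invokes a thunk, doubling the delay between failed attempts) could look like:

```typescript
// Simplified sketch only; the actual helper in ../util may differ in
// signature and error handling.
function retryWithExponentialBackoff({
  maxRetries = 2,
  initialDelayInMs = 0, // 0 so the sketch runs instantly; a real delay would start higher
}: { maxRetries?: number; initialDelayInMs?: number } = {}) {
  return async function retry<T>(fn: () => PromiseLike<T>): Promise<T> {
    let delayInMs = initialDelayInMs;
    for (let attempt = 0; ; attempt++) {
      try {
        return await fn();
      } catch (error) {
        if (attempt >= maxRetries) throw error; // retries exhausted
        await new Promise(resolve => setTimeout(resolve, delayInMs));
        delayInMs = delayInMs === 0 ? 1 : delayInMs * 2; // exponential backoff
      }
    }
  };
}

// usage mirroring embedMany: retry a flaky async call
const retry = retryWithExponentialBackoff({ maxRetries: 2 });
let attempts = 0;
retry(async () => {
  attempts++;
  if (attempts < 3) throw new Error('transient failure');
  return 'ok';
}).then(result => console.log(result, attempts)); // ok 3
```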

Diff for: packages/core/core/embed/embed.test.ts (+18 −8)

````diff
@@ -1,5 +1,8 @@
 import assert from 'node:assert';
-import { MockEmbeddingModelV1 } from '../test/mock-embedding-model-v1';
+import {
+  MockEmbeddingModelV1,
+  mockEmbed,
+} from '../test/mock-embedding-model-v1';
 import { embed } from './embed';
 
 const dummyEmbedding = [0.1, 0.2, 0.3];
@@ -9,17 +12,24 @@ describe('result.embedding', () => {
   it('should generate embedding', async () => {
     const result = await embed({
       model: new MockEmbeddingModelV1({
-        doEmbed: async ({ values }) => {
-          assert.deepStrictEqual(values, [testValue]);
-
-          return {
-            embeddings: [dummyEmbedding],
-          };
-        },
+        doEmbed: mockEmbed([testValue], [dummyEmbedding]),
       }),
       value: testValue,
     });
 
     assert.deepStrictEqual(result.embedding, dummyEmbedding);
   });
 });
+
+describe('result.value', () => {
+  it('should include value in the result', async () => {
+    const result = await embed({
+      model: new MockEmbeddingModelV1({
+        doEmbed: mockEmbed([testValue], [dummyEmbedding]),
+      }),
+      value: testValue,
+    });
+
+    assert.deepStrictEqual(result.value, testValue);
+  });
+});
````

Diff for: packages/core/core/embed/index.ts (+1)

````diff
@@ -1 +1,2 @@
 export * from './embed';
+export * from './embed-many';
````

Diff for: packages/core/core/test/mock-embedding-model-v1.ts (+11)

````diff
@@ -1,4 +1,5 @@
 import { EmbeddingModelV1 } from '@ai-sdk/provider';
+import { Embedding } from '../types';
 
 export class MockEmbeddingModelV1<VALUE> implements EmbeddingModelV1<VALUE> {
   readonly specificationVersion = 'v1';
@@ -31,6 +32,16 @@ export class MockEmbeddingModelV1<VALUE> implements EmbeddingModelV1<VALUE> {
   }
 }
 
+export function mockEmbed<VALUE>(
+  expectedValues: Array<VALUE>,
+  embeddings: Array<Embedding>,
+): EmbeddingModelV1<VALUE>['doEmbed'] {
+  return async ({ values }) => {
+    assert.deepStrictEqual(expectedValues, values);
+    return { embeddings };
+  };
+}
+
 function notImplemented(): never {
   throw new Error('Not implemented');
 }
````

Diff for: packages/core/core/util/split-array.test.ts (+62)

````diff
@@ -0,0 +1,62 @@
+import { describe, it, expect } from 'vitest';
+import { splitArray } from './split-array';
+
+describe('splitArray', () => {
+  it('should split an array into chunks of the specified size', () => {
+    const array = [1, 2, 3, 4, 5];
+    const size = 2;
+    const result = splitArray(array, size);
+    expect(result).toEqual([[1, 2], [3, 4], [5]]);
+  });
+
+  it('should return an empty array when the input array is empty', () => {
+    const array: number[] = [];
+    const size = 2;
+    const result = splitArray(array, size);
+    expect(result).toEqual([]);
+  });
+
+  it('should return the original array when the chunk size is greater than the array length', () => {
+    const array = [1, 2, 3];
+    const size = 5;
+    const result = splitArray(array, size);
+    expect(result).toEqual([[1, 2, 3]]);
+  });
+
+  it('should return the original array when the chunk size is equal to the array length', () => {
+    const array = [1, 2, 3];
+    const size = 3;
+    const result = splitArray(array, size);
+    expect(result).toEqual([[1, 2, 3]]);
+  });
+
+  it('should handle chunk size of 1 correctly', () => {
+    const array = [1, 2, 3];
+    const size = 1;
+    const result = splitArray(array, size);
+    expect(result).toEqual([[1], [2], [3]]);
+  });
+
+  it('should throw an error for chunk size of 0', () => {
+    const array = [1, 2, 3];
+    const size = 0;
+    expect(() => splitArray(array, size)).toThrow(
+      'chunkSize must be greater than 0',
+    );
+  });
+
+  it('should throw an error for negative chunk size', () => {
+    const array = [1, 2, 3];
+    const size = -1;
+    expect(() => splitArray(array, size)).toThrow(
+      'chunkSize must be greater than 0',
+    );
+  });
+
+  it('should handle non-integer chunk size by flooring the size', () => {
+    const array = [1, 2, 3, 4, 5];
+    const size = 2.5;
+    const result = splitArray(array, Math.floor(size));
+    expect(result).toEqual([[1, 2], [3, 4], [5]]);
+  });
+});
````

Diff for: packages/core/core/util/split-array.ts (+20)

````diff
@@ -0,0 +1,20 @@
+/**
+ * Splits an array into chunks of a specified size.
+ *
+ * @template T - The type of elements in the array.
+ * @param {T[]} array - The array to split.
+ * @param {number} chunkSize - The size of each chunk.
+ * @returns {T[][]} - A new array containing the chunks.
+ */
+export function splitArray<T>(array: T[], chunkSize: number): T[][] {
+  if (chunkSize <= 0) {
+    throw new Error('chunkSize must be greater than 0');
+  }
+
+  const result = [];
+  for (let i = 0; i < array.length; i += chunkSize) {
+    result.push(array.slice(i, i + chunkSize));
+  }
+
+  return result;
+}
````
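This helper is what lets `embedMany` respect a model's `maxEmbeddingsPerCall`: chunk the inputs, call the model once per chunk, and flatten the per-chunk results back into input order. A self-contained sketch of that flow (the `doEmbed` stand-in below is illustrative, not a real model call):

```typescript
// splitArray as added in this commit:
function splitArray<T>(array: T[], chunkSize: number): T[][] {
  if (chunkSize <= 0) {
    throw new Error('chunkSize must be greater than 0');
  }
  const result: T[][] = [];
  for (let i = 0; i < array.length; i += chunkSize) {
    result.push(array.slice(i, i + chunkSize));
  }
  return result;
}

// Stand-in "model": embeds each string as [its length] so the flow is visible.
const maxEmbeddingsPerCall = 2;
const doEmbed = (values: string[]): number[][] => values.map(v => [v.length]);

// Mirrors the chunked path of embedMany: one model call per chunk,
// results flattened in input order.
function embedInChunks(values: string[]): number[][] {
  const embeddings: number[][] = [];
  for (const chunk of splitArray(values, maxEmbeddingsPerCall)) {
    embeddings.push(...doEmbed(chunk)); // two calls for three values here
  }
  return embeddings;
}

console.log(embedInChunks(['a', 'bb', 'ccc'])); // [ [ 1 ], [ 2 ], [ 3 ] ]
```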
