Skip to content

Commit 4728c37

Browse files
lgrammeliteratetograceness
andauthoredJun 14, 2024··
feat (core): add text embedding model support to provider registry (#1959)
Co-authored-by: Grace Yun <74513600+iteratetograceness@users.noreply.github.com>
1 parent 1121364 commit 4728c37

File tree

16 files changed

+345
-113
lines changed

16 files changed

+345
-113
lines changed
 

‎.changeset/brave-poets-worry.md

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
---
2+
'@ai-sdk/google-vertex': patch
3+
'@ai-sdk/anthropic': patch
4+
'@ai-sdk/mistral': patch
5+
'@ai-sdk/google': patch
6+
'@ai-sdk/openai': patch
7+
'@ai-sdk/azure': patch
8+
'ai': patch
9+
---
10+
11+
feat (core): add text embedding model support to provider registry

‎content/docs/03-ai-sdk-core/40-provider-management.mdx

+17-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ The Vercel AI SDK provides a [`ProviderRegistry`](/docs/reference/ai-sdk-core/pr
1414
You can register multiple providers. The provider id will become the prefix of the model id:
1515
`providerId:modelId`.
1616

17-
### Setup (Example)
17+
### Setup
1818

1919
You can create a registry with multiple providers and models using `experimental_createProviderRegistry`.
2020

@@ -39,7 +39,7 @@ export const registry = createProviderRegistry({
3939
});
4040
```
4141

42-
### Usage (Example)
42+
### Language models
4343

4444
You can access language models by using the `languageModel` method on the registry.
4545
The provider id will become the prefix of the model id: `providerId:modelId`.
@@ -53,3 +53,18 @@ const { text } = await generateText({
5353
prompt: 'Invent a new holiday and describe its traditions.',
5454
});
5555
```
56+
57+
### Text embedding models
58+
59+
You can access text embedding models by using the `textEmbeddingModel` method on the registry.
60+
The provider id will become the prefix of the model id: `providerId:modelId`.
61+
62+
```ts highlight={"5"}
63+
import { embed } from 'ai';
64+
import { registry } from './registry';
65+
66+
const { embedding } = await embed({
67+
model: registry.textEmbeddingModel('openai:text-embedding-3-small'),
68+
value: 'sunny day at the beach',
69+
});
70+
```

‎content/docs/07-reference/ai-sdk-core/40-provider-registry.mdx

+58-35
Original file line numberDiff line numberDiff line change
@@ -13,40 +13,6 @@ in a central place and access the models through simple string ids.
1313
`createProviderRegistry` lets you create a registry with multiple providers that you
1414
can access by their ids.
1515

16-
### Setup (Example)
17-
18-
You can create a registry with multiple providers and models using `createProviderRegistry`.
19-
20-
```ts
21-
import { anthropic } from '@ai-sdk/anthropic';
22-
import { createOpenAI } from '@ai-sdk/openai';
23-
import { experimental_createProviderRegistry as createProviderRegistry } from 'ai';
24-
25-
export const registry = createProviderRegistry({
26-
// register provider with prefix and default setup:
27-
anthropic,
28-
29-
// register provider with prefix and custom setup:
30-
openai: createOpenAI({
31-
apiKey: process.env.OPENAI_API_KEY,
32-
}),
33-
});
34-
```
35-
36-
### Usage (Example)
37-
38-
You can access language models by using the `languageModel` method on the registry.
39-
The provider id will become the prefix of the model id: `providerId:modelId`.
40-
41-
```ts highlight={"4"}
42-
import { generateText } from 'ai';
43-
44-
const { text } = await generateText({
45-
model: registry.languageModel('openai:gpt-4-turbo'),
46-
prompt: 'Invent a new holiday and describe its traditions.',
47-
});
48-
```
49-
5016
## Import
5117

5218
<Snippet
@@ -64,7 +30,7 @@ Registers a language model provider with a given id.
6430
content={[
6531
{
6632
name: 'providers',
67-
type: 'Record<string, (id: string) => LanguageModel>',
33+
type: 'Record<string, { languageModel: (id: string) => LanguageModel; textEmbedding: (id: string) => EmbeddingModel<string> }>',
6834
description: `The unique identifier for the provider. It should be unique within the registry.`,
6935
},
7036
]}
@@ -81,5 +47,62 @@ The `experimental_createProviderRegistry` function returns a `experimental_Provi
8147
type: '(id: string) => LanguageModel',
8248
description: `A function that returns a language model by its id (format: providerId:modelId)`,
8349
},
50+
{
51+
name: 'textEmbeddingModel',
52+
type: '(id: string) => EmbeddingModel<string>',
53+
description: `A function that returns a text embedding model by its id (format: providerId:modelId)`,
54+
},
8455
]}
8556
/>
57+
58+
## Examples
59+
60+
### Setup
61+
62+
You can create a registry with multiple providers and models using `createProviderRegistry`.
63+
64+
```ts
65+
import { anthropic } from '@ai-sdk/anthropic';
66+
import { createOpenAI } from '@ai-sdk/openai';
67+
import { experimental_createProviderRegistry as createProviderRegistry } from 'ai';
68+
69+
export const registry = createProviderRegistry({
70+
// register provider with prefix and default setup:
71+
anthropic,
72+
73+
// register provider with prefix and custom setup:
74+
openai: createOpenAI({
75+
apiKey: process.env.OPENAI_API_KEY,
76+
}),
77+
});
78+
```
79+
80+
### Language models
81+
82+
You can access language models by using the `languageModel` method on the registry.
83+
The provider id will become the prefix of the model id: `providerId:modelId`.
84+
85+
```ts highlight={"5"}
86+
import { generateText } from 'ai';
87+
import { registry } from './registry';
88+
89+
const { text } = await generateText({
90+
model: registry.languageModel('openai:gpt-4-turbo'),
91+
prompt: 'Invent a new holiday and describe its traditions.',
92+
});
93+
```
94+
95+
### Text embedding models
96+
97+
You can access text embedding models by using the `textEmbeddingModel` method on the registry.
98+
The provider id will become the prefix of the model id: `providerId:modelId`.
99+
100+
```ts highlight={"5"}
101+
import { embed } from 'ai';
102+
import { registry } from './registry';
103+
104+
const { embedding } = await embed({
105+
model: registry.textEmbeddingModel('openai:text-embedding-3-small'),
106+
value: 'sunny day at the beach',
107+
});
108+
```
+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import { embed } from 'ai';
2+
import { registry } from './setup-registry';
3+
4+
async function main() {
5+
const { embedding } = await embed({
6+
model: registry.textEmbeddingModel('openai:text-embedding-3-small'),
7+
value: 'sunny day at the beach',
8+
});
9+
10+
console.log(embedding);
11+
}
12+
13+
main().catch(console.error);

‎examples/ai-core/src/registry/setup-registry.ts

+2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { anthropic } from '@ai-sdk/anthropic';
2+
import { mistral } from '@ai-sdk/mistral';
23
import { createOpenAI } from '@ai-sdk/openai';
34
import { experimental_createProviderRegistry as createProviderRegistry } from 'ai';
45
import dotenv from 'dotenv';
@@ -8,6 +9,7 @@ dotenv.config();
89
export const registry = createProviderRegistry({
910
// register provider with prefix and default setup:
1011
anthropic,
12+
mistral,
1113

1214
// register provider with prefix and custom setup:
1315
openai: createOpenAI({

‎packages/anthropic/src/anthropic-provider.ts

+9
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@ Creates a model for text generation.
1616

1717
/**
1818
Creates a model for text generation.
19+
*/
20+
languageModel(
21+
modelId: AnthropicMessagesModelId,
22+
settings?: AnthropicMessagesSettings,
23+
): AnthropicMessagesLanguageModel;
24+
25+
/**
26+
Creates a model for text generation.
1927
*/
2028
chat(
2129
modelId: AnthropicMessagesModelId,
@@ -108,6 +116,7 @@ export function createAnthropic(
108116
return createChatModel(modelId, settings);
109117
};
110118

119+
provider.languageModel = createChatModel;
111120
provider.chat = createChatModel;
112121
provider.messages = createChatModel;
113122

‎packages/azure/src/azure-openai-provider.ts

+9
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ export interface AzureOpenAIProvider {
1010
settings?: OpenAIChatSettings,
1111
): OpenAIChatLanguageModel;
1212

13+
/**
14+
Creates an Azure OpenAI chat model for text generation.
15+
*/
16+
languageModel(
17+
deploymentId: string,
18+
settings?: OpenAIChatSettings,
19+
): OpenAIChatLanguageModel;
20+
1321
/**
1422
Creates an Azure OpenAI chat model for text generation.
1523
*/
@@ -85,6 +93,7 @@ export function createAzure(
8593
return createChatModel(deploymentId, settings as OpenAIChatSettings);
8694
};
8795

96+
provider.languageModel = createChatModel;
8897
provider.chat = createChatModel;
8998

9099
return provider as AzureOpenAIProvider;

‎packages/core/core/registry/no-such-model-error.ts

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,30 @@
11
export class NoSuchModelError extends Error {
22
readonly modelId: string;
3+
readonly modelType: string;
34

45
constructor({
56
modelId,
6-
message = `No such model: ${modelId}`,
7+
modelType,
8+
message = `No such ${modelType}: ${modelId}`,
79
}: {
810
modelId: string;
11+
modelType: string;
912
message?: string;
1013
}) {
1114
super(message);
1215

1316
this.name = 'AI_NoSuchModelError';
1417

1518
this.modelId = modelId;
19+
this.modelType = modelType;
1620
}
1721

1822
static isNoSuchModelError(error: unknown): error is NoSuchModelError {
1923
return (
2024
error instanceof Error &&
2125
error.name === 'AI_NoSuchModelError' &&
22-
typeof (error as NoSuchModelError).modelId === 'string'
26+
typeof (error as NoSuchModelError).modelId === 'string' &&
27+
typeof (error as NoSuchModelError).modelType === 'string'
2328
);
2429
}
2530

@@ -30,6 +35,7 @@ export class NoSuchModelError extends Error {
3035
stack: this.stack,
3136

3237
modelId: this.modelId,
38+
modelType: this.modelType,
3339
};
3440
}
3541
}

‎packages/core/core/registry/no-such-provider-error.ts

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,30 @@
11
export class NoSuchProviderError extends Error {
22
readonly providerId: string;
3+
readonly availableProviders: string[];
34

45
constructor({
56
providerId,
6-
message = `No such provider: ${providerId}`,
7+
availableProviders,
8+
message = `No such provider: ${providerId} (available providers: ${availableProviders.join()})`,
79
}: {
810
providerId: string;
11+
availableProviders: string[];
912
message?: string;
1013
}) {
1114
super(message);
1215

1316
this.name = 'AI_NoSuchProviderError';
1417

1518
this.providerId = providerId;
19+
this.availableProviders = availableProviders;
1620
}
1721

1822
static isNoSuchProviderError(error: unknown): error is NoSuchProviderError {
1923
return (
2024
error instanceof Error &&
2125
error.name === 'AI_NoSuchProviderError' &&
22-
typeof (error as NoSuchProviderError).providerId === 'string'
26+
typeof (error as NoSuchProviderError).providerId === 'string' &&
27+
Array.isArray((error as NoSuchProviderError).availableProviders)
2328
);
2429
}
2530

@@ -30,6 +35,7 @@ export class NoSuchProviderError extends Error {
3035
stack: this.stack,
3136

3237
providerId: this.providerId,
38+
availableProviders: this.availableProviders,
3339
};
3440
}
3541
}
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,100 @@
1+
import { MockEmbeddingModelV1 } from '../test/mock-embedding-model-v1';
12
import { MockLanguageModelV1 } from '../test/mock-language-model-v1';
23
import { InvalidModelIdError } from './invalid-model-id-error';
34
import { NoSuchModelError } from './no-such-model-error';
45
import { NoSuchProviderError } from './no-such-provider-error';
56
import { experimental_createProviderRegistry } from './provider-registry';
67

7-
it('should return language model from provider', () => {
8-
const model = new MockLanguageModelV1();
8+
describe('languageModel', () => {
9+
it('should return language model from provider', () => {
10+
const model = new MockLanguageModelV1();
911

10-
const modelRegistry = experimental_createProviderRegistry({
11-
provider: id => {
12-
expect(id).toEqual('model');
13-
return model;
14-
},
12+
const modelRegistry = experimental_createProviderRegistry({
13+
provider: {
14+
languageModel: id => {
15+
expect(id).toEqual('model');
16+
return model;
17+
},
18+
},
19+
});
20+
21+
expect(modelRegistry.languageModel('provider:model')).toEqual(model);
1522
});
1623

17-
expect(modelRegistry.languageModel('provider:model')).toEqual(model);
18-
});
24+
it('should throw NoSuchProviderError if provider does not exist', () => {
25+
const registry = experimental_createProviderRegistry({});
1926

20-
it('should throw NoSuchProviderError if provider does not exist', () => {
21-
const registry = experimental_createProviderRegistry({});
27+
expect(() => registry.languageModel('provider:model')).toThrowError(
28+
NoSuchProviderError,
29+
);
30+
});
2231

23-
expect(() => registry.languageModel('provider:model')).toThrowError(
24-
NoSuchProviderError,
25-
);
26-
});
32+
it('should throw NoSuchModelError if provider does not return a model', () => {
33+
const registry = experimental_createProviderRegistry({
34+
provider: {
35+
languageModel: () => {
36+
return null as any;
37+
},
38+
},
39+
});
2740

28-
it('should throw NoSuchModelError if provider does not return a model', () => {
29-
const registry = experimental_createProviderRegistry({
30-
provider: () => null as any,
41+
expect(() => registry.languageModel('provider:model')).toThrowError(
42+
NoSuchModelError,
43+
);
3144
});
3245

33-
expect(() => registry.languageModel('provider:model')).toThrowError(
34-
NoSuchModelError,
35-
);
46+
it("should throw InvalidModelIdError if model id doesn't contain a colon", () => {
47+
const registry = experimental_createProviderRegistry({});
48+
49+
expect(() => registry.languageModel('model')).toThrowError(
50+
InvalidModelIdError,
51+
);
52+
});
3653
});
3754

38-
it("should throw InvalidModelIdError if model id doesn't contain a colon", () => {
39-
const registry = experimental_createProviderRegistry({});
55+
describe('textEmbeddingModel', () => {
56+
it('should return embedding model from provider', () => {
57+
const model = new MockEmbeddingModelV1<string>();
58+
59+
const modelRegistry = experimental_createProviderRegistry({
60+
provider: {
61+
textEmbedding: id => {
62+
expect(id).toEqual('model');
63+
return model;
64+
},
65+
},
66+
});
67+
68+
expect(modelRegistry.textEmbeddingModel('provider:model')).toEqual(model);
69+
});
4070

41-
expect(() => registry.languageModel('model')).toThrowError(
42-
InvalidModelIdError,
43-
);
71+
it('should throw NoSuchProviderError if provider does not exist', () => {
72+
const registry = experimental_createProviderRegistry({});
73+
74+
expect(() => registry.textEmbeddingModel('provider:model')).toThrowError(
75+
NoSuchProviderError,
76+
);
77+
});
78+
79+
it('should throw NoSuchModelError if provider does not return a model', () => {
80+
const registry = experimental_createProviderRegistry({
81+
provider: {
82+
textEmbedding: () => {
83+
return null as any;
84+
},
85+
},
86+
});
87+
88+
expect(() => registry.languageModel('provider:model')).toThrowError(
89+
NoSuchModelError,
90+
);
91+
});
92+
93+
it("should throw InvalidModelIdError if model id doesn't contain a colon", () => {
94+
const registry = experimental_createProviderRegistry({});
95+
96+
expect(() => registry.textEmbeddingModel('model')).toThrowError(
97+
InvalidModelIdError,
98+
);
99+
});
44100
});

‎packages/core/core/registry/provider-registry.ts

+69-42
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { LanguageModel } from '../types';
1+
import { EmbeddingModel, LanguageModel } from '../types';
22
import { InvalidModelIdError } from './invalid-model-id-error';
33
import { NoSuchModelError } from './no-such-model-error';
44
import { NoSuchProviderError } from './no-such-provider-error';
@@ -19,86 +19,113 @@ The model id is then passed to the provider function to get the model.
1919
@returns {LanguageModel} The language model associated with the id.
2020
*/
2121
languageModel(id: string): LanguageModel;
22+
23+
/**
24+
Returns the text embedding model with the given id in the format `providerId:modelId`.
25+
The model id is then passed to the provider function to get the model.
26+
27+
@param {string} id - The id of the model to return.
28+
29+
@throws {NoSuchModelError} If no model with the given id exists.
30+
@throws {NoSuchProviderError} If no provider with the given id exists.
31+
32+
@returns {LanguageModel} The language model associated with the id.
33+
*/
34+
textEmbeddingModel(id: string): EmbeddingModel<string>;
2235
};
2336

2437
/**
2538
* @deprecated Use `experimental_ProviderRegistry` instead.
2639
*/
2740
export type experimental_ModelRegistry = experimental_ProviderRegistry;
2841

42+
/**
43+
* Provider for language and text embedding models. Compatible with the
44+
* provider registry.
45+
*/
46+
interface Provider {
47+
/**
48+
* Returns a language model with the given id.
49+
*/
50+
languageModel?: (modelId: string) => LanguageModel;
51+
52+
/**
53+
* Returns a text embedding model with the given id.
54+
*/
55+
textEmbedding?: (modelId: string) => EmbeddingModel<string>;
56+
}
57+
2958
/**
3059
* Creates a registry for the given providers.
3160
*/
3261
export function experimental_createProviderRegistry(
33-
providers: Record<string, (id: string) => LanguageModel>,
62+
providers: Record<string, Provider>,
3463
): experimental_ProviderRegistry {
3564
const registry = new DefaultProviderRegistry();
3665

3766
for (const [id, provider] of Object.entries(providers)) {
38-
registry.registerLanguageModelProvider({ id, provider });
67+
registry.registerProvider({ id, provider });
3968
}
4069

4170
return registry;
4271
}
4372

44-
class DefaultProviderRegistry implements experimental_ProviderRegistry {
45-
// Mapping of provider id to provider
46-
private providers: Record<string, (id: string) => LanguageModel> = {};
73+
/**
74+
* @deprecated Use `experimental_createProviderRegistry` instead.
75+
*/
76+
export const experimental_createModelRegistry =
77+
experimental_createProviderRegistry;
4778

48-
/**
49-
Registers a language model provider with a given id.
79+
class DefaultProviderRegistry implements experimental_ProviderRegistry {
80+
private providers: Record<string, Provider> = {};
5081

51-
@param {string} id - The id of the provider.
52-
@param {(id: string) => LanguageModel} provider - The provider function to register.
53-
*/
54-
registerLanguageModelProvider({
55-
id,
56-
provider,
57-
}: {
58-
id: string;
59-
provider: (id: string) => LanguageModel;
60-
}): void {
82+
registerProvider({ id, provider }: { id: string; provider: Provider }): void {
6183
this.providers[id] = provider;
6284
}
6385

64-
/**
65-
Returns the language model with the given id.
66-
The id can either be a registered model id or use a provider prefix.
67-
Provider ids are separated from the model id by a colon: `providerId:modelId`.
68-
The model id is then passed to the provider function to get the model.
86+
private getProvider(id: string): Provider {
87+
const provider = this.providers[id];
6988

70-
@param {string} id - The id of the model to return.
89+
if (provider == null) {
90+
throw new NoSuchProviderError({
91+
providerId: id,
92+
availableProviders: Object.keys(this.providers),
93+
});
94+
}
7195

72-
@throws {NoSuchModelError} If no model with the given id exists.
73-
@throws {NoSuchProviderError} If no provider with the given id exists.
96+
return provider;
97+
}
7498

75-
@returns {LanguageModel} The language model associated with the id.
76-
*/
77-
languageModel(id: string): LanguageModel {
99+
private splitId(id: string): [string, string] {
78100
if (!id.includes(':')) {
79101
throw new InvalidModelIdError({ id });
80102
}
81103

82-
const [providerId, modelId] = id.split(':');
104+
return id.split(':') as [string, string];
105+
}
83106

84-
const provider = this.providers[providerId];
107+
languageModel(id: string): LanguageModel {
108+
const [providerId, modelId] = this.splitId(id);
109+
const model = this.getProvider(providerId).languageModel?.(modelId);
85110

86-
if (!provider) {
87-
throw new NoSuchProviderError({ providerId });
111+
if (model == null) {
112+
throw new NoSuchModelError({ modelId: id, modelType: 'language model' });
88113
}
89114

90-
const model = provider(modelId);
115+
return model;
116+
}
117+
118+
textEmbeddingModel(id: string): EmbeddingModel<string> {
119+
const [providerId, modelId] = this.splitId(id);
120+
const model = this.getProvider(providerId).textEmbedding?.(modelId);
91121

92-
if (!model) {
93-
throw new NoSuchModelError({ modelId: id });
122+
if (model == null) {
123+
throw new NoSuchModelError({
124+
modelId: id,
125+
modelType: 'text embedding model',
126+
});
94127
}
95128

96129
return model;
97130
}
98131
}
99-
100-
/**
101-
* @deprecated Use `experimental_createProviderRegistry` instead.
102-
*/
103-
export const experimental_createModelRegistry =
104-
experimental_createProviderRegistry;

‎packages/core/core/test/mock-embedding-model-v1.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ export class MockEmbeddingModelV1<VALUE> implements EmbeddingModelV1<VALUE> {
2323
maxEmbeddingsPerCall?: EmbeddingModelV1<VALUE>['maxEmbeddingsPerCall'];
2424
supportsParallelCalls?: EmbeddingModelV1<VALUE>['supportsParallelCalls'];
2525
doEmbed?: EmbeddingModelV1<VALUE>['doEmbed'];
26-
}) {
26+
} = {}) {
2727
this.provider = provider;
2828
this.modelId = modelId;
2929
this.maxEmbeddingsPerCall = maxEmbeddingsPerCall;

‎packages/google-vertex/src/google-vertex-provider.ts

+6-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ Creates a model for text generation.
1414
modelId: GoogleVertexModelId,
1515
settings?: GoogleVertexSettings,
1616
): GoogleVertexLanguageModel;
17+
18+
languageModel: (
19+
modelId: GoogleVertexModelId,
20+
settings?: GoogleVertexSettings,
21+
) => GoogleVertexLanguageModel;
1722
}
1823

1924
export interface GoogleVertexProviderSettings {
@@ -96,7 +101,7 @@ export function createVertex(
96101
return createChatModel(modelId, settings);
97102
};
98103

99-
provider.chat = createChatModel;
104+
provider.languageModel = createChatModel;
100105

101106
return provider as GoogleVertexProvider;
102107
}

‎packages/google/src/google-provider.ts

+6
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ export interface GoogleGenerativeAIProvider {
1515
settings?: GoogleGenerativeAISettings,
1616
): GoogleGenerativeAILanguageModel;
1717

18+
languageModel(
19+
modelId: GoogleGenerativeAIModelId,
20+
settings?: GoogleGenerativeAISettings,
21+
): GoogleGenerativeAILanguageModel;
22+
1823
chat(
1924
modelId: GoogleGenerativeAIModelId,
2025
settings?: GoogleGenerativeAISettings,
@@ -105,6 +110,7 @@ export function createGoogleGenerativeAI(
105110
return createChatModel(modelId, settings);
106111
};
107112

113+
provider.languageModel = createChatModel;
108114
provider.chat = createChatModel;
109115
provider.generativeAI = createChatModel;
110116

‎packages/mistral/src/mistral-provider.ts

+18
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ export interface MistralProvider {
2222

2323
/**
2424
Creates a model for text generation.
25+
*/
26+
languageModel(
27+
modelId: MistralChatModelId,
28+
settings?: MistralChatSettings,
29+
): MistralChatLanguageModel;
30+
31+
/**
32+
Creates a model for text generation.
2533
*/
2634
chat(
2735
modelId: MistralChatModelId,
@@ -35,6 +43,14 @@ Creates a model for text embeddings.
3543
modelId: MistralEmbeddingModelId,
3644
settings?: MistralEmbeddingSettings,
3745
): MistralEmbeddingModel;
46+
47+
/**
48+
Creates a model for text embeddings.
49+
*/
50+
textEmbedding(
51+
modelId: MistralEmbeddingModelId,
52+
settings?: MistralEmbeddingSettings,
53+
): MistralEmbeddingModel;
3854
}
3955

4056
export interface MistralProviderSettings {
@@ -124,8 +140,10 @@ export function createMistral(
124140
return createChatModel(modelId, settings);
125141
};
126142

143+
provider.languageModel = createChatModel;
127144
provider.chat = createChatModel;
128145
provider.embedding = createEmbeddingModel;
146+
provider.textEmbedding = createEmbeddingModel;
129147

130148
return provider as MistralProvider;
131149
}

‎packages/openai/src/openai-provider.ts

+28-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@ export interface OpenAIProvider {
2222
settings?: OpenAIChatSettings,
2323
): OpenAIChatLanguageModel;
2424

25+
languageModel(
26+
modelId: 'gpt-3.5-turbo-instruct',
27+
settings?: OpenAICompletionSettings,
28+
): OpenAICompletionLanguageModel;
29+
languageModel(
30+
modelId: OpenAIChatModelId,
31+
settings?: OpenAIChatSettings,
32+
): OpenAIChatLanguageModel;
33+
2534
/**
2635
Creates an OpenAI chat model for text generation.
2736
*/
@@ -45,6 +54,14 @@ Creates a model for text embeddings.
4554
modelId: OpenAIEmbeddingModelId,
4655
settings?: OpenAIEmbeddingSettings,
4756
): OpenAIEmbeddingModel;
57+
58+
/**
59+
Creates a model for text embeddings.
60+
*/
61+
textEmbedding(
62+
modelId: OpenAIEmbeddingModelId,
63+
settings?: OpenAIEmbeddingSettings,
64+
): OpenAIEmbeddingModel;
4865
}
4966

5067
export interface OpenAIProviderSettings {
@@ -151,10 +168,10 @@ export function createOpenAI(
151168
fetch: options.fetch,
152169
});
153170

154-
const provider = function (
171+
const createLanguageModel = (
155172
modelId: OpenAIChatModelId | OpenAICompletionModelId,
156173
settings?: OpenAIChatSettings | OpenAICompletionSettings,
157-
) {
174+
) => {
158175
if (new.target) {
159176
throw new Error(
160177
'The OpenAI model function cannot be called with the new keyword.',
@@ -171,9 +188,18 @@ export function createOpenAI(
171188
return createChatModel(modelId, settings as OpenAIChatSettings);
172189
};
173190

191+
const provider = function (
192+
modelId: OpenAIChatModelId | OpenAICompletionModelId,
193+
settings?: OpenAIChatSettings | OpenAICompletionSettings,
194+
) {
195+
return createLanguageModel(modelId, settings);
196+
};
197+
198+
provider.languageModel = createLanguageModel;
174199
provider.chat = createChatModel;
175200
provider.completion = createCompletionModel;
176201
provider.embedding = createEmbeddingModel;
202+
provider.textEmbedding = createEmbeddingModel;
177203

178204
return provider as OpenAIProvider;
179205
}

0 commit comments

Comments
 (0)
Please sign in to comment.