
Commit b729094

lgrammel and tobiasstrebitzer authored Jul 9, 2024
feat (ai/core): add usage information for embed and embedMany (#2220)
Co-authored-by: Tobias Strebitzer <tobias.strebitzer@magloft.com>
1 parent e09c0b9 commit b729094

34 files changed (+281 −84 lines)
 

.changeset/cool-snakes-agree.md (+5)

@@ -0,0 +1,5 @@
+---
+'ai': patch
+---
+
+chore (ai/core): rename TokenUsage type to CompletionTokenUsage

.changeset/unlucky-owls-admire.md (+8)

@@ -0,0 +1,8 @@
+---
+'@ai-sdk/provider': patch
+'@ai-sdk/mistral': patch
+'@ai-sdk/openai': patch
+'ai': patch
+---
+
+feat (ai/core): add token usage to embed and embedMany

content/docs/03-ai-sdk-core/30-embeddings.mdx (+17)

@@ -69,3 +69,20 @@ console.log(
   `cosine similarity: ${cosineSimilarity(embeddings[0], embeddings[1])}`,
 );
 ```
+
+## Token Usage
+
+Many providers charge based on the number of tokens used to generate embeddings.
+Both `embed` and `embedMany` provide token usage information in the `usage` property of the result object:
+
+```ts highlight={"4,9"}
+import { openai } from '@ai-sdk/openai';
+import { embed } from 'ai';
+
+const { embedding, usage } = await embed({
+  model: openai.embedding('text-embedding-3-small'),
+  value: 'sunny day at the beach',
+});
+
+console.log(usage); // { tokens: 10 }
+```
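
The snippet above covers the single-value case. For the batched path, here is a sketch of the corresponding `embedMany` call; the second input string and the logged token count are illustrative and not taken from this commit:

```ts
import { openai } from '@ai-sdk/openai';
import { embedMany } from 'ai';

async function main() {
  const { embeddings, usage } = await embedMany({
    model: openai.embedding('text-embedding-3-small'),
    // the second value is an arbitrary example input
    values: ['sunny day at the beach', 'rainy afternoon in the city'],
  });

  console.log(embeddings.length); // 2
  console.log(usage); // e.g. { tokens: 12 }; the exact count comes from the provider
}

main().catch(console.error);
```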

content/docs/07-reference/ai-sdk-core/01-generate-text.mdx (+2 −2)

@@ -367,11 +367,11 @@ console.log(text);
     },
     {
       name: 'usage',
-      type: 'TokenUsage',
+      type: 'CompletionTokenUsage',
       description: 'The token usage of the generated text.',
       properties: [
         {
-          type: 'TokenUsage',
+          type: 'CompletionTokenUsage',
           parameters: [
             {
               name: 'promptTokens',

content/docs/07-reference/ai-sdk-core/02-stream-text.mdx (+2 −2)

@@ -433,12 +433,12 @@ for await (const textPart of textStream) {
     },
     {
       name: 'usage',
-      type: 'Promise<TokenUsage>',
+      type: 'Promise<CompletionTokenUsage>',
       description:
         'The token usage of the generated text. Resolved when the response is finished.',
       properties: [
         {
-          type: 'TokenUsage',
+          type: 'CompletionTokenUsage',
           parameters: [
             {
               name: 'promptTokens',

content/docs/07-reference/ai-sdk-core/03-generate-object.mdx (+2 −2)

@@ -329,11 +329,11 @@ console.log(JSON.stringify(object, null, 2));
     },
     {
       name: 'usage',
-      type: 'TokenUsage',
+      type: 'CompletionTokenUsage',
       description: 'The token usage of the generated text.',
       properties: [
         {
-          type: 'TokenUsage',
+          type: 'CompletionTokenUsage',
           parameters: [
             {
               name: 'promptTokens',

content/docs/07-reference/ai-sdk-core/04-stream-object.mdx (+4 −4)

@@ -325,11 +325,11 @@ for await (const partialObject of partialObjectStream) {
     parameters: [
       {
         name: 'usage',
-        type: 'TokenUsage',
+        type: 'CompletionTokenUsage',
         description: 'The token usage of the generated text.',
         properties: [
           {
-            type: 'TokenUsage',
+            type: 'CompletionTokenUsage',
             parameters: [
               {
                 name: 'promptTokens',
@@ -400,12 +400,12 @@ for await (const partialObject of partialObjectStream) {
   content={[
     {
       name: 'usage',
-      type: 'Promise<TokenUsage>',
+      type: 'Promise<CompletionTokenUsage>',
       description:
         'The token usage of the generated text. Resolved when the response is finished.',
       properties: [
         {
-          type: 'TokenUsage',
+          type: 'CompletionTokenUsage',
          parameters: [
            {
              name: 'promptTokens',

content/docs/07-reference/ai-sdk-core/05-embed.mdx (+17)

@@ -77,6 +77,23 @@ const { embedding } = await embed({
       type: 'number[]',
       description: 'The embedding of the value.',
     },
+    {
+      name: 'usage',
+      type: 'EmbeddingTokenUsage',
+      description: 'The token usage for generating the embeddings.',
+      properties: [
+        {
+          type: 'EmbeddingTokenUsage',
+          parameters: [
+            {
+              name: 'tokens',
+              type: 'number',
+              description: 'The total number of input tokens.',
+            },
+          ],
+        },
+      ],
+    },
     {
       name: 'rawResponse',
       type: 'RawResponse',

content/docs/07-reference/ai-sdk-core/06-embed-many.mdx (+17)

@@ -83,5 +83,22 @@ const { embeddings } = await embedMany({
       type: 'number[][]',
       description: 'The embeddings. They are in the same order as the values.',
     },
+    {
+      name: 'usage',
+      type: 'EmbeddingTokenUsage',
+      description: 'The token usage for generating the embeddings.',
+      properties: [
+        {
+          type: 'EmbeddingTokenUsage',
+          parameters: [
+            {
+              name: 'tokens',
+              type: 'number',
+              description: 'The total number of input tokens.',
+            },
+          ],
+        },
+      ],
+    },
   ]}
 />

examples/ai-core/src/embed-many/azure.ts (+2 −1)

@@ -5,7 +5,7 @@ import dotenv from 'dotenv';
 dotenv.config();

 async function main() {
-  const { embeddings } = await embedMany({
+  const { embeddings, usage } = await embedMany({
     model: azure.embedding('my-embedding-deployment'),
     values: [
       'sunny day at the beach',
@@ -15,6 +15,7 @@ async function main() {
   });

   console.log(embeddings);
+  console.log(usage);
 }

 main().catch(console.error);

examples/ai-core/src/embed-many/mistral.ts (+2 −1)

@@ -5,7 +5,7 @@ import dotenv from 'dotenv';
 dotenv.config();

 async function main() {
-  const { embeddings } = await embedMany({
+  const { embeddings, usage } = await embedMany({
     model: mistral.embedding('mistral-embed'),
     values: [
       'sunny day at the beach',
@@ -15,6 +15,7 @@ async function main() {
   });

   console.log(embeddings);
+  console.log(usage);
 }

 main().catch(console.error);

examples/ai-core/src/embed-many/openai.ts (+2 −1)

@@ -5,7 +5,7 @@ import dotenv from 'dotenv';
 dotenv.config();

 async function main() {
-  const { embeddings } = await embedMany({
+  const { embeddings, usage } = await embedMany({
     model: openai.embedding('text-embedding-3-small'),
     values: [
       'sunny day at the beach',
@@ -15,6 +15,7 @@ async function main() {
   });

   console.log(embeddings);
+  console.log(usage);
 }

 main().catch(console.error);

examples/ai-core/src/embed/azure.ts (+3 −1)

@@ -5,11 +5,13 @@ import dotenv from 'dotenv';
 dotenv.config();

 async function main() {
-  const { embedding } = await embed({
+  const { embedding, usage } = await embed({
     model: azure.embedding('my-embedding-deployment'),
     value: 'sunny day at the beach',
   });
+
   console.log(embedding);
+  console.log(usage);
 }

 main().catch(console.error);

examples/ai-core/src/embed/mistral.ts (+2 −1)

@@ -5,12 +5,13 @@ import dotenv from 'dotenv';
 dotenv.config();

 async function main() {
-  const { embedding } = await embed({
+  const { embedding, usage } = await embed({
     model: mistral.embedding('mistral-embed'),
     value: 'sunny day at the beach',
   });

   console.log(embedding);
+  console.log(usage);
 }

 main().catch(console.error);

examples/ai-core/src/embed/openai.ts (+2 −1)

@@ -5,12 +5,13 @@ import dotenv from 'dotenv';
 dotenv.config();

 async function main() {
-  const { embedding } = await embed({
+  const { embedding, usage } = await embed({
     model: openai.embedding('text-embedding-3-small'),
     value: 'sunny day at the beach',
   });

   console.log(embedding);
+  console.log(usage);
 }

 main().catch(console.error);

packages/core/core/embed/embed-many.test.ts (+40 −12)

@@ -37,19 +37,16 @@ describe('result.embedding', () => {
     model: new MockEmbeddingModelV1({
       maxEmbeddingsPerCall: 2,
       doEmbed: async ({ values }) => {
-        if (callCount === 0) {
-          assert.deepStrictEqual(values, testValues.slice(0, 2));
-          callCount++;
-          return { embeddings: dummyEmbeddings.slice(0, 2) };
+        switch (callCount++) {
+          case 0:
+            assert.deepStrictEqual(values, testValues.slice(0, 2));
+            return { embeddings: dummyEmbeddings.slice(0, 2) };
+          case 1:
+            assert.deepStrictEqual(values, testValues.slice(2));
+            return { embeddings: dummyEmbeddings.slice(2) };
+          default:
+            throw new Error('Unexpected call');
         }
-
-        if (callCount === 1) {
-          assert.deepStrictEqual(values, testValues.slice(2));
-          callCount++;
-          return { embeddings: dummyEmbeddings.slice(2) };
-        }
-
-        throw new Error('Unexpected call');
       },
     }),
     values: testValues,
@@ -73,6 +70,37 @@ describe('result.values', () => {
   });
 });

+describe('result.usage', () => {
+  it('should include usage in the result', async () => {
+    let callCount = 0;
+
+    const result = await embedMany({
+      model: new MockEmbeddingModelV1({
+        maxEmbeddingsPerCall: 2,
+        doEmbed: async () => {
+          switch (callCount++) {
+            case 0:
+              return {
+                embeddings: dummyEmbeddings.slice(0, 2),
+                usage: { tokens: 10 },
+              };
+            case 1:
+              return {
+                embeddings: dummyEmbeddings.slice(2),
+                usage: { tokens: 20 },
+              };
+            default:
+              throw new Error('Unexpected call');
+          }
+        },
+      }),
+      values: testValues,
+    });
+
+    assert.deepStrictEqual(result.usage, { tokens: 30 });
+  });
+});
+
 describe('options.headers', () => {
   it('should set headers', async () => {
     const result = await embedMany({

packages/core/core/embed/embed-many.ts (+17 −2)

@@ -1,4 +1,5 @@
 import { Embedding, EmbeddingModel } from '../types';
+import { EmbeddingTokenUsage } from '../types/token-usage';
 import { retryWithExponentialBackoff } from '../util/retry-with-exponential-backoff';
 import { splitArray } from '../util/split-array';

@@ -66,6 +67,7 @@ Only applicable for HTTP-based providers.
     return new EmbedManyResult({
       values,
       embeddings: modelResponse.embeddings,
+      usage: modelResponse.usage ?? { tokens: NaN },
     });
   }

@@ -74,14 +76,17 @@ Only applicable for HTTP-based providers.

   // serially embed the chunks:
   const embeddings = [];
+  let tokens = 0;
+
   for (const chunk of valueChunks) {
     const modelResponse = await retry(() =>
       model.doEmbed({ values: chunk, abortSignal, headers }),
     );
     embeddings.push(...modelResponse.embeddings);
+    tokens += modelResponse.usage?.tokens ?? NaN;
   }

-  return new EmbedManyResult({ values, embeddings });
+  return new EmbedManyResult({ values, embeddings, usage: { tokens } });
 }

 /**
@@ -99,8 +104,18 @@ The embeddings. They are in the same order as the values.
   */
   readonly embeddings: Array<Embedding>;

-  constructor(options: { values: Array<VALUE>; embeddings: Array<Embedding> }) {
+  /**
+  The embedding token usage.
+  */
+  readonly usage: EmbeddingTokenUsage;
+
+  constructor(options: {
+    values: Array<VALUE>;
+    embeddings: Array<Embedding>;
+    usage: EmbeddingTokenUsage;
+  }) {
     this.values = options.values;
     this.embeddings = options.embeddings;
+    this.usage = options.usage;
   }
 }
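
The chunked branch above sums per-chunk usage, and a single chunk without usage poisons the total with NaN. A self-contained sketch of that aggregation rule (the helper name is made up for illustration):

```ts
type EmbeddingTokenUsage = { tokens: number };

// Mirrors `tokens += modelResponse.usage?.tokens ?? NaN` from the diff above:
// per-chunk token counts are summed, and one missing usage turns the total into NaN.
function aggregateUsage(
  chunkUsages: Array<EmbeddingTokenUsage | undefined>,
): EmbeddingTokenUsage {
  let tokens = 0;
  for (const usage of chunkUsages) {
    tokens += usage?.tokens ?? NaN;
  }
  return { tokens };
}

console.log(aggregateUsage([{ tokens: 10 }, { tokens: 20 }])); // { tokens: 30 }
console.log(aggregateUsage([{ tokens: 10 }, undefined]));      // { tokens: NaN }
```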

packages/core/core/embed/embed.test.ts (+13)

@@ -34,6 +34,19 @@ describe('result.value', () => {
   });
 });

+describe('result.usage', () => {
+  it('should include usage in the result', async () => {
+    const result = await embed({
+      model: new MockEmbeddingModelV1({
+        doEmbed: mockEmbed([testValue], [dummyEmbedding], { tokens: 10 }),
+      }),
+      value: testValue,
+    });
+
+    assert.deepStrictEqual(result.usage, { tokens: 10 });
+  });
+});
+
 describe('options.headers', () => {
   it('should set headers', async () => {
     const result = await embed({

packages/core/core/embed/embed.ts (+10 −3)

@@ -1,4 +1,5 @@
 import { Embedding, EmbeddingModel } from '../types';
+import { EmbeddingTokenUsage } from '../types/token-usage';
 import { retryWithExponentialBackoff } from '../util/retry-with-exponential-backoff';

 /**
@@ -57,6 +58,7 @@ Only applicable for HTTP-based providers.
   return new EmbedResult({
     value,
     embedding: modelResponse.embeddings[0],
+    usage: modelResponse.usage ?? { tokens: NaN },
     rawResponse: modelResponse.rawResponse,
   });
 }
@@ -76,6 +78,11 @@ The embedding of the value.
   */
   readonly embedding: Embedding;

+  /**
+  The embedding token usage.
+  */
+  readonly usage: EmbeddingTokenUsage;
+
   /**
   Optional raw response data.
   */
@@ -89,12 +96,12 @@ Response headers.
   constructor(options: {
     value: VALUE;
     embedding: Embedding;
-    rawResponse?: {
-      headers?: Record<string, string>;
-    };
+    usage: EmbeddingTokenUsage;
+    rawResponse?: { headers?: Record<string, string> };
   }) {
     this.value = options.value;
     this.embedding = options.embedding;
+    this.usage = options.usage;
     this.rawResponse = options.rawResponse;
   }
 }
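
Because `embed` falls back to `{ tokens: NaN }` when the provider reports no usage, callers that do arithmetic on the count may want to guard for it. A consumer-side sketch (model id and input value copied from the examples above; the guard itself is a suggestion, not part of this commit):

```ts
import { openai } from '@ai-sdk/openai';
import { embed } from 'ai';

async function main() {
  const { usage } = await embed({
    model: openai.embedding('text-embedding-3-small'),
    value: 'sunny day at the beach',
  });

  // NaN means "provider did not report usage", so check before summing or billing.
  if (Number.isNaN(usage.tokens)) {
    console.log('no token usage reported');
  } else {
    console.log(`used ${usage.tokens} input tokens`);
  }
}

main().catch(console.error);
```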

packages/core/core/generate-object/generate-object.ts (+9 −6)

@@ -1,17 +1,20 @@
 import { NoObjectGeneratedError } from '@ai-sdk/provider';
 import { safeParseJSON } from '@ai-sdk/provider-utils';
 import { z } from 'zod';
-import { TokenUsage, calculateTokenUsage } from '../generate-text/token-usage';
 import { CallSettings } from '../prompt/call-settings';
 import { convertToLanguageModelPrompt } from '../prompt/convert-to-language-model-prompt';
 import { getValidatedPrompt } from '../prompt/get-validated-prompt';
 import { prepareCallSettings } from '../prompt/prepare-call-settings';
 import { Prompt } from '../prompt/prompt';
 import { CallWarning, FinishReason, LanguageModel, LogProbs } from '../types';
+import {
+  CompletionTokenUsage,
+  calculateCompletionTokenUsage,
+} from '../types/token-usage';
 import { convertZodToJSONSchema } from '../util/convert-zod-to-json-schema';
+import { prepareResponseHeaders } from '../util/prepare-response-headers';
 import { retryWithExponentialBackoff } from '../util/retry-with-exponential-backoff';
 import { injectJsonSchemaIntoSystem } from './inject-json-schema-into-system';
-import { prepareResponseHeaders } from '../util/prepare-response-headers';

 /**
 Generate a structured, typed object for a given prompt and schema using a language model.
@@ -99,7 +102,7 @@ Default and recommended: 'auto' (best mode for the model).

   let result: string;
   let finishReason: FinishReason;
-  let usage: Parameters<typeof calculateTokenUsage>[0];
+  let usage: Parameters<typeof calculateCompletionTokenUsage>[0];
   let warnings: CallWarning[] | undefined;
   let rawResponse: { headers?: Record<string, string> } | undefined;
   let logprobs: LogProbs | undefined;
@@ -228,7 +231,7 @@ Default and recommended: 'auto' (best mode for the model).
   return new GenerateObjectResult({
     object: parseResult.value,
     finishReason,
-    usage: calculateTokenUsage(usage),
+    usage: calculateCompletionTokenUsage(usage),
     warnings,
     rawResponse,
     logprobs,
@@ -252,7 +255,7 @@ The reason why the generation finished.
   /**
   The token usage of the generated text.
   */
-  readonly usage: TokenUsage;
+  readonly usage: CompletionTokenUsage;

   /**
   Warnings from the model provider (e.g. unsupported settings)
@@ -278,7 +281,7 @@ Logprobs for the completion.
   constructor(options: {
     object: T;
     finishReason: FinishReason;
-    usage: TokenUsage;
+    usage: CompletionTokenUsage;
     warnings: CallWarning[] | undefined;
     rawResponse?: {
       headers?: Record<string, string>;

packages/core/core/generate-object/stream-object.ts (+14 −9)

@@ -8,23 +8,26 @@ import {
   isDeepEqualData,
   parsePartialJson,
 } from '@ai-sdk/ui-utils';
+import { ServerResponse } from 'http';
 import { z } from 'zod';
-import { TokenUsage, calculateTokenUsage } from '../generate-text/token-usage';
 import { CallSettings } from '../prompt/call-settings';
 import { convertToLanguageModelPrompt } from '../prompt/convert-to-language-model-prompt';
 import { getValidatedPrompt } from '../prompt/get-validated-prompt';
 import { prepareCallSettings } from '../prompt/prepare-call-settings';
 import { Prompt } from '../prompt/prompt';
 import { CallWarning, FinishReason, LanguageModel, LogProbs } from '../types';
+import {
+  CompletionTokenUsage,
+  calculateCompletionTokenUsage,
+} from '../types/token-usage';
 import {
   AsyncIterableStream,
   createAsyncIterableStream,
 } from '../util/async-iterable-stream';
 import { convertZodToJSONSchema } from '../util/convert-zod-to-json-schema';
+import { prepareResponseHeaders } from '../util/prepare-response-headers';
 import { retryWithExponentialBackoff } from '../util/retry-with-exponential-backoff';
 import { injectJsonSchemaIntoSystem } from './inject-json-schema-into-system';
-import { prepareResponseHeaders } from '../util/prepare-response-headers';
-import { ServerResponse } from 'http';

 /**
 Generate a structured, typed object for a given prompt and schema using a language model.
@@ -110,7 +113,7 @@ Callback that is called when the LLM response and the final object validation ar
   /**
   The token usage of the generated response.
   */
-  usage: TokenUsage;
+  usage: CompletionTokenUsage;

   /**
   The generated object (typed according to the schema). Can be undefined if the final object does not match the schema.
@@ -327,7 +330,7 @@ The generated object (typed according to the schema). Resolved when the response
   /**
   The token usage of the generated response. Resolved when the response is finished.
   */
-  readonly usage: Promise<TokenUsage>;
+  readonly usage: Promise<CompletionTokenUsage>;

   /**
   Optional raw response data.
@@ -368,13 +371,15 @@ Response headers.
     });

     // initialize usage promise
-    let resolveUsage: (value: TokenUsage | PromiseLike<TokenUsage>) => void;
-    this.usage = new Promise<TokenUsage>(resolve => {
+    let resolveUsage: (
+      value: CompletionTokenUsage | PromiseLike<CompletionTokenUsage>,
+    ) => void;
+    this.usage = new Promise<CompletionTokenUsage>(resolve => {
       resolveUsage = resolve;
     });

     // store information for onFinish callback:
-    let usage: TokenUsage | undefined;
+    let usage: CompletionTokenUsage | undefined;
     let object: T | undefined;
     let error: unknown | undefined;

@@ -425,7 +430,7 @@ Response headers.
         }

         // store usage for promises and onFinish callback:
-        usage = calculateTokenUsage(chunk.usage);
+        usage = calculateCompletionTokenUsage(chunk.usage);

         controller.enqueue({ ...chunk, usage });
packages/core/core/generate-text/generate-text.ts (+7 −4)

@@ -16,8 +16,11 @@ import {
   LanguageModel,
   LogProbs,
 } from '../types';
+import {
+  CompletionTokenUsage,
+  calculateCompletionTokenUsage,
+} from '../types/token-usage';
 import { retryWithExponentialBackoff } from '../util/retry-with-exponential-backoff';
-import { TokenUsage, calculateTokenUsage } from './token-usage';
 import { ToToolCallArray, parseToolCall } from './tool-call';
 import { ToToolResultArray } from './tool-result';

@@ -176,7 +179,7 @@ By default, it's set to 0, which will disable the feature.
     toolCalls: currentToolCalls,
     toolResults: currentToolResults,
     finishReason: currentModelResponse.finishReason,
-    usage: calculateTokenUsage(currentModelResponse.usage),
+    usage: calculateCompletionTokenUsage(currentModelResponse.usage),
     warnings: currentModelResponse.warnings,
     rawResponse: currentModelResponse.rawResponse,
     logprobs: currentModelResponse.logprobs,
@@ -243,7 +246,7 @@ The reason why the generation finished.
   /**
   The token usage of the generated text.
   */
-  readonly usage: TokenUsage;
+  readonly usage: CompletionTokenUsage;

   /**
   Warnings from the model provider (e.g. unsupported settings)
@@ -280,7 +283,7 @@ Logprobs for the completion.
     toolCalls: ToToolCallArray<TOOLS>;
     toolResults: ToToolResultArray<TOOLS>;
     finishReason: FinishReason;
-    usage: TokenUsage;
+    usage: CompletionTokenUsage;
     warnings: CallWarning[] | undefined;
     rawResponse?: {
       headers?: Record<string, string>;
(file name not shown in this capture)

@@ -1,3 +1,2 @@
 export * from './generate-text';
 export * from './stream-text';
-export type { TokenUsage } from './token-usage';

packages/core/core/generate-text/run-tools-transformation.ts (+2 −2)

@@ -1,8 +1,8 @@
 import { LanguageModelV1StreamPart, NoSuchToolError } from '@ai-sdk/provider';
 import { generateId } from '@ai-sdk/ui-utils';
 import { CoreTool } from '../tool';
+import { calculateCompletionTokenUsage } from '../types/token-usage';
 import { TextStreamPart } from './stream-text';
-import { calculateTokenUsage } from './token-usage';
 import { parseToolCall } from './tool-call';

 export function runToolsTransformation<TOOLS extends Record<string, CoreTool>>({
@@ -131,7 +131,7 @@ export function runToolsTransformation<TOOLS extends Record<string, CoreTool>>({
         type: 'finish',
         finishReason: chunk.finishReason,
         logprobs: chunk.logprobs,
-        usage: calculateTokenUsage(chunk.usage),
+        usage: calculateCompletionTokenUsage(chunk.usage),
       });
       break;
     }

packages/core/core/generate-text/stream-text.ts (+8 −6)

@@ -18,14 +18,14 @@ import {
   LanguageModel,
   LogProbs,
 } from '../types';
+import { CompletionTokenUsage } from '../types/token-usage';
 import {
   AsyncIterableStream,
   createAsyncIterableStream,
 } from '../util/async-iterable-stream';
 import { prepareResponseHeaders } from '../util/prepare-response-headers';
 import { retryWithExponentialBackoff } from '../util/retry-with-exponential-backoff';
 import { runToolsTransformation } from './run-tools-transformation';
-import { TokenUsage } from './token-usage';
 import { ToToolCall } from './tool-call';
 import { ToToolResult } from './tool-result';

@@ -109,7 +109,7 @@ The reason why the generation finished.
   /**
   The token usage of the generated response.
   */
-  usage: TokenUsage;
+  usage: CompletionTokenUsage;

   /**
   The full text that has been generated.
@@ -210,7 +210,7 @@ Warnings from the model provider (e.g. unsupported settings).
   /**
   The token usage of the generated response. Resolved when the response is finished.
   */
-  readonly usage: Promise<TokenUsage>;
+  readonly usage: Promise<CompletionTokenUsage>;

   /**
   The reason why the generation finished. Resolved when the response is finished.
@@ -260,8 +260,10 @@ Response headers.
     this.onFinish = onFinish;

     // initialize usage promise
-    let resolveUsage: (value: TokenUsage | PromiseLike<TokenUsage>) => void;
-    this.usage = new Promise<TokenUsage>(resolve => {
+    let resolveUsage: (
+      value: CompletionTokenUsage | PromiseLike<CompletionTokenUsage>,
+    ) => void;
+    this.usage = new Promise<CompletionTokenUsage>(resolve => {
       resolveUsage = resolve;
     });

@@ -297,7 +299,7 @@ Response headers.

     // store information for onFinish callback:
     let finishReason: FinishReason | undefined;
-    let usage: TokenUsage | undefined;
+    let usage: CompletionTokenUsage | undefined;
     let text = '';
     const toolCalls: ToToolCall<TOOLS>[] = [];
     const toolResults: ToToolResult<TOOLS>[] = [];

packages/core/core/test/mock-embedding-model-v1.ts (+3 −1)

@@ -1,5 +1,6 @@
 import { EmbeddingModelV1 } from '@ai-sdk/provider';
 import { Embedding } from '../types';
+import { EmbeddingTokenUsage } from '../types/token-usage';

 export class MockEmbeddingModelV1<VALUE> implements EmbeddingModelV1<VALUE> {
   readonly specificationVersion = 'v1';
@@ -35,10 +36,11 @@ export class MockEmbeddingModelV1<VALUE> implements EmbeddingModelV1<VALUE> {
 export function mockEmbed<VALUE>(
   expectedValues: Array<VALUE>,
   embeddings: Array<Embedding>,
+  usage?: EmbeddingTokenUsage,
 ): EmbeddingModelV1<VALUE>['doEmbed'] {
   return async ({ values }) => {
     assert.deepStrictEqual(expectedValues, values);
-    return { embeddings };
+    return { embeddings, usage };
   };
 }

packages/core/core/types/index.ts (+10)

@@ -1,3 +1,13 @@
+import type { CompletionTokenUsage as CompletionTokenUsageOriginal } from './token-usage';
+
 export * from './embedding-model';
 export * from './errors';
 export * from './language-model';
+
+/**
+ * @deprecated Use CompletionTokenUsage instead.
+ */
+export type TokenUsage = CompletionTokenUsageOriginal;
+export type CompletionTokenUsage = CompletionTokenUsageOriginal;
+
+export type { EmbeddingTokenUsage } from './token-usage';
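
The deprecated alias keeps existing code compiling; a small sketch, assuming both type names remain re-exported from the `ai` package as `TokenUsage` was before:

```ts
import type { CompletionTokenUsage, TokenUsage } from 'ai';

const usage: CompletionTokenUsage = {
  promptTokens: 10,
  completionTokens: 20,
  totalTokens: 30,
};

// TokenUsage is now just an alias for CompletionTokenUsage, so this still type-checks.
const legacy: TokenUsage = usage;
console.log(legacy.totalTokens); // 30
```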

packages/core/core/generate-text/token-usage.ts renamed to packages/core/core/types/token-usage.ts (+13 −3)

@@ -1,7 +1,7 @@
 /**
 Represents the number of tokens used in a prompt and completion.
 */
-export type TokenUsage = {
+export type CompletionTokenUsage = {
   /**
   The number of tokens used in the prompt
   */
@@ -18,10 +18,20 @@ The total number of tokens used (promptTokens + completionTokens).
   totalTokens: number;
 };

-export function calculateTokenUsage(usage: {
+/**
+Represents the number of tokens used in an embedding.
+*/
+export type EmbeddingTokenUsage = {
+  /**
+  The number of tokens used in the embedding.
+  */
+  tokens: number;
+};
+
+export function calculateCompletionTokenUsage(usage: {
   promptTokens: number;
   completionTokens: number;
-}): TokenUsage {
+}): CompletionTokenUsage {
   return {
     promptTokens: usage.promptTokens,
     completionTokens: usage.completionTokens,
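
The tail of `calculateCompletionTokenUsage` falls outside this hunk. Based on the doc comment above ("promptTokens + completionTokens"), here is a sketch of the expected mapping, with the helper name chosen for illustration and the `totalTokens` sum stated as an assumption:

```ts
type CompletionTokenUsage = {
  promptTokens: number;
  completionTokens: number;
  totalTokens: number;
};

// Assumed behavior: totalTokens is the sum of prompt and completion tokens.
function toCompletionTokenUsage(usage: {
  promptTokens: number;
  completionTokens: number;
}): CompletionTokenUsage {
  return {
    promptTokens: usage.promptTokens,
    completionTokens: usage.completionTokens,
    totalTokens: usage.promptTokens + usage.completionTokens,
  };
}

console.log(toCompletionTokenUsage({ promptTokens: 10, completionTokens: 20 }));
// { promptTokens: 10, completionTokens: 20, totalTokens: 30 }
```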

packages/core/rsc/stream-ui/stream-ui.tsx (+6 −6)

@@ -14,13 +14,13 @@ import { prepareCallSettings } from '../../core/prompt/prepare-call-settings';
 import { prepareToolsAndToolChoice } from '../../core/prompt/prepare-tools-and-tool-choice';
 import { Prompt } from '../../core/prompt/prompt';
 import { CallWarning, CoreToolChoice, FinishReason } from '../../core/types';
+import {
+  CompletionTokenUsage,
+  calculateCompletionTokenUsage,
+} from '../../core/types/token-usage';
 import { retryWithExponentialBackoff } from '../../core/util/retry-with-exponential-backoff';
 import { createStreamableUI } from '../streamable';
 import { createResolvablePromise } from '../utils';
-import {
-  TokenUsage,
-  calculateTokenUsage,
-} from '../../core/generate-text/token-usage';

 type Streamable = ReactNode | Promise<ReactNode>;

@@ -123,7 +123,7 @@ export async function streamUI<
   /**
   * The token usage of the generated response.
   */
-  usage: TokenUsage;
+  usage: CompletionTokenUsage;
   /**
   * The final ui node that was generated.
   */
@@ -350,7 +350,7 @@ export async function streamUI<
   case 'finish': {
     onFinish?.({
       finishReason: value.finishReason,
-      usage: calculateTokenUsage(value.usage),
+      usage: calculateCompletionTokenUsage(value.usage),
       value: ui.value,
       warnings: result.warnings,
       rawResponse: result.rawResponse,

packages/mistral/src/mistral-embedding-model.test.ts (+14 −2)

@@ -18,8 +18,10 @@ describe('doEmbed', () => {

 function prepareJsonResponse({
   embeddings = dummyEmbeddings,
+  usage = { prompt_tokens: 8, total_tokens: 8 },
 }: {
   embeddings?: EmbeddingModelV1Embedding[];
+  usage?: { prompt_tokens: number; total_tokens: number };
 } = {}) {
   server.responseBodyJson = {
     id: 'b322cfc2b9d34e2f8e14fc99874faee5',
@@ -30,7 +32,7 @@ describe('doEmbed', () => {
       index: i,
     })),
     model: 'mistral-embed',
-    usage: { prompt_tokens: 8, total_tokens: 8, completion_tokens: 0 },
+    usage,
   };
 }

@@ -42,6 +44,16 @@
     expect(embeddings).toStrictEqual(dummyEmbeddings);
   });

+  it('should extract usage', async () => {
+    prepareJsonResponse({
+      usage: { prompt_tokens: 20, total_tokens: 20 },
+    });
+
+    const { usage } = await model.doEmbed({ values: testValues });
+
+    expect(usage).toStrictEqual({ tokens: 20 });
+  });
+
   it('should expose the raw response headers', async () => {
     prepareJsonResponse();

@@ -53,7 +65,7 @@

     expect(rawResponse?.headers).toStrictEqual({
       // default headers:
-      'content-length': '289',
+      'content-length': '267',
       'content-type': 'application/json',

       // custom header

packages/mistral/src/mistral-embedding-model.ts (+5 −5)

@@ -86,6 +86,9 @@ export class MistralEmbeddingModel implements EmbeddingModelV1<string> {

     return {
       embeddings: response.data.map(item => item.embedding),
+      usage: response.usage
+        ? { tokens: response.usage.prompt_tokens }
+        : undefined,
       rawResponse: { headers: responseHeaders },
     };
   }
@@ -94,9 +97,6 @@ export class MistralEmbeddingModel implements EmbeddingModelV1<string> {
 // minimal version of the schema, focussed on what is needed for the implementation
 // this approach limits breakages when the API changes and increases efficiency
 const MistralTextEmbeddingResponseSchema = z.object({
-  data: z.array(
-    z.object({
-      embedding: z.array(z.number()),
-    }),
-  ),
+  data: z.array(z.object({ embedding: z.array(z.number()) })),
+  usage: z.object({ prompt_tokens: z.number() }).nullish(),
 });
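
A standalone sketch of the response-to-usage mapping above, using the same minimal zod schema; the sample payload is illustrative:

```ts
import { z } from 'zod';

// Same shape as the minimal schema in the diff above.
const responseSchema = z.object({
  data: z.array(z.object({ embedding: z.array(z.number()) })),
  usage: z.object({ prompt_tokens: z.number() }).nullish(),
});

// Illustrative payload in the shape the provider API returns.
const response = responseSchema.parse({
  data: [{ embedding: [0.1, 0.2] }],
  usage: { prompt_tokens: 20 },
});

// Same mapping as in doEmbed: prompt_tokens becomes the embedding token count,
// and a missing usage object stays undefined.
const usage = response.usage
  ? { tokens: response.usage.prompt_tokens }
  : undefined;

console.log(usage); // { tokens: 20 }
```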

packages/openai/src/openai-embedding-model.test.ts (+13 −1)

@@ -18,8 +18,10 @@ describe('doEmbed', () => {

 function prepareJsonResponse({
   embeddings = dummyEmbeddings,
+  usage = { prompt_tokens: 8, total_tokens: 8 },
 }: {
   embeddings?: EmbeddingModelV1Embedding[];
+  usage?: { prompt_tokens: number; total_tokens: number };
 } = {}) {
   server.responseBodyJson = {
     object: 'list',
@@ -29,7 +31,7 @@ describe('doEmbed', () => {
       embedding,
     })),
     model: 'text-embedding-3-large',
-    usage: { prompt_tokens: 8, total_tokens: 8 },
+    usage,
   };
 }

@@ -60,6 +62,16 @@
     });
   });

+  it('should extract usage', async () => {
+    prepareJsonResponse({
+      usage: { prompt_tokens: 20, total_tokens: 20 },
+    });
+
+    const { usage } = await model.doEmbed({ values: testValues });
+
+    expect(usage).toStrictEqual({ tokens: 20 });
+  });
+
   it('should pass the model and the values', async () => {
     prepareJsonResponse();

packages/openai/src/openai-embedding-model.ts (+5 −5)

@@ -89,6 +89,9 @@ export class OpenAIEmbeddingModel implements EmbeddingModelV1<string> {

     return {
       embeddings: response.data.map(item => item.embedding),
+      usage: response.usage
+        ? { tokens: response.usage.prompt_tokens }
+        : undefined,
       rawResponse: { headers: responseHeaders },
     };
   }
@@ -97,9 +100,6 @@ export class OpenAIEmbeddingModel implements EmbeddingModelV1<string> {
 // minimal version of the schema, focussed on what is needed for the implementation
 // this approach limits breakages when the API changes and increases efficiency
 const openaiTextEmbeddingResponseSchema = z.object({
-  data: z.array(
-    z.object({
-      embedding: z.array(z.number()),
-    }),
-  ),
+  data: z.array(z.object({ embedding: z.array(z.number()) })),
+  usage: z.object({ prompt_tokens: z.number() }).nullish(),
 });

packages/provider/src/embedding-model/v1/embedding-model-v1.ts (+5)

@@ -66,6 +66,11 @@ Generated embeddings. They are in the same order as the input values.
   */
   embeddings: Array<EmbeddingModelV1Embedding>;

+  /**
+  Token usage. We only have input tokens for embeddings.
+  */
+  usage?: { tokens: number };
+
   /**
   Optional raw response information for debugging purposes.
   */
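
For provider authors, a hedged sketch of a `doEmbed` that reports the new optional `usage` field, typed the same way as `mockEmbed` above; the placeholder vectors and the token arithmetic are illustrative, not from this commit:

```ts
import type { EmbeddingModelV1 } from '@ai-sdk/provider';

const doEmbed: EmbeddingModelV1<string>['doEmbed'] = async ({ values }) => ({
  // placeholder 3-dimensional vectors, one per input value
  embeddings: values.map(() => [0, 0, 0]),
  // illustrative input-token count; a real provider reports what its API returns
  usage: { tokens: values.length * 4 },
});
```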

0 commit comments
