Skip to content

Commit

Permalink
ai/core: export language model types (#1463)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel committed Apr 28, 2024
1 parent b4c68ec commit 41d5736
Show file tree
Hide file tree
Showing 12 changed files with 119 additions and 59 deletions.
6 changes: 6 additions & 0 deletions .changeset/afraid-panthers-sing.md
@@ -0,0 +1,6 @@
---
'@ai-sdk/provider': patch
'ai': patch
---

ai/core: re-expose language model types.
35 changes: 15 additions & 20 deletions packages/core/core/generate-object/generate-object.ts
@@ -1,10 +1,4 @@
import {
LanguageModelV1,
LanguageModelV1CallWarning,
LanguageModelV1FinishReason,
LanguageModelV1LogProbs,
NoTextGeneratedError,
} from '@ai-sdk/provider';
import { NoObjectGeneratedError } from '@ai-sdk/provider';
import { safeParseJSON } from '@ai-sdk/provider-utils';
import { z } from 'zod';
import { TokenUsage, calculateTokenUsage } from '../generate-text/token-usage';
Expand All @@ -13,6 +7,7 @@ import { convertToLanguageModelPrompt } from '../prompt/convert-to-language-mode
import { getValidatedPrompt } from '../prompt/get-validated-prompt';
import { prepareCallSettings } from '../prompt/prepare-call-settings';
import { Prompt } from '../prompt/prompt';
import { CallWarning, FinishReason, LanguageModel, LogProbs } from '../types';
import { convertZodToJSONSchema } from '../util/convert-zod-to-json-schema';
import { retryWithExponentialBackoff } from '../util/retry-with-exponential-backoff';
import { injectJsonSchemaIntoSystem } from './inject-json-schema-into-system';
Expand Down Expand Up @@ -68,7 +63,7 @@ export async function experimental_generateObject<T>({
/**
The language model to use.
*/
model: LanguageModelV1;
model: LanguageModel;

/**
The schema of the object that the model should generate.
Expand All @@ -91,11 +86,11 @@ Default and recommended: 'auto' (best mode for the model).
}

let result: string;
let finishReason: LanguageModelV1FinishReason;
let finishReason: FinishReason;
let usage: Parameters<typeof calculateTokenUsage>[0];
let warnings: LanguageModelV1CallWarning[] | undefined;
let warnings: CallWarning[] | undefined;
let rawResponse: { headers?: Record<string, string> } | undefined;
let logprobs: LanguageModelV1LogProbs | undefined;
let logprobs: LogProbs | undefined;

switch (mode) {
case 'json': {
Expand All @@ -116,7 +111,7 @@ Default and recommended: 'auto' (best mode for the model).
});

if (generateResult.text === undefined) {
throw new NoTextGeneratedError();
throw new NoObjectGeneratedError();
}

result = generateResult.text;
Expand Down Expand Up @@ -147,7 +142,7 @@ Default and recommended: 'auto' (best mode for the model).
);

if (generateResult.text === undefined) {
throw new NoTextGeneratedError();
throw new NoObjectGeneratedError();
}

result = generateResult.text;
Expand Down Expand Up @@ -188,7 +183,7 @@ Default and recommended: 'auto' (best mode for the model).
const functionArgs = generateResult.toolCalls?.[0]?.args;

if (functionArgs === undefined) {
throw new NoTextGeneratedError();
throw new NoObjectGeneratedError();
}

result = functionArgs;
Expand Down Expand Up @@ -239,7 +234,7 @@ The generated object (typed according to the schema).
/**
The reason why the generation finished.
*/
readonly finishReason: LanguageModelV1FinishReason;
readonly finishReason: FinishReason;

/**
The token usage of the generated text.
Expand All @@ -249,7 +244,7 @@ The token usage of the generated text.
/**
Warnings from the model provider (e.g. unsupported settings)
*/
readonly warnings: LanguageModelV1CallWarning[] | undefined;
readonly warnings: CallWarning[] | undefined;

/**
Optional raw response data.
Expand All @@ -265,17 +260,17 @@ Response headers.
Logprobs for the completion.
`undefined` if the mode does not support logprobs or if it was not enabled
*/
readonly logprobs: LanguageModelV1LogProbs | undefined;
readonly logprobs: LogProbs | undefined;

constructor(options: {
object: T;
finishReason: LanguageModelV1FinishReason;
finishReason: FinishReason;
usage: TokenUsage;
warnings: LanguageModelV1CallWarning[] | undefined;
warnings: CallWarning[] | undefined;
rawResponse?: {
headers?: Record<string, string>;
};
logprobs: LanguageModelV1LogProbs | undefined;
logprobs: LogProbs | undefined;
}) {
this.object = options.object;
this.finishReason = options.finishReason;
Expand Down
17 changes: 7 additions & 10 deletions packages/core/core/generate-object/stream-object.ts
@@ -1,17 +1,15 @@
import {
LanguageModelV1,
LanguageModelV1CallOptions,
LanguageModelV1CallWarning,
LanguageModelV1FinishReason,
LanguageModelV1LogProbs,
LanguageModelV1StreamPart,
} from '@ai-sdk/provider';
import { z } from 'zod';
import { calculateTokenUsage } from '../generate-text/token-usage';
import { CallSettings } from '../prompt/call-settings';
import { convertToLanguageModelPrompt } from '../prompt/convert-to-language-model-prompt';
import { getValidatedPrompt } from '../prompt/get-validated-prompt';
import { prepareCallSettings } from '../prompt/prepare-call-settings';
import { Prompt } from '../prompt/prompt';
import { CallWarning, FinishReason, LanguageModel, LogProbs } from '../types';
import {
AsyncIterableStream,
createAsyncIterableStream,
Expand All @@ -22,7 +20,6 @@ import { isDeepEqualData } from '../util/is-deep-equal-data';
import { parsePartialJson } from '../util/parse-partial-json';
import { retryWithExponentialBackoff } from '../util/retry-with-exponential-backoff';
import { injectJsonSchemaIntoSystem } from './inject-json-schema-into-system';
import { calculateTokenUsage } from '../generate-text/token-usage';

/**
Generate a structured, typed object for a given prompt and schema using a language model.
Expand Down Expand Up @@ -75,7 +72,7 @@ export async function experimental_streamObject<T>({
/**
The language model to use.
*/
model: LanguageModelV1;
model: LanguageModel;

/**
The schema of the object that the model should generate.
Expand Down Expand Up @@ -231,8 +228,8 @@ export type ObjectStreamPartInput =
}
| {
type: 'finish';
finishReason: LanguageModelV1FinishReason;
logprobs?: LanguageModelV1LogProbs;
finishReason: FinishReason;
logprobs?: LogProbs;
usage: {
promptTokens: number;
completionTokens: number;
Expand All @@ -258,7 +255,7 @@ export class StreamObjectResult<T> {
/**
Warnings from the model provider (e.g. unsupported settings)
*/
readonly warnings: LanguageModelV1CallWarning[] | undefined;
readonly warnings: CallWarning[] | undefined;

/**
Optional raw response data.
Expand All @@ -276,7 +273,7 @@ Response headers.
rawResponse,
}: {
stream: ReadableStream<string | ObjectStreamPartInput>;
warnings: LanguageModelV1CallWarning[] | undefined;
warnings: CallWarning[] | undefined;
rawResponse?: {
headers?: Record<string, string>;
};
Expand Down
21 changes: 8 additions & 13 deletions packages/core/core/generate-text/generate-text.ts
@@ -1,15 +1,10 @@
import {
LanguageModelV1,
LanguageModelV1CallWarning,
LanguageModelV1FinishReason,
LanguageModelV1LogProbs,
} from '@ai-sdk/provider';
import { CallSettings } from '../prompt/call-settings';
import { convertToLanguageModelPrompt } from '../prompt/convert-to-language-model-prompt';
import { getValidatedPrompt } from '../prompt/get-validated-prompt';
import { prepareCallSettings } from '../prompt/prepare-call-settings';
import { Prompt } from '../prompt/prompt';
import { ExperimentalTool } from '../tool/tool';
import { CallWarning, FinishReason, LanguageModel, LogProbs } from '../types';
import { convertZodToJSONSchema } from '../util/convert-zod-to-json-schema';
import { retryWithExponentialBackoff } from '../util/retry-with-exponential-backoff';
import { TokenUsage, calculateTokenUsage } from './token-usage';
Expand Down Expand Up @@ -66,7 +61,7 @@ export async function experimental_generateText<
/**
The language model to use.
*/
model: LanguageModelV1;
model: LanguageModel;

/**
The tools that the model can call. The model needs to support calling tools.
Expand Down Expand Up @@ -177,7 +172,7 @@ The results of the tool calls.
/**
The reason why the generation finished.
*/
readonly finishReason: LanguageModelV1FinishReason;
readonly finishReason: FinishReason;

/**
The token usage of the generated text.
Expand All @@ -187,7 +182,7 @@ The token usage of the generated text.
/**
Warnings from the model provider (e.g. unsupported settings)
*/
readonly warnings: LanguageModelV1CallWarning[] | undefined;
readonly warnings: CallWarning[] | undefined;

/**
Optional raw response data.
Expand All @@ -203,19 +198,19 @@ Response headers.
Logprobs for the completion.
`undefined` if the mode does not support logprobs or if it was not enabled
*/
readonly logprobs: LanguageModelV1LogProbs | undefined;
readonly logprobs: LogProbs | undefined;

constructor(options: {
text: string;
toolCalls: ToToolCallArray<TOOLS>;
toolResults: ToToolResultArray<TOOLS>;
finishReason: LanguageModelV1FinishReason;
finishReason: FinishReason;
usage: TokenUsage;
warnings: LanguageModelV1CallWarning[] | undefined;
warnings: CallWarning[] | undefined;
rawResponse?: {
headers?: Record<string, string>;
};
logprobs: LanguageModelV1LogProbs | undefined;
logprobs: LogProbs | undefined;
}) {
this.text = options.text;
this.toolCalls = options.toolCalls;
Expand Down
17 changes: 6 additions & 11 deletions packages/core/core/generate-text/stream-text.ts
@@ -1,9 +1,3 @@
import {
LanguageModelV1,
LanguageModelV1CallWarning,
LanguageModelV1FinishReason,
LanguageModelV1LogProbs,
} from '@ai-sdk/provider';
import { ServerResponse } from 'node:http';
import {
AIStreamCallbacksAndOptions,
Expand All @@ -17,6 +11,7 @@ import { getValidatedPrompt } from '../prompt/get-validated-prompt';
import { prepareCallSettings } from '../prompt/prepare-call-settings';
import { Prompt } from '../prompt/prompt';
import { ExperimentalTool } from '../tool';
import { CallWarning, FinishReason, LanguageModel, LogProbs } from '../types';
import {
AsyncIterableStream,
createAsyncIterableStream,
Expand Down Expand Up @@ -77,7 +72,7 @@ export async function experimental_streamText<
/**
The language model to use.
*/
model: LanguageModelV1;
model: LanguageModel;

/**
The tools that the model can call. The model needs to support calling tools.
Expand Down Expand Up @@ -134,8 +129,8 @@ export type TextStreamPart<TOOLS extends Record<string, ExperimentalTool>> =
} & ToToolResult<TOOLS>)
| {
type: 'finish';
finishReason: LanguageModelV1FinishReason;
logprobs?: LanguageModelV1LogProbs;
finishReason: FinishReason;
logprobs?: LogProbs;
usage: {
promptTokens: number;
completionTokens: number;
Expand All @@ -152,7 +147,7 @@ export class StreamTextResult<TOOLS extends Record<string, ExperimentalTool>> {
/**
Warnings from the model provider (e.g. unsupported settings)
*/
readonly warnings: LanguageModelV1CallWarning[] | undefined;
readonly warnings: CallWarning[] | undefined;

/**
Optional raw response data.
Expand All @@ -170,7 +165,7 @@ Response headers.
rawResponse,
}: {
stream: ReadableStream<TextStreamPart<TOOLS>>;
warnings: LanguageModelV1CallWarning[] | undefined;
warnings: CallWarning[] | undefined;
rawResponse?: {
headers?: Record<string, string>;
};
Expand Down
1 change: 1 addition & 0 deletions packages/core/core/index.ts
Expand Up @@ -2,4 +2,5 @@ export * from './generate-object';
export * from './generate-text';
export * from './prompt';
export * from './tool';
export * from './types';
export * from './util/deep-partial';
18 changes: 18 additions & 0 deletions packages/core/core/types/errors.ts
@@ -0,0 +1,18 @@
// Re-exports the error classes from '@ai-sdk/provider' so that consumers of
// the core package can import them from the core types barrel without taking
// a direct dependency on '@ai-sdk/provider'.
export {
  APICallError,
  EmptyResponseBodyError,
  InvalidArgumentError,
  InvalidDataContentError,
  InvalidPromptError,
  InvalidResponseDataError,
  InvalidToolArgumentsError,
  JSONParseError,
  LoadAPIKeyError,
  NoObjectGeneratedError,
  NoSuchToolError,
  RetryError,
  ToolCallParseError,
  TypeValidationError,
  UnsupportedFunctionalityError,
  UnsupportedJSONSchemaError,
} from '@ai-sdk/provider';
2 changes: 2 additions & 0 deletions packages/core/core/types/index.ts
@@ -0,0 +1,2 @@
// Barrel file for the shared core types: provider error re-exports and the
// language-model type aliases.
export * from './errors';
export * from './language-model';
35 changes: 35 additions & 0 deletions packages/core/core/types/language-model.ts
@@ -0,0 +1,35 @@
import {
LanguageModelV1,
LanguageModelV1CallWarning,
LanguageModelV1FinishReason,
LanguageModelV1LogProbs,
} from '@ai-sdk/provider';

/**
Language model that is used by the AI SDK Core functions.

Alias for the versioned `LanguageModelV1` specification interface, so that
user-facing code does not have to reference the versioned provider type
directly.
*/
export type LanguageModel = LanguageModelV1;

/**
Reason why a language model finished generating a response.

Can be one of the following:
- `stop`: model generated stop sequence
- `length`: model generated maximum number of tokens
- `content-filter`: content filter violation stopped the model
- `tool-calls`: model triggered tool calls
- `error`: model stopped because of an error
- `other`: model stopped for other reasons
*/
export type FinishReason = LanguageModelV1FinishReason;

/**
Log probabilities for each token and its top log probabilities.
*/
export type LogProbs = LanguageModelV1LogProbs;

/**
Warning from the model provider for this call. The call will proceed, but e.g.
some settings might not be supported, which can lead to suboptimal results.
*/
export type CallWarning = LanguageModelV1CallWarning;
12 changes: 7 additions & 5 deletions packages/provider/src/errors/no-object-generated-error.ts
@@ -1,14 +1,16 @@
export class NoTextGeneratedError extends Error {
export class NoObjectGeneratedError extends Error {
readonly cause: unknown;

constructor() {
super(`No text generated.`);
super(`No object generated.`);

this.name = 'AI_NoTextGeneratedError';
this.name = 'AI_NoObjectGeneratedError';
}

static isNoTextGeneratedError(error: unknown): error is NoTextGeneratedError {
return error instanceof Error && error.name === 'AI_NoTextGeneratedError';
static isNoTextGeneratedError(
error: unknown,
): error is NoObjectGeneratedError {
return error instanceof Error && error.name === 'AI_NoObjectGeneratedError';
}

toJSON() {
Expand Down

0 comments on commit 41d5736

Please sign in to comment.