ai/core: expose raw response headers (#1417)
lgrammel committed Apr 23, 2024
1 parent d6431ae commit 25f3350
Showing 27 changed files with 543 additions and 133 deletions.
11 changes: 11 additions & 0 deletions .changeset/short-seas-flash.md
@@ -0,0 +1,11 @@
---
'@ai-sdk/provider-utils': patch
'@ai-sdk/anthropic': patch
'@ai-sdk/provider': patch
'@ai-sdk/mistral': patch
'@ai-sdk/google': patch
'@ai-sdk/openai': patch
'ai': patch
---

ai/core: add support for getting raw response headers.
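In consumer code this surfaces as an optional `rawResponse` field on every result object. A minimal sketch of reading a header from a one-shot generation (the model and header name are illustrative):

import { openai } from '@ai-sdk/openai';
import { experimental_generateText } from 'ai';

async function main() {
  const { text, rawResponse } = await experimental_generateText({
    model: openai('gpt-3.5-turbo'),
    prompt: 'Hello!',
  });

  // rawResponse and its headers are optional, so guard every access.
  console.log(`Request ID: ${rawResponse?.headers?.['x-request-id']}`);
  console.log(text);
}

main().catch(console.error);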
24 changes: 24 additions & 0 deletions examples/ai-core/src/stream-text/openai-response-headers.ts
@@ -0,0 +1,24 @@
import { openai } from '@ai-sdk/openai';
import { experimental_streamText } from 'ai';
import dotenv from 'dotenv';

dotenv.config();

async function main() {
const result = await experimental_streamText({
model: openai('gpt-3.5-turbo'),
maxTokens: 512,
temperature: 0.3,
maxRetries: 5,
prompt: 'Invent a new holiday and describe its traditions.',
});

console.log(`Request ID: ${result.rawResponse?.headers?.['x-request-id']}`);
console.log();

for await (const textPart of result.textStream) {
process.stdout.write(textPart);
}
}

main().catch(console.error);
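Assuming an `OPENAI_API_KEY` entry in a local `.env` file (picked up by `dotenv.config()`), this example should print the `x-request-id` header that the OpenAI API attaches to its responses, followed by the streamed holiday description.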
51 changes: 47 additions & 4 deletions packages/anthropic/src/anthropic-messages-language-model.test.ts
@@ -11,10 +11,7 @@ const TEST_PROMPT: LanguageModelV1Prompt = [
{ role: 'user', content: [{ type: 'text', text: 'Hello' }] },
];

-const provider = createAnthropic({
-  apiKey: 'test-api-key',
-});
-
+const provider = createAnthropic({ apiKey: 'test-api-key' });
const model = provider.chat('claude-3-haiku-20240307');

describe('doGenerate', () => {
@@ -181,6 +178,28 @@ describe('doGenerate', () => {
});
});

it('should expose the raw response headers', async () => {
prepareJsonResponse({});

server.responseHeaders = {
'test-header': 'test-value',
};

const { rawResponse } = await model.doGenerate({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
});

expect(rawResponse?.headers).toStrictEqual({
// default headers:
'content-type': 'application/json',

// custom header
'test-header': 'test-value',
});
});

it('should pass the model and the messages', async () => {
prepareJsonResponse({});

@@ -279,6 +298,30 @@ describe('doStream', () => {
]);
});

it('should expose the raw response headers', async () => {
prepareStreamResponse({ content: [] });

server.responseHeaders = {
'test-header': 'test-value',
};

const { rawResponse } = await model.doStream({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
});

expect(rawResponse?.headers).toStrictEqual({
// default headers:
'content-type': 'text/event-stream',
'cache-control': 'no-cache',
connection: 'keep-alive',

// custom header
'test-header': 'test-value',
});
});

it('should pass the messages and the model', async () => {
prepareStreamResponse({ content: [] });

6 changes: 4 additions & 2 deletions packages/anthropic/src/anthropic-messages-language-model.ts
@@ -164,7 +164,7 @@ export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
const { args, warnings } = this.getArgs(options);

-const response = await postJsonToApi({
+const { responseHeaders, value: response } = await postJsonToApi({
url: `${this.config.baseURL}/messages`,
headers: this.config.headers(),
body: args,
@@ -210,6 +210,7 @@ export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
completionTokens: response.usage.output_tokens,
},
rawCall: { rawPrompt, rawSettings },
rawResponse: { headers: responseHeaders },
warnings,
};
}
@@ -219,7 +220,7 @@ export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
const { args, warnings } = this.getArgs(options);

-const response = await postJsonToApi({
+const { responseHeaders, value: response } = await postJsonToApi({
url: `${this.config.baseURL}/messages`,
headers: this.config.headers(),
body: {
@@ -296,6 +297,7 @@ export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
}),
),
rawCall: { rawPrompt, rawSettings },
rawResponse: { headers: responseHeaders },
warnings,
};
}
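The `@ai-sdk/provider-utils` side of this change is not shown in the excerpt: `postJsonToApi` now resolves to `{ responseHeaders, value }` rather than the bare parsed body. A plausible sketch of how the fetch `Headers` object gets flattened into the `Record<string, string>` shape asserted in the tests above (helper name hypothetical):

// Hypothetical helper: flatten a fetch Headers object into the plain
// Record<string, string> that providers pass along as rawResponse.headers.
function extractResponseHeaders(response: Response): Record<string, string> {
  const headers: Record<string, string> = {};
  response.headers.forEach((value, key) => {
    headers[key] = value; // header names arrive lower-cased per the fetch spec
  });
  return headers;
}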
19 changes: 19 additions & 0 deletions packages/core/core/generate-object/generate-object.ts
@@ -94,6 +94,7 @@ Default and recommended: 'auto' (best mode for the model).
let finishReason: LanguageModelV1FinishReason;
let usage: Parameters<typeof calculateTokenUsage>[0];
let warnings: LanguageModelV1CallWarning[] | undefined;
let rawResponse: { headers?: Record<string, string> } | undefined;
let logprobs: LanguageModelV1LogProbs | undefined;

switch (mode) {
@@ -122,6 +123,7 @@ Default and recommended: 'auto' (best mode for the model).
finishReason = generateResult.finishReason;
usage = generateResult.usage;
warnings = generateResult.warnings;
rawResponse = generateResult.rawResponse;
logprobs = generateResult.logprobs;

break;
@@ -152,6 +154,7 @@ Default and recommended: 'auto' (best mode for the model).
finishReason = generateResult.finishReason;
usage = generateResult.usage;
warnings = generateResult.warnings;
rawResponse = generateResult.rawResponse;
logprobs = generateResult.logprobs;

break;
@@ -192,6 +195,7 @@ Default and recommended: 'auto' (best mode for the model).
finishReason = generateResult.finishReason;
usage = generateResult.usage;
warnings = generateResult.warnings;
rawResponse = generateResult.rawResponse;
logprobs = generateResult.logprobs;

break;
@@ -218,6 +222,7 @@ Default and recommended: 'auto' (best mode for the model).
finishReason,
usage: calculateTokenUsage(usage),
warnings,
rawResponse,
logprobs,
});
}
@@ -246,6 +251,16 @@ Warnings from the model provider (e.g. unsupported settings)
*/
readonly warnings: LanguageModelV1CallWarning[] | undefined;

/**
Optional raw response data.
*/
rawResponse?: {
/**
Response headers.
*/
headers?: Record<string, string>;
};

/**
Logprobs for the completion.
`undefined` if the mode does not support logprobs or if was not enabled
@@ -257,12 +272,16 @@ Logprobs for the completion.
finishReason: LanguageModelV1FinishReason;
usage: TokenUsage;
warnings: LanguageModelV1CallWarning[] | undefined;
rawResponse?: {
headers?: Record<string, string>;
};
logprobs: LanguageModelV1LogProbs | undefined;
}) {
this.object = options.object;
this.finishReason = options.finishReason;
this.usage = options.usage;
this.warnings = options.warnings;
this.rawResponse = options.rawResponse;
this.logprobs = options.logprobs;
}
}
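For callers, the new field mirrors the text API. A sketch of reading headers from object generation (the schema and header name are illustrative):

import { openai } from '@ai-sdk/openai';
import { experimental_generateObject } from 'ai';
import { z } from 'zod';

async function main() {
  const result = await experimental_generateObject({
    model: openai('gpt-4-turbo'),
    schema: z.object({
      holiday: z.string(),
      traditions: z.array(z.string()),
    }),
    prompt: 'Invent a new holiday and describe its traditions.',
  });

  // rawResponse and its headers are optional, so guard the access.
  console.log(result.rawResponse?.headers?.['x-request-id']);
  console.log(result.object);
}

main().catch(console.error);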
16 changes: 16 additions & 0 deletions packages/core/core/generate-object/stream-object.ts
@@ -220,6 +220,7 @@ Default and recommended: 'auto' (best mode for the model).
return new StreamObjectResult({
stream: result.stream.pipeThrough(new TransformStream(transformer)),
warnings: result.warnings,
rawResponse: result.rawResponse,
});
}

@@ -259,15 +260,30 @@ Warnings from the model provider (e.g. unsupported settings)
*/
readonly warnings: LanguageModelV1CallWarning[] | undefined;

/**
Optional raw response data.
*/
rawResponse?: {
/**
Response headers.
*/
headers?: Record<string, string>;
};

constructor({
stream,
warnings,
rawResponse,
}: {
stream: ReadableStream<string | ObjectStreamPartInput>;
warnings: LanguageModelV1CallWarning[] | undefined;
rawResponse?: {
headers?: Record<string, string>;
};
}) {
this.originalStream = stream;
this.warnings = warnings;
this.rawResponse = rawResponse;
}

get partialObjectStream(): AsyncIterableStream<DeepPartial<T>> {
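As with `streamText`, the raw response is attached to the result itself, so headers can be read before the partial-object stream is consumed. A sketch (schema illustrative):

import { openai } from '@ai-sdk/openai';
import { experimental_streamObject } from 'ai';
import { z } from 'zod';

async function main() {
  const result = await experimental_streamObject({
    model: openai('gpt-4-turbo'),
    schema: z.object({ holiday: z.string() }),
    prompt: 'Invent a new holiday.',
  });

  // Headers are available here, before any partial objects arrive.
  console.log(result.rawResponse?.headers?.['x-request-id']);

  for await (const partialObject of result.partialObjectStream) {
    console.log(partialObject);
  }
}

main().catch(console.error);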
15 changes: 15 additions & 0 deletions packages/core/core/generate-text/generate-text.ts
@@ -116,6 +116,7 @@ The tools that the model can call. The model needs to support calling tools.
finishReason: modelResponse.finishReason,
usage: calculateTokenUsage(modelResponse.usage),
warnings: modelResponse.warnings,
rawResponse: modelResponse.rawResponse,
logprobs: modelResponse.logprobs,
});
}
@@ -188,6 +189,16 @@ Warnings from the model provider (e.g. unsupported settings)
*/
readonly warnings: LanguageModelV1CallWarning[] | undefined;

/**
Optional raw response data.
*/
rawResponse?: {
/**
Response headers.
*/
headers?: Record<string, string>;
};

/**
Logprobs for the completion.
`undefined` if the mode does not support logprobs or if was not enabled
@@ -201,6 +212,9 @@ Logprobs for the completion.
finishReason: LanguageModelV1FinishReason;
usage: TokenUsage;
warnings: LanguageModelV1CallWarning[] | undefined;
rawResponse?: {
headers?: Record<string, string>;
};
logprobs: LanguageModelV1LogProbs | undefined;
}) {
this.text = options.text;
@@ -209,6 +223,7 @@ Logprobs for the completion.
this.finishReason = options.finishReason;
this.usage = options.usage;
this.warnings = options.warnings;
this.rawResponse = options.rawResponse;
this.logprobs = options.logprobs;
}
}
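Note that `rawResponse` is optional everywhere it appears (`rawResponse?:`), presumably because a `LanguageModelV1` implementation is not required to sit on top of an HTTP API; callers should treat missing headers as a normal case rather than an error.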
3 changes: 1 addition & 2 deletions packages/core/core/generate-text/stream-text.test.ts
@@ -4,9 +4,8 @@ import { convertArrayToReadableStream } from '../test/convert-array-to-readable-
import { convertAsyncIterableToArray } from '../test/convert-async-iterable-to-array';
import { convertReadableStreamToArray } from '../test/convert-readable-stream-to-array';
import { MockLanguageModelV1 } from '../test/mock-language-model-v1';
-import { experimental_streamText } from './stream-text';
import { ServerResponse } from 'node:http';
import { createMockServerResponse } from '../test/mock-server-response';
+import { experimental_streamText } from './stream-text';

describe('result.textStream', () => {
it('should send text deltas', async () => {
18 changes: 17 additions & 1 deletion packages/core/core/generate-text/stream-text.ts
@@ -85,7 +85,7 @@ The tools that the model can call. The model needs to support calling tools.
}): Promise<StreamTextResult<TOOLS>> {
const retry = retryWithExponentialBackoff({ maxRetries });
const validatedPrompt = getValidatedPrompt({ system, prompt, messages });
-const { stream, warnings } = await retry(() =>
+const { stream, warnings, rawResponse } = await retry(() =>
model.doStream({
mode: {
type: 'regular',
@@ -112,6 +112,7 @@ The tools that the model can call. The model needs to support calling tools.
generatorStream: stream,
}),
warnings,
rawResponse,
});
}

@@ -152,15 +153,30 @@ Warnings from the model provider (e.g. unsupported settings)
*/
readonly warnings: LanguageModelV1CallWarning[] | undefined;

/**
Optional raw response data.
*/
rawResponse?: {
/**
Response headers.
*/
headers?: Record<string, string>;
};

constructor({
stream,
warnings,
rawResponse,
}: {
stream: ReadableStream<TextStreamPart<TOOLS>>;
warnings: LanguageModelV1CallWarning[] | undefined;
rawResponse?: {
headers?: Record<string, string>;
};
}) {
this.originalStream = stream;
this.warnings = warnings;
this.rawResponse = rawResponse;
}

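Because `doStream` resolves as soon as the response headers arrive, `rawResponse.headers` can be read before the text stream is consumed; this is what the `openai-response-headers.ts` example at the top of this commit relies on.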
