Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

infer errors #5554

Draft
wants to merge 20 commits into
base: next
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { TRPC_ERROR_CODE_KEY } from '../rpc/codes';
import type { Overwrite, TypeError } from '../types';
import { isObject } from '../utils';

class UnknownCauseError extends Error {
Expand Down Expand Up @@ -31,7 +32,9 @@ export function getCauseFromUnknown(cause: unknown): Error | undefined {
return undefined;
}

export function getTRPCErrorFromUnknown(cause: unknown): TRPCError {
export function getTRPCErrorFromUnknown(
cause: unknown,
): TRPCError | TypedTRPCError<unknown> {
if (cause instanceof TRPCError) {
return cause;
}
Expand All @@ -53,17 +56,19 @@ export function getTRPCErrorFromUnknown(cause: unknown): TRPCError {
return trpcError;
}

type TRPCErrorOptions = {
message?: string;
code: TRPC_ERROR_CODE_KEY;
cause?: unknown;
};

export class TRPCError extends Error {
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore override doesn't work in all environments due to "This member cannot have an 'override' modifier because it is not declared in the base class 'Error'"
public override readonly cause?: Error;
public readonly code;

constructor(opts: {
message?: string;
code: TRPC_ERROR_CODE_KEY;
cause?: unknown;
}) {
constructor(opts: TRPCErrorOptions) {
const cause = getCauseFromUnknown(opts.cause);
const message = opts.message ?? cause?.message ?? opts.code;

Expand All @@ -80,3 +85,64 @@ export class TRPCError extends Error {
}
}
}

export class TRPCInputValidationError extends TRPCError {
constructor(cause: unknown) {
super({
code: 'BAD_REQUEST',
cause,
});
}
}

export const trpcTypedErrorSymbol = Symbol('errorSymbol');

export type TypedTRPCError<TData> = Overwrite<TRPCError, TData> & {
[trpcTypedErrorSymbol]: typeof trpcTypedErrorSymbol;
};

export function trpcError<
TData extends Partial<TRPCErrorOptions> & Record<string, unknown>,
>(
data: TData extends {
stack?: any;
}
? TypeError<"key 'stack' should not be passed to trpcError">
: TData,
) {
const {
code = 'BAD_REQUEST',
cause,
message,
stack: _,
...rest
} = data as TData;

const error = new TRPCError({
code,
cause,
message,
}) as TypedTRPCError<
Overwrite<
TData,
{
code: TData['code'] extends TRPC_ERROR_CODE_KEY
? TData['code']
: 'BAD_REQUEST';
}
>
>;

error[trpcTypedErrorSymbol] = trpcTypedErrorSymbol;
for (const [key, value] of Object.entries(rest)) {
(error as any)[key] = value;
}

return error;
}

export function isTRPCError(
err: unknown,
): err is TRPCError | TypedTRPCError<any> {
return err instanceof TRPCError;
}
47 changes: 35 additions & 12 deletions packages/server/src/unstable-core-do-not-import/middleware.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { TRPCError } from './error/TRPCError';
import type { TypedTRPCError } from './error/TRPCError';
import { TRPCError, TRPCInputValidationError } from './error/TRPCError';
import type { ParseFn } from './parser';
import type { ProcedureType } from './procedure';
import type { GetRawInputFn, Overwrite, Simplify } from './types';
import type { GetRawInputFn, MaybePromise, Overwrite, Simplify } from './types';
import { isObject } from './utils';

/** @internal */
Expand Down Expand Up @@ -56,7 +57,10 @@ export interface MiddlewareBuilder<
TMeta,
TContextOverrides,
$ContextOverridesOut,
TInputOut
TInputOut,
// FIXME
never,
never
>
| MiddlewareBuilder<
Overwrite<TContext, TContextOverrides>,
Expand All @@ -79,7 +83,10 @@ export interface MiddlewareBuilder<
TMeta,
TContextOverrides,
object,
TInputOut
TInputOut,
// FIXME
any,
any
>[];
}

Expand All @@ -92,6 +99,8 @@ export type MiddlewareFunction<
TContextOverridesIn,
$ContextOverridesOut,
TInputOut,
TInferErrors extends boolean,
$ErrorOutput,
> = {
(opts: {
ctx: Simplify<Overwrite<TContext, TContextOverridesIn>>;
Expand All @@ -110,11 +119,22 @@ export type MiddlewareFunction<
MiddlewareResult<TContextOverridesIn>
>;
};
}): Promise<MiddlewareResult<$ContextOverridesOut>>;
}): MaybePromise<
| MiddlewareResult<$ContextOverridesOut>
| (true extends TInferErrors ? TypedTRPCError<$ErrorOutput> : never)
>;
_type?: string | undefined;
};

export type AnyMiddlewareFunction = MiddlewareFunction<any, any, any, any, any>;
export type AnyMiddlewareFunction = MiddlewareFunction<
any,
any,
any,
any,
any,
any,
any
>;
export type AnyMiddlewareBuilder = MiddlewareBuilder<any, any, any, any>;
/**
* @internal
Expand All @@ -140,13 +160,19 @@ export function createMiddlewareFactory<
};
}

function createMiddleware<$ContextOverrides>(
function createMiddleware<
$ContextOverrides,
TInferErrors extends boolean,
$Error,
>(
fn: MiddlewareFunction<
TContext,
TMeta,
object,
$ContextOverrides,
TInputOut
TInputOut,
TInferErrors,
$Error
>,
): MiddlewareBuilder<TContext, TMeta, $ContextOverrides, TInputOut> {
return createMiddlewareInner([fn]);
Expand Down Expand Up @@ -187,10 +213,7 @@ export function createInputMiddleware<TInput>(parse: ParseFn<TInput>) {
try {
parsedInput = await parse(rawInput);
} catch (cause) {
throw new TRPCError({
code: 'BAD_REQUEST',
cause,
});
throw new TRPCInputValidationError(cause);
}

// Multiple input parsers
Expand Down