Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

subscribe: simplify by rewriting as async functions #3024

Merged
merged 1 commit into from Apr 4, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
114 changes: 53 additions & 61 deletions src/subscription/subscribe.js
Expand Up @@ -57,7 +57,7 @@ export type SubscriptionArgs = {|
*
* Accepts either an object with named arguments, or individual arguments.
*/
export function subscribe(
export async function subscribe(
args: SubscriptionArgs,
): Promise<AsyncGenerator<ExecutionResult, void, void> | ExecutionResult> {
const {
Expand All @@ -71,7 +71,8 @@ export function subscribe(
subscribeFieldResolver,
} = args;

const sourcePromise = createSourceEventStream(
// $FlowFixMe[incompatible-call]
const resultOrStream = await createSourceEventStream(
schema,
document,
rootValue,
Expand All @@ -81,6 +82,10 @@ export function subscribe(
subscribeFieldResolver,
);

if (!isAsyncIterable(resultOrStream)) {
return resultOrStream;
}

// For each payload yielded from a subscription, map it over the normal
// GraphQL `execute` function, with `payload` as the rootValue.
// This implements the "MapSourceToResponseEvent" algorithm described in
Expand All @@ -98,30 +103,13 @@ export function subscribe(
fieldResolver,
});

// Resolve the Source Stream, then map every source value to a
// ExecutionResult value as described above.
return sourcePromise.then((resultOrStream) =>
// Note: Flow can't refine isAsyncIterable, so explicit casts are used.
isAsyncIterable(resultOrStream)
? mapAsyncIterator(
resultOrStream,
mapSourceToResponse,
reportGraphQLError,
)
: ((resultOrStream: any): ExecutionResult),
);
}

/**
* This function checks if the error is a GraphQLError. If it is, report it as
* an ExecutionResult, containing only errors and no data. Otherwise treat the
* error as a system-class error and re-throw it.
*/
function reportGraphQLError(error: mixed): ExecutionResult {
if (error instanceof GraphQLError) {
return { errors: [error] };
}
throw error;
// Map every source value to a ExecutionResult value as described above.
return mapAsyncIterator(resultOrStream, mapSourceToResponse, (error) => {
if (error instanceof GraphQLError) {
return { errors: [error] };
}
throw error;
});
}

/**
Expand Down Expand Up @@ -152,7 +140,7 @@ function reportGraphQLError(error: mixed): ExecutionResult {
* or otherwise separating these two steps. For more on this, see the
* "Supporting Subscriptions at Scale" information in the GraphQL specification.
*/
export function createSourceEventStream(
export async function createSourceEventStream(
schema: GraphQLSchema,
document: DocumentNode,
rootValue?: mixed,
Expand All @@ -165,9 +153,8 @@ export function createSourceEventStream(
// developer mistake which should throw an early error.
assertValidExecutionArguments(schema, document, variableValues);

return new Promise((resolve) => {
// If a valid context cannot be created due to incorrect arguments,
// this will throw an error.
try {
// If a valid context cannot be created due to incorrect arguments, this will throw an error.
const exeContext = buildExecutionContext(
schema,
document,
Expand All @@ -178,18 +165,35 @@ export function createSourceEventStream(
fieldResolver,
);

resolve(
// Return early errors if execution context failed.
Array.isArray(exeContext)
? { errors: exeContext }
: executeSubscription(exeContext),
);
}).catch(reportGraphQLError);
// Return early errors if execution context failed.
if (Array.isArray(exeContext)) {
return { errors: exeContext };
}

const eventStream = await executeSubscription(exeContext);

// Assert field returned an event stream, otherwise yield an error.
if (!isAsyncIterable(eventStream)) {
throw new Error(
'Subscription field must return Async Iterable. ' +
`Received: ${inspect(eventStream)}.`,
);
}

return eventStream;
} catch (error) {
// If it GraphQLError, report it as an ExecutionResult, containing only errors and no data.
// Otherwise treat the error as a system-class error and re-throw it.
if (error instanceof GraphQLError) {
return { errors: [error] };
}
throw error;
}
}

function executeSubscription(
async function executeSubscription(
exeContext: ExecutionContext,
): Promise<AsyncIterable<mixed>> {
): Promise<mixed> {
const { schema, operation, variableValues, rootValue } = exeContext;
const type = getOperationRootType(schema, operation);
const fields = collectFields(
Expand All @@ -216,8 +220,7 @@ function executeSubscription(
const path = addPath(undefined, responseName, type.name);
const info = buildResolveInfo(exeContext, fieldDef, fieldNodes, type, path);

// Coerce to Promise for easier error handling and consistent return type.
return new Promise((resolveResult) => {
try {
// Implements the "ResolveFieldEventStream" algorithm from GraphQL specification.
// It differs from "ResolveFieldValue" due to providing a different `resolveFn`.

Expand All @@ -233,24 +236,13 @@ function executeSubscription(
// Call the `subscribe()` resolver or the default resolver to produce an
// AsyncIterable yielding raw payloads.
const resolveFn = fieldDef.subscribe ?? exeContext.fieldResolver;
resolveResult(resolveFn(rootValue, args, contextValue, info));
}).then(
(eventStream) => {
if (eventStream instanceof Error) {
throw locatedError(eventStream, fieldNodes, pathToArray(path));
}

// Assert field returned an event stream, otherwise yield an error.
if (!isAsyncIterable(eventStream)) {
throw new Error(
'Subscription field must return Async Iterable. ' +
`Received: ${inspect(eventStream)}.`,
);
}
return eventStream;
},
(error) => {
throw locatedError(error, fieldNodes, pathToArray(path));
},
);
const eventStream = await resolveFn(rootValue, args, contextValue, info);

if (eventStream instanceof Error) {
throw eventStream;
}
return eventStream;
} catch (error) {
throw locatedError(error, fieldNodes, pathToArray(path));
}
}