Skip to content

Commit

Permalink
Merge pull request #3537 from graphql-java/19.x-backport-3525-max-res…
Browse files Browse the repository at this point in the history
…ult-nodes

19.x Backport PR 3525 max result nodes
  • Loading branch information
dondonz committed Mar 19, 2024
2 parents 84d4e39 + 90c1e51 commit c0b905c
Show file tree
Hide file tree
Showing 10 changed files with 249 additions and 5 deletions.
7 changes: 7 additions & 0 deletions src/main/java/graphql/GraphQLContext.java
Expand Up @@ -224,6 +224,13 @@ public static GraphQLContext of(Consumer<GraphQLContext.Builder> contextBuilderC
return of(builder.map);
}

/**
* @return a new and empty graphql context object
*/
public static GraphQLContext getDefault() {
return GraphQLContext.newContext().build();
}

/**
* Creates a new GraphqlContext builder
*
Expand Down
1 change: 1 addition & 0 deletions src/main/java/graphql/execution/Execution.java
Expand Up @@ -97,6 +97,7 @@ public CompletableFuture<ExecutionResult> execute(Document document, GraphQLSche
.executionInput(executionInput)
.build();

executionContext.getGraphQLContext().put(ResultNodesInfo.RESULT_NODES_INFO, executionContext.getResultNodesInfo());

InstrumentationExecutionParameters parameters = new InstrumentationExecutionParameters(
executionInput, graphQLSchema, instrumentationState
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/graphql/execution/ExecutionContext.java
Expand Up @@ -56,6 +56,7 @@ public class ExecutionContext {
private final ValueUnboxer valueUnboxer;
private final ExecutionInput executionInput;
private final Supplier<ExecutableNormalizedOperation> queryTree;
private final ResultNodesInfo resultNodesInfo = new ResultNodesInfo();

ExecutionContext(ExecutionContextBuilder builder) {
this.graphQLSchema = builder.graphQLSchema;
Expand Down Expand Up @@ -287,4 +288,8 @@ public ExecutionContext transform(Consumer<ExecutionContextBuilder> builderConsu
builderConsumer.accept(builder);
return builder.build();
}

public ResultNodesInfo getResultNodesInfo() {
return resultNodesInfo;
}
}
29 changes: 27 additions & 2 deletions src/main/java/graphql/execution/ExecutionStrategy.java
Expand Up @@ -60,6 +60,7 @@
import static graphql.execution.FieldValueInfo.CompleteValueType.NULL;
import static graphql.execution.FieldValueInfo.CompleteValueType.OBJECT;
import static graphql.execution.FieldValueInfo.CompleteValueType.SCALAR;
import static graphql.execution.ResultNodesInfo.MAX_RESULT_NODES;
import static graphql.execution.instrumentation.SimpleInstrumentationContext.nonNullCtx;
import static graphql.schema.DataFetchingEnvironmentImpl.newDataFetchingEnvironment;
import static graphql.schema.GraphQLTypeUtil.isEnum;
Expand Down Expand Up @@ -237,8 +238,23 @@ protected CompletableFuture<FetchedValue> fetchField(ExecutionContext executionC
MergedField field = parameters.getField();
GraphQLObjectType parentType = (GraphQLObjectType) parameters.getExecutionStepInfo().getUnwrappedNonNullType();
GraphQLFieldDefinition fieldDef = getFieldDef(executionContext.getGraphQLSchema(), parentType, field.getSingleField());
return fetchField(fieldDef, executionContext, parameters);
}

GraphQLCodeRegistry codeRegistry = executionContext.getGraphQLSchema().getCodeRegistry();
private CompletableFuture<FetchedValue> fetchField(GraphQLFieldDefinition fieldDef, ExecutionContext executionContext, ExecutionStrategyParameters parameters) {

int resultNodesCount = executionContext.getResultNodesInfo().incrementAndGetResultNodesCount();

Integer maxNodes;
if ((maxNodes = executionContext.getGraphQLContext().get(MAX_RESULT_NODES)) != null) {
if (resultNodesCount > maxNodes) {
executionContext.getResultNodesInfo().maxResultNodesExceeded();
return CompletableFuture.completedFuture(new FetchedValue(null, null, ImmutableKit.emptyList(), null));
}
}

MergedField field = parameters.getField();
GraphQLObjectType parentType = (GraphQLObjectType) parameters.getExecutionStepInfo().getUnwrappedNonNullType();
GraphQLOutputType fieldType = fieldDef.getType();

// if the DF (like PropertyDataFetcher) does not use the arguments of execution step info then dont build any
Expand All @@ -252,7 +268,6 @@ protected CompletableFuture<FetchedValue> fetchField(ExecutionContext executionC
DataFetchingFieldSelectionSet fieldCollector = DataFetchingFieldSelectionSetImpl.newCollector(executionContext.getGraphQLSchema(), fieldType, normalizedFieldSupplier);
QueryDirectives queryDirectives = new QueryDirectivesImpl(field, executionContext.getGraphQLSchema(), executionContext.getVariables());


DataFetchingEnvironment environment = newDataFetchingEnvironment(executionContext)
.source(parameters.getSource())
.localContext(parameters.getLocalContext())
Expand All @@ -266,6 +281,7 @@ protected CompletableFuture<FetchedValue> fetchField(ExecutionContext executionC
.queryDirectives(queryDirectives)
.build();

GraphQLCodeRegistry codeRegistry = executionContext.getGraphQLSchema().getCodeRegistry();
DataFetcher<?> dataFetcher = codeRegistry.getDataFetcher(parentType, fieldDef);

Instrumentation instrumentation = executionContext.getInstrumentation();
Expand Down Expand Up @@ -528,6 +544,15 @@ protected FieldValueInfo completeValueForList(ExecutionContext executionContext,
List<FieldValueInfo> fieldValueInfos = new ArrayList<>(size.orElse(1));
int index = 0;
for (Object item : iterableValues) {
int resultNodesCount = executionContext.getResultNodesInfo().incrementAndGetResultNodesCount();
Integer maxNodes;
if ((maxNodes = executionContext.getGraphQLContext().get(MAX_RESULT_NODES)) != null) {
if (resultNodesCount > maxNodes) {
executionContext.getResultNodesInfo().maxResultNodesExceeded();
return new FieldValueInfo(NULL, completedFuture(ExecutionResultImpl.newExecutionResult().build()), fieldValueInfos);
}
}

ResultPath indexedPath = parameters.getPath().segment(index);

ExecutionStepInfo stepInfoForListElement = executionStepInfoFactory.newExecutionStepInfoForListElement(executionStepInfo, index);
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/graphql/execution/FetchedValue.java
Expand Up @@ -19,7 +19,7 @@ public class FetchedValue {
private final Object localContext;
private final ImmutableList<GraphQLError> errors;

private FetchedValue(Object fetchedValue, Object rawFetchedValue, ImmutableList<GraphQLError> errors, Object localContext) {
FetchedValue(Object fetchedValue, Object rawFetchedValue, ImmutableList<GraphQLError> errors, Object localContext) {
this.fetchedValue = fetchedValue;
this.rawFetchedValue = rawFetchedValue;
this.errors = errors;
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/graphql/execution/FieldValueInfo.java
Expand Up @@ -25,7 +25,7 @@ public enum CompleteValueType {
private final CompletableFuture<ExecutionResult> fieldValue;
private final List<FieldValueInfo> fieldValueInfos;

private FieldValueInfo(CompleteValueType completeValueType, CompletableFuture<ExecutionResult> fieldValue, List<FieldValueInfo> fieldValueInfos) {
FieldValueInfo(CompleteValueType completeValueType, CompletableFuture<ExecutionResult> fieldValue, List<FieldValueInfo> fieldValueInfos) {
assertNotNull(fieldValueInfos, () -> "fieldValueInfos can't be null");
this.completeValueType = completeValueType;
this.fieldValue = fieldValue;
Expand Down
55 changes: 55 additions & 0 deletions src/main/java/graphql/execution/ResultNodesInfo.java
@@ -0,0 +1,55 @@
package graphql.execution;

import graphql.Internal;
import graphql.PublicApi;

import java.util.concurrent.atomic.AtomicInteger;

/**
* This class is used to track the number of result nodes that have been created during execution.
* After each execution the GraphQLContext contains a ResultNodeInfo object under the key {@link ResultNodesInfo#RESULT_NODES_INFO}
* <p>
* The number of result can be limited (and should be for security reasons) by setting the maximum number of result nodes
* in the GraphQLContext under the key {@link ResultNodesInfo#MAX_RESULT_NODES} to an Integer
* </p>
*/
@PublicApi
public class ResultNodesInfo {

public static final String MAX_RESULT_NODES = "__MAX_RESULT_NODES";
public static final String RESULT_NODES_INFO = "__RESULT_NODES_INFO";

private volatile boolean maxResultNodesExceeded = false;
private final AtomicInteger resultNodesCount = new AtomicInteger(0);

@Internal
public int incrementAndGetResultNodesCount() {
return resultNodesCount.incrementAndGet();
}

@Internal
public void maxResultNodesExceeded() {
this.maxResultNodesExceeded = true;
}

/**
* The number of result nodes created.
* Note: this can be higher than max result nodes because
* a each node that exceeds the number of max nodes is set to null,
* but still is a result node (with value null)
*
* @return number of result nodes created
*/
public int getResultNodesCount() {
return resultNodesCount.get();
}

/**
* If the number of result nodes has exceeded the maximum allowed numbers.
*
* @return true if the number of result nodes has exceeded the maximum allowed numbers
*/
public boolean isMaxResultNodesExceeded() {
return maxResultNodesExceeded;
}
}
141 changes: 141 additions & 0 deletions src/test/groovy/graphql/GraphQLTest.groovy
Expand Up @@ -14,6 +14,7 @@ import graphql.execution.ExecutionId
import graphql.execution.ExecutionIdProvider
import graphql.execution.ExecutionStrategyParameters
import graphql.execution.MissingRootTypeException
import graphql.execution.ResultNodesInfo
import graphql.execution.SubscriptionExecutionStrategy
import graphql.execution.ValueUnboxer
import graphql.execution.instrumentation.ChainedInstrumentation
Expand Down Expand Up @@ -47,6 +48,7 @@ import static graphql.ExecutionInput.Builder
import static graphql.ExecutionInput.newExecutionInput
import static graphql.Scalars.GraphQLInt
import static graphql.Scalars.GraphQLString
import static graphql.execution.ResultNodesInfo.MAX_RESULT_NODES
import static graphql.schema.GraphQLArgument.newArgument
import static graphql.schema.GraphQLFieldDefinition.newFieldDefinition
import static graphql.schema.GraphQLInputObjectField.newInputObjectField
Expand Down Expand Up @@ -1363,4 +1365,143 @@ many lines''']
then:
! er.errors.isEmpty()
}
def "max result nodes not breached"() {
given:
def sdl = '''
type Query {
hello: String
}
'''
def df = { env -> "world" } as DataFetcher
def fetchers = ["Query": ["hello": df]]
def schema = TestUtil.schema(sdl, fetchers)
def graphQL = GraphQL.newGraphQL(schema).build()
def query = "{ hello h1: hello h2: hello h3: hello } "
def ei = newExecutionInput(query).build()
ei.getGraphQLContext().put(MAX_RESULT_NODES, 4);
when:
def er = graphQL.execute(ei)
def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo
then:
!rni.maxResultNodesExceeded
rni.resultNodesCount == 4
er.data == [hello: "world", h1: "world", h2: "world", h3: "world"]
}
def "max result nodes breached"() {
given:
def sdl = '''
type Query {
hello: String
}
'''
def df = { env -> "world" } as DataFetcher
def fetchers = ["Query": ["hello": df]]
def schema = TestUtil.schema(sdl, fetchers)
def graphQL = GraphQL.newGraphQL(schema).build()
def query = "{ hello h1: hello h2: hello h3: hello } "
def ei = newExecutionInput(query).build()
ei.getGraphQLContext().put(MAX_RESULT_NODES, 3);
when:
def er = graphQL.execute(ei)
def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo
then:
rni.maxResultNodesExceeded
rni.resultNodesCount == 4
er.data == [hello: "world", h1: "world", h2: "world", h3: null]
}
def "max result nodes breached with list"() {
given:
def sdl = '''
type Query {
hello: [String]
}
'''
def df = { env -> ["w1", "w2", "w3"] } as DataFetcher
def fetchers = ["Query": ["hello": df]]
def schema = TestUtil.schema(sdl, fetchers)
def graphQL = GraphQL.newGraphQL(schema).build()
def query = "{ hello}"
def ei = newExecutionInput(query).build()
ei.getGraphQLContext().put(MAX_RESULT_NODES, 3);
when:
def er = graphQL.execute(ei)
def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo
then:
rni.maxResultNodesExceeded
rni.resultNodesCount == 4
er.data == [hello: null]
}
def "max result nodes breached with list 2"() {
given:
def sdl = '''
type Query {
hello: [Foo]
}
type Foo {
name: String
}
'''
def df = { env -> [[name: "w1"], [name: "w2"], [name: "w3"]] } as DataFetcher
def fetchers = ["Query": ["hello": df]]
def schema = TestUtil.schema(sdl, fetchers)
def graphQL = GraphQL.newGraphQL(schema).build()
def query = "{ hello {name}}"
def ei = newExecutionInput(query).build()
// we have 7 result nodes overall
ei.getGraphQLContext().put(MAX_RESULT_NODES, 6);
when:
def er = graphQL.execute(ei)
def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo
then:
rni.resultNodesCount == 7
rni.maxResultNodesExceeded
er.data == [hello: [[name: "w1"], [name: "w2"], [name: null]]]
}
def "max result nodes not breached with list"() {
given:
def sdl = '''
type Query {
hello: [Foo]
}
type Foo {
name: String
}
'''
def df = { env -> [[name: "w1"], [name: "w2"], [name: "w3"]] } as DataFetcher
def fetchers = ["Query": ["hello": df]]
def schema = TestUtil.schema(sdl, fetchers)
def graphQL = GraphQL.newGraphQL(schema).build()
def query = "{ hello {name}}"
def ei = newExecutionInput(query).build()
// we have 7 result nodes overall
ei.getGraphQLContext().put(MAX_RESULT_NODES, 7);
when:
def er = graphQL.execute(ei)
def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo
then:
!rni.maxResultNodesExceeded
rni.resultNodesCount == 7
er.data == [hello: [[name: "w1"], [name: "w2"], [name: "w3"]]]
}
}
Expand Up @@ -2,6 +2,7 @@ package graphql.execution

import graphql.ErrorType
import graphql.ExecutionResult
import graphql.GraphQLContext
import graphql.execution.instrumentation.ExecutionStrategyInstrumentationContext
import graphql.execution.instrumentation.SimpleInstrumentation
import graphql.execution.instrumentation.parameters.InstrumentationExecutionStrategyParameters
Expand All @@ -28,6 +29,8 @@ import static org.awaitility.Awaitility.await

class AsyncExecutionStrategyTest extends Specification {

def graphqlContextMock = Mock(GraphQLContext)

GraphQLSchema schema(DataFetcher dataFetcher1, DataFetcher dataFetcher2) {
GraphQLFieldDefinition.Builder fieldDefinition = newFieldDefinition()
.name("hello")
Expand Down Expand Up @@ -82,6 +85,7 @@ class AsyncExecutionStrategyTest extends Specification {
.operationDefinition(operation)
.instrumentation(SimpleInstrumentation.INSTANCE)
.valueUnboxer(ValueUnboxer.DEFAULT)
.graphQLContext(graphqlContextMock)
.build()
ExecutionStrategyParameters executionStrategyParameters = ExecutionStrategyParameters
.newParameters()
Expand Down Expand Up @@ -121,6 +125,7 @@ class AsyncExecutionStrategyTest extends Specification {
.operationDefinition(operation)
.valueUnboxer(ValueUnboxer.DEFAULT)
.instrumentation(SimpleInstrumentation.INSTANCE)
.graphQLContext(graphqlContextMock)
.build()
ExecutionStrategyParameters executionStrategyParameters = ExecutionStrategyParameters
.newParameters()
Expand Down Expand Up @@ -162,6 +167,7 @@ class AsyncExecutionStrategyTest extends Specification {
.operationDefinition(operation)
.valueUnboxer(ValueUnboxer.DEFAULT)
.instrumentation(SimpleInstrumentation.INSTANCE)
.graphQLContext(graphqlContextMock)
.build()
ExecutionStrategyParameters executionStrategyParameters = ExecutionStrategyParameters
.newParameters()
Expand Down Expand Up @@ -202,6 +208,7 @@ class AsyncExecutionStrategyTest extends Specification {
.operationDefinition(operation)
.instrumentation(SimpleInstrumentation.INSTANCE)
.valueUnboxer(ValueUnboxer.DEFAULT)
.graphQLContext(graphqlContextMock)
.build()
ExecutionStrategyParameters executionStrategyParameters = ExecutionStrategyParameters
.newParameters()
Expand Down Expand Up @@ -262,6 +269,7 @@ class AsyncExecutionStrategyTest extends Specification {
}
}
})
.graphQLContext(graphqlContextMock)
.build()
ExecutionStrategyParameters executionStrategyParameters = ExecutionStrategyParameters
.newParameters()
Expand Down

0 comments on commit c0b905c

Please sign in to comment.