Skip to content

Commit

Permalink
Merge pull request #3551 from jbellenger/jbellenger-validate-dir-args
Browse files Browse the repository at this point in the history
validate non-nullable directive args
  • Loading branch information
bbakerman committed Apr 16, 2024
2 parents c4df085 + b906880 commit 0eec91e
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 18 deletions.
Expand Up @@ -6,11 +6,14 @@
import graphql.execution.ValuesResolver;
import graphql.language.Value;
import graphql.schema.CoercingParseValueException;
import graphql.schema.GraphQLAppliedDirective;
import graphql.schema.GraphQLAppliedDirectiveArgument;
import graphql.schema.GraphQLArgument;
import graphql.schema.GraphQLDirective;
import graphql.schema.GraphQLInputType;
import graphql.schema.GraphQLSchema;
import graphql.schema.GraphQLSchemaElement;
import graphql.schema.GraphQLTypeUtil;
import graphql.schema.GraphQLTypeVisitorStub;
import graphql.schema.InputValueWithState;
import graphql.util.TraversalControl;
Expand All @@ -32,29 +35,56 @@ public TraversalControl visitGraphQLDirective(GraphQLDirective directive, Traver
// if there is no parent it means it is just a directive definition and not an applied directive
if (context.getParentNode() != null) {
for (GraphQLArgument graphQLArgument : directive.getArguments()) {
checkArgument(directive, graphQLArgument, context);
checkArgument(
directive.getName(),
graphQLArgument.getName(),
graphQLArgument.getArgumentValue(),
graphQLArgument.getType(),
context
);
}
}
return TraversalControl.CONTINUE;
}

private void checkArgument(GraphQLDirective directive, GraphQLArgument argument, TraverserContext<GraphQLSchemaElement> context) {
if (!argument.hasSetValue()) {
return;
@Override
public TraversalControl visitGraphQLAppliedDirective(GraphQLAppliedDirective directive, TraverserContext<GraphQLSchemaElement> context) {
// if there is no parent it means it is just a directive definition and not an applied directive
if (context.getParentNode() != null) {
for (GraphQLAppliedDirectiveArgument graphQLArgument : directive.getArguments()) {
checkArgument(
directive.getName(),
graphQLArgument.getName(),
graphQLArgument.getArgumentValue(),
graphQLArgument.getType(),
context
);
}
}
return TraversalControl.CONTINUE;
}

private void checkArgument(
String directiveName,
String argumentName,
InputValueWithState argumentValue,
GraphQLInputType argumentType,
TraverserContext<GraphQLSchemaElement> context
) {
GraphQLSchema schema = context.getVarFromParents(GraphQLSchema.class);
SchemaValidationErrorCollector errorCollector = context.getVarFromParents(SchemaValidationErrorCollector.class);
InputValueWithState argumentValue = argument.getArgumentValue();
boolean invalid = false;
if (argumentValue.isLiteral() &&
!validationUtil.isValidLiteralValue((Value<?>) argumentValue.getValue(), argument.getType(), schema, GraphQLContext.getDefault(), Locale.getDefault())) {
!validationUtil.isValidLiteralValue((Value<?>) argumentValue.getValue(), argumentType, schema, GraphQLContext.getDefault(), Locale.getDefault())) {
invalid = true;
} else if (argumentValue.isExternal() &&
!isValidExternalValue(schema, argumentValue.getValue(), argument.getType(), GraphQLContext.getDefault(), Locale.getDefault())) {
!isValidExternalValue(schema, argumentValue.getValue(), argumentType, GraphQLContext.getDefault(), Locale.getDefault())) {
invalid = true;
} else if (argumentValue.isNotSet() && GraphQLTypeUtil.isNonNull(argumentType)) {
invalid = true;
}
if (invalid) {
String message = format("Invalid argument '%s' for applied directive of name '%s'", argument.getName(), directive.getName());
String message = format("Invalid argument '%s' for applied directive of name '%s'", argumentName, directiveName);
errorCollector.addError(new SchemaValidationError(SchemaValidationErrorType.InvalidAppliedDirectiveArgument, message));
}
}
Expand Down
50 changes: 40 additions & 10 deletions src/test/groovy/graphql/schema/GraphQLArgumentTest.groovy
Expand Up @@ -196,23 +196,53 @@ class GraphQLArgumentTest extends Specification {
resolvedDefaultValue == null
}

def "Applied schema directives arguments are validated for programmatic schemas"() {
def "schema directive arguments are validated for programmatic schemas"() {
given:
def arg = newArgument().name("arg").type(GraphQLInt).valueProgrammatic(ImmutableKit.emptyMap()).build() // Retain for test coverage
def directive = mkDirective("cached", ARGUMENT_DEFINITION, arg)
def field = newFieldDefinition()
.name("hello")
.type(GraphQLString)
.argument(arg)
.withDirective(directive)
.build()
.name("hello")
.type(GraphQLString)
.argument(arg)
.withDirective(directive)
.build()
when:
newSchema().query(
newSchema()
.query(
newObject()
.name("Query")
.field(field)
.build())
.name("Query")
.field(field)
.build()
)
.additionalDirective(directive)
.build()
then:
def e = thrown(InvalidSchemaException)
e.message.contains("Invalid argument 'arg' for applied directive of name 'cached'")
}

def "applied directive arguments are validated for programmatic schemas"() {
given:
def arg = newArgument()
.name("arg")
.type(GraphQLNonNull.nonNull(GraphQLInt))
.build()
def directive = mkDirective("cached", ARGUMENT_DEFINITION, arg)
def field = newFieldDefinition()
.name("hello")
.type(GraphQLString)
.withAppliedDirective(directive.toAppliedDirective())
.build()
when:
newSchema()
.query(
newObject()
.name("Query")
.field(field)
.build()
)
.additionalDirective(directive)
.build()
then:
def e = thrown(InvalidSchemaException)
e.message.contains("Invalid argument 'arg' for applied directive of name 'cached'")
Expand Down

0 comments on commit 0eec91e

Please sign in to comment.