Skip to content
This repository has been archived by the owner on Oct 26, 2020. It is now read-only.

Support for repeatable directives #5

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/main/scala/sangria/ast/QueryAst.scala
Expand Up @@ -457,6 +457,7 @@ case class DirectiveDefinition(
arguments: Vector[InputValueDefinition],
locations: Vector[DirectiveLocation],
description: Option[StringValue] = None,
repeatable: Boolean = false,
comments: Vector[Comment] = Vector.empty,
location: Option[AstLocation] = None) extends TypeSystemDefinition with WithDescription

Expand Down Expand Up @@ -911,7 +912,7 @@ object AstVisitor {
tc.foreach(c => loop(c))
breakOrSkip(onLeave(n))
}
case n @ DirectiveDefinition(_, args, locations, description, comment, _) =>
case n @ DirectiveDefinition(_, args, locations, description, _, comment, _) =>
if (breakOrSkip(onEnter(n))) {
args.foreach(d => loop(d))
locations.foreach(d => loop(d))
Expand Down
Expand Up @@ -82,7 +82,8 @@ object IntrospectionParser {
name = mapStringField(directive, "name", path),
description = mapStringFieldOpt(directive, "description"),
locations = um.getListValue(mapField(directive, "locations")).map(v => DirectiveLocation.fromString(stringValue(v, path :+ "locations"))).toSet,
args = mapFieldOpt(directive, "args") map um.getListValue getOrElse Vector.empty map (arg => parseInputValue(arg, path :+ "args")))
args = mapFieldOpt(directive, "args") map um.getListValue getOrElse Vector.empty map (arg => parseInputValue(arg, path :+ "args")),
repeatable = mapBooleanFieldOpt(directive, "isRepeatable") getOrElse false)

private def parseType[In : InputUnmarshaller](tpe: In, path: Vector[String]) =
mapStringField(tpe, "kind", path) match {
Expand Down Expand Up @@ -148,11 +149,14 @@ object IntrospectionParser {
private def mapBooleanField[In : InputUnmarshaller](map: In, name: String, path: Vector[String] = Vector.empty): Boolean =
booleanValue(mapField(map, name, path), path :+ name)

private def mapBooleanFieldOpt[In : InputUnmarshaller](map: In, name: String, path: Vector[String] = Vector.empty): Option[Boolean] =
mapFieldOpt(map, name) filter um.isDefined map (booleanValue(_, path :+ name))

private def mapFieldOpt[In : InputUnmarshaller](map: In, name: String): Option[In] =
um.getMapValue(map, name) filter um.isDefined

private def mapStringFieldOpt[In : InputUnmarshaller](map: In, name: String, path: Vector[String] = Vector.empty): Option[String] =
mapFieldOpt(map, name) filter um.isDefined map (s => stringValue(s, path :+ name) )
mapFieldOpt(map, name) filter um.isDefined map (stringValue(_, path :+ name))

private def um[T: InputUnmarshaller] = implicitly[InputUnmarshaller[T]]

Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/sangria/introspection/model.scala
Expand Up @@ -116,4 +116,5 @@ case class IntrospectionDirective(
name: String,
description: Option[String],
locations: Set[DirectiveLocation.Value],
args: Seq[IntrospectionInputValue])
args: Seq[IntrospectionInputValue],
repeatable: Boolean)
11 changes: 7 additions & 4 deletions src/main/scala/sangria/introspection/package.scala
Expand Up @@ -254,7 +254,9 @@ package object introspection {
Field("name", StringType, resolve = _.value.name),
Field("description", OptionType(StringType), resolve = _.value.description),
Field("locations", ListType(__DirectiveLocation), resolve = _.value.locations.toVector.sorted),
Field("args", ListType(__InputValue), resolve = _.value.arguments)))
Field("args", ListType(__InputValue), resolve = _.value.arguments),
Field("isRepeatable", BooleanType, Some("Permits using the directive multiple times at the same location."),
resolve = _.value.repeatable)))

val __Schema = ObjectType(
name = "__Schema",
Expand Down Expand Up @@ -309,10 +311,10 @@ package object introspection {

def introspectionQuery: ast.Document = introspectionQuery()

def introspectionQuery(schemaDescription: Boolean = true): ast.Document =
QueryParser.parse(introspectionQueryString(schemaDescription))
def introspectionQuery(schemaDescription: Boolean = true, directiveRepeatableFlag: Boolean = true): ast.Document =
QueryParser.parse(introspectionQueryString(schemaDescription, directiveRepeatableFlag))

def introspectionQueryString(schemaDescription: Boolean = true): String =
def introspectionQueryString(schemaDescription: Boolean = true, directiveRepeatableFlag: Boolean = true): String =
s"""query IntrospectionQuery {
| __schema {
| queryType { name }
Expand All @@ -328,6 +330,7 @@ package object introspection {
| args {
| ...InputValue
| }
| ${if (directiveRepeatableFlag) "isRepeatable" else ""}
| }
| ${if (schemaDescription) "description" else ""}
| }
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/sangria/macros/AstLiftable.scala
Expand Up @@ -76,8 +76,8 @@ trait AstLiftable {
case FragmentDefinition(n, t, d, s, v, c, tc, p) =>
q"_root_.sangria.ast.FragmentDefinition($n, $t, $d, $s, $v, $c, $tc, $p)"

case DirectiveDefinition(n, a, l, desc, c, p) =>
q"_root_.sangria.ast.DirectiveDefinition($n, $a, $l, $desc, $c, $p)"
case DirectiveDefinition(n, a, l, desc, r, c, p) =>
q"_root_.sangria.ast.DirectiveDefinition($n, $a, $l, $desc, $r, $c, $p)"
case SchemaDefinition(o, d, desc, c, tc, p) =>
q"_root_.sangria.ast.SchemaDefinition($o, $d, $desc, $c, $tc, $p)"

Expand Down
6 changes: 4 additions & 2 deletions src/main/scala/sangria/parser/QueryParser.scala
Expand Up @@ -320,9 +320,11 @@ trait TypeSystemDefinitions { this: Parser with Tokens with Ignored with Directi
wsNoComment('{') ~ (test(legacyEmptyFields) ~ InputValueDefinition.* | InputValueDefinition.+) ~ Comments ~ wsNoComment('}') ~> (_ -> _)
}

def repeatable = rule { capture(Keyword("repeatable")).? ~> (_.isDefined)}

def DirectiveDefinition = rule {
Description ~ Comments ~ trackPos ~ directive ~ '@' ~ NameStrict ~ (ArgumentsDefinition.? ~> (_ getOrElse Vector.empty)) ~ on ~ DirectiveLocations ~> (
(descr, comment, location, name, args, locations) => ast.DirectiveDefinition(name, args, locations, descr, comment, location))
Description ~ Comments ~ trackPos ~ directive ~ '@' ~ NameStrict ~ (ArgumentsDefinition.? ~> (_ getOrElse Vector.empty)) ~ repeatable ~ on ~ DirectiveLocations ~> (
(descr, comment, location, name, args, rep, locations) => ast.DirectiveDefinition(name, args, locations, descr, rep, comment, location))
}

def DirectiveLocations = rule { ws('|').? ~ DirectiveLocation.+(wsNoComment('|')) ~> (_.toVector) }
Expand Down
6 changes: 4 additions & 2 deletions src/main/scala/sangria/renderer/QueryRenderer.scala
Expand Up @@ -557,7 +557,7 @@ object QueryRenderer {
renderDirs(dirs, config, indent, frontSep = true) +
renderOperationTypeDefinitions(ops, ext, indent, config, frontSep = true)

case dd @ DirectiveDefinition(name, args, locations, description, _, _) =>
case dd @ DirectiveDefinition(name, args, locations, description, rep, _, _) =>
val locsRendered = locations.zipWithIndex map { case (l, idx) =>
(if (idx != 0 && shouldRenderComment(l, None, config)) config.lineBreak else "") +
(if (shouldRenderComment(l, None, config)) config.lineBreak else if (idx != 0) config.separator else "") +
Expand All @@ -568,7 +568,9 @@ object QueryRenderer {
renderComment(dd, description orElse prev, indent, config) +
indent.str + "directive" + config.separator + "@" + name +
renderInputValueDefs(args, indent, config) + (if (args.isEmpty) config.mandatorySeparator else "") +
"on" + (if (shouldRenderComment(locations.head, None, config)) "" else config.mandatorySeparator) +
(if (rep) "repeatable" + config.mandatorySeparator else "") +
"on" +
(if (shouldRenderComment(locations.head, None, config)) "" else config.mandatorySeparator) +
locsRendered.mkString(config.separator + "|")

case dl @ DirectiveLocation(name, _, _) =>
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/sangria/renderer/SchemaRenderer.scala
Expand Up @@ -214,10 +214,10 @@ object SchemaRenderer {
ast.DirectiveLocation(__DirectiveLocation.byValue(loc).name)

def renderDirective(dir: Directive) =
ast.DirectiveDefinition(dir.name, renderArgs(dir.arguments), dir.locations.toVector.map(renderDirectiveLocation).sortBy(_.name), renderDescription(dir.description))
ast.DirectiveDefinition(dir.name, renderArgs(dir.arguments), dir.locations.toVector.map(renderDirectiveLocation).sortBy(_.name), renderDescription(dir.description), dir.repeatable)

def renderDirective(dir: IntrospectionDirective) =
ast.DirectiveDefinition(dir.name, renderArgsI(dir.args), dir.locations.toVector.map(renderDirectiveLocation).sortBy(_.name), renderDescription(dir.description))
ast.DirectiveDefinition(dir.name, renderArgsI(dir.args), dir.locations.toVector.map(renderDirectiveLocation).sortBy(_.name), renderDescription(dir.description), dir.repeatable)

def schemaAstFromIntrospection(introspectionSchema: IntrospectionSchema, filter: SchemaFilter = SchemaFilter.default): ast.Document = {
val schemaDef = if (filter.renderSchema) renderSchemaDefinition(introspectionSchema) else None
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/sangria/schema/AstSchemaBuilder.scala
Expand Up @@ -674,6 +674,7 @@ class DefaultAstSchemaBuilder[Ctx] extends AstSchemaBuilder[Ctx] {
description = directiveDescription(definition),
locations = locations,
arguments = arguments,
repeatable = definition.repeatable,
shouldInclude = directiveShouldInclude(definition)))

def transformInputObjectType[T](
Expand Down
Expand Up @@ -258,6 +258,7 @@ class DefaultIntrospectionSchemaBuilder[Ctx] extends IntrospectionSchemaBuilder[
description = directiveDescription(definition),
locations = definition.locations,
arguments = arguments,
repeatable = definition.repeatable,
shouldInclude = directiveShouldInclude(definition)))

def objectTypeInstanceCheck(definition: IntrospectionObjectType): Option[(Any, Class[_]) => Boolean] =
Expand Down
20 changes: 20 additions & 0 deletions src/main/scala/sangria/schema/ResolverBasedAstSchemaBuilder.scala
Expand Up @@ -102,6 +102,26 @@ class ResolverBasedAstSchemaBuilder[Ctx](val resolvers: Seq[AstSchemaResolver[Ct
case r @ AnyFieldResolver(fn) if fn.isDefinedAt(origin) => r
}

override def buildSchema(
definition: Option[ast.SchemaDefinition],
extensions: List[ast.SchemaExtensionDefinition],
queryType: ObjectType[Ctx, Any],
mutationType: Option[ObjectType[Ctx, Any]],
subscriptionType: Option[ObjectType[Ctx, Any]],
additionalTypes: List[Type with Named],
directives: List[Directive],
mat: AstSchemaMaterializer[Ctx]) =
Schema[Ctx, Any](
query = queryType,
mutation = mutationType,
subscription = subscriptionType,
additionalTypes = additionalTypes,
description = definition.flatMap(_.description.map(_.value)),
directives = directives,
astDirectives = definition.fold(Vector.empty[ast.Directive])(_.directives) ++ extensions.flatMap(_.directives),
astNodes = Vector(mat.document) ++ extensions ++ definition.toVector,
validationRules = SchemaValidationRule.default :+ new ResolvedDirectiveValidationRule(this.directives.filterNot(_.repeatable).map(_.name).toSet))

override def resolveField(
origin: MatOrigin,
typeDefinition: Either[ast.TypeDefinition, ObjectLikeType[Ctx, _]],
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/sangria/schema/Schema.scala
Expand Up @@ -754,6 +754,7 @@ case class Directive(
description: Option[String] = None,
arguments: List[Argument[_]] = Nil,
locations: Set[DirectiveLocation.Value] = Set.empty,
repeatable: Boolean = false,
shouldInclude: DirectiveContext => Boolean = _ => true) extends HasArguments with Named {
def rename(newName: String) = copy(name = newName).asInstanceOf[this.type]
def toAst: ast.DirectiveDefinition = SchemaRenderer.renderDirective(this)
Expand Down
11 changes: 10 additions & 1 deletion src/main/scala/sangria/schema/SchemaComparator.scala
Expand Up @@ -59,6 +59,12 @@ object SchemaComparator {
}

private def findInDirective(oldDir: Directive, newDir: Directive): Vector[SchemaChange] = {
val repeatableChanged =
if (oldDir.repeatable != newDir.repeatable)
Vector(SchemaChange.DirectiveRepeatableChanged(newDir, oldDir.repeatable, newDir.repeatable, !newDir.repeatable))
else
Vector.empty

val locationChanges = findInDirectiveLocations(oldDir, newDir)
val fieldChanges = findInArgs(oldDir.arguments, newDir.arguments,
added = SchemaChange.DirectiveArgumentAdded(newDir, _, _),
Expand All @@ -69,7 +75,7 @@ object SchemaComparator {
dirAdded = SchemaChange.DirectiveArgumentAstDirectiveAdded(newDir, _, _),
dirRemoved = SchemaChange.DirectiveArgumentAstDirectiveRemoved(newDir, _, _))

locationChanges ++ fieldChanges
repeatableChanged ++ locationChanges ++ fieldChanges
}

private def findInDirectiveLocations(oldDir: Directive, newDir: Directive): Vector[SchemaChange] = {
Expand Down Expand Up @@ -659,6 +665,9 @@ object SchemaChange {
case class DirectiveArgumentAdded(directive: Directive, argument: Argument[_], breaking: Boolean)
extends AbstractChange(s"Argument `${argument.name}` was added to `${directive.name}` directive", breaking)

case class DirectiveRepeatableChanged(directive: Directive, oldRepeatable: Boolean, newRepeatable: Boolean, breaking: Boolean)
extends AbstractChange(if (newRepeatable) s"Directive `${directive.name}` was made repeatable per location" else s"Directive `${directive.name}` was made unique per location", breaking)

case class InputFieldTypeChanged(tpe: InputObjectType[_], field: InputField[_], breaking: Boolean, oldFiledType: InputType[_], newFieldType: InputType[_])
extends AbstractChange(s"`${tpe.name}.${field.name}` input field type changed from `${SchemaRenderer.renderTypeName(oldFiledType)}` to `${SchemaRenderer.renderTypeName(newFieldType)}`", breaking) with TypeChange

Expand Down
32 changes: 32 additions & 0 deletions src/main/scala/sangria/schema/SchemaValidationRule.scala
Expand Up @@ -493,5 +493,37 @@ class FullSchemaTraversalValidationRule(validators: SchemaElementValidator*) ext
def validName(name: String): Boolean = !reservedNames.contains(name)
}

/**
* Validates uniqueness of directives on types and the schema definition.
*
* It is not fully covered by `UniqueDirectivesPerLocation` since it onl looks at one AST node at a time,
* so it does not cover type + type extension scenario.
*/
class ResolvedDirectiveValidationRule(knownUniqueDirectives: Set[String]) extends SchemaValidationRule {
def validate[Ctx, Val](schema: Schema[Ctx, Val]): List[Violation] = {
val uniqueDirectives = knownUniqueDirectives ++ schema.directives.filterNot(_.repeatable).map(_.name)
val sourceMapper = SchemaElementValidator.sourceMapper(schema)

val schemaViolations = validateUniqueDirectives(schema, uniqueDirectives, sourceMapper)

val typeViolations =
schema.typeList.collect {
case withDirs: HasAstInfo => validateUniqueDirectives(withDirs, uniqueDirectives, sourceMapper)
}

schemaViolations.toList ++ typeViolations.flatten
}

private def validateUniqueDirectives(withDirs: HasAstInfo, uniqueDirectives: Set[String], sourceMapper: Option[SourceMapper]) = {
val duplicates = withDirs.astDirectives
.filter(d => uniqueDirectives.contains(d.name))
.groupBy(_.name)
.filter(_._2.size > 1)
.toVector

duplicates.map{case (dirName, dups) => DuplicateDirectiveViolation(dirName, sourceMapper, dups.flatMap(_.location).toList)}
}
}

case class SchemaValidationException(violations: Vector[Violation], eh: ExceptionHandler = ExceptionHandler.empty) extends ExecutionError(
s"Schema does not pass validation. Violations:\n\n${violations map (_.errorMessage) mkString "\n\n"}", eh) with WithViolations with QueryAnalysisError
Expand Up @@ -14,6 +14,8 @@ import scala.collection.mutable.{Map => MutableMap}
*/
class UniqueDirectivesPerLocation extends ValidationRule {
override def visitor(ctx: ValidationContext) = new AstValidatingVisitor {
val repeatableDirectives = ctx.schema.directivesByName.mapValues(d => d.repeatable)

override val onEnter: ValidationVisit = {
// Many different AST nodes may contain directives. Rather than listing
// them all, just listen for entering any node, and check to see if it
Expand All @@ -22,11 +24,13 @@ class UniqueDirectivesPerLocation extends ValidationRule {
val knownDirectives = MutableMap[String, ast.Directive]()

val errors = node.directives.foldLeft(Vector.empty[Violation]) {
case (errors, d) if knownDirectives contains d.name =>
errors :+ DuplicateDirectiveViolation(d.name, ctx.sourceMapper, knownDirectives(d.name).location.toList ++ d.location.toList )
case (errors, d) =>
case (es, d) if repeatableDirectives.getOrElse(d.name, true) =>
es
case (es, d) if knownDirectives contains d.name =>
es :+ DuplicateDirectiveViolation(d.name, ctx.sourceMapper, knownDirectives(d.name).location.toList ++ d.location.toList )
case (es, d) =>
knownDirectives(d.name) = d
errors
es
}

if (errors.nonEmpty) Left(errors)
Expand Down
2 changes: 2 additions & 0 deletions src/test/resources/queries/schema-kitchen-sink-pretty.graphql
Expand Up @@ -84,4 +84,6 @@ extend type Foo @onType
"cool skip"
directive @skip(if: Boolean!) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT

directive @myRepeatableDir(name: String!) repeatable on OBJECT | INTERFACE

directive @include(if: Boolean!) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT
4 changes: 4 additions & 0 deletions src/test/resources/queries/schema-kitchen-sink.graphql
Expand Up @@ -83,6 +83,10 @@ extend type Foo @onType
"cool skip"
directive @skip(if: Boolean!) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT

directive @myRepeatableDir(name: String!) repeatable on
| OBJECT
| INTERFACE

directive @include(if: Boolean!)
on FIELD
| FRAGMENT_SPREAD
Expand Down