Skip to content

Commit

Permalink
Simple implementation of DSL inline fragments
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed Aug 26, 2021
1 parent ca4021d commit b613fef
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 8 deletions.
37 changes: 29 additions & 8 deletions gql/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
GraphQLObjectType,
GraphQLSchema,
GraphQLWrappingType,
InlineFragmentNode,
ListTypeNode,
ListValueNode,
NamedTypeNode,
Expand Down Expand Up @@ -407,6 +408,10 @@ class DSLField:
method.
"""

_type: Union[GraphQLObjectType, GraphQLInterfaceType]
ast_field: FieldNode
field: GraphQLField

def __init__(
self,
name: str,
Expand All @@ -423,11 +428,9 @@ def __init__(
:param graphql_type: the GraphQL type definition from the schema
:param graphql_field: the GraphQL field definition from the schema
"""
self._type: Union[GraphQLObjectType, GraphQLInterfaceType] = graphql_type
self.field: GraphQLField = graphql_field
self.ast_field: FieldNode = FieldNode(
name=NameNode(value=name), arguments=FrozenList()
)
self._type = graphql_type
self.field = graphql_field
self.ast_field = FieldNode(name=NameNode(value=name), arguments=FrozenList())
log.debug(f"Creating {self!r}")

@staticmethod
Expand Down Expand Up @@ -585,7 +588,25 @@ def __str__(self) -> str:
return print_ast(self.ast_field)

def __repr__(self) -> str:
return (
f"<{self.__class__.__name__} {self._type.name}"
f"::{self.ast_field.name.value}>"
name = self._type.name
try:
name += f"::{self.ast_field.name.value}"
except AttributeError:
pass
return f"<{self.__class__.__name__} {name}>"


class DSLFragment(DSLField):
def __init__(
self, type_condition: Optional[DSLType] = None,
):
self.ast_field = InlineFragmentNode() # type: ignore
if type_condition:
self.on(type_condition)

def on(self, type_condition: DSLType):
self._type = type_condition._type
self.ast_field.type_condition = NamedTypeNode( # type: ignore
name=NameNode(value=self._type.name)
)
return self
19 changes: 19 additions & 0 deletions tests/starwars/test_dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from gql import Client
from gql.dsl import (
DSLFragment,
DSLMutation,
DSLQuery,
DSLSchema,
Expand Down Expand Up @@ -416,6 +417,24 @@ def test_multiple_operations(ds):
)


def test_inline_fragments(ds):
query = """hero(episode: JEDI) {
name
... on Droid {
primaryFunction
}
... on Human {
homePlanet
}
}"""
query_dsl = ds.Query.hero.args(episode=6).select(
ds.Character.name,
DSLFragment().on(ds.Droid).select(ds.Droid.primaryFunction),
DSLFragment().on(ds.Human).select(ds.Human.homePlanet),
)
assert query == str(query_dsl)


def test_dsl_query_all_fields_should_be_instances_of_DSLField():
with pytest.raises(
TypeError, match="fields must be instances of DSLField. Received type:"
Expand Down

0 comments on commit b613fef

Please sign in to comment.