Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DSL implementation of fragments #235

Merged
merged 11 commits into from Sep 12, 2021
37 changes: 29 additions & 8 deletions gql/dsl.py
Expand Up @@ -17,6 +17,7 @@
GraphQLObjectType,
GraphQLSchema,
GraphQLWrappingType,
InlineFragmentNode,
ListTypeNode,
ListValueNode,
NamedTypeNode,
Expand Down Expand Up @@ -407,6 +408,10 @@ class DSLField:
method.
"""

_type: Union[GraphQLObjectType, GraphQLInterfaceType]
ast_field: FieldNode
field: GraphQLField

def __init__(
self,
name: str,
Expand All @@ -423,11 +428,9 @@ def __init__(
:param graphql_type: the GraphQL type definition from the schema
:param graphql_field: the GraphQL field definition from the schema
"""
self._type: Union[GraphQLObjectType, GraphQLInterfaceType] = graphql_type
self.field: GraphQLField = graphql_field
self.ast_field: FieldNode = FieldNode(
name=NameNode(value=name), arguments=FrozenList()
)
self._type = graphql_type
self.field = graphql_field
self.ast_field = FieldNode(name=NameNode(value=name), arguments=FrozenList())
log.debug(f"Creating {self!r}")

@staticmethod
Expand Down Expand Up @@ -585,7 +588,25 @@ def __str__(self) -> str:
return print_ast(self.ast_field)

def __repr__(self) -> str:
return (
f"<{self.__class__.__name__} {self._type.name}"
f"::{self.ast_field.name.value}>"
name = self._type.name
try:
name += f"::{self.ast_field.name.value}"
except AttributeError:
pass
return f"<{self.__class__.__name__} {name}>"


class DSLFragment(DSLField):
def __init__(
self, type_condition: Optional[DSLType] = None,
):
self.ast_field = InlineFragmentNode() # type: ignore
if type_condition:
self.on(type_condition)

def on(self, type_condition: DSLType):
self._type = type_condition._type
self.ast_field.type_condition = NamedTypeNode( # type: ignore
name=NameNode(value=self._type.name)
)
return self
19 changes: 19 additions & 0 deletions tests/starwars/test_dsl.py
Expand Up @@ -15,6 +15,7 @@

from gql import Client
from gql.dsl import (
DSLFragment,
DSLMutation,
DSLQuery,
DSLSchema,
Expand Down Expand Up @@ -416,6 +417,24 @@ def test_multiple_operations(ds):
)


def test_inline_fragments(ds):
query = """hero(episode: JEDI) {
name
... on Droid {
primaryFunction
}
... on Human {
homePlanet
}
}"""
query_dsl = ds.Query.hero.args(episode=6).select(
ds.Character.name,
DSLFragment().on(ds.Droid).select(ds.Droid.primaryFunction),
DSLFragment().on(ds.Human).select(ds.Human.homePlanet),
)
assert query == str(query_dsl)


def test_dsl_query_all_fields_should_be_instances_of_DSLField():
with pytest.raises(
TypeError, match="fields must be instances of DSLField. Received type:"
Expand Down