Skip to content

Commit

Permalink
DSLField and DSLFragment inherit new DSLSelection class
Browse files Browse the repository at this point in the history
  • Loading branch information
leszekhanusz committed Aug 28, 2021
1 parent b613fef commit d505962
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 66 deletions.
172 changes: 106 additions & 66 deletions gql/dsl.py
Expand Up @@ -245,7 +245,7 @@ def __init__(
self.variable_definitions: DSLVariableDefinitions = DSLVariableDefinitions()

# Concatenate fields without and with alias
all_fields: Tuple["DSLField", ...] = DSLField.get_aliased_fields(
all_fields: Tuple["DSLSelection", ...] = DSLField.get_aliased_fields(
fields, fields_with_alias
)

Expand All @@ -265,7 +265,7 @@ def __init__(
)

self.selection_set: SelectionSetNode = SelectionSetNode(
selections=FrozenList(DSLField.get_ast_fields(all_fields))
selections=FrozenList(DSLSelection.get_ast_fields(all_fields))
)


Expand Down Expand Up @@ -397,56 +397,35 @@ def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self._type!r}>"


class DSLField:
"""The DSLField represents a GraphQL field for the DSL code.
Instances of this class are generated for you automatically as attributes
of the :class:`DSLType`
class DSLSelection(ABC):
"""DSLSelection is an abstract class which define the
:meth:`select <gql.dsl.DSLSelection.select>` method to select
children fields in the query.
If this field contains children fields, then you need to select which ones
you want in the request using the :meth:`select <gql.dsl.DSLField.select>`
method.
subclasses:
:class:`DSLField`
:class:`DSLFragment`
"""

_type: Union[GraphQLObjectType, GraphQLInterfaceType]
ast_field: FieldNode
field: GraphQLField

def __init__(
self,
name: str,
graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType],
graphql_field: GraphQLField,
):
"""Initialize the DSLField.
.. warning::
Don't instantiate this class yourself.
Use attributes of the :class:`DSLType` instead.
:param name: the name of the field
:param graphql_type: the GraphQL type definition from the schema
:param graphql_field: the GraphQL field definition from the schema
"""
self._type = graphql_type
self.field = graphql_field
self.ast_field = FieldNode(name=NameNode(value=name), arguments=FrozenList())
log.debug(f"Creating {self!r}")
ast_field: Union[FieldNode, InlineFragmentNode]

@staticmethod
def get_ast_fields(fields: Iterable["DSLField"]) -> List[FieldNode]:
def get_ast_fields(
fields: Iterable["DSLSelection"],
) -> List[Union[FieldNode, InlineFragmentNode]]:
"""
:meta private:
Equivalent to: :code:`[field.ast_field for field in fields]`
But with a type check for each field in the list.
:raises TypeError: if any of the provided fields are not instances
of the :class:`DSLField` class.
of the :class:`DSLSelection` class.
"""
ast_fields = []
for field in fields:
if isinstance(field, DSLField):
if isinstance(field, DSLSelection):
ast_fields.append(field.ast_field)
else:
raise TypeError(f'Received incompatible field: "{field}".')
Expand All @@ -455,8 +434,8 @@ def get_ast_fields(fields: Iterable["DSLField"]) -> List[FieldNode]:

@staticmethod
def get_aliased_fields(
fields: Iterable["DSLField"], fields_with_alias: Dict[str, "DSLField"]
) -> Tuple["DSLField", ...]:
fields: Iterable["DSLSelection"], fields_with_alias: Dict[str, "DSLField"]
) -> Tuple["DSLSelection", ...]:
"""
:meta private:
Expand All @@ -471,30 +450,32 @@ def get_aliased_fields(
)

def select(
self, *fields: "DSLField", **fields_with_alias: "DSLField"
) -> "DSLField":
self, *fields: "DSLSelection", **fields_with_alias: "DSLField"
) -> "DSLSelection":
r"""Select the new children fields
that we want to receive in the request.
If used multiple times, we will add the new children fields
to the existing children fields.
:param \*fields: new children fields
:type \*fields: DSLField
:type \*fields: DSLSelection (DSLField or DSLFragment)
:param \**fields_with_alias: new children fields with alias as key
:type \**fields_with_alias: DSLField
:return: itself
:raises TypeError: if any of the provided fields are not instances
of the :class:`DSLField` class.
of the :class:`DSLSelection` class.
"""

# Concatenate fields without and with alias
added_fields: Tuple["DSLField", ...] = self.get_aliased_fields(
added_fields: Tuple["DSLSelection", ...] = self.get_aliased_fields(
fields, fields_with_alias
)

added_selections: List[FieldNode] = self.get_ast_fields(added_fields)
added_selections: List[
Union[FieldNode, InlineFragmentNode]
] = self.get_ast_fields(added_fields)

current_selection_set: Optional[SelectionSetNode] = self.ast_field.selection_set

Expand All @@ -511,6 +492,58 @@ def select(

return self

@property
def type_name(self):
""":meta private:"""
return self._type.name

def __str__(self) -> str:
return print_ast(self.ast_field)


class DSLField(DSLSelection):
"""The DSLField represents a GraphQL field for the DSL code.
Instances of this class are generated for you automatically as attributes
of the :class:`DSLType`
If this field contains children fields, then you need to select which ones
you want in the request using the :meth:`select <gql.dsl.DSLField.select>`
method.
"""

ast_field: FieldNode
field: GraphQLField

def __init__(
self,
name: str,
graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType],
graphql_field: GraphQLField,
):
"""Initialize the DSLField.
.. warning::
Don't instantiate this class yourself.
Use attributes of the :class:`DSLType` instead.
:param name: the name of the field
:param graphql_type: the GraphQL type definition from the schema
:param graphql_field: the GraphQL field definition from the schema
"""
self._type = graphql_type
self.field = graphql_field
self.ast_field = FieldNode(name=NameNode(value=name), arguments=FrozenList())
log.debug(f"Creating {self!r}")

def select(
self, *fields: "DSLSelection", **fields_with_alias: "DSLField"
) -> "DSLField":
"""Calling :meth:`select <gql.dsl.DSLSelection.select>` method with
corrected typing hints
"""
return cast("DSLField", super().select(*fields, **fields_with_alias))

def __call__(self, **kwargs) -> "DSLField":
return self.args(**kwargs)

Expand All @@ -519,7 +552,7 @@ def alias(self, alias: str) -> "DSLField":
.. note::
You can also pass the alias directly at the
:meth:`select <gql.dsl.DSLField.select>` method.
:meth:`select <gql.dsl.DSLSelection.select>` method.
:code:`ds.Query.human.select(my_name=ds.Character.name)` is equivalent to:
:code:`ds.Query.human.select(ds.Character.name.alias("my_name"))`
Expand Down Expand Up @@ -579,34 +612,41 @@ def _get_argument(self, name: str) -> GraphQLArgument:

return arg

@property
def type_name(self):
""":meta private:"""
return self._type.name
def __repr__(self) -> str:
return (
f"<{self.__class__.__name__} {self._type.name}"
f"::{self.ast_field.name.value}>"
)

def __str__(self) -> str:
return print_ast(self.ast_field)

def __repr__(self) -> str:
name = self._type.name
try:
name += f"::{self.ast_field.name.value}"
except AttributeError:
pass
return f"<{self.__class__.__name__} {name}>"
class DSLFragment(DSLSelection):

ast_field: InlineFragmentNode

class DSLFragment(DSLField):
def __init__(
self, type_condition: Optional[DSLType] = None,
):
self.ast_field = InlineFragmentNode() # type: ignore
if type_condition:
self.on(type_condition)
def __init__(self):
self.ast_field = InlineFragmentNode()

def select(
self, *fields: "DSLSelection", **fields_with_alias: "DSLField"
) -> "DSLFragment":
"""Calling :meth:`select <gql.dsl.DSLSelection.select>` method with
corrected typing hints
"""
return cast("DSLFragment", super().select(*fields, **fields_with_alias))

def on(self, type_condition: DSLType):
self._type = type_condition._type
self.ast_field.type_condition = NamedTypeNode( # type: ignore
self.ast_field.type_condition = NamedTypeNode(
name=NameNode(value=self._type.name)
)
return self

def __repr__(self) -> str:
type_info = ""

try:
type_info += f" on {self._type.name}"
except AttributeError:
pass

return f"<{self.__class__.__name__}{type_info}>"
7 changes: 7 additions & 0 deletions tests/starwars/test_dsl.py
Expand Up @@ -435,6 +435,13 @@ def test_inline_fragments(ds):
assert query == str(query_dsl)


def test_inline_fragments_repr(ds):

assert repr(DSLFragment()) == "<DSLFragment>"

assert repr(DSLFragment().on(ds.Droid)) == "<DSLFragment on Droid>"


def test_dsl_query_all_fields_should_be_instances_of_DSLField():
with pytest.raises(
TypeError, match="fields must be instances of DSLField. Received type:"
Expand Down

0 comments on commit d505962

Please sign in to comment.