From d505962abdb4795e76f994ef93a80b7e3fae8df4 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 28 Aug 2021 18:16:17 +0200 Subject: [PATCH] DSLField and DSLFragment inherit new DSLSelection class --- gql/dsl.py | 172 +++++++++++++++++++++++-------------- tests/starwars/test_dsl.py | 7 ++ 2 files changed, 113 insertions(+), 66 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index bcbb0be7..f0d64f0f 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -245,7 +245,7 @@ def __init__( self.variable_definitions: DSLVariableDefinitions = DSLVariableDefinitions() # Concatenate fields without and with alias - all_fields: Tuple["DSLField", ...] = DSLField.get_aliased_fields( + all_fields: Tuple["DSLSelection", ...] = DSLField.get_aliased_fields( fields, fields_with_alias ) @@ -265,7 +265,7 @@ def __init__( ) self.selection_set: SelectionSetNode = SelectionSetNode( - selections=FrozenList(DSLField.get_ast_fields(all_fields)) + selections=FrozenList(DSLSelection.get_ast_fields(all_fields)) ) @@ -397,44 +397,23 @@ def __repr__(self) -> str: return f"<{self.__class__.__name__} {self._type!r}>" -class DSLField: - """The DSLField represents a GraphQL field for the DSL code. - - Instances of this class are generated for you automatically as attributes - of the :class:`DSLType` +class DSLSelection(ABC): + """DSLSelection is an abstract class which define the + :meth:`select ` method to select + children fields in the query. - If this field contains children fields, then you need to select which ones - you want in the request using the :meth:`select ` - method. + subclasses: + :class:`DSLField` + :class:`DSLFragment` """ _type: Union[GraphQLObjectType, GraphQLInterfaceType] - ast_field: FieldNode - field: GraphQLField - - def __init__( - self, - name: str, - graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType], - graphql_field: GraphQLField, - ): - """Initialize the DSLField. - - .. warning:: - Don't instantiate this class yourself. - Use attributes of the :class:`DSLType` instead. - - :param name: the name of the field - :param graphql_type: the GraphQL type definition from the schema - :param graphql_field: the GraphQL field definition from the schema - """ - self._type = graphql_type - self.field = graphql_field - self.ast_field = FieldNode(name=NameNode(value=name), arguments=FrozenList()) - log.debug(f"Creating {self!r}") + ast_field: Union[FieldNode, InlineFragmentNode] @staticmethod - def get_ast_fields(fields: Iterable["DSLField"]) -> List[FieldNode]: + def get_ast_fields( + fields: Iterable["DSLSelection"], + ) -> List[Union[FieldNode, InlineFragmentNode]]: """ :meta private: @@ -442,11 +421,11 @@ def get_ast_fields(fields: Iterable["DSLField"]) -> List[FieldNode]: But with a type check for each field in the list. :raises TypeError: if any of the provided fields are not instances - of the :class:`DSLField` class. + of the :class:`DSLSelection` class. """ ast_fields = [] for field in fields: - if isinstance(field, DSLField): + if isinstance(field, DSLSelection): ast_fields.append(field.ast_field) else: raise TypeError(f'Received incompatible field: "{field}".') @@ -455,8 +434,8 @@ def get_ast_fields(fields: Iterable["DSLField"]) -> List[FieldNode]: @staticmethod def get_aliased_fields( - fields: Iterable["DSLField"], fields_with_alias: Dict[str, "DSLField"] - ) -> Tuple["DSLField", ...]: + fields: Iterable["DSLSelection"], fields_with_alias: Dict[str, "DSLField"] + ) -> Tuple["DSLSelection", ...]: """ :meta private: @@ -471,8 +450,8 @@ def get_aliased_fields( ) def select( - self, *fields: "DSLField", **fields_with_alias: "DSLField" - ) -> "DSLField": + self, *fields: "DSLSelection", **fields_with_alias: "DSLField" + ) -> "DSLSelection": r"""Select the new children fields that we want to receive in the request. @@ -480,21 +459,23 @@ def select( to the existing children fields. :param \*fields: new children fields - :type \*fields: DSLField + :type \*fields: DSLSelection (DSLField or DSLFragment) :param \**fields_with_alias: new children fields with alias as key :type \**fields_with_alias: DSLField :return: itself :raises TypeError: if any of the provided fields are not instances - of the :class:`DSLField` class. + of the :class:`DSLSelection` class. """ # Concatenate fields without and with alias - added_fields: Tuple["DSLField", ...] = self.get_aliased_fields( + added_fields: Tuple["DSLSelection", ...] = self.get_aliased_fields( fields, fields_with_alias ) - added_selections: List[FieldNode] = self.get_ast_fields(added_fields) + added_selections: List[ + Union[FieldNode, InlineFragmentNode] + ] = self.get_ast_fields(added_fields) current_selection_set: Optional[SelectionSetNode] = self.ast_field.selection_set @@ -511,6 +492,58 @@ def select( return self + @property + def type_name(self): + """:meta private:""" + return self._type.name + + def __str__(self) -> str: + return print_ast(self.ast_field) + + +class DSLField(DSLSelection): + """The DSLField represents a GraphQL field for the DSL code. + + Instances of this class are generated for you automatically as attributes + of the :class:`DSLType` + + If this field contains children fields, then you need to select which ones + you want in the request using the :meth:`select ` + method. + """ + + ast_field: FieldNode + field: GraphQLField + + def __init__( + self, + name: str, + graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType], + graphql_field: GraphQLField, + ): + """Initialize the DSLField. + + .. warning:: + Don't instantiate this class yourself. + Use attributes of the :class:`DSLType` instead. + + :param name: the name of the field + :param graphql_type: the GraphQL type definition from the schema + :param graphql_field: the GraphQL field definition from the schema + """ + self._type = graphql_type + self.field = graphql_field + self.ast_field = FieldNode(name=NameNode(value=name), arguments=FrozenList()) + log.debug(f"Creating {self!r}") + + def select( + self, *fields: "DSLSelection", **fields_with_alias: "DSLField" + ) -> "DSLField": + """Calling :meth:`select ` method with + corrected typing hints + """ + return cast("DSLField", super().select(*fields, **fields_with_alias)) + def __call__(self, **kwargs) -> "DSLField": return self.args(**kwargs) @@ -519,7 +552,7 @@ def alias(self, alias: str) -> "DSLField": .. note:: You can also pass the alias directly at the - :meth:`select ` method. + :meth:`select ` method. :code:`ds.Query.human.select(my_name=ds.Character.name)` is equivalent to: :code:`ds.Query.human.select(ds.Character.name.alias("my_name"))` @@ -579,34 +612,41 @@ def _get_argument(self, name: str) -> GraphQLArgument: return arg - @property - def type_name(self): - """:meta private:""" - return self._type.name + def __repr__(self) -> str: + return ( + f"<{self.__class__.__name__} {self._type.name}" + f"::{self.ast_field.name.value}>" + ) - def __str__(self) -> str: - return print_ast(self.ast_field) - def __repr__(self) -> str: - name = self._type.name - try: - name += f"::{self.ast_field.name.value}" - except AttributeError: - pass - return f"<{self.__class__.__name__} {name}>" +class DSLFragment(DSLSelection): + ast_field: InlineFragmentNode -class DSLFragment(DSLField): - def __init__( - self, type_condition: Optional[DSLType] = None, - ): - self.ast_field = InlineFragmentNode() # type: ignore - if type_condition: - self.on(type_condition) + def __init__(self): + self.ast_field = InlineFragmentNode() + + def select( + self, *fields: "DSLSelection", **fields_with_alias: "DSLField" + ) -> "DSLFragment": + """Calling :meth:`select ` method with + corrected typing hints + """ + return cast("DSLFragment", super().select(*fields, **fields_with_alias)) def on(self, type_condition: DSLType): self._type = type_condition._type - self.ast_field.type_condition = NamedTypeNode( # type: ignore + self.ast_field.type_condition = NamedTypeNode( name=NameNode(value=self._type.name) ) return self + + def __repr__(self) -> str: + type_info = "" + + try: + type_info += f" on {self._type.name}" + except AttributeError: + pass + + return f"<{self.__class__.__name__}{type_info}>" diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index dbc06e07..cee7a20d 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -435,6 +435,13 @@ def test_inline_fragments(ds): assert query == str(query_dsl) +def test_inline_fragments_repr(ds): + + assert repr(DSLFragment()) == "" + + assert repr(DSLFragment().on(ds.Droid)) == "" + + def test_dsl_query_all_fields_should_be_instances_of_DSLField(): with pytest.raises( TypeError, match="fields must be instances of DSLField. Received type:"