diff --git a/dev-env-requirements.txt b/dev-env-requirements.txt index 4f70c88..52eed6f 100644 --- a/dev-env-requirements.txt +++ b/dev-env-requirements.txt @@ -1,6 +1,6 @@ -r requirements.txt -graphene==2.1.8 -graphene-django==2.7.1 +graphene==3.0b7 +graphene-django==3.0.0b7 pytest==4.6.3 pytest-django==3.5.0 pytest-cov==2.7.1 diff --git a/graphene_django_optimizer/field.py b/graphene_django_optimizer/field.py index c6bd310..48165cc 100644 --- a/graphene_django_optimizer/field.py +++ b/graphene_django_optimizer/field.py @@ -1,3 +1,4 @@ +import types from graphene.types.field import Field from graphene.types.unmountedtype import UnmountedType @@ -9,12 +10,12 @@ def field(field_type, *args, **kwargs): field_type = Field.mounted(field_type) optimization_hints = OptimizationHints(*args, **kwargs) - get_resolver = field_type.get_resolver + wrap_resolve = field_type.wrap_resolve - def get_optimized_resolver(parent_resolver): - resolver = get_resolver(parent_resolver) + def get_optimized_resolver(self, parent_resolver): + resolver = wrap_resolve(parent_resolver) resolver.optimization_hints = optimization_hints return resolver - field_type.get_resolver = get_optimized_resolver + field_type.wrap_resolve = types.MethodType(get_optimized_resolver, field_type) return field_type diff --git a/graphene_django_optimizer/query.py b/graphene_django_optimizer/query.py index 057125a..e64f041 100644 --- a/graphene_django_optimizer/query.py +++ b/graphene_django_optimizer/query.py @@ -8,20 +8,20 @@ from graphene.types.generic import GenericScalar from graphene.types.resolver import default_resolver from graphene_django import DjangoObjectType -from graphql import ResolveInfo -from graphql.execution.base import ( - get_field_def, -) +from graphql import GraphQLResolveInfo, GraphQLSchema +from graphql.execution.execute import get_field_def from graphql.language.ast import ( - FragmentSpread, - InlineFragment, - Variable, + FragmentSpreadNode, + InlineFragmentNode, + VariableNode, ) from graphql.type.definition import ( GraphQLInterfaceType, GraphQLUnionType, ) +from graphql.pyutils import Path + from .utils import is_iterable @@ -31,7 +31,7 @@ def query(queryset, info, **options): Arguments: - queryset (Django QuerySet object) - The queryset to be optimized - - info (GraphQL ResolveInfo object) - This is passed by the graphene-django resolve methods + - info (GraphQL GraphQLResolveInfo object) - This is passed by the graphene-django resolve methods - **options - optimization options/settings - disable_abort_only (boolean) - in case the objecttype contains any extra fields, then this will keep the "only" optimization enabled. @@ -54,7 +54,7 @@ def optimize(self, queryset): field_def = get_field_def(info.schema, info.parent_type, info.field_name) store = self._optimize_gql_selections( self._get_type(field_def), - info.field_asts[0], + info.field_nodes[0], # info.parent_type, ) return store.optimize_queryset(queryset) @@ -65,9 +65,16 @@ def _get_type(self, field_def): a_type = a_type.of_type return a_type + def _get_graphql_schema(self, schema): + if isinstance(schema, GraphQLSchema): + return schema + else: + return schema.graphql_schema + def _get_possible_types(self, graphql_type): if isinstance(graphql_type, (GraphQLInterfaceType, GraphQLUnionType)): - return self.root_info.schema.get_possible_types(graphql_type) + graphql_schema = self._get_graphql_schema(self.root_info.schema) + return graphql_schema.get_possible_types(graphql_type) else: return (graphql_type,) @@ -80,7 +87,8 @@ def _get_base_model(self, graphql_types): def handle_inline_fragment(self, selection, schema, possible_types, store): fragment_type_name = selection.type_condition.name.value - fragment_type = schema.get_type(fragment_type_name) + graphql_schema = self._get_graphql_schema(schema) + fragment_type = graphql_schema.get_type(fragment_type_name) fragment_possible_types = self._get_possible_types(fragment_type) for fragment_possible_type in fragment_possible_types: fragment_model = fragment_possible_type.graphene_type._meta.model @@ -120,14 +128,16 @@ def _optimize_gql_selections(self, field_type, field_ast): return store optimized_fields_by_model = {} schema = self.root_info.schema - graphql_type = schema.get_graphql_type(field_type.graphene_type) + graphql_schema = self._get_graphql_schema(schema) + graphql_type = graphql_schema.get_type(field_type.name) + possible_types = self._get_possible_types(graphql_type) for selection in selection_set.selections: - if isinstance(selection, InlineFragment): + if isinstance(selection, InlineFragmentNode): self.handle_inline_fragment(selection, schema, possible_types, store) else: name = selection.name.value - if isinstance(selection, FragmentSpread): + if isinstance(selection, FragmentSpreadNode): self.handle_fragment_spread(store, name, field_type) else: for possible_type in possible_types: @@ -176,7 +186,7 @@ def _optimize_field(self, store, model, selection, field_def, parent_type): store.abort_only_optimization() def _optimize_field_by_name(self, store, model, selection, field_def): - name = self._get_name_from_resolver(field_def.resolver) + name = self._get_name_from_resolver(field_def.resolve) if not name: return False model_field = self._get_model_field_from_name(model, name) @@ -215,7 +225,7 @@ def _get_optimization_hints(self, resolver): return getattr(resolver, "optimization_hints", None) def _get_value(self, info, value): - if isinstance(value, Variable): + if isinstance(value, VariableNode): var_name = value.name.value value = info.variable_values.get(var_name) return value @@ -225,7 +235,7 @@ def _get_value(self, info, value): return GenericScalar.parse_literal(value) def _optimize_field_by_hints(self, store, selection, field_def, parent_type): - optimization_hints = self._get_optimization_hints(field_def.resolver) + optimization_hints = self._get_optimization_hints(field_def.resolve) if not optimization_hints: return False info = self._create_resolve_info( @@ -316,17 +326,19 @@ def _is_foreign_key_id(self, model_field, name): ) def _create_resolve_info(self, field_name, field_asts, return_type, parent_type): - return ResolveInfo( + return GraphQLResolveInfo( field_name, field_asts, return_type, parent_type, + Path(None, 0, None), schema=self.root_info.schema, fragments=self.root_info.fragments, root_value=self.root_info.root_value, operation=self.root_info.operation, variable_values=self.root_info.variable_values, context=self.root_info.context, + is_awaitable=self.root_info.is_awaitable, ) diff --git a/tests/graphql_utils.py b/tests/graphql_utils.py index 6b061e4..c6f1a8c 100644 --- a/tests/graphql_utils.py +++ b/tests/graphql_utils.py @@ -1,40 +1,38 @@ from graphql import ( - ResolveInfo, + GraphQLResolveInfo, Source, Undefined, parse, ) -from graphql.execution.base import ( +from graphql.execution.execute import ( ExecutionContext, - collect_fields, get_field_def, - get_operation_root_type, ) -from graphql.pyutils.default_ordered_dict import DefaultOrderedDict +from graphql.utilities import get_operation_root_type +from collections import defaultdict + +from graphql.pyutils import Path def create_execution_context(schema, request_string, variables=None): source = Source(request_string, "GraphQL request") document_ast = parse(source) - return ExecutionContext( + return ExecutionContext.build( schema, document_ast, root_value=None, context_value=None, - variable_values=variables, + raw_variable_values=variables, operation_name=None, - executor=None, middleware=None, - allow_subscriptions=False, ) def get_field_asts_from_execution_context(exe_context): - fields = collect_fields( - exe_context, + fields = exe_context.collect_fields( type, exe_context.operation.selection_set, - DefaultOrderedDict(list), + defaultdict(list), set(), ) # field_asts = next(iter(fields.values())) @@ -42,7 +40,7 @@ def get_field_asts_from_execution_context(exe_context): return field_asts -def create_resolve_info(schema, request_string, variables=None): +def create_resolve_info(schema, request_string, variables=None, return_type=None): exe_context = create_execution_context(schema, request_string, variables) parent_type = get_operation_root_type(schema, exe_context.operation) field_asts = get_field_asts_from_execution_context(exe_context) @@ -50,24 +48,26 @@ def create_resolve_info(schema, request_string, variables=None): field_ast = field_asts[0] field_name = field_ast.name.value - field_def = get_field_def(schema, parent_type, field_name) - if not field_def: - return Undefined - return_type = field_def.type + if return_type is None: + field_def = get_field_def(schema, parent_type, field_name) + if not field_def: + return Undefined + return_type = field_def.type # The resolve function's optional third argument is a context value that # is provided to every resolve function within an execution. It is commonly # used to represent an authenticated user, or request-specific caches. - context = exe_context.context_value - return ResolveInfo( + return GraphQLResolveInfo( field_name, field_asts, return_type, parent_type, - schema=schema, - fragments=exe_context.fragments, - root_value=exe_context.root_value, - operation=exe_context.operation, - variable_values=exe_context.variable_values, - context=context, + Path(None, 0, None), + schema, + exe_context.fragments, + exe_context.root_value, + exe_context.operation, + exe_context.variable_values, + exe_context.context_value, + exe_context.is_awaitable, ) diff --git a/tests/schema.py b/tests/schema.py index 66ffd9a..fc30203 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -99,6 +99,7 @@ class BaseItemType(OptimizedDjangoObjectType): class Meta: model = Item + fields = "__all__" @gql_optimizer.resolver_hints( model_field="children", @@ -110,6 +111,8 @@ def resolve_relay_all_children(root, info, **kwargs): class ItemNode(BaseItemType): class Meta: model = Item + fields = "__all__" + interfaces = ( graphene.relay.Node, ItemInterface, @@ -119,16 +122,19 @@ class Meta: class SomeOtherItemType(OptimizedDjangoObjectType): class Meta: model = SomeOtherItem + fields = "__all__" class OtherItemType(OptimizedDjangoObjectType): class Meta: model = OtherItem + fields = "__all__" class ItemType(BaseItemType): class Meta: model = Item + fields = "__all__" interfaces = (ItemInterface,) @@ -144,29 +150,34 @@ class DetailedInterface(graphene.Interface): class DetailedItemType(ItemType): class Meta: model = DetailedItem + fields = "__all__" interfaces = (ItemInterface, DetailedInterface) class RelatedItemType(ItemType): class Meta: model = RelatedItem + fields = "__all__" interfaces = (ItemInterface,) class ExtraDetailedItemType(DetailedItemType): class Meta: model = ExtraDetailedItem + fields = "__all__" interfaces = (ItemInterface,) class RelatedOneToManyItemType(OptimizedDjangoObjectType): class Meta: model = RelatedOneToManyItem + fields = "__all__" class UnrelatedModelType(OptimizedDjangoObjectType): class Meta: model = UnrelatedModel + fields = "__all__" interfaces = (DetailedInterface,) @@ -200,6 +211,21 @@ def resolve_other_items(root, info): return gql_optimizer.query(OtherItemType.objects.all(), info) -schema = graphene.Schema( - query=Query, types=(UnrelatedModelType,), mutation=DummyItemMutation -) +class Schema(graphene.Schema): + @property + def query_type(self): + return self.graphql_schema.get_type("Query") + + @property + def mutation_type(self): + return self.graphql_schema.get_type("Mutation") + + @property + def subscription_type(self): + return self.graphql_schema.get_type("Subscription") + + def get_type(self, _type): + return self.graphql_schema.get_type(_type) + + +schema = Schema(query=Query, types=(UnrelatedModelType,), mutation=DummyItemMutation) diff --git a/tests/test_field.py b/tests/test_field.py index d7a8957..71d617e 100644 --- a/tests/test_field.py +++ b/tests/test_field.py @@ -1,13 +1,13 @@ +import pytest import graphene_django_optimizer as gql_optimizer from .graphql_utils import create_resolve_info -from .models import ( - Item, -) +from .models import Item from .schema import schema from .test_utils import assert_query_equality +@pytest.mark.django_db def test_should_optimize_non_django_field_if_it_has_an_optimization_hint_in_the_field(): info = create_resolve_info( schema, @@ -29,6 +29,7 @@ def test_should_optimize_non_django_field_if_it_has_an_optimization_hint_in_the_ assert_query_equality(items, optimized_items) +@pytest.mark.django_db def test_should_optimize_with_only_hint(): info = create_resolve_info( schema, diff --git a/tests/test_query.py b/tests/test_query.py index a821ec5..fa7339d 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -15,7 +15,7 @@ from .test_utils import assert_query_equality -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_reduce_number_of_queries_by_using_select_related(): # parent = Item.objects.create(name='foo') # Item.objects.create(name='bar', parent=parent) @@ -39,7 +39,7 @@ def test_should_reduce_number_of_queries_by_using_select_related(): assert_query_equality(items, optimized_items) -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_reduce_number_of_queries_by_using_prefetch_related(): # parent = Item.objects.create(name='foo') # Item.objects.create(name='bar', parent=parent) @@ -64,7 +64,7 @@ def test_should_reduce_number_of_queries_by_using_prefetch_related(): assert_query_equality(items, optimized_items) -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_optimize_scalar_model_fields(): # Item.objects.create(name='foo') info = create_resolve_info( @@ -84,7 +84,7 @@ def test_should_optimize_scalar_model_fields(): assert_query_equality(items, optimized_items) -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_optimize_scalar_foreign_key_model_fields(): # parent = Item.objects.create(name='foo') # Item.objects.create(name='bar', parent=parent) @@ -105,7 +105,7 @@ def test_should_optimize_scalar_foreign_key_model_fields(): assert_query_equality(items, optimized_items) -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_not_try_to_optimize_non_model_fields(): # Item.objects.create(name='foo') info = create_resolve_info( @@ -125,7 +125,7 @@ def test_should_not_try_to_optimize_non_model_fields(): assert_query_equality(items, optimized_items) -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_not_try_to_optimize_non_field_model_fields(): # Item.objects.create(name='foo') info = create_resolve_info( @@ -145,6 +145,7 @@ def test_should_not_try_to_optimize_non_field_model_fields(): assert_query_equality(items, optimized_items) +@pytest.mark.django_db def test_should_try_to_optimize_non_field_model_fields_when_disabling_abort_only(): # Item.objects.create(name='foo') info = create_resolve_info( @@ -164,7 +165,7 @@ def test_should_try_to_optimize_non_field_model_fields_when_disabling_abort_only assert_query_equality(items, optimized_items) -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_optimize_when_using_fragments(): # parent = Item.objects.create(name='foo') # Item.objects.create(name='bar', parent=parent) @@ -190,7 +191,7 @@ def test_should_optimize_when_using_fragments(): assert_query_equality(items, optimized_items) -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_prefetch_field_with_camel_case_name(): # item = Item.objects.create(name='foo') # Item.objects.create(name='bar', item=item) @@ -215,7 +216,7 @@ def test_should_prefetch_field_with_camel_case_name(): assert_query_equality(items, optimized_items) -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_select_nested_related_fields(): # parent = Item.objects.create(name='foo') # parent = Item.objects.create(name='bar', parent=parent) @@ -243,7 +244,7 @@ def test_should_select_nested_related_fields(): assert_query_equality(items, optimized_items) -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_prefetch_nested_related_fields(): # parent = Item.objects.create(name='foo') # parent = Item.objects.create(name='bar', parent=parent) @@ -273,7 +274,7 @@ def test_should_prefetch_nested_related_fields(): assert_query_equality(items, optimized_items) -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_prefetch_nested_select_related_field(): # parent = Item.objects.create(name='foo') # item = Item.objects.create(name='foobar') @@ -304,7 +305,7 @@ def test_should_prefetch_nested_select_related_field(): assert_query_equality(items, optimized_items) -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_select_nested_prefetch_related_field(): # parent = Item.objects.create(name='foo') # Item.objects.create(name='bar', parent=parent) @@ -333,7 +334,7 @@ def test_should_select_nested_prefetch_related_field(): assert_query_equality(items, optimized_items) -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_select_nested_prefetch_and_select_related_fields(): # parent = Item.objects.create(name='foo') # item = Item.objects.create(name='bar_item') @@ -368,7 +369,7 @@ def test_should_select_nested_prefetch_and_select_related_fields(): assert_query_equality(items, optimized_items) -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_fetch_fields_of_related_field(): # parent = Item.objects.create(name='foo') # Item.objects.create(name='bar', parent=parent) @@ -391,7 +392,7 @@ def test_should_fetch_fields_of_related_field(): assert_query_equality(items, optimized_items) -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_fetch_fields_of_prefetched_field(): # parent = Item.objects.create(name='foo') # Item.objects.create(name='bar', parent=parent) @@ -417,7 +418,7 @@ def test_should_fetch_fields_of_prefetched_field(): assert_query_equality(items, optimized_items) -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_fetch_child_model_field_for_interface_field(): # Item.objects.create(name='foo') # ExtraDetailedItem.objects.create(name='foo', extra_detail='test') @@ -443,7 +444,7 @@ def test_should_fetch_child_model_field_for_interface_field(): @pytest.mark.skip(reason="will be tested in the future") -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_fetch_field_of_child_model_when_parent_has_no_optimized_field(): # Item.objects.create(name='foo') # DetailedItem.objects.create(name='foo', item_type='test') @@ -466,6 +467,7 @@ def test_should_fetch_field_of_child_model_when_parent_has_no_optimized_field(): assert_query_equality(items, optimized_items) +@pytest.mark.django_db def test_should_fetch_field_inside_interface_fragment(): info = create_resolve_info( schema, @@ -488,6 +490,7 @@ def test_should_fetch_field_inside_interface_fragment(): assert_query_equality(items, optimized_items) +@pytest.mark.django_db def test_should_use_nested_prefetch_related_while_also_selecting_only_required_fields(): info = create_resolve_info( schema, @@ -565,7 +568,7 @@ def test_should_check_reverse_relations_add_foreign_key(): assert len(expected_query_capture) == len(optimized_query_capture) -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_only_use_the_only_and_not_select_related(): info = create_resolve_info( schema, diff --git a/tests/test_relay.py b/tests/test_relay.py index 1e003f4..bfd75e8 100644 --- a/tests/test_relay.py +++ b/tests/test_relay.py @@ -3,9 +3,7 @@ import graphene_django_optimizer as gql_optimizer from .graphql_utils import create_resolve_info -from .models import ( - Item, -) +from .models import Item from .schema import schema from .test_utils import assert_query_equality @@ -37,6 +35,7 @@ def test_should_return_valid_result_in_a_relay_query(): assert result.data["relayItems"]["edges"][0]["node"]["name"] == "foo" +@pytest.mark.django_db def test_should_reduce_number_of_queries_in_relay_schema_by_using_select_related(): info = create_resolve_info( schema, @@ -62,6 +61,7 @@ def test_should_reduce_number_of_queries_in_relay_schema_by_using_select_related assert_query_equality(items, optimized_items) +@pytest.mark.django_db def test_should_reduce_number_of_queries_in_relay_schema_by_using_prefetch_related(): info = create_resolve_info( schema, @@ -88,6 +88,7 @@ def test_should_reduce_number_of_queries_in_relay_schema_by_using_prefetch_relat assert_query_equality(items, optimized_items) +@pytest.mark.django_db def test_should_optimize_query_by_only_requesting_id_field(): try: from django.db.models import DEFERRED # noqa: F401 diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 7a063a4..bbeb1fa 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -4,14 +4,12 @@ import graphene_django_optimizer as gql_optimizer from .graphql_utils import create_resolve_info -from .models import ( - Item, -) +from .models import Item from .schema import schema from .test_utils import assert_query_equality -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_optimize_non_django_field_if_it_has_an_optimization_hint_in_the_resolver(): # parent = Item.objects.create(name='foo') # Item.objects.create(name='bar', parent=parent) @@ -39,7 +37,7 @@ def test_should_optimize_non_django_field_if_it_has_an_optimization_hint_in_the_ assert_query_equality(items, optimized_items) -# @pytest.mark.django_db +@pytest.mark.django_db def test_should_optimize_with_prefetch_related_as_a_string(): # parent = Item.objects.create(name='foo') # Item.objects.create(name='bar', parent=parent) @@ -62,6 +60,7 @@ def test_should_optimize_with_prefetch_related_as_a_string(): assert_query_equality(items, optimized_items) +@pytest.mark.django_db def test_should_optimize_with_prefetch_related_as_a_function(): # parent = Item.objects.create(name='foo') # Item.objects.create(name='bar', parent=parent) diff --git a/tests/test_types.py b/tests/test_types.py index 345efd5..706151f 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -25,9 +25,9 @@ def test_should_optimize_the_single_node(mocked_optimizer): } } """, + return_type=schema.graphql_schema.get_type("SomeOtherItemType"), ) - info.return_type = schema.get_type("SomeOtherItemType") result = SomeOtherItemType.get_node(info, 7) assert result, "Expected the item to be found and returned" @@ -55,9 +55,9 @@ def test_should_return_none_when_node_is_not_resolved(mocked_optimizer): } } """, + return_type=schema.graphql_schema.get_type("SomeOtherItemType"), ) - info.return_type = schema.get_type("SomeOtherItemType") qs = SomeOtherItem.objects mocked_optimizer.return_value = qs @@ -84,9 +84,9 @@ def test_mutating_should_not_optimize(mocked_optimizer): } } """, + return_type=schema.graphql_schema.get_type("SomeOtherItemType"), ) - info.return_type = schema.get_type("SomeOtherItemType") result = DummyItemMutation.mutate(info, to_global_id("ItemNode", 7)) assert result assert result.pk == 7 @@ -111,9 +111,9 @@ def test_should_optimize_the_queryset(mocked_optimizer): } } """, + return_type=schema.graphql_schema.get_type("SomeOtherItemType"), ) - info.return_type = schema.get_type("SomeOtherItemType") qs = SomeOtherItem.objects.filter(pk=7) result = SomeOtherItemType.get_queryset(qs, info).get()