diff --git a/report_builder/api/views.py b/report_builder/api/views.py index d544f6c5..2ff0d728 100644 --- a/report_builder/api/views.py +++ b/report_builder/api/views.py @@ -44,7 +44,7 @@ class ContentTypeViewSet(ReportBuilderViewMixin, viewsets.ReadOnlyModelViewSet): """ Read only view of content types. Used to populate choices for new report root model. """ - queryset = ContentType.objects.all() + queryset = Report.allowed_models() serializer_class = ContentTypeSerializer permission_classes = (IsAdminUser,) diff --git a/report_builder/models.py b/report_builder/models.py index e56aa4e5..31cf5f13 100644 --- a/report_builder/models.py +++ b/report_builder/models.py @@ -2,9 +2,10 @@ from django.contrib.contenttypes.models import ContentType from django.conf import settings from django.core.urlresolvers import reverse -from django.core.exceptions import ValidationError, ObjectDoesNotExist +from django.core.exceptions import ValidationError, ObjectDoesNotExist, ImproperlyConfigured from django.utils.safestring import mark_safe from django.utils.functional import cached_property +from django.utils.lru_cache import lru_cache from django.db import models from django.db.models import Avg, Min, Max, Count, Sum, F from django.db.models.fields import FieldDoesNotExist @@ -22,29 +23,51 @@ AUTH_USER_MODEL = getattr(settings, 'AUTH_USER_MODEL', 'auth.User') +@lru_cache(maxsize=None) +def get_model_contenttype(name): + if "." in name: + # Fully qualified model name + split_name = name.lower().split('.') + # Negative indexing allows, eg. "django.contrib.auth.User" + model_query = { + 'app_label': split_name[-2], + 'model': split_name[-1] + } + model_name = "%s.%s" % (model_query['app_label'], model_query['model']) + else: + # short model name + model_query = {'model': name.lower()} + model_name = model_query['model'] + + # ensure it exists, and warn early if there's an issue + try: + model_ct = ContentType.objects.get(**model_query) + except ContentType.DoesNotExist: + raise ImproperlyConfigured( + "REPORT_BUILDER: Model '%s' (from '%s') could not be found." % (model_name, name)) + except ContentType.MultipleObjectsReturned: + possible_cts = ContentType.objects.filter(**model_query).values_list('app_label', 'model') + possible_cts = [".".join(ct) for ct in possible_cts] + + raise ImproperlyConfigured( + "REPORT_BUILDER: Model '%s' is ambiguous. Possible values: %s" % (name, str(possible_cts))) + else: + return model_ct + + def get_allowed_models(): models = ContentType.objects.all() if getattr(settings, 'REPORT_BUILDER_INCLUDE', False): - all_model_names = [] - additional_models = [] + ct_pks = [] # pks of all included model contenttypes for element in settings.REPORT_BUILDER_INCLUDE: - split_element = element.split('.') - if len(split_element) == 2: - additional_models.append(models.filter(app_label=split_element[0], model=split_element[1])) - else: - all_model_names.append(element) - models = models.filter(model__in=all_model_names) - for additional_model in additional_models: - models = models | additional_model + ct_pks.append(get_model_contenttype(element).pk) + models = ContentType.objects.filter(pk__in=ct_pks) + if getattr(settings, 'REPORT_BUILDER_EXCLUDE', False): - all_model_names = [] + ct_pks = [] # pks of all excluded model contenttypes for element in settings.REPORT_BUILDER_EXCLUDE: - split_element = element.split('.') - if len(split_element) == 2: - models = models.exclude(app_label=split_element[0], model=split_element[1]) - else: - all_model_names.append(element) - models = models.exclude(model__in=all_model_names) + ct_pks.append(get_model_contenttype(element).pk) + models = models.exclude(pk__in=ct_pks) return models diff --git a/report_builder/tests.py b/report_builder/tests.py index 023e894e..ca59c200 100644 --- a/report_builder/tests.py +++ b/report_builder/tests.py @@ -1,6 +1,7 @@ from django.contrib.contenttypes.models import ContentType from django.core import mail from django.core.urlresolvers import reverse +from django.core.exceptions import ImproperlyConfigured from django.db.models.query import QuerySet from django.test import TestCase from django.test.utils import override_settings @@ -155,6 +156,37 @@ def setUp(self): self.client = APIClient() self.client.login(username='testy', password='pass') + def test_bad_names(self): + settings.REPORT_BUILDER_EXCLUDE = None + + # 'bar' is ambiguous + settings.REPORT_BUILDER_INCLUDE = ('bar',) + with self.assertRaises(ImproperlyConfigured): + get_allowed_models() + + # app_label is invalid + settings.REPORT_BUILDER_INCLUDE = ('invalid.bar',) + with self.assertRaises(ImproperlyConfigured): + get_allowed_models() + settings.REPORT_BUILDER_INCLUDE = ('invalid.invalid',) + with self.assertRaises(ImproperlyConfigured): + get_allowed_models() + + # model is invalid + settings.REPORT_BUILDER_INCLUDE = ('invalid',) + with self.assertRaises(ImproperlyConfigured): + get_allowed_models() + settings.REPORT_BUILDER_INCLUDE = ('demo_models.invalid',) + with self.assertRaises(ImproperlyConfigured): + get_allowed_models() + + # ensure duplicate models don't overlap + settings.REPORT_BUILDER_INCLUDE = ( + 'demo_models.bar', + 'demo_second_app.bar', + ) + self.assertEqual(len(get_allowed_models()), 2) + def test_get_allowed_models_for_include(self): pre_include_duplicates = find_duplicates_in_contexttype() settings.REPORT_BUILDER_INCLUDE = (