diff --git a/report_builder/api/tests.py b/report_builder/api/tests.py new file mode 100644 index 00000000..4e7e84aa --- /dev/null +++ b/report_builder/api/tests.py @@ -0,0 +1,70 @@ +from django.test import TestCase, override_settings +from django.urls import reverse +from django.contrib.contenttypes.models import ContentType +from django.contrib.auth import get_user_model +from rest_framework.test import APIClient +import json + + +class ApiTestCase(TestCase): + + def setUp(self): + um = get_user_model() + self.superuser = um.objects.create_superuser('su', 'su@example.com', 'su') + self.regularuser = um.objects.create_user('user', 'user@example.com', 'user') + self.client = APIClient() + + def get_json(self, url): + self.client.login(username='su', password='su') + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + content = json.loads(response.content) + return content + + def post_json(self, url, data): + response = self.client.post( + url, + data, + format='json' + ) + self.assertEqual(response.status_code, 200) + content = json.loads(response.content) + return content + + def test_endpoint_not_accessable_by_regular_user(self): + self.client.login(username='user', password='user') + response = self.client.get('/report_builder/api/contenttypes/') + self.assertEqual(response.status_code, 403) + + def get_related_fields_for_ct(self, app_label, model_name): + self.client.login(username='su', password='su') + ct = ContentType.objects.get_by_natural_key(app_label, model_name) + content = self.post_json( + reverse('related_fields'), + {'field': '', 'model': ct.id, 'path': ''}, + ) + return content + + def test_related_fields(self): + related_fields = self.get_related_fields_for_ct('demo_models', 'child') + self.assertEqual(len(related_fields), 1) + self.assertEqual(related_fields[0]['field_name'], 'parent') + self.assertTrue(related_fields[0]['included_model']) + + @override_settings(REPORT_BUILDER_EXCLUDE=['demo_models.person']) + def test_related_fields_exclude(self): + related_fields = self.get_related_fields_for_ct('demo_models', 'child') + self.assertEqual(len(related_fields), 1) + self.assertEqual(related_fields[0]['field_name'], 'parent') + self.assertFalse(related_fields[0]['included_model']) + + def test_get_all_content_types(self): + num_content_types = ContentType.objects.count() + response = self.get_json('/report_builder/api/contenttypes/') + self.assertEqual(len(response), num_content_types) + + @override_settings(REPORT_BUILDER_EXCLUDE=['demo_models.person']) + def test_get_content_types_with_exclude(self): + num_content_types = ContentType.objects.count() + response = self.get_json('/report_builder/api/contenttypes/') + self.assertEqual(len(response), num_content_types - 1) diff --git a/report_builder/api/views.py b/report_builder/api/views.py index 331a03c6..3a0cb992 100644 --- a/report_builder/api/views.py +++ b/report_builder/api/views.py @@ -4,18 +4,16 @@ from django.http import JsonResponse from django.utils.functional import cached_property from django.conf import settings -from rest_framework import viewsets, status +from rest_framework import viewsets from rest_framework.views import APIView from rest_framework.response import Response from rest_framework.decorators import action -from rest_framework.response import Response from rest_framework.permissions import IsAdminUser -from ..models import Report, Format, FilterField +from ..models import Report, Format, FilterField, get_allowed_models from .serializers import ( ReportNestedSerializer, ReportSerializer, FormatSerializer, FilterFieldSerializer, ContentTypeSerializer) from ..mixins import GetFieldsMixin, DataExportMixin -from django.core import serializers from ..utils import duplicate @@ -35,6 +33,7 @@ class ReportBuilderViewMixin: permission_classes = (IsAdminUser,) pagination_class = None + class ConfigView(ReportBuilderViewMixin, APIView): def get(self, request): data = { @@ -58,9 +57,11 @@ class ContentTypeViewSet(ReportBuilderViewMixin, viewsets.ReadOnlyModelViewSet): """ Read only view of content types. Used to populate choices for new report root model. """ - queryset = ContentType.objects.all() serializer_class = ContentTypeSerializer + def get_queryset(self): + return get_allowed_models() + class ReportViewSet(ReportBuilderViewMixin, viewsets.ModelViewSet): queryset = Report.objects.all() @@ -101,8 +102,6 @@ def copy_report(self, request, pk=None): serializer = ReportNestedSerializer(new_report) return JsonResponse(serializer.data) - - class RelatedFieldsView(ReportBuilderViewMixin, GetFieldsMixin, APIView): """ Get related fields from an ORM model """ @@ -123,26 +122,18 @@ def post(self, request): result = [] for new_field in new_fields: included_model = True - split_name = new_field.name.split(':') - if len(split_name) == 1: - split_name.append('') - split_name[1] = split_name[0] - split_name[0] = False - model_information = split_name[1] - else: - model_information = split_name[0] + "." + split_name[1] - app_label = split_name[0] - model_name = split_name[1] + related_model = new_field.related_model + label = related_model._meta.label_lower + app_label, model_name = label.split('.') if getattr(settings, 'REPORT_BUILDER_INCLUDE', False): includes = getattr(settings, 'REPORT_BUILDER_INCLUDE') # If it is not included as 'foo' and not as 'demo_models.foo' - if (model_name not in includes and - model_information not in includes): + if (model_name not in includes and label not in includes): included_model = False if getattr(settings, 'REPORT_BUILDER_EXCLUDE', False): excludes = getattr(settings, 'REPORT_BUILDER_EXCLUDE') # If it is excluded as 'foo' and as 'demo_models.foo' - if (model_name in excludes or model_information in excludes): + if (model_name in excludes or label in excludes): included_model = False verbose_name = getattr(new_field, 'verbose_name', None) if verbose_name is None: diff --git a/report_builder/urls.py b/report_builder/urls.py index c76787e3..55785d23 100644 --- a/report_builder/urls.py +++ b/report_builder/urls.py @@ -10,7 +10,7 @@ router.register(r'report', api_views.ReportNestedViewSet) router.register(r'formats', api_views.FormatViewSet) router.register(r'filterfields', api_views.FilterFieldViewSet) -router.register(r'contenttypes', api_views.ContentTypeViewSet) +router.register(r'contenttypes', api_views.ContentTypeViewSet, base_name='contenttypes') urlpatterns = [ url(r'^report/(?P\d+)/download_file/$', views.DownloadFileView.as_view(), name="report_download_file"),