diff --git a/incident/api/serializers.py b/incident/api/serializers.py index 55ef9079b..89e842822 100644 --- a/incident/api/serializers.py +++ b/incident/api/serializers.py @@ -99,7 +99,27 @@ def get_url(self, obj): return obj.get_full_url() -class BaseIncidentSerializer(serializers.Serializer): +class VariableFieldSerializer(serializers.Serializer): + """A serializer that takes a set of field names from its context under + the key `requested_fields`, which controls what fields should be + returned. + + """ + + def __init__(self, *args, **kwargs): + requested_fields = kwargs.get('context', {}).get('requested_fields', set()) + + super().__init__(*args, **kwargs) + + if requested_fields: + # Drop any fields that are not specified in + # `requested_fields`. + existing = set(self.fields) + for field_name in existing - requested_fields: + self.fields.pop(field_name) + + +class BaseIncidentSerializer(VariableFieldSerializer): title = serializers.CharField() url = serializers.SerializerMethodField() first_published_at = serializers.DateTimeField() @@ -153,21 +173,6 @@ class BaseIncidentSerializer(serializers.Serializer): legal_order_type = serializers.CharField(source='get_legal_order_type_display') status_of_prior_restraint = serializers.CharField(source='get_status_of_prior_restraint_display') - def __init__(self, *args, **kwargs): - request = kwargs.get('context', {}).get('request') - str_fields = request.GET.get('fields', '') if request else None - fields = str_fields.split(',') if str_fields else None - - super().__init__(*args, **kwargs) - - if fields is not None: - # Drop any fields that are not specified in the `fields` - # argument. 
- allowed = set(fields) - existing = set(self.fields) - for field_name in existing - allowed: - self.fields.pop(field_name) - @extend_schema_field(OpenApiTypes.URI) def get_url(self, obj): if self.context.get('request'): @@ -241,3 +246,11 @@ class FlatIncidentSerializer(BaseIncidentSerializer): subpoena_statuses = FlatListField( child=ChoiceField(choices.SUBPOENA_STATUS) ) + + +class CSVIncidentSerializer(VariableFieldSerializer): + title = serializers.CharField() + date = serializers.DateField() + url = serializers.CharField() + tags = serializers.CharField(source='tag_summary') + categories = serializers.CharField(source='category_summary') diff --git a/incident/api/views.py b/incident/api/views.py index 7057ad85d..05178c002 100644 --- a/incident/api/views.py +++ b/incident/api/views.py @@ -8,6 +8,7 @@ OuterRef, Subquery, ) +from django.utils.functional import cached_property from rest_framework.decorators import action from drf_spectacular.utils import ( extend_schema, @@ -27,6 +28,7 @@ EquipmentSerializer, CategorySerializer, FlatIncidentSerializer, + CSVIncidentSerializer, ) from incident import models from incident.utils.incident_filter import IncidentFilter, get_openapi_parameters, DateFilter @@ -123,6 +125,30 @@ def dispatch(self, *args, **kwargs) -> 'HttpResponse': return response + @cached_property + def csv_format_is_requested(self): + return getattr(self.request, 'accepted_renderer', None) and \ + self.request.accepted_renderer.format == 'csv' + + @cached_property + def requested_fields(self): + if fields := self.request.GET.get('fields'): + return set(map(str.strip, fields.split(','))) + else: + return set() + + def get_serializer_context(self): + context = super().get_serializer_context() + context['requested_fields'] = self.requested_fields + return context + + @cached_property + def can_apply_csv_serializer(self): + if self.requested_fields: + return self.requested_fields <= CSVIncidentSerializer().fields.keys() + else: + return False + def 
get_renderer_context(self): context = super().get_renderer_context() @@ -135,8 +161,13 @@ def get_renderer_context(self): return context def get_serializer_class(self): - if getattr(self.request, 'accepted_renderer', None) and self.request.accepted_renderer.format == 'csv': - return FlatIncidentSerializer + if self.csv_format_is_requested: + # If it's possible to use the faster CSV serializer, use + # that. + if self.can_apply_csv_serializer: + return CSVIncidentSerializer + else: + return FlatIncidentSerializer return super().get_serializer_class() def paginate_queryset(self, queryset): @@ -145,6 +176,30 @@ def paginate_queryset(self, queryset): return super().paginate_queryset(queryset) def get_queryset(self): + if self.csv_format_is_requested and self.can_apply_csv_serializer: + annotated_fields = { + 'categories': 'category_summary', + 'tags': 'tag_summary', + 'url': 'url', + } + result_fields = [] + annotations = [] + + requested_fields = self.requested_fields + serializer = self.get_serializer() + valid_fields = requested_fields.intersection(set(serializer.fields)) + for field in valid_fields: + result_fields.append( + annotated_fields.get(field, field) + ) + annotations = { + v for k, v in annotated_fields.items() + if k in requested_fields + } + return models.IncidentPage.objects.live()\ + .for_csv(annotations, self.request)\ + .values(*result_fields) + incident_filter = IncidentFilter(self.request.GET) incidents = incident_filter.get_queryset() diff --git a/incident/models/incident_page.py b/incident/models/incident_page.py index b2f372a60..b397ddfca 100644 --- a/incident/models/incident_page.py +++ b/incident/models/incident_page.py @@ -14,13 +14,13 @@ Value, When, ) -from django.db.models.functions import ExtractDay, Cast, Trunc, TruncMonth, Coalesce +from django.db.models.functions import ExtractDay, Cast, Trunc, TruncMonth, Coalesce, Concat from django.utils.functional import cached_property from django.utils.html import strip_tags from 
django.template.defaultfilters import truncatewords from modelcluster.fields import ParentalManyToManyField, ParentalKey from django.contrib.postgres.fields import ArrayField -from django.contrib.postgres.aggregates import ArrayAgg +from django.contrib.postgres.aggregates import ArrayAgg, StringAgg from psycopg2.extras import DateRange from wagtail.admin.panels import ( FieldPanel, @@ -91,6 +91,47 @@ def formfield(self, **kwargs): class IncidentQuerySet(PageQuerySet): """A QuerySet for incident pages that incorporates update data""" + def for_csv(self, with_annotations, request): + # TODO: if 'url' in with_annotations, get the actually correct + # base URI for the incident index page. + base_uri = request.build_absolute_uri('/all-incidents/') + available_annotations = { + 'tag_summary': Subquery( + IncidentPage.objects.only('tags').annotate( + tag_summary=StringAgg( + 'tags__title', + delimiter=', ', + ordering=('tags__title',) + ) + ).filter( + pk=OuterRef('pk') + ).values('tag_summary'), + output_field=models.CharField() + ), + 'category_summary': Subquery( + IncidentPage.objects.only('categories').annotate( + category_summary=StringAgg( + 'categories__category__title', + delimiter=', ', + ordering=('categories__category__title',) + ) + ).filter( + pk=OuterRef('pk') + ).values('category_summary'), + output_field=models.CharField(), + ), + 'url': Concat( + Value(base_uri), + "slug", + output_field=models.CharField() + ), + } + annotations_to_apply = { + label: expression for label, expression in available_annotations.items() + if label in with_annotations + } + return self.annotate(**annotations_to_apply) + def with_public_associations(self): """Prefetch and select related data for public consumption @@ -1030,3 +1071,34 @@ def get_all_targets_for_display(self): for institution in self.targeted_institutions.all(): items.append(f'{institution.title}') return ', '.join(items) + + +# CSV_ANNOTATIONS = { +# 'tag_summary': Subquery( +# 
IncidentPage.objects.only('tags').annotate( +# tag_summary=StringAgg( +# 'tags__title', +# delimiter=', ', +# ordering=('tags__title',) +# ) +# ).filter( +# pk=OuterRef('pk') +# ).values('tag_summary'), +# output_field=models.CharField(), +# ), +# 'category_summary': Subquery( +# IncidentPage.objects.only('categories').annotate( +# category_summary=StringAgg( +# 'categories__category__title', +# delimiter=', ', +# ordering=('categories__category__title',) +# ) +# ).filter( +# pk=OuterRef('pk') +# ).values('category_summary'), +# output_field=models.CharField(), +# ), +# # 'authors': models.IncidentPage.only('authors').annotate(

# # ) +# }