Skip to content

Commit

Permalink
Initial implementation of faster CSV API view
Browse files Browse the repository at this point in the history
  • Loading branch information
chigby committed May 11, 2023
1 parent 9752e26 commit 3a9f506
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 20 deletions.
45 changes: 29 additions & 16 deletions incident/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,27 @@ def get_url(self, obj):
return obj.get_full_url()


class BaseIncidentSerializer(serializers.Serializer):
class VariableFieldSerializer(serializers.Serializer):
    """Serializer whose output fields can be narrowed per instance.

    Reads an optional collection of field names from the serializer
    context under the key ``requested_fields``. When that collection is
    non-empty, every declared field whose name is not in it is dropped.
    When it is missing or empty, all declared fields are kept.
    """

    def __init__(self, *args, **kwargs):
        # `or {}` tolerates an explicit `context=None`.
        context = kwargs.get('context') or {}
        requested = context.get('requested_fields') or ()
        if isinstance(requested, str):
            # A bare string is one field name, not an iterable of
            # single-character names.
            requested = (requested,)
        # Coerce to a set so any iterable of names (list, tuple, set)
        # can take part in the set difference below.
        requested_fields = set(requested)

        super().__init__(*args, **kwargs)

        if requested_fields:
            # Drop any fields that are not specified in
            # `requested_fields`. The difference is materialised first,
            # so `self.fields` is not mutated while being iterated.
            for field_name in set(self.fields) - requested_fields:
                self.fields.pop(field_name)


class BaseIncidentSerializer(VariableFieldSerializer):
title = serializers.CharField()
url = serializers.SerializerMethodField()
first_published_at = serializers.DateTimeField()
Expand Down Expand Up @@ -153,21 +173,6 @@ class BaseIncidentSerializer(serializers.Serializer):
legal_order_type = serializers.CharField(source='get_legal_order_type_display')
status_of_prior_restraint = serializers.CharField(source='get_status_of_prior_restraint_display')

def __init__(self, *args, **kwargs):
request = kwargs.get('context', {}).get('request')
str_fields = request.GET.get('fields', '') if request else None
fields = str_fields.split(',') if str_fields else None

super().__init__(*args, **kwargs)

if fields is not None:
# Drop any fields that are not specified in the `fields`
# argument.
allowed = set(fields)
existing = set(self.fields)
for field_name in existing - allowed:
self.fields.pop(field_name)

@extend_schema_field(OpenApiTypes.URI)
def get_url(self, obj):
if self.context.get('request'):
Expand Down Expand Up @@ -241,3 +246,11 @@ class FlatIncidentSerializer(BaseIncidentSerializer):
subpoena_statuses = FlatListField(
child=ChoiceField(choices.SUBPOENA_STATUS)
)


class CSVIncidentSerializer(VariableFieldSerializer):
    """Reduced incident serializer exposing five plain-valued fields.

    Every field is a simple char/date field, with no nested or method
    fields. NOTE(review): presumably fed by rows produced by
    `IncidentQuerySet.for_csv(...).values(...)`, which would supply
    `tag_summary`, `category_summary` and `url` -- confirm against the
    view's `get_queryset`.
    """
    title = serializers.CharField()
    date = serializers.DateField()
    url = serializers.CharField()
    # `tags` and `categories` are read from differently named sources
    # (`tag_summary` / `category_summary`) on the serialized object.
    tags = serializers.CharField(source='tag_summary')
    categories = serializers.CharField(source='category_summary')
59 changes: 57 additions & 2 deletions incident/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
OuterRef,
Subquery,
)
from django.utils.functional import cached_property
from rest_framework.decorators import action
from drf_spectacular.utils import (
extend_schema,
Expand All @@ -27,6 +28,7 @@
EquipmentSerializer,
CategorySerializer,
FlatIncidentSerializer,
CSVIncidentSerializer,
)
from incident import models
from incident.utils.incident_filter import IncidentFilter, get_openapi_parameters, DateFilter
Expand Down Expand Up @@ -123,6 +125,30 @@ def dispatch(self, *args, **kwargs) -> 'HttpResponse':

return response

@cached_property
def csv_format_is_requested(self):
    """Truthy when the renderer negotiated for this request is CSV.

    Evaluates to a falsy value (``None`` or ``False``) when the request
    has no ``accepted_renderer`` or the renderer's format is not
    ``'csv'``. The result is cached on the instance after first access.
    """
    renderer = getattr(self.request, 'accepted_renderer', None)
    return renderer and renderer.format == 'csv'

@cached_property
def requested_fields(self):
    """Return the set of field names given in the ``fields`` query parameter.

    The parameter is a comma-separated list. Surrounding whitespace is
    stripped from each name, and empty names (from stray or trailing
    commas, e.g. ``fields=title,``) are discarded so they cannot
    masquerade as a requested field. Returns an empty set when the
    parameter is absent or names no fields. The result is cached on the
    instance after first access.
    """
    raw = self.request.GET.get('fields')
    if not raw:
        return set()
    stripped = map(str.strip, raw.split(','))
    return {name for name in stripped if name}

def get_serializer_context(self):
    """Extend the parent's serializer context with ``requested_fields``.

    The added entry carries the set of field names parsed from the
    request, so serializers can read it from their context.
    """
    serializer_context = super().get_serializer_context()
    serializer_context.update(requested_fields=self.requested_fields)
    return serializer_context

@cached_property
def can_apply_csv_serializer(self):
    """Whether ``CSVIncidentSerializer`` can serve this request.

    True only when specific fields were requested and every one of them
    is a field declared on ``CSVIncidentSerializer``. Always ``False``
    when no fields were requested. The result is cached on the instance
    after first access.
    """
    wanted = self.requested_fields
    if not wanted:
        return False
    # Subset test against the serializer's declared field names.
    return wanted <= CSVIncidentSerializer().fields.keys()

def get_renderer_context(self):
context = super().get_renderer_context()

Expand All @@ -135,8 +161,13 @@ def get_renderer_context(self):
return context

def get_serializer_class(self):
    """Pick the serializer class for the current request.

    CSV requests use ``CSVIncidentSerializer`` when
    ``can_apply_csv_serializer`` is true and ``FlatIncidentSerializer``
    otherwise. Every other format defers to the parent implementation.
    """
    if not self.csv_format_is_requested:
        return super().get_serializer_class()
    # If it's possible to use the faster CSV serializer, use that.
    # This must be the only CSV check: an earlier unconditional
    # `return FlatIncidentSerializer` for CSV would make it unreachable.
    if self.can_apply_csv_serializer:
        return CSVIncidentSerializer
    return FlatIncidentSerializer

def paginate_queryset(self, queryset):
Expand All @@ -145,6 +176,30 @@ def paginate_queryset(self, queryset):
return super().paginate_queryset(queryset)

def get_queryset(self):
if self.csv_format_is_requested and self.can_apply_csv_serializer:
annotated_fields = {
'categories': 'category_summary',
'tags': 'tag_summary',
'url': 'url',
}
result_fields = []
annotations = []

requested_fields = self.requested_fields
serializer = self.get_serializer()
valid_fields = requested_fields.intersection(set(serializer.fields))
for field in valid_fields:
result_fields.append(
annotated_fields.get(field, field)
)
annotations = {
v for k, v in annotated_fields.items()
if k in requested_fields
}
return models.IncidentPage.objects.live()\
.for_csv(annotations, self.request)\
.values(*result_fields)

incident_filter = IncidentFilter(self.request.GET)
incidents = incident_filter.get_queryset()

Expand Down
76 changes: 74 additions & 2 deletions incident/models/incident_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
Value,
When,
)
from django.db.models.functions import ExtractDay, Cast, Trunc, TruncMonth, Coalesce
from django.db.models.functions import ExtractDay, Cast, Trunc, TruncMonth, Coalesce, Concat
from django.utils.functional import cached_property
from django.utils.html import strip_tags
from django.template.defaultfilters import truncatewords
from modelcluster.fields import ParentalManyToManyField, ParentalKey
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.aggregates import ArrayAgg
from django.contrib.postgres.aggregates import ArrayAgg, StringAgg
from psycopg2.extras import DateRange
from wagtail.admin.panels import (
FieldPanel,
Expand Down Expand Up @@ -91,6 +91,47 @@ def formfield(self, **kwargs):

class IncidentQuerySet(PageQuerySet):
"""A QuerySet for incident pages that incorporates update data"""
def for_csv(self, with_annotations, request):
    """Annotate the queryset with the computed columns a CSV export asks for.

    ``with_annotations`` is a collection of annotation labels. Of
    ``tag_summary``, ``category_summary`` and ``url``, only the labels
    it contains are applied. ``request`` is used to build the absolute
    base URI that is prefixed to each row's ``slug`` to form ``url``.
    Returns the annotated queryset.
    """
    def summary_of(label, only_field, title_path):
        # Correlated subquery for the outer row: the values reached
        # through `title_path`, ordered and joined with ", ".
        joined_titles = StringAgg(
            title_path,
            delimiter=', ',
            ordering=(title_path,)
        )
        per_incident = (
            IncidentPage.objects.only(only_field)
            .annotate(**{label: joined_titles})
            .filter(pk=OuterRef('pk'))
            .values(label)
        )
        return Subquery(per_incident, output_field=models.CharField())

    # TODO: if 'url' in with_annotations, get the actually correct
    # base URI for the incident index page.
    base_uri = request.build_absolute_uri('/all-incidents/')

    candidates = {
        'tag_summary': summary_of('tag_summary', 'tags', 'tags__title'),
        'category_summary': summary_of(
            'category_summary',
            'categories',
            'categories__category__title',
        ),
        'url': Concat(
            Value(base_uri),
            "slug",
            output_field=models.CharField()
        ),
    }
    # Keep only the annotations the caller asked for.
    chosen = {
        label: candidates[label]
        for label in candidates
        if label in with_annotations
    }
    return self.annotate(**chosen)

def with_public_associations(self):
"""Prefetch and select related data for public consumption
Expand Down Expand Up @@ -1030,3 +1071,34 @@ def get_all_targets_for_display(self):
for institution in self.targeted_institutions.all():
items.append(f'{institution.title}')
return ', '.join(items)


# CSV_ANNOTATIONS = {
# 'tag_summary': Subquery(
# IncidentPage.objects.only('tags').annotate(
# tag_summary=StringAgg(
# 'tags__title',
# delimiter=', ',
# ordering=('tags__title',)
# )
# ).filter(
# pk=OuterRef('pk')
# ).values('tag_summary'),
# output_field=models.CharField(),
# ),
# 'category_summary': Subquery(
# IncidentPage.objects.only('categories').annotate(
# tag_summary=StringAgg(
# 'categories__category__title',
# delimiter=', ',
# ordering=('categories__category__title',)
# )
# ).filter(
# pk=OuterRef('pk')
# ).values('category_summary'),
# output_field=models.CharField(),
# ),
# # 'authors': models.IncidentPage.only('authors').annotate(

# # )
# }

0 comments on commit 3a9f506

Please sign in to comment.