diff --git a/course_discovery/apps/api/v2/tests/test_views/test_search.py b/course_discovery/apps/api/v2/tests/test_views/test_search.py index 69348ab95a..7a2385df72 100644 --- a/course_discovery/apps/api/v2/tests/test_views/test_search.py +++ b/course_discovery/apps/api/v2/tests/test_views/test_search.py @@ -1,6 +1,7 @@ -""" Test cases for api/v2/search/all """ +""" Test cases for api/v2/search/all endpoint """ import json +from urllib.parse import parse_qs, urlparse import ddt from django.urls import reverse @@ -17,11 +18,16 @@ class AggregateSearchViewSetV2Tests(mixins.LoginMixin, ElasticsearchTestMixin, mixins.APITestCase): list_path = reverse("api:v2:search-all-list") - def fetch_page_data(self, page_size, search_after=None): - query_params = {"page_size": page_size} - if search_after: - query_params["search_after"] = search_after - response = self.client.get(self.list_path, data=query_params) + def fetch_page_data(self, page_size, next_url=None): + """ + Fetch a page of data using the provided page size and next_url. + If next_url is not provided, fetch the first page. + """ + if next_url: + response = self.client.get(next_url) + else: + query_params = {"page_size": page_size} + response = self.client.get(self.list_path, data=query_params) assert response.status_code == 200 return response.json() @@ -96,20 +102,27 @@ def test_search_after_pagination(self): self.validate_page_data(response_data, page_size) all_results = response_data["results"] - next_token = response_data.get("next") + next_url = response_data.get("next") + + while next_url: + # Parse the `search_after` value from the next_url query params + parsed_url = urlparse(next_url) + query_params = parse_qs(parsed_url.query) + search_after = query_params.get("search_after", [None])[0] + assert search_after is not None, "'search_after' parameter is missing in the next_url" - while next_token: - response_data = self.fetch_page_data(page_size, search_after=json.dumps(next_token)) + last_sort_value = all_results[-1]["sort"] + assert ( + json.loads(search_after) == last_sort_value + ), "The 'search_after' value in the next_url does not match the 'sort' field of the last result" + + response_data = self.fetch_page_data(page_size, next_url=next_url) expected_size = min(page_size, 75 - len(all_results)) self.validate_page_data(response_data, expected_size) all_results.extend(response_data["results"]) - next_token = response_data.get("next") - - if next_token: - last_sort_value = response_data["results"][-1]["sort"] - assert last_sort_value == next_token + next_url = response_data.get("next") assert len(all_results) == 75, "The total number of results does not match the expected count" diff --git a/course_discovery/apps/edx_elasticsearch_dsl_extensions/viewsets.py b/course_discovery/apps/edx_elasticsearch_dsl_extensions/viewsets.py index 56e543c21e..c6e3f4fc5b 100644 --- a/course_discovery/apps/edx_elasticsearch_dsl_extensions/viewsets.py +++ b/course_discovery/apps/edx_elasticsearch_dsl_extensions/viewsets.py @@ -5,6 +5,7 @@ from django_elasticsearch_dsl_drf.pagination import PageNumberPagination from django_elasticsearch_dsl_drf.viewsets import DocumentViewSet as OriginDocumentViewSet from rest_framework.permissions import IsAuthenticated +from rest_framework.utils.urls import replace_query_param from course_discovery.apps.api import mixins from course_discovery.apps.edx_elasticsearch_dsl_extensions.backends import MultiMatchSearchFilterBackend @@ -107,13 +108,14 @@ class SearchAfterPagination(PageNumberPagination): """ page_size_query_param = "page_size" + search_after_param = "search_after" def paginate_queryset(self, queryset, request, view=None): """ Paginate the Elasticsearch queryset using search_after. """ - search_after = request.query_params.get("search_after") + search_after = request.query_params.get(self.search_after_param) if search_after: try: queryset = queryset.extra(search_after=json.loads(search_after)) @@ -122,16 +124,20 @@ def paginate_queryset(self, queryset, request, view=None): return super().paginate_queryset(queryset, request, view) - def get_paginated_response(self, data): - """ - Get paginated response, including search_after value for the next page. - """ - response = super().get_paginated_response(data) - last_item = data[-1] if data else None - search_after = last_item.get("sort") if last_item else None - next_link = response.data.pop("next", None) - response.data["next"] = search_after if next_link else None - return response + def get_next_link(self): + if not self.page.has_next(): + return None + + last_item_sort = self._get_last_item_sort() + if not last_item_sort: + return None + + url = self.request.build_absolute_uri() + return replace_query_param(url, self.search_after_param, json.dumps(last_item_sort)) + + def _get_last_item_sort(self): + last_item = self.page.object_list[-1] if self.page.object_list else None + return list(last_item.meta.sort) if last_item else None class BaseElasticsearchDocumentViewSet(mixins.DetailMixin, mixins.FacetMixin, DocumentViewSet):