Skip to content

Commit

Permalink
unit test updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Ali-D-Akbar committed Jan 10, 2025
1 parent a70ebb7 commit b467ca7
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 4 deletions.
6 changes: 4 additions & 2 deletions course_discovery/apps/course_metadata/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,19 +1155,19 @@ class SearchAfterMixin:
"""

@classmethod
def search(cls, query, queryset=None):
def search(cls, query, queryset=None, page_size=settings.ELASTICSEARCH_DSL_QUERYSET_PAGINATION):
"""
Queries the Elasticsearch index with optional pagination using `search_after`.
Args:
query (str) -- Elasticsearch querystring (e.g. `title:intro*`)
queryset (models.QuerySet) -- base queryset to search, defaults to objects.all()
page_size (int) -- Number of results per page.
Returns:
QuerySet
"""
query = clean_query(query)
page_size = 10000
queryset = queryset or cls.objects.all()

if query == '(*)':
Expand All @@ -1191,6 +1191,8 @@ def search(cls, query, queryset=None):
)

try:
import pdb
pdb.set_trace()
results = search.execute()
except RequestError as e:
logger.warning(f"Elasticsearch request failed: {e}")
Expand Down
81 changes: 79 additions & 2 deletions course_discovery/apps/course_metadata/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from decimal import Decimal
from functools import partial
from unittest import mock
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import ddt
import pytest
Expand All @@ -17,6 +17,7 @@
from django.core.exceptions import ValidationError
from django.core.management import call_command
from django.db import IntegrityError, transaction
from django.db.models import QuerySet
from django.test import TestCase, override_settings
from edx_django_utils.cache import RequestCache
from edx_toggles.toggles.testutils import override_waffle_switch
Expand All @@ -39,7 +40,8 @@
FAQ, AbstractHeadingBlurbModel, AbstractMediaModel, AbstractNamedModel, AbstractTitleDescriptionModel,
AbstractValueModel, CorporateEndorsement, Course, CourseEditor, CourseRun, CourseRunType, CourseType, Curriculum,
CurriculumCourseMembership, CurriculumCourseRunExclusion, CurriculumProgramMembership, DegreeCost, DegreeDeadline,
Endorsement, Organization, OrganizationMapping, Program, ProgramType, Ranking, Seat, SeatType, Subject, Topic
Endorsement, Organization, OrganizationMapping, Program, ProgramType, Ranking, SearchAfterMixin, Seat, SeatType,
Subject, Topic
)
from course_discovery.apps.course_metadata.publishers import (
CourseRunMarketingSitePublisher, ProgramMarketingSitePublisher
Expand Down Expand Up @@ -4192,3 +4194,78 @@ def test_basic(self):
self.assertEqual(course_run.restricted_run, restricted_course_run)
self.assertEqual(restricted_course_run.restriction_type, 'custom-b2b-enterprise')
self.assertEqual(str(restricted_course_run), "course-v1:SC+BreadX+3T2015: <custom-b2b-enterprise>")


class MockQuerySet(QuerySet):
def __init__(self, model=None, items=None):
self.model = model
self.items = items or []
super().__init__()

def filter(self, **kwargs):
import pdb
pdb.set_trace()
pk_in = kwargs.get("pk__in", [])
pk_in = set(map(int, pk_in))
print(f"Filtering with pk_in: {pk_in}, self.items: {[item.pk for item in self.items]}")
return MockQuerySet(model=self.model, items=[item for item in self.items if item.pk in pk_in])

def __iter__(self):
return iter(self.items)

def __len__(self):
return len(self.items)

def all(self):
return self

def _chain(self):
# Mimic Django's queryset chaining
return self.__class__(model=self.model, items=self.items)


class MockModel:
def __init__(self, pk):
self.pk = pk


class SearchAfterMixinTest(SearchAfterMixin, MockModel):
objects = MockQuerySet(model=MockModel)


class TestSearchAfterMixin(TestCase):
@patch("course_discovery.apps.course_metadata.models.registry.get_documents")
@patch("course_discovery.apps.course_metadata.models.logger")
@patch("course_discovery.apps.course_metadata.models.clean_query")
def test_search_with_mock_data(self, mock_clean_query, mock_logger, mock_registry):
mock_document = MagicMock()
mock_search = MagicMock()
mock_document.search.return_value = mock_search
mock_registry.return_value = (mock_document,)

mock_result1 = MagicMock()
mock_result1.pk = 1
mock_result1.meta.sort = ["sort1"]

mock_result2 = MagicMock()
mock_result2.pk = 2
mock_result2.meta.sort = ["sort2"]

mock_search.execute.side_effect = [
[mock_result1, mock_result2],
[],
]

mock_clean_query.return_value = "cleaned_query"

SearchAfterMixinTest.objects = MockQuerySet(
model=MockModel,
items=[MockModel(1), MockModel(2), MockModel(3)]
)

result_queryset = SearchAfterMixinTest.search("query")

self.assertEqual(len(result_queryset), 2)
self.assertTrue(all(item.pk in {1, 2} for item in result_queryset))
mock_logger.info.assert_called()
mock_registry.assert_called_once_with(models=(SearchAfterMixinTest,))

0 comments on commit b467ca7

Please sign in to comment.