diff --git a/corehq/apps/data_interfaces/deduplication.py b/corehq/apps/data_interfaces/deduplication.py index ba068f1ac340..4eb1d080bf95 100644 --- a/corehq/apps/data_interfaces/deduplication.py +++ b/corehq/apps/data_interfaces/deduplication.py @@ -46,12 +46,7 @@ def apply_criterion_to_query(query_, definition): queries.MUST_NOT, ) elif isinstance(definition, LocationFilterDefinition): - # Get all users owning cases at definition.location_id - owners_ids = user_ids_at_locations([definition.location_id]) - # Add the definition.location_id for cases which belong to definition.location_id - owners_ids.append(definition.location_id) - - query_ = query_.owner(owners_ids) + query_ = query_.owner(_get_owner_ids_for_definition(definition)) return query_ @@ -61,6 +56,15 @@ def apply_criterion_to_query(query_, definition): return query +def _get_owner_ids_for_definition(location_definition): + location_ids = location_definition.get_location_ids() + owner_ids = user_ids_at_locations(location_ids) + # Add the location_ids for cases which belong to a location + owner_ids.extend(location_ids) + + return owner_ids + + def case_exists_in_es( domain, case, diff --git a/corehq/apps/data_interfaces/models.py b/corehq/apps/data_interfaces/models.py index 3e745abd4b58..055487257a5c 100644 --- a/corehq/apps/data_interfaces/models.py +++ b/corehq/apps/data_interfaces/models.py @@ -755,6 +755,14 @@ def is_matching_location(location_id): def location(self): return SQLLocation.by_location_id(self.location_id) + def get_location_ids(self): + if self.include_child_locations: + location_ids = list(SQLLocation.objects.get_locations_and_children_ids([self.location_id])) + else: + location_ids = [self.location_id] + + return location_ids + def to_dict(self): return { 'location_id': self.location_id, diff --git a/corehq/apps/data_interfaces/tests/test_case_deduplication.py b/corehq/apps/data_interfaces/tests/test_case_deduplication.py index a4649a953ab9..cddd1e8f74b3 100644 --- a/corehq/apps/data_interfaces/tests/test_case_deduplication.py +++ b/corehq/apps/data_interfaces/tests/test_case_deduplication.py @@ -1,6 +1,6 @@ from datetime import datetime, timedelta from itertools import chain -from unittest.mock import patch +from unittest.mock import patch, PropertyMock from django.test import TestCase from freezegun import freeze_time @@ -35,6 +35,7 @@ from corehq.apps.users.tasks import tag_cases_as_deleted_and_remove_indices from corehq.form_processor.models import CommCareCase from corehq.util.test_utils import flag_enabled, set_parent_case +from corehq.apps.locations.models import SQLLocation, LocationType @es_test(requires=[case_search_adapter]) @@ -77,6 +78,7 @@ def test_with_location_filter(self): # Create a filter criteria of cases to consider definition = LocationFilterDefinition.objects.create( location_id='mustafar_id', + include_child_locations=False, ) criteria = CaseRuleCriteria(rule=rule) criteria.definition = definition @@ -97,6 +99,50 @@ def test_with_location_filter(self): self.assertTrue(retrieved_cases[0]['owner_id'] == location_id) self.assertTrue(retrieved_cases[1]['owner_id'] == location_id) + @patch.object(LocationType, 'commtrack_enabled', new_callable=PropertyMock, return_value=False) + @patch('corehq.apps.commtrack.models.sync_supply_point') + def test_with_child_location(self, sync_supply_mock, mock_commtrack_enabled): + loc_type = LocationType.objects.create(domain='test-domain', name='level1') + parent_location = SQLLocation.objects.create( + domain='test-domain', name='parent', location_type=loc_type, location_id='parent' + ) + SQLLocation.objects.create( + domain='test-domain', name='child', location_type=loc_type, parent=parent_location, location_id='child' + ) + + cases = [ + self.factory.create_case(case_name=case_name, update={'dob': dob}) for (case_name, dob) in [ + ("Anakin Skywalker", "1977-03-25"), + ("Darth Vadar", "1977-03-25"), + ("Wannabe Anakin Skywalker", "1977-03-25"), + ("Wannabe Darth Vadar", "1977-03-25"), + ] + ] + + rule = self.create_rule('test rule', cases[0].type) + definition = LocationFilterDefinition.objects.create( + location_id='parent', + include_child_locations=True, + ) + criteria = CaseRuleCriteria(rule=rule) + criteria.definition = definition + criteria.save() + + location_id = 'child' + + # Only assign location id to first 2 cases, since we want only those two cases to be considered + cases[0].owner_id = location_id + cases[1].owner_id = location_id + + self._prime_es_index(cases) + + query = _get_es_filtered_case_query(self.domain, cases[0], rule.memoized_criteria) + retrieved_cases = query.run().hits + + self.assertEqual(len(retrieved_cases), 2) + self.assertTrue(retrieved_cases[0]['owner_id'] == location_id) + self.assertTrue(retrieved_cases[1]['owner_id'] == location_id) + def test_with_case_properties_filter_match_equal(self): match_type = MatchPropertyDefinition.MATCH_EQUAL