Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add associated document id field to ChangeMeta #33807

Merged
merged 10 commits into from
Dec 5, 2023
6 changes: 1 addition & 5 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@ on:
branches:
- master
- hotfix-deploy
- ap/sql-repeater/phase-2
- mjr/erm-update-rules
- mjr/erm-fixtures
- mjr/erm-custom-roles
- mjr/erm-roles
- mjr/add-change-meta-context
schedule:
# see corehq/apps/hqadmin/management/commands/static_analysis.py
- cron: '47 12 * * *'
Expand Down
24 changes: 22 additions & 2 deletions corehq/apps/data_interfaces/pillow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

from corehq.apps.data_interfaces.deduplication import is_dedupe_xmlns
from corehq.apps.data_interfaces.models import AutomaticUpdateRule
from corehq.form_processor.exceptions import XFormNotFound
from corehq.form_processor.models import CommCareCase
from corehq.form_processor.models.forms import XFormInstance
from corehq.toggles import CASE_DEDUPE
from corehq.util.soft_assert import soft_assert


class CaseDeduplicationProcessor(PillowProcessor):
Expand All @@ -29,14 +32,15 @@ def process_change(self, change):
if change.deleted:
return

if is_dedupe_xmlns(change.get_document().get('xmlns')):
associated_form = self._get_associated_form(change)
if not associated_form or is_dedupe_xmlns(associated_form.xmlns):
return

rules = self._get_rules(domain)
if not rules:
return

for case_update in get_case_updates(change.get_document()):
for case_update in get_case_updates(associated_form, for_case=change.id):
self._process_case_update(domain, case_update)

def _get_rules(self, domain):
Expand All @@ -54,3 +58,19 @@ def _process_action(self, domain, rule, action, changed_properties, case_id):
case = CommCareCase.objects.get_case(case_id, domain)
if case.type == rule.case_type:
rule.run_rule(case, datetime.utcnow())

def _get_associated_form(self, change):
associated_form_id = change.metadata.associated_document_id
associated_form = None
if associated_form_id:
try:
associated_form = XFormInstance.objects.get_form(associated_form_id)
except XFormNotFound:
_assert = soft_assert(['mriley_at_dimagi_dot_com'.replace('_at_', '@').replace('_dot_', '.')])
_assert(False, 'Associated form not found', {
'case_id': change.id,
'form_id': associated_form_id
})
associated_form = None

return associated_form
8 changes: 4 additions & 4 deletions corehq/apps/data_interfaces/tests/test_case_deduplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from corehq.apps.users.models import CommCareUser
from corehq.apps.users.tasks import tag_cases_as_deleted_and_remove_indices
from corehq.form_processor.models import CommCareCase, XFormInstance
from corehq.pillows.xform import get_xform_pillow
from corehq.pillows.case import get_case_pillow
from corehq.util.test_utils import flag_enabled, set_parent_case


Expand Down Expand Up @@ -714,10 +714,10 @@ def setUpClass(cls):
cls.domain = 'naboo'
cls.case_type = 'people'
cls.factory = CaseFactory(cls.domain)
cls.pillow = get_xform_pillow(skip_ucr=True)
cls.pillow = get_case_pillow(skip_ucr=True)

def setUp(self):
self.kafka_offset = get_topic_offset(topics.FORM_SQL)
self.kafka_offset = get_topic_offset(topics.CASE_SQL)

@patch("corehq.apps.data_interfaces.models.find_duplicate_case_ids")
def test_pillow_processes_changes(self, find_duplicate_cases_mock):
Expand All @@ -729,7 +729,7 @@ def test_pillow_processes_changes(self, find_duplicate_cases_mock):

find_duplicate_cases_mock.return_value = [case1.case_id, case2.case_id]

new_kafka_sec = get_topic_offset(topics.FORM_SQL)
new_kafka_sec = get_topic_offset(topics.CASE_SQL)
self.pillow.process_changes(since=self.kafka_offset, forever=False)

self._assert_case_duplicate_pair(case1.case_id, [case2.case_id])
Expand Down
2 changes: 1 addition & 1 deletion corehq/apps/hqcase/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def bulk_update_cases(domain, case_changes, device_id, xmlns=None):

def resave_case(domain, case, send_post_save_signal=True):
from corehq.form_processor.change_publishers import publish_case_saved
publish_case_saved(case, send_post_save_signal)
publish_case_saved(case, send_post_save_signal=send_post_save_signal)


def get_last_non_blank_value(case, case_property):
Expand Down
82 changes: 82 additions & 0 deletions corehq/ex-submodules/casexml/apps/case/tests/test_xform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from django.test import SimpleTestCase
from casexml.apps.case.xml import V2
from casexml.apps.case.xform import get_case_updates
from casexml.apps.case.xml.parser import CaseUpdate


class TestGetCaseUpdates(SimpleTestCase):
default_case_id = '1111'
default_user_id = '2222'
default_modified_time = '2023-11-28T15:26:55.859000Z'

def test_processes_single_case(self):
case_block = self._create_case_block(
case_id='case1',
user_id='abc',
modified_on='2023-11-28T15:26:55.859000Z',
create_block={'case_name': 'test', 'case_type': 'test_type'}
)
xform = {
'case': case_block
}

updates = get_case_updates(xform)
expected_case = self._create_case_update(
case_id='case1',
user_id='abc',
modified_on='2023-11-28T15:26:55.859000Z',
create_block={'case_name': 'test', 'case_type': 'test_type'}
)
self.assertEqual(expected_case, updates[0])

def test_processes_sub_case(self):
case1 = self._create_case_block(case_id='1')
case2 = self._create_case_block(case_id='2')
xform = {
'case': case1,
'sub_case': {
'case': case2
}
}

updates = get_case_updates(xform)
self.assertEqual(updates, [self._create_case_update(case_id='1'), self._create_case_update(case_id='2')])

def test_can_restrict_by_id(self):
case1 = self._create_case_block(case_id='1')
case2 = self._create_case_block(case_id='2')
xform = {
'case': case1,
'sub_case': {
'case': case2
}
}

updates = get_case_updates(xform, for_case='1')
self.assertEqual(updates, [self._create_case_update(case_id='1')])

def _create_case_block(
self, case_id=None, user_id=None, modified_on=None, create_block=None, update_block=None):
block = {
'@case_id': case_id or self.default_case_id,
'@date_modified': modified_on or self.default_modified_time,
'@user_id': user_id or self.default_user_id,
'@xmlns': 'http://commcarehq.org/case/transaction/v2',
}

if create_block:
block['create'] = create_block

if update_block:
block['update'] = update_block

return block

def _create_case_update(
self, case_id=None, user_id=None, modified_on=None, create_block=None, update_block=None):
block = self._create_case_block(case_id, user_id, modified_on, create_block, update_block)

return CaseUpdate(
case_id or self.default_case_id, V2, block,
user_id=(user_id or self.default_user_id),
modified_on_str=modified_on or self.default_modified_time)
42 changes: 41 additions & 1 deletion corehq/ex-submodules/casexml/apps/case/tests/xml/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from casexml.apps.case.xml.parser import CaseUpdate
from casexml.apps.case.xml.parser import CaseUpdate, CaseCreateAction
from casexml.apps.case.xml import V2
from django.test import SimpleTestCase

Expand Down Expand Up @@ -40,6 +40,29 @@ def test_get_normalized_updates(self):
self.assertEqual(case_update.get_normalized_update_property_names(),
{'name', 'owner_id', 'type'})

def test_equality(self):
create_block = {
'case_name': 'test_case',
'owner_id': '12345',
'case_type': 'test_case_type'
}
case_block = self._create_case_block(create_block)

case_update_1 = CaseUpdate('case_id', V2, case_block)
case_update_2 = CaseUpdate('case_id', V2, case_block)
self.assertEqual(case_update_1, case_update_2)

def test_non_equality(self):
create_block = {
'case_name': 'test_case',
'owner_id': '12345',
'case_type': 'test_case_type'
}
case_block = self._create_case_block(create_block)
case_update_1 = CaseUpdate('case_id', V2, case_block)
case_update_2 = CaseUpdate('case_id2', V2, case_block)
self.assertNotEqual(case_update_1, case_update_2)

def _create_case_block(self, create_block=None, update_block=None):
block = {
'@case_id': '1111',
Expand All @@ -55,3 +78,20 @@ def _create_case_block(self, create_block=None, update_block=None):
block['update'] = update_block

return block


class CaseActionTests(SimpleTestCase):
def test_equality(self):
block = {
'case_name': 'test'
}
action1 = CaseCreateAction(block)
action2 = CaseCreateAction(block)

self.assertEqual(action1, action2)

def test_non_equality(self):
action1 = CaseCreateAction({'case_name': 'one'})
action2 = CaseCreateAction({'case_name': 'two'})

self.assertNotEqual(action1, action2)
20 changes: 14 additions & 6 deletions corehq/ex-submodules/casexml/apps/case/xform.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,14 +226,22 @@ def _extract_case_blocks(data, path=None, form_id=Ellipsis):
yield from _extract_case_blocks(value, new_path, form_id=form_id)


def get_case_updates(xform):
def get_case_updates(xform, for_case=None):
if not xform:
return []
updates = sorted(
[case_update_from_block(cb) for cb in extract_case_blocks(xform)],
key=lambda update: update.id
)
by_case_id = groupby(updates, lambda update: update.id)

updates = [case_update_from_block(cb) for cb in extract_case_blocks(xform)]

if for_case:
updates = [update for update in updates if update.id == for_case]
by_case_id = [(for_case, updates)]
else:
updates = sorted(
updates,
key=lambda update: update.id
)
by_case_id = groupby(updates, lambda update: update.id)

return list(itertools.chain(
*[order_updates(updates) for case_id, updates in by_case_id]
))
Expand Down
17 changes: 17 additions & 0 deletions corehq/ex-submodules/casexml/apps/case/xml/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ def get_known_properties(self):
def __repr__(self):
return f"{type(self).__name__}(block={self.raw_block!r})"

def __eq__(self, other):
return (isinstance(other, self.__class__)
and self.__dict__ == other.__dict__)

def __ne__(self, other):
return not self.__eq__(other)

@classmethod
def _from_block_and_mapping(cls, block, mapping):
def _normalize(val):
Expand Down Expand Up @@ -345,6 +352,16 @@ def has_attachments(self):
def __str__(self):
return "%s: %s" % (self.version, self.id)

def __repr__(self):
return str(self.__dict__)

def __eq__(self, other):
return (isinstance(other, self.__class__)
and self.__dict__ == other.__dict__)

def __ne__(self, other):
return not self.__eq__(other)

def _filtered_action(self, func):
# filters the actions, assumes exactly 0 or 1 match.
filtered = list(filter(func, self.actions))
Expand Down
9 changes: 6 additions & 3 deletions corehq/ex-submodules/pillowtop/feed/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class ChangeMeta(jsonobject.JsonObject):
# track when first published (will not get updated on retry, unlike publish_timestamp)
original_publication_datetime = jsonobject.DateTimeProperty(default=datetime.utcnow)

# available to hold any associated document. For cases, this is the form ID responsible for the change
associated_document_id = jsonobject.StringProperty()


class Change(object):
"""
Expand Down Expand Up @@ -119,10 +122,10 @@ def __getitem__(self, key):
return self._dict[key]

def __setitem__(self, key, value):
raise NotImplemented('This is a read-only dictionary!')
raise NotImplementedError('This is a read-only dictionary!')

def __delitem__(self, key, value):
raise NotImplemented('This is a read-only dictionary!')
raise NotImplementedError('This is a read-only dictionary!')

def __iter__(self):
return iter(self._dict)
Expand All @@ -134,7 +137,7 @@ def get(self, key, default=None):
return self._dict.get(key, default)

def pop(self, key, default):
raise NotImplemented('This is a read-only dictionary!')
raise NotImplementedError('This is a read-only dictionary!')

def to_dict(self):
return self._dict
Expand Down
6 changes: 5 additions & 1 deletion corehq/form_processor/backends/sql/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,11 @@ def publish_changes_to_kafka(processed_forms, cases, stock_result):
publish_form_saved(processed_forms.submitted)
cases = cases or []
for case in cases:
publish_case_saved(case, send_post_save_signal=False)
publish_case_saved(
case,
associated_form_id=processed_forms.submitted.form_id,
send_post_save_signal=False
)

if stock_result:
for ledger in stock_result.models_to_save:
Expand Down
7 changes: 4 additions & 3 deletions corehq/form_processor/change_publishers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,16 @@ def publish_form_deleted(domain, form_id):
))


def publish_case_saved(case, send_post_save_signal=True):
def publish_case_saved(case, associated_form_id=None, send_post_save_signal=True):
"""
Publish the change to kafka and run case post-save signals.
"""
producer.send_change(topics.CASE_SQL, change_meta_from_sql_case(case))
producer.send_change(topics.CASE_SQL, change_meta_from_sql_case(case, associated_form_id))
if send_post_save_signal:
sql_case_post_save.send(case.__class__, case=case)


def change_meta_from_sql_case(case):
def change_meta_from_sql_case(case, associated_form_id=None):
return ChangeMeta(
document_id=case.case_id,
data_source_type=data_sources.SOURCE_SQL,
Expand All @@ -49,6 +49,7 @@ def change_meta_from_sql_case(case):
document_subtype=case.type,
domain=case.domain,
is_deletion=case.is_deleted,
associated_document_id=associated_form_id
)


Expand Down
3 changes: 3 additions & 0 deletions corehq/pillows/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
get_ucr_processor,
)
from corehq.form_processor.backends.sql.dbaccessors import CaseReindexAccessor
from corehq.apps.data_interfaces.pillow import CaseDeduplicationProcessor
from corehq.messaging.pillow import CaseMessagingSyncProcessor
from corehq.pillows.base import is_couch_change_for_sql_domain
from corehq.pillows.case_search import get_case_search_processor
Expand Down Expand Up @@ -119,6 +120,8 @@ def get_case_pillow(
processors = [case_to_es_processor, CaseMessagingSyncProcessor()]
if settings.RUN_CASE_SEARCH_PILLOW:
processors.append(case_search_processor)
if settings.RUN_DEDUPLICATION_PILLOW:
processors.append(CaseDeduplicationProcessor())
if not skip_ucr:
# this option is useful in tests to avoid extra UCR setup where unneccessary
processors = [ucr_processor, ucr_dr_processor] + processors
Expand Down
Loading
Loading