From a45f986e2f0c47ec39dca878c7da1c486a04541c Mon Sep 17 00:00:00 2001 From: ZeyadTarekk Date: Sat, 9 Nov 2024 01:07:11 +0200 Subject: [PATCH 1/8] change test_hash_from_x --- .../signal_type/tests/test_hash_from_x.py | 68 +++++++++---------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_hash_from_x.py b/python-threatexchange/threatexchange/signal_type/tests/test_hash_from_x.py index 29ad6da74..ec1578079 100644 --- a/python-threatexchange/threatexchange/signal_type/tests/test_hash_from_x.py +++ b/python-threatexchange/threatexchange/signal_type/tests/test_hash_from_x.py @@ -1,7 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -import unittest - +import pytest from threatexchange.signal_type import ( md5, raw_text, @@ -13,38 +12,39 @@ from threatexchange.signal_type.signal_base import TextHasher -class SignalTypeHashTest(unittest.TestCase): +# List of signal types to test +SIGNAL_TYPES_TO_TEST = [ + md5.VideoMD5Signal, + signal.PdqSignal, + raw_text.RawTextSignal, + trend_query.TrendQuerySignal, + url_md5.UrlMD5Signal, + url.URLSignal, +] + + +@pytest.mark.parametrize("signal_type", SIGNAL_TYPES_TO_TEST) +def test_signal_names_unique(signal_type): + """ + Verify that each signal type has a unique name. + """ + name = signal_type.get_name() + seen = set() # Using a set to automatically manage unique entries + assert name not in seen, f"Two signal types share the same name: {signal_type!r} and {seen}" + seen.add(name) + + +@pytest.mark.parametrize("signal_type", SIGNAL_TYPES_TO_TEST) +def test_signal_types_have_content(signal_type): """ - Sanity check for signal type hashing methods. + Ensure that each signal type has associated content types. """ + assert signal_type.get_content_types(), f"{signal_type!r} has no content types" + - # TODO - maybe make a metaclass for this to automatically detect? - SIGNAL_TYPES_TO_TEST = [ - md5.VideoMD5Signal, - signal.PdqSignal, - raw_text.RawTextSignal, - trend_query.TrendQuerySignal, - url_md5.UrlMD5Signal, - url.URLSignal, - ] - - def test_signal_names_unique(self): - seen = {} - for s in self.SIGNAL_TYPES_TO_TEST: - name = s.get_name() - assert ( - name not in seen - ), f"Two signal types share the same name: {s!r} and {seen[name]!r}" - - def test_signal_types_have_content(self): - for s in self.SIGNAL_TYPES_TO_TEST: - assert s.get_content_types(), "{s!r} has no content types" - - def test_str_hashers_have_impl(self): - text_hashers = [ - s for s in self.SIGNAL_TYPES_TO_TEST if isinstance(s, TextHasher) - ] - for s in text_hashers: - assert s.hash_from_str( - "test string" - ), "{s!r} produced no output from hasher" +@pytest.mark.parametrize("signal_type", [s for s in SIGNAL_TYPES_TO_TEST if isinstance(s, TextHasher)]) +def test_str_hashers_have_impl(signal_type): + """ + Check that each TextHasher has an implementation that produces output. + """ + assert signal_type.hash_from_str("test string"), f"{signal_type!r} produced no output from hasher" From 890ac4bb1857da10bcd13892db7e7873f0f86c85 Mon Sep 17 00:00:00 2001 From: ZeyadTarekk Date: Sat, 9 Nov 2024 01:09:45 +0200 Subject: [PATCH 2/8] change test_md5_hash --- .../signal_type/tests/test_md5_hash.py | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_md5_hash.py b/python-threatexchange/threatexchange/signal_type/tests/test_md5_hash.py index 88a12af58..2611572ce 100644 --- a/python-threatexchange/threatexchange/signal_type/tests/test_md5_hash.py +++ b/python-threatexchange/threatexchange/signal_type/tests/test_md5_hash.py @@ -1,23 +1,28 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -import unittest import pathlib - +import pytest from threatexchange.signal_type.md5 import VideoMD5Signal +# Define the test file path TEST_FILE = pathlib.Path(__file__).parent.parent.parent.parent.joinpath( "data", "sample-b.jpg" ) +@pytest.fixture +def file_content(): + """ + Fixture to open and yield file content for testing, + then close the file after the test. + """ + with open(TEST_FILE, "rb") as f: + yield f.read() -class VideoMD5SignalTestCase(unittest.TestCase): - def setUp(self): - self.a_file = open(TEST_FILE, "rb") - - def tearDown(self): - self.a_file.close() - def test_can_hash_simple_files(self): - assert "d35c785545392755e7e4164457657269" == VideoMD5Signal.hash_from_bytes( - self.a_file.read() - ), "MD5 hash does not match" +def test_can_hash_simple_files(file_content): + """ + Test that the VideoMD5Signal produces the expected hash. + """ + expected_hash = "d35c785545392755e7e4164457657269" + computed_hash = VideoMD5Signal.hash_from_bytes(file_content) + assert computed_hash == expected_hash, "MD5 hash does not match" From a4ccca2c87bd3d312e6beade338dc431c7424cae Mon Sep 17 00:00:00 2001 From: ZeyadTarekk Date: Sat, 9 Nov 2024 01:14:54 +0200 Subject: [PATCH 3/8] change test_raw_text --- .../signal_type/tests/test_raw_text.py | 26 ++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_raw_text.py b/python-threatexchange/threatexchange/signal_type/tests/test_raw_text.py index ede5c9f9a..e6e0e1808 100644 --- a/python-threatexchange/threatexchange/signal_type/tests/test_raw_text.py +++ b/python-threatexchange/threatexchange/signal_type/tests/test_raw_text.py @@ -1,5 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import pytest +from threatexchange.signal_type.raw_text import RawTextSignal from threatexchange.signal_type.tests.signal_type_test_helper import MatchesStrAutoTest from threatexchange.signal_type.raw_text import RawTextSignal @@ -8,16 +10,19 @@ class TestRawTextSignal(MatchesStrAutoTest): TYPE = RawTextSignal - def get_validate_hash_cases(self): + @pytest.fixture + def validate_hash_cases(self): return [ ("a", "a"), ("a ", "a"), ] - def get_compare_hash_cases(self): + @pytest.fixture + def compare_hash_cases(self): return [] - def get_matches_str_cases(self): + @pytest.fixture + def matches_str_cases(self): return [ ("", ""), ("a", "a"), @@ -30,3 +35,18 @@ def get_matches_str_cases(self): ("a" * 19, "a" * 18 + "b", False, 1), ("a" * 20, "a" * 19 + "b", True, 1), ] + + def test_validate_hash(self, validate_hash_cases): + for case in validate_hash_cases: + input_val, expected_hash = case + assert self.TYPE.validate_hash(input_val) == expected_hash + + def test_compare_hash(self, compare_hash_cases): + for case in compare_hash_cases: + input_val, expected_result = case + assert self.TYPE.compare_hash(input_val) == expected_result + + def test_matches_str(self, matches_str_cases): + for case in matches_str_cases: + input_str, match_str, expected_match, threshold = case + assert self.TYPE.matches_str(input_str, match_str, threshold) == expected_match From 485f8f7d02284f65cb888d93f6db8880f2b59075 Mon Sep 17 00:00:00 2001 From: ZeyadTarekk Date: Sat, 9 Nov 2024 01:21:37 +0200 Subject: [PATCH 4/8] change test_url_md5_hash --- .../signal_type/tests/test_url_md5_hash.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_url_md5_hash.py b/python-threatexchange/threatexchange/signal_type/tests/test_url_md5_hash.py index f80443568..d361bf065 100644 --- a/python-threatexchange/threatexchange/signal_type/tests/test_url_md5_hash.py +++ b/python-threatexchange/threatexchange/signal_type/tests/test_url_md5_hash.py @@ -1,21 +1,14 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -import unittest - +import pytest from threatexchange.signal_type.url_md5 import UrlMD5Signal URL_TEST = "www.facebook.com/?user=123" FULL_URL_TEST = "https://www.facebook.com/?user=123" URL_TEST_MD5 = "e359430911fe80c2dd526d3cca21da30" +def test_can_hash_simple_url(): + assert UrlMD5Signal.hash_from_str(URL_TEST) == URL_TEST_MD5, "MD5 hash does not match" -class UrlMD5SignalTestCase(unittest.TestCase): - def test_can_hash_simple_url(self): - assert URL_TEST_MD5 == UrlMD5Signal.hash_from_str( - URL_TEST - ), "MD5 hash does not match" - - def test_can_hash_full_url(self): - assert URL_TEST_MD5 == UrlMD5Signal.hash_from_str( - FULL_URL_TEST - ), "MD5 hash does not match" +def test_can_hash_full_url(): + assert UrlMD5Signal.hash_from_str(FULL_URL_TEST) == URL_TEST_MD5, "MD5 hash does not match" From e71df225143ebbabd226688c62173c14c9e800d8 Mon Sep 17 00:00:00 2001 From: ZeyadTarekk Date: Sat, 9 Nov 2024 01:24:12 +0200 Subject: [PATCH 5/8] change test_pdq_index --- .../signal_type/tests/test_pdq_index.py | 192 ++++++------------ 1 file changed, 62 insertions(+), 130 deletions(-) diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index.py b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index.py index e6bad4c7e..9922caacc 100644 --- a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index.py +++ b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index.py @@ -1,8 +1,6 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. - -import unittest import pickle import typing as t +import pytest import functools from threatexchange.signal_type.index import ( @@ -13,145 +11,79 @@ test_entries = [ ( "0000000000000000000000000000000000000000000000000000000000000000", - dict( - { - "hash_type": "pdq", - "system_id": 9, - } - ), + {"hash_type": "pdq", "system_id": 9}, ), ( "000000000000000000000000000000000000000000000000000000000000ffff", - dict( - { - "hash_type": "pdq", - "system_id": 8, - } - ), + {"hash_type": "pdq", "system_id": 8}, ), ( "0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f", - dict( - { - "hash_type": "pdq", - "system_id": 7, - } - ), + {"hash_type": "pdq", "system_id": 7}, ), ( "f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0", - dict( - { - "hash_type": "pdq", - "system_id": 6, - } - ), + {"hash_type": "pdq", "system_id": 6}, ), ( "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", - dict( - { - "hash_type": "pdq", - "system_id": 5, - } - ), + {"hash_type": "pdq", "system_id": 5}, ), ] - -class TestPDQIndex(unittest.TestCase): - def setUp(self): - self.index = PDQIndex.build(test_entries) - - def assertEqualPDQIndexMatchResults( - self, result: t.List[PDQIndexMatch], expected: t.List[PDQIndexMatch] - ): - self.assertEqual( - len(result), len(expected), "search results not of expected length" - ) - - accum_type = t.Dict[int, t.Set[int]] - - # Between python 3.8.6 and 3.8.11, something caused the order of results - # from the index to change. This was noticed for items which had the - # same distance. To allow for this, we convert result and expected - # arrays from - # [PDQIndexMatch, PDQIndexMatch] to { distance: } - # This allows you to compare [PDQIndexMatch A, PDQIndexMatch B] with - # [PDQIndexMatch B, PDQIndexMatch A] as long as A.distance == B.distance. - def quality_indexed_dict_reducer( - acc: accum_type, item: PDQIndexMatch - ) -> accum_type: - acc[item.similarity_info.distance] = acc.get( - item.similarity_info.distance, set() - ) - # Instead of storing the unhashable item.metadata dict, store its - # hash so we can compare using self.assertSetEqual - acc[item.similarity_info.distance].add(hash(frozenset(item.metadata))) - return acc - - # Convert results to distance -> set of metadata map - distance_to_result_items_map: accum_type = functools.reduce( - quality_indexed_dict_reducer, result, {} - ) - - # Convert expected to distance -> set of metadata map - distance_to_expected_items_map: accum_type = functools.reduce( - quality_indexed_dict_reducer, expected, {} - ) - - assert len(distance_to_expected_items_map) == len( - distance_to_result_items_map - ), "Unequal number of items in expected and results." - - for distance, result_items in distance_to_result_items_map.items(): - assert ( - distance in distance_to_expected_items_map - ), f"Unexpected distance {distance} found" - self.assertSetEqual(result_items, distance_to_expected_items_map[distance]) - - def test_search_index_for_matches(self): - entry_hash = test_entries[1][0] - result = self.index.query(entry_hash) - self.assertEqualPDQIndexMatchResults( - result, - [ - PDQIndexMatch( - SignalSimilarityInfoWithIntDistance(0), test_entries[1][1] - ), - PDQIndexMatch( - SignalSimilarityInfoWithIntDistance(16), test_entries[0][1] - ), - ], - ) - - def test_search_index_with_no_match(self): - query_hash = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" - result = self.index.query(query_hash) - self.assertEqualPDQIndexMatchResults(result, []) - - def test_supports_pickling(self): - pickled_data = pickle.dumps(self.index) - assert pickled_data != None, "index does not support pickling to a data stream" - - reconstructed_index = pickle.loads(pickled_data) - assert ( - reconstructed_index != None - ), "index does not support unpickling from data stream" - assert ( - reconstructed_index.index.faiss_index != self.index.index.faiss_index - ), "unpickling should create it's own faiss index in memory" - - query = test_entries[0][0] - result = reconstructed_index.query(query) - self.assertEqualPDQIndexMatchResults( - result, - [ - PDQIndexMatch( - SignalSimilarityInfoWithIntDistance(0), test_entries[1][1] - ), - PDQIndexMatch( - SignalSimilarityInfoWithIntDistance(16), test_entries[0][1] - ), - ], - ) +@pytest.fixture +def index(): + return PDQIndex.build(test_entries) + +def assert_equal_pdq_index_match_results(result: t.List[PDQIndexMatch], expected: t.List[PDQIndexMatch]): + assert len(result) == len(expected), "search results not of expected length" + + def quality_indexed_dict_reducer( + acc: t.Dict[int, t.Set[int]], item: PDQIndexMatch + ) -> t.Dict[int, t.Set[int]]: + acc[item.similarity_info.distance] = acc.get(item.similarity_info.distance, set()) + acc[item.similarity_info.distance].add(hash(frozenset(item.metadata))) + return acc + + distance_to_result_items_map = functools.reduce(quality_indexed_dict_reducer, result, {}) + distance_to_expected_items_map = functools.reduce(quality_indexed_dict_reducer, expected, {}) + + assert len(distance_to_expected_items_map) == len(distance_to_result_items_map), "Unequal number of items" + + for distance, result_items in distance_to_result_items_map.items(): + assert distance in distance_to_expected_items_map, f"Unexpected distance {distance} found" + assert result_items == distance_to_expected_items_map[distance] + +def test_search_index_for_matches(index): + entry_hash = test_entries[1][0] + result = index.query(entry_hash) + assert_equal_pdq_index_match_results( + result, + [ + PDQIndexMatch(SignalSimilarityInfoWithIntDistance(0), test_entries[1][1]), + PDQIndexMatch(SignalSimilarityInfoWithIntDistance(16), test_entries[0][1]), + ], + ) + +def test_search_index_with_no_match(index): + query_hash = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + result = index.query(query_hash) + assert_equal_pdq_index_match_results(result, []) + +def test_supports_pickling(index): + pickled_data = pickle.dumps(index) + assert pickled_data is not None, "index does not support pickling to a data stream" + + reconstructed_index = pickle.loads(pickled_data) + assert reconstructed_index is not None, "index does not support unpickling from data stream" + assert reconstructed_index.index.faiss_index != index.index.faiss_index, "unpickling should create its own faiss index in memory" + + query = test_entries[0][0] + result = reconstructed_index.query(query) + assert_equal_pdq_index_match_results( + result, + [ + PDQIndexMatch(SignalSimilarityInfoWithIntDistance(0), test_entries[1][1]), + PDQIndexMatch(SignalSimilarityInfoWithIntDistance(16), test_entries[0][1]), + ], + ) From 29c27d443c5d26d10b9c9d3ebaf7cecee091ef14 Mon Sep 17 00:00:00 2001 From: ZeyadTarekk Date: Tue, 10 Dec 2024 00:41:47 +0200 Subject: [PATCH 6/8] Fix formatting --- .../signal_type/tests/test_hash_from_x.py | 12 ++++-- .../signal_type/tests/test_md5_hash.py | 1 + .../signal_type/tests/test_pdq_index.py | 37 +++++++++++++++---- .../signal_type/tests/test_raw_text.py | 4 +- .../signal_type/tests/test_url_md5_hash.py | 10 ++++- 5 files changed, 50 insertions(+), 14 deletions(-) diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_hash_from_x.py b/python-threatexchange/threatexchange/signal_type/tests/test_hash_from_x.py index ec1578079..9e20d3b16 100644 --- a/python-threatexchange/threatexchange/signal_type/tests/test_hash_from_x.py +++ b/python-threatexchange/threatexchange/signal_type/tests/test_hash_from_x.py @@ -30,7 +30,9 @@ def test_signal_names_unique(signal_type): """ name = signal_type.get_name() seen = set() # Using a set to automatically manage unique entries - assert name not in seen, f"Two signal types share the same name: {signal_type!r} and {seen}" + assert ( + name not in seen + ), f"Two signal types share the same name: {signal_type!r} and {seen}" seen.add(name) @@ -42,9 +44,13 @@ def test_signal_types_have_content(signal_type): assert signal_type.get_content_types(), f"{signal_type!r} has no content types" -@pytest.mark.parametrize("signal_type", [s for s in SIGNAL_TYPES_TO_TEST if isinstance(s, TextHasher)]) +@pytest.mark.parametrize( + "signal_type", [s for s in SIGNAL_TYPES_TO_TEST if isinstance(s, TextHasher)] +) def test_str_hashers_have_impl(signal_type): """ Check that each TextHasher has an implementation that produces output. """ - assert signal_type.hash_from_str("test string"), f"{signal_type!r} produced no output from hasher" + assert signal_type.hash_from_str( + "test string" + ), f"{signal_type!r} produced no output from hasher" diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_md5_hash.py b/python-threatexchange/threatexchange/signal_type/tests/test_md5_hash.py index 2611572ce..657ce9aa5 100644 --- a/python-threatexchange/threatexchange/signal_type/tests/test_md5_hash.py +++ b/python-threatexchange/threatexchange/signal_type/tests/test_md5_hash.py @@ -9,6 +9,7 @@ "data", "sample-b.jpg" ) + @pytest.fixture def file_content(): """ diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index.py b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index.py index 9922caacc..2f5152c67 100644 --- a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index.py +++ b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index.py @@ -31,29 +31,44 @@ ), ] + @pytest.fixture def index(): return PDQIndex.build(test_entries) -def assert_equal_pdq_index_match_results(result: t.List[PDQIndexMatch], expected: t.List[PDQIndexMatch]): + +def assert_equal_pdq_index_match_results( + result: t.List[PDQIndexMatch], expected: t.List[PDQIndexMatch] +): assert len(result) == len(expected), "search results not of expected length" def quality_indexed_dict_reducer( acc: t.Dict[int, t.Set[int]], item: PDQIndexMatch ) -> t.Dict[int, t.Set[int]]: - acc[item.similarity_info.distance] = acc.get(item.similarity_info.distance, set()) + acc[item.similarity_info.distance] = acc.get( + item.similarity_info.distance, set() + ) acc[item.similarity_info.distance].add(hash(frozenset(item.metadata))) return acc - distance_to_result_items_map = functools.reduce(quality_indexed_dict_reducer, result, {}) - distance_to_expected_items_map = functools.reduce(quality_indexed_dict_reducer, expected, {}) + distance_to_result_items_map = functools.reduce( + quality_indexed_dict_reducer, result, {} + ) + distance_to_expected_items_map = functools.reduce( + quality_indexed_dict_reducer, expected, {} + ) - assert len(distance_to_expected_items_map) == len(distance_to_result_items_map), "Unequal number of items" + assert len(distance_to_expected_items_map) == len( + distance_to_result_items_map + ), "Unequal number of items" for distance, result_items in distance_to_result_items_map.items(): - assert distance in distance_to_expected_items_map, f"Unexpected distance {distance} found" + assert ( + distance in distance_to_expected_items_map + ), f"Unexpected distance {distance} found" assert result_items == distance_to_expected_items_map[distance] + def test_search_index_for_matches(index): entry_hash = test_entries[1][0] result = index.query(entry_hash) @@ -65,18 +80,24 @@ def test_search_index_for_matches(index): ], ) + def test_search_index_with_no_match(index): query_hash = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" result = index.query(query_hash) assert_equal_pdq_index_match_results(result, []) + def test_supports_pickling(index): pickled_data = pickle.dumps(index) assert pickled_data is not None, "index does not support pickling to a data stream" reconstructed_index = pickle.loads(pickled_data) - assert reconstructed_index is not None, "index does not support unpickling from data stream" - assert reconstructed_index.index.faiss_index != index.index.faiss_index, "unpickling should create its own faiss index in memory" + assert ( + reconstructed_index is not None + ), "index does not support unpickling from data stream" + assert ( + reconstructed_index.index.faiss_index != index.index.faiss_index + ), "unpickling should create its own faiss index in memory" query = test_entries[0][0] result = reconstructed_index.query(query) diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_raw_text.py b/python-threatexchange/threatexchange/signal_type/tests/test_raw_text.py index e6e0e1808..bbf115a79 100644 --- a/python-threatexchange/threatexchange/signal_type/tests/test_raw_text.py +++ b/python-threatexchange/threatexchange/signal_type/tests/test_raw_text.py @@ -49,4 +49,6 @@ def test_compare_hash(self, compare_hash_cases): def test_matches_str(self, matches_str_cases): for case in matches_str_cases: input_str, match_str, expected_match, threshold = case - assert self.TYPE.matches_str(input_str, match_str, threshold) == expected_match + assert ( + self.TYPE.matches_str(input_str, match_str, threshold) == expected_match + ) diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_url_md5_hash.py b/python-threatexchange/threatexchange/signal_type/tests/test_url_md5_hash.py index d361bf065..548f1710b 100644 --- a/python-threatexchange/threatexchange/signal_type/tests/test_url_md5_hash.py +++ b/python-threatexchange/threatexchange/signal_type/tests/test_url_md5_hash.py @@ -7,8 +7,14 @@ FULL_URL_TEST = "https://www.facebook.com/?user=123" URL_TEST_MD5 = "e359430911fe80c2dd526d3cca21da30" + def test_can_hash_simple_url(): - assert UrlMD5Signal.hash_from_str(URL_TEST) == URL_TEST_MD5, "MD5 hash does not match" + assert ( + UrlMD5Signal.hash_from_str(URL_TEST) == URL_TEST_MD5 + ), "MD5 hash does not match" + def test_can_hash_full_url(): - assert UrlMD5Signal.hash_from_str(FULL_URL_TEST) == URL_TEST_MD5, "MD5 hash does not match" + assert ( + UrlMD5Signal.hash_from_str(FULL_URL_TEST) == URL_TEST_MD5 + ), "MD5 hash does not match" From f30396524486ceccda97907153741548fb6ae0f5 Mon Sep 17 00:00:00 2001 From: ZeyadTarekk Date: Tue, 10 Dec 2024 00:47:45 +0200 Subject: [PATCH 7/8] Update pdf test --- .../signal_type/tests/test_pdq_index.py | 56 +++++++++++++------ 1 file changed, 38 insertions(+), 18 deletions(-) diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index.py b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index.py index 2f5152c67..0240a5f42 100644 --- a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index.py +++ b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index.py @@ -40,7 +40,7 @@ def index(): def assert_equal_pdq_index_match_results( result: t.List[PDQIndexMatch], expected: t.List[PDQIndexMatch] ): - assert len(result) == len(expected), "search results not of expected length" + assert len(result) == len(expected), "Search results not of expected length" def quality_indexed_dict_reducer( acc: t.Dict[int, t.Set[int]], item: PDQIndexMatch @@ -60,44 +60,64 @@ def quality_indexed_dict_reducer( assert len(distance_to_expected_items_map) == len( distance_to_result_items_map - ), "Unequal number of items" + ), "Unequal number of distance groups" for distance, result_items in distance_to_result_items_map.items(): assert ( distance in distance_to_expected_items_map - ), f"Unexpected distance {distance} found" - assert result_items == distance_to_expected_items_map[distance] + ), f"Unexpected distance {distance} found in results" + assert result_items == distance_to_expected_items_map[distance], ( + f"Mismatch at distance {distance}. " + f"Expected: {distance_to_expected_items_map[distance]}, Got: {result_items}" + ) -def test_search_index_for_matches(index): - entry_hash = test_entries[1][0] +@pytest.mark.parametrize( + "entry_hash, expected_matches", + [ + ( + test_entries[1][0], + [ + PDQIndexMatch( + SignalSimilarityInfoWithIntDistance(0), test_entries[1][1] + ), + PDQIndexMatch( + SignalSimilarityInfoWithIntDistance(16), test_entries[0][1] + ), + ], + ), + ( + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + [], + ), + ], +) +def test_search_index(index, entry_hash, expected_matches): result = index.query(entry_hash) - assert_equal_pdq_index_match_results( - result, - [ - PDQIndexMatch(SignalSimilarityInfoWithIntDistance(0), test_entries[1][1]), - PDQIndexMatch(SignalSimilarityInfoWithIntDistance(16), test_entries[0][1]), - ], - ) + assert_equal_pdq_index_match_results(result, expected_matches) -def test_search_index_with_no_match(index): - query_hash = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" +def test_partial_match_below_threshold(index): + query_hash = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffff00000000" result = index.query(query_hash) assert_equal_pdq_index_match_results(result, []) def test_supports_pickling(index): pickled_data = pickle.dumps(index) - assert pickled_data is not None, "index does not support pickling to a data stream" + assert pickled_data is not None, "Index does not support pickling to a data stream" reconstructed_index = pickle.loads(pickled_data) assert ( reconstructed_index is not None - ), "index does not support unpickling from data stream" + ), "Index does not support unpickling from data stream" assert ( reconstructed_index.index.faiss_index != index.index.faiss_index - ), "unpickling should create its own faiss index in memory" + ), "Unpickling should create its own FAISS index in memory" + + assert ( + reconstructed_index.index_size == index.index_size + ), "Index size mismatch after unpickling" query = test_entries[0][0] result = reconstructed_index.query(query) From be67fa4aecb101d3cd8d81db01653a6032f663a1 Mon Sep 17 00:00:00 2001 From: ZeyadTarekk Date: Tue, 10 Dec 2024 00:49:34 +0200 Subject: [PATCH 8/8] fix test raw text --- .../signal_type/tests/test_raw_text.py | 45 ++++++++++--------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_raw_text.py b/python-threatexchange/threatexchange/signal_type/tests/test_raw_text.py index bbf115a79..52ecc1f47 100644 --- a/python-threatexchange/threatexchange/signal_type/tests/test_raw_text.py +++ b/python-threatexchange/threatexchange/signal_type/tests/test_raw_text.py @@ -4,8 +4,6 @@ from threatexchange.signal_type.raw_text import RawTextSignal from threatexchange.signal_type.tests.signal_type_test_helper import MatchesStrAutoTest -from threatexchange.signal_type.raw_text import RawTextSignal - class TestRawTextSignal(MatchesStrAutoTest): TYPE = RawTextSignal @@ -24,31 +22,34 @@ def compare_hash_cases(self): @pytest.fixture def matches_str_cases(self): return [ - ("", ""), - ("a", "a"), - ("a", "b", False, 1), - ("aaaaaaaaaa", "a", False, 9), - ("a", "aaaaaaaaa", False, 8), - # Normalization removes spaces - ("a a a a a a a a a", "aaaaaaaaa", True, 0), - # Default threshold is 95% - ("a" * 19, "a" * 18 + "b", False, 1), - ("a" * 20, "a" * 19 + "b", True, 1), + ("", "", True, 0), # Empty strings match + ("a", "a", True, 0), # Identical single-character match + ("a", "b", False, 1), # Single-character mismatch + ("aaaaaaaaaa", "a", False, 9), # Longer string doesn't match shorter + ("a", "aaaaaaaaa", False, 8), # Shorter string doesn't match longer + ("a a a a a a a a a", "aaaaaaaaa", True, 0), # Normalization removes spaces + ("a" * 19, "a" * 18 + "b", False, 1), # Fails threshold 95% + ("a" * 20, "a" * 19 + "b", True, 1), # Meets threshold with 20 chars ] def test_validate_hash(self, validate_hash_cases): - for case in validate_hash_cases: - input_val, expected_hash = case - assert self.TYPE.validate_hash(input_val) == expected_hash + for input_val, expected_hash in validate_hash_cases: + assert self.TYPE.validate_hash(input_val) == expected_hash, ( + f"Expected {expected_hash} for input {input_val}, but got " + f"{self.TYPE.validate_hash(input_val)}" + ) def test_compare_hash(self, compare_hash_cases): - for case in compare_hash_cases: - input_val, expected_result = case - assert self.TYPE.compare_hash(input_val) == expected_result + for input_val, expected_result in compare_hash_cases: + assert self.TYPE.compare_hash(input_val) == expected_result, ( + f"Expected {expected_result} for input {input_val}, but got " + f"{self.TYPE.compare_hash(input_val)}" + ) def test_matches_str(self, matches_str_cases): - for case in matches_str_cases: - input_str, match_str, expected_match, threshold = case - assert ( - self.TYPE.matches_str(input_str, match_str, threshold) == expected_match + for input_str, match_str, expected_match, threshold in matches_str_cases: + result = self.TYPE.matches_str(input_str, match_str, threshold) + assert result == expected_match, ( + f"Expected {expected_match} for input ({input_str}, {match_str}) with " + f"threshold {threshold}, but got {result}" )