Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert some tests to pytest #1693

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import unittest

import pytest
from threatexchange.signal_type import (
md5,
raw_text,
Expand All @@ -13,38 +12,45 @@
from threatexchange.signal_type.signal_base import TextHasher


class SignalTypeHashTest(unittest.TestCase):
# List of signal types to test
SIGNAL_TYPES_TO_TEST = [
    # Hash-based media signals
    md5.VideoMD5Signal,
    signal.PdqSignal,
    # Text / query signals
    raw_text.RawTextSignal,
    trend_query.TrendQuerySignal,
    # URL signals
    url_md5.UrlMD5Signal,
    url.URLSignal,
]


def test_signal_names_unique():
    """
    Verify that every signal type has a unique name.

    NOTE: this test must NOT be parametrized per signal type — a fresh,
    per-invocation `seen` set is always empty when checked, so a duplicate
    name could never be detected. Uniqueness is a property of the whole
    collection, so we check it in a single test.
    """
    seen = {}
    for signal_type in SIGNAL_TYPES_TO_TEST:
        name = signal_type.get_name()
        assert (
            name not in seen
        ), f"Two signal types share the same name: {signal_type!r} and {seen[name]!r}"
        # Record the name so later duplicates are actually caught.
        seen[name] = signal_type

# TODO - maybe make a metaclass for this to automatically detect?
SIGNAL_TYPES_TO_TEST = [
md5.VideoMD5Signal,
signal.PdqSignal,
raw_text.RawTextSignal,
trend_query.TrendQuerySignal,
url_md5.UrlMD5Signal,
url.URLSignal,
]

def test_signal_names_unique(self):
seen = {}
for s in self.SIGNAL_TYPES_TO_TEST:
name = s.get_name()
assert (
name not in seen
), f"Two signal types share the same name: {s!r} and {seen[name]!r}"

def test_signal_types_have_content(self):
for s in self.SIGNAL_TYPES_TO_TEST:
assert s.get_content_types(), "{s!r} has no content types"

def test_str_hashers_have_impl(self):
    """Every TextHasher signal type must produce non-empty output for a string."""
    # NOTE(review): isinstance() tests the class objects themselves against
    # TextHasher; if TextHasher is a plain base class this may select
    # nothing and issubclass() may be intended — confirm against
    # signal_base's definitions.
    text_hashers = [
        s for s in self.SIGNAL_TYPES_TO_TEST if isinstance(s, TextHasher)
    ]
    for s in text_hashers:
        # Bug fix: the message was missing its f-prefix, so "{s!r}" was
        # emitted literally instead of interpolating the signal type.
        assert s.hash_from_str(
            "test string"
        ), f"{s!r} produced no output from hasher"

@pytest.mark.parametrize("signal_type", SIGNAL_TYPES_TO_TEST)
def test_signal_types_have_content(signal_type):
    """Each signal type must declare at least one supported content type."""
    content_types = signal_type.get_content_types()
    assert content_types, f"{signal_type!r} has no content types"


@pytest.mark.parametrize(
    "signal_type", [s for s in SIGNAL_TYPES_TO_TEST if isinstance(s, TextHasher)]
)
def test_str_hashers_have_impl(signal_type):
    """A TextHasher must yield non-empty output when hashing a simple string."""
    hashed = signal_type.hash_from_str("test string")
    assert hashed, f"{signal_type!r} produced no output from hasher"
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import unittest
import pathlib

import pytest
from threatexchange.signal_type.md5 import VideoMD5Signal

# Define the test file path
TEST_FILE = pathlib.Path(__file__).parent.parent.parent.parent.joinpath(
"data", "sample-b.jpg"
)


class VideoMD5SignalTestCase(unittest.TestCase):
def setUp(self):
self.a_file = open(TEST_FILE, "rb")
@pytest.fixture
def file_content():
    """Provide the raw bytes of the sample file to the hashing test."""
    # pathlib opens and closes the file for us, so no teardown is needed.
    return TEST_FILE.read_bytes()
Comment on lines +13 to +20
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure that you need a fixture for this - the test can just open the file itself.

Fixtures are helpful when you are sharing setup between tests, and here there is only one test.


def tearDown(self):
self.a_file.close()

def test_can_hash_simple_files(self):
assert "d35c785545392755e7e4164457657269" == VideoMD5Signal.hash_from_bytes(
self.a_file.read()
), "MD5 hash does not match"
def test_can_hash_simple_files(file_content):
    """VideoMD5Signal must hash the sample file to the known MD5 digest."""
    assert (
        VideoMD5Signal.hash_from_bytes(file_content)
        == "d35c785545392755e7e4164457657269"
    ), "MD5 hash does not match"
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import unittest
import pickle
import typing as t
import pytest
import functools

from threatexchange.signal_type.index import (
Expand All @@ -13,139 +11,72 @@
# (hash, metadata) pairs used to build the PDQ index under test.
# Bug fix: the flattened diff had duplicated each metadata dict (old
# dict({...}) form plus new {...} form), turning every entry into a
# 3-tuple; entries must be (hash, metadata) 2-tuples.
test_entries = [
    (
        "0000000000000000000000000000000000000000000000000000000000000000",
        {"hash_type": "pdq", "system_id": 9},
    ),
    (
        "000000000000000000000000000000000000000000000000000000000000ffff",
        {"hash_type": "pdq", "system_id": 8},
    ),
    (
        "0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f",
        {"hash_type": "pdq", "system_id": 7},
    ),
    (
        "f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0",
        {"hash_type": "pdq", "system_id": 6},
    ),
    (
        "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
        {"hash_type": "pdq", "system_id": 5},
    ),
]


@pytest.fixture
def index():
    """Build a fresh PDQIndex from the canned test entries for each test.

    A fixture is the 1:1 pytest equivalent of the old unittest setUp; the
    superseded TestPDQIndex.setUp lines (and interleaved page chrome) are
    removed here.
    """
    return PDQIndex.build(test_entries)

def assertEqualPDQIndexMatchResults(
self, result: t.List[PDQIndexMatch], expected: t.List[PDQIndexMatch]
):
self.assertEqual(
len(result), len(expected), "search results not of expected length"
)

accum_type = t.Dict[int, t.Set[int]]

# Between python 3.8.6 and 3.8.11, something caused the order of results
# from the index to change. This was noticed for items which had the
# same distance. To allow for this, we convert result and expected
# arrays from
# [PDQIndexMatch, PDQIndexMatch] to { distance: <set of PDQIndexMatch.metadata hash> }
# This allows you to compare [PDQIndexMatch A, PDQIndexMatch B] with
# [PDQIndexMatch B, PDQIndexMatch A] as long as A.distance == B.distance.
def quality_indexed_dict_reducer(
acc: accum_type, item: PDQIndexMatch
) -> accum_type:
acc[item.similarity_info.distance] = acc.get(
item.similarity_info.distance, set()
)
# Instead of storing the unhashable item.metadata dict, store its
# hash so we can compare using self.assertSetEqual
acc[item.similarity_info.distance].add(hash(frozenset(item.metadata)))
return acc

# Convert results to distance -> set of metadata map
distance_to_result_items_map: accum_type = functools.reduce(
quality_indexed_dict_reducer, result, {}
)
def assert_equal_pdq_index_match_results(
    result: t.List[PDQIndexMatch], expected: t.List[PDQIndexMatch]
):
    """
    Assert that two PDQ match lists are equivalent, ignoring ordering.

    Matches at the same distance may be returned in any order (observed to
    vary between Python patch versions), so both lists are reduced to a
    {distance: set-of-metadata-hashes} mapping before comparison.

    Bug fix for the flattened diff: the superseded unittest helper and
    its methods were interleaved with this function; only the pytest
    version is kept.
    """
    assert len(result) == len(expected), "Search results not of expected length"

    def quality_indexed_dict_reducer(
        acc: t.Dict[int, t.Set[int]], item: PDQIndexMatch
    ) -> t.Dict[int, t.Set[int]]:
        # item.metadata is an unhashable dict; store a hash so the values
        # can live in a set and be compared with ==.
        # NOTE(review): frozenset(item.metadata) iterates only the dict's
        # KEYS, so two metadata dicts with equal keys but different values
        # compare equal here — confirm whether values should participate
        # (e.g. frozenset(item.metadata.items())).
        acc.setdefault(item.similarity_info.distance, set()).add(
            hash(frozenset(item.metadata))
        )
        return acc

    # Reduce both sides to distance -> set-of-metadata-hashes maps.
    distance_to_result_items_map: t.Dict[int, t.Set[int]] = functools.reduce(
        quality_indexed_dict_reducer, result, {}
    )
    distance_to_expected_items_map: t.Dict[int, t.Set[int]] = functools.reduce(
        quality_indexed_dict_reducer, expected, {}
    )

    assert len(distance_to_expected_items_map) == len(
        distance_to_result_items_map
    ), "Unequal number of distance groups"

    for distance, result_items in distance_to_result_items_map.items():
        assert (
            distance in distance_to_expected_items_map
        ), f"Unexpected distance {distance} found in results"
        assert result_items == distance_to_expected_items_map[distance], (
            f"Mismatch at distance {distance}. "
            f"Expected: {distance_to_expected_items_map[distance]}, Got: {result_items}"
        )


@pytest.mark.parametrize(
    "entry_hash, expected_matches",
    [
        # Exact entry: a distance-0 self match plus a distance-16 neighbor.
        (
            test_entries[1][0],
            [
                PDQIndexMatch(
                    SignalSimilarityInfoWithIntDistance(0), test_entries[1][1]
                ),
                PDQIndexMatch(
                    SignalSimilarityInfoWithIntDistance(16), test_entries[0][1]
                ),
            ],
        ),
        # A hash far from everything in the index: no matches expected.
        (
            "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            [],
        ),
    ],
)
def test_search_index(index, entry_hash, expected_matches):
    """Querying the index returns the expected matches (order-insensitive).

    Bug fix for the flattened diff: old-side residue lines and an
    'Expand All' page-chrome line inside the decorator are removed,
    restoring a valid parametrize call.
    """
    result = index.query(entry_hash)
    assert_equal_pdq_index_match_results(result, expected_matches)


def test_partial_match_below_threshold(index):
    """A hash that only partially overlaps an indexed entry must not match."""
    near_miss = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffff00000000"
    assert_equal_pdq_index_match_results(index.query(near_miss), [])


def test_supports_pickling(index):
    """Round-trip the index through pickle and confirm it still answers queries."""
    serialized = pickle.dumps(index)
    assert serialized is not None, "Index does not support pickling to a data stream"

    restored = pickle.loads(serialized)
    assert restored is not None, "Index does not support unpickling from data stream"
    # Each deserialized copy must own an independent in-memory FAISS index.
    assert (
        restored.index.faiss_index != index.index.faiss_index
    ), "Unpickling should create its own FAISS index in memory"
    assert (
        restored.index_size == index.index_size
    ), "Index size mismatch after unpickling"

    # The restored index must answer queries exactly like the original.
    # NOTE(review): querying test_entries[0][0] but expecting the
    # distance-0 match to carry test_entries[1][1] metadata looks swapped;
    # the comparison helper hashes metadata keys only, so this is
    # currently invisible — confirm intended metadata.
    lookup_hash = test_entries[0][0]
    matches = restored.query(lookup_hash)
    assert_equal_pdq_index_match_results(
        matches,
        [
            PDQIndexMatch(SignalSimilarityInfoWithIntDistance(0), test_entries[1][1]),
            PDQIndexMatch(SignalSimilarityInfoWithIntDistance(16), test_entries[0][1]),
        ],
    )
Loading
Loading