# Copied from feature_extraction_test
def _config_validation(self, config):
    """Validate that all items of a single feature config are valid.

    Required keys: a query (``query_string`` or, failing that,
    ``query_dsl``) that must be a string, ``attribute``, ``store_as`` and
    ``re`` (a string that compiles as a regular expression).

    Optional keys, checked whenever they are present with a non-None
    value: ``emojis`` and ``tags`` must be a list or tuple (and every
    emoji name must resolve to a non-empty code), ``create_view`` and
    ``aggregate`` must be booleans.

    Args:
        config: dict describing one feature extraction entry.

    Raises:
        AssertionError: through the test case's assert helpers, when any
            of the checks above does not hold.
    """
    query = config.get("query_string", config.get("query_dsl"))
    self.assertIsNotNone(query)
    self.assertIsInstance(query, str)

    attribute = config.get("attribute")
    self.assertIsNotNone(attribute)

    store_as = config.get("store_as")
    self.assertIsNotNone(store_as)

    expression = config.get("re")
    self.assertIsNotNone(expression)
    # A non-string pattern would make re.compile() raise TypeError, which
    # surfaces as a test *error*; turn it into a regular failure instead.
    self.assertIsInstance(expression, str)
    try:
        re.compile(expression)
    except re.error as exception:
        self.fail(f"Invalid regular expression {expression!r}: {exception}")

    # Optional keys are compared against None rather than tested for
    # truthiness, so that wrongly typed falsy values (0, "", ...) are
    # reported instead of silently skipping validation.
    emojis_to_add = config.get("emojis")
    if emojis_to_add is not None:
        self.assertIsInstance(emojis_to_add, (list, tuple))
        for emoji_name in emojis_to_add:
            emoji_code = emojis.get_emoji(emoji_name)
            self.assertNotEqual(emoji_code, "")

    tags = config.get("tags")
    if tags is not None:
        self.assertIsInstance(tags, (list, tuple))

    create_view = config.get("create_view")
    if create_view is not None:
        self.assertIsInstance(create_view, bool)

    aggregate = config.get("aggregate")
    if aggregate is not None:
        self.assertIsInstance(aggregate, bool)

# TODO: Add tests for the feature extraction.
def test_config(self):
    """Tests that the config file is valid.

    Loads ``data/features.yaml`` (resolved relative to the current working
    directory) and runs ``_config_validation`` on every entry.
    """
    config_file = os.path.join("data", "features.yaml")
    self.assertTrue(os.path.isfile(config_file))

    # Explicit encoding so the result does not depend on the locale of the
    # machine running the tests.
    with open(config_file, encoding="utf-8") as fh:
        config = yaml.safe_load(fh)

    self.assertIsInstance(config, dict)

    for key, value in config.items():
        # subTest names the offending feature entry in the failure report.
        with self.subTest(feature=key):
            self.assertIsInstance(key, str)
            self.assertIsInstance(value, dict)
            self._config_validation(value)


# Mock the OpenSearch datastore.
@mock.patch("timesketch.lib.analyzers.interface.OpenSearchDataStore", MockDataStore)
def test_get_attribute_value(self):
    """Test function _get_attribute_value().

    Each case pins the value returned for one combination of the
    ``keep_multi``, ``merge_values`` and ``type_list`` flags. List results
    are sorted before comparison, so element order is not part of what is
    being tested; string results are compared as returned.
    """
    analyzer = FeatureSketchPlugin(
        index_name="test_index", sketch_id=1, timeline_id=1
    )
    plugin = regex_features.RegexFeatureExtractionPlugin(analyzer)

    # (current_val, extracted_value, keep_multi, merge_values, type_list,
    #  expected). Every case owns fresh list objects so that one call
    # cannot influence the next.
    cases = (
        (["hello"], ["hello"], True, True, True, ["hello"]),
        (
            ["hello"],
            ["hello2", "hello3"],
            True,
            True,
            True,
            ["hello", "hello2", "hello3"],
        ),
        (["hello"], ["hello2", "hello3"], False, True, True, ["hello", "hello2"]),
        (["hello"], ["hello2", "hello3"], False, False, True, ["hello2"]),
        (["hello"], ["hello2", "hello3"], True, False, True, ["hello2", "hello3"]),
        ("hello", ["hello2", "hello3"], True, True, False, "hello,hello2,hello3"),
        ("hello", ["hello2", "hello3"], False, True, False, "hello,hello2"),
        ("hello", ["hello2", "hello3"], True, False, False, "hello2,hello3"),
        ("hello", ["hello2", "hello3"], False, False, False, "hello2"),
    )

    for current_val, extracted_value, keep_multi, merge_values, type_list, expected in cases:
        with self.subTest(
            current_val=current_val,
            keep_multi=keep_multi,
            merge_values=merge_values,
            type_list=type_list,
        ):
            # pylint: disable=protected-access
            new_val = plugin._get_attribute_value(
                current_val=current_val,
                extracted_value=extracted_value,
                keep_multi=keep_multi,
                merge_values=merge_values,
                type_list=type_list,
            )
            if type_list:
                new_val = sorted(new_val)

            self.assertEqual(new_val, expected)