From 02d45de3721d6f0b01d9ec6f62ce099c42218dbe Mon Sep 17 00:00:00 2001 From: Md Azam Date: Wed, 22 Nov 2023 09:30:15 -0400 Subject: [PATCH 1/2] Fix parameter assignment in error handling function (#1616) --- .../stix_transmission/error_mapper.py | 2 +- .../stix_transmission/error_mapper.py | 2 +- .../stix_transmission/error_mapper.py | 2 +- .../stix_transmission/connector.py | 88 +++++++++++++------ .../stix_transmission/error_mapper.py | 36 ++++++++ .../stix_transmission/test_crowdstrike.py | 80 ++++++++--------- tests/utils/async_utils.py | 7 +- 7 files changed, 145 insertions(+), 72 deletions(-) create mode 100644 stix_shifter_modules/crowdstrike/stix_transmission/error_mapper.py diff --git a/stix_shifter_modules/arcsight/stix_transmission/error_mapper.py b/stix_shifter_modules/arcsight/stix_transmission/error_mapper.py index 05dfa712a..209c52843 100644 --- a/stix_shifter_modules/arcsight/stix_transmission/error_mapper.py +++ b/stix_shifter_modules/arcsight/stix_transmission/error_mapper.py @@ -45,4 +45,4 @@ def set_error_code(json_data, return_obj, connector=None): if error_code == ErrorMapper.DEFAULT_ERROR: ErrorMapper.logger.debug("failed to map: " + str(json_data)) - ErrorMapperBase.set_error_code(return_obj, error_code, connector) + ErrorMapperBase.set_error_code(return_obj, error_code, connector=connector) diff --git a/stix_shifter_modules/aws_athena/stix_transmission/error_mapper.py b/stix_shifter_modules/aws_athena/stix_transmission/error_mapper.py index 83a08dc33..7ef001841 100644 --- a/stix_shifter_modules/aws_athena/stix_transmission/error_mapper.py +++ b/stix_shifter_modules/aws_athena/stix_transmission/error_mapper.py @@ -43,4 +43,4 @@ def set_error_code(json_data, return_obj, connector=None): if error_code == ErrorMapper.DEFAULT_ERROR: ErrorMapper.logger.debug("failed to map: " + str(json_data)) - ErrorMapperBase.set_error_code(return_obj, error_code, connector) + ErrorMapperBase.set_error_code(return_obj, error_code, connector=connector) diff --git a/stix_shifter_modules/aws_cloud_watch_logs/stix_transmission/error_mapper.py b/stix_shifter_modules/aws_cloud_watch_logs/stix_transmission/error_mapper.py index 0ff1c75b4..7c73ce608 100644 --- a/stix_shifter_modules/aws_cloud_watch_logs/stix_transmission/error_mapper.py +++ b/stix_shifter_modules/aws_cloud_watch_logs/stix_transmission/error_mapper.py @@ -44,4 +44,4 @@ def set_error_code(json_data, return_obj, connector=None): if error_code == ErrorMapper.DEFAULT_ERROR: ErrorMapper.logger.debug("failed to map: " + str(json_data)) - ErrorMapperBase.set_error_code(return_obj, error_code, connector) + ErrorMapperBase.set_error_code(return_obj, error_code, connector=connector) diff --git a/stix_shifter_modules/crowdstrike/stix_transmission/connector.py b/stix_shifter_modules/crowdstrike/stix_transmission/connector.py index 71347b596..b2b532358 100644 --- a/stix_shifter_modules/crowdstrike/stix_transmission/connector.py +++ b/stix_shifter_modules/crowdstrike/stix_transmission/connector.py @@ -3,6 +3,11 @@ from .api_client import APIClient from stix_shifter_utils.utils.error_response import ErrorResponder from stix_shifter_utils.utils import logger +from requests.exceptions import ConnectionError + + +class QueryException(Exception): + pass class Connector(BaseJsonSyncConnector): @@ -31,45 +36,76 @@ def _handle_errors(self, response, return_obj): """ response_code = response.code response_txt = response.read().decode('utf-8') + response_type = response.headers.get('Content-Type') + response_dict = {} if 200 <= response_code < 300: return_obj['success'] = True return_obj['data'] = response_txt return return_obj - elif ErrorResponder.is_plain_string(response_txt): - ErrorResponder.fill_error(return_obj, message=response_txt) - raise Exception(return_obj) - elif ErrorResponder.is_json_string(response_txt): - response_json = json.loads(response_txt) - ErrorResponder.fill_error(return_obj, response_json, ['reason'], connector=self.connector) - raise Exception(return_obj) - else: - raise Exception(return_obj) + elif response_code >= 400: + if response_type == 'application/json': + error_response = json.loads(response_txt) + response_dict['type'] = 'ValidationError' + response_dict['message'] = error_response['errors'][0]['message'] + ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector) + raise QueryException(return_obj) + elif response_type == 'text/html': + error = ConnectionError(f'Error connecting the datasource: {response_txt}') + ErrorResponder.fill_error(return_obj, response_dict, error=error, connector=self.connector) + raise QueryException(return_obj) + else: + raise Exception(response_txt) + async def ping_connection(self): response_txt = None return_obj = {} + response_dict = {} try: response = await self.api_client.ping_box() response_code = response.code response_txt = response.read().decode('utf-8') + response_type = response.headers.get('Content-Type') if 199 < response_code < 300: return_obj['success'] = True - elif isinstance(json.loads(response_txt), dict): - response_error_ping = json.loads(response_txt) - response_dict = response_error_ping['errors'][0] - ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector) + elif response_code == 401: + if response_type == 'application/json': + error_response = json.loads(response_txt) + response_dict['type'] = 'AuthenticationError' + response_dict['message'] = error_response['errors'][0]['message'] + self.logger.error('Error connecting the Crowdstrike datasource: ' + str(error_response)) + ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector) + else: + raise Exception(response_txt) + elif response_code == 400: + if response_type == 'application/json': + error_response = json.loads(response_txt) + response_dict['type'] = 'ValidationError' + response_dict['message'] = error_response['errors'][0]['message'] + self.logger.error('Error connecting the Crowdstrike datasource: ' + str(error_response)) + ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector) + else: + raise Exception(response_txt) else: - raise Exception(response_txt) + if response_type == 'application/json': + response_error_ping = json.loads(response_txt) + response_dict = response_error_ping['errors'][0] + ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector) + elif response_type == 'text/html': + error = ConnectionError(f'Error connecting the datasource: {response_txt}') + ErrorResponder.fill_error(return_obj, response_dict, error=error, connector=self.connector) + else: + raise Exception(response_txt) except Exception as e: if response_txt is not None: - ErrorResponder.fill_error(return_obj, message='unexpected exception', connector=self.connector) - self.logger.error('can not parse response: ' + str(response_txt)) + ErrorResponder.fill_error(return_obj, message='unexpected exception: ' + str(response_txt), connector=self.connector) + self.logger.error('Can not parse response Crowdstrike error: ' + str(response_txt)) else: raise e - + return return_obj - + async def send_info_request_and_handle_errors(self, ids_lst): return_obj = dict() response = await self.api_client.get_detections_info(ids_lst) @@ -142,7 +178,6 @@ async def create_results_connection(self, query, offset, length): :param offset: int,offset value :param length: int,length value""" result_limit = offset + length - response_txt = None ids_obj = dict() return_obj = dict() table_event_data = [] @@ -195,10 +230,13 @@ async def create_results_connection(self, query, offset, length): if not return_obj.get('success'): return_obj['success'] = True return return_obj - + except QueryException as ex: + return ex.args[0] except Exception as ex: - if response_txt is not None: - ErrorResponder.fill_error(return_obj, message='unexpected exception', connector=self.connector) - self.logger.error('can not parse response: ' + str(response_txt)) - else: - raise ex + error_dict = {} + error_dict['type'] = 'AttributeError' + error_dict['message'] = 'Error while parsing API response: ' + str(ex) + ErrorResponder.fill_error(return_obj, error_dict, ['message'], connector=self.connector) + self.logger.error('Unexpected exception from Crowdstrike datasource: ' + str(ex)) + + return return_obj diff --git a/stix_shifter_modules/crowdstrike/stix_transmission/error_mapper.py b/stix_shifter_modules/crowdstrike/stix_transmission/error_mapper.py new file mode 100644 index 000000000..703d3912d --- /dev/null +++ b/stix_shifter_modules/crowdstrike/stix_transmission/error_mapper.py @@ -0,0 +1,36 @@ +from stix_shifter_utils.utils.error_mapper_base import ErrorMapperBase +from stix_shifter_utils.utils.error_response import ErrorCode +from stix_shifter_utils.utils import logger + +error_mapping = { + "ConnectionError": ErrorCode.TRANSMISSION_REMOTE_SYSTEM_IS_UNAVAILABLE, + "AuthenticationError": ErrorCode.TRANSMISSION_AUTH_CREDENTIALS, + "ValidationError": ErrorCode.TRANSMISSION_QUERY_LOGICAL_ERROR, + "AttributeError": ErrorCode.TRANSMISSION_INVALID_PARAMETER, +} + + +class ErrorMapper: + """ + Set Error Code + """ + logger = logger.set_logger(__name__) + DEFAULT_ERROR = ErrorCode.TRANSMISSION_MODULE_DEFAULT_ERROR + + @staticmethod + def set_error_code(json_data, return_obj, connector=None): + err_type = None + try: + err_type = json_data['type'] + except KeyError: + pass + + error_type = ErrorMapper.DEFAULT_ERROR + + if err_type in error_mapping: + error_type = error_mapping.get(err_type) + + if error_type == ErrorMapper.DEFAULT_ERROR: + ErrorMapper.logger.error("failed to map: %s", str(json_data)) + + ErrorMapperBase.set_error_code(return_obj, error_type, connector=connector) diff --git a/stix_shifter_modules/crowdstrike/tests/stix_transmission/test_crowdstrike.py b/stix_shifter_modules/crowdstrike/tests/stix_transmission/test_crowdstrike.py index 6503f799a..b582bc8e0 100644 --- a/stix_shifter_modules/crowdstrike/tests/stix_transmission/test_crowdstrike.py +++ b/stix_shifter_modules/crowdstrike/tests/stix_transmission/test_crowdstrike.py @@ -1,6 +1,4 @@ -import json import unittest -from unittest.mock import ANY from unittest.mock import patch from tests.utils.async_utils import get_mock_response from stix_shifter_modules.crowdstrike.entry_point import EntryPoint @@ -18,6 +16,8 @@ 'host': 'api.crowdstrike.com' } +headers = {'Content-Type': 'application/json'} + @patch('stix_shifter_modules.crowdstrike.stix_transmission.api_client.APIClient.get_detections_IDs', autospec=True) class TestCrowdStrikeConnection(unittest.TestCase, object): @@ -54,21 +54,22 @@ def test_create_query_connection(self, mock_api_client): def test_no_results_response(self, mock_requests_response): mocked_return_value = """ -{"terms": ["process_name:notepad.exe"], - "results": [], - "elapsed": 0.01921701431274414, - "comprehensive_search": true, - "all_segments": true, - "total_results": 0, - "highlights": [], - "facets": {}, - "tagged_pids": {"00000036-0000-0a02-01d4-97e70c22b346-0167c881d4b3": [{"name": "Default Investigation", "id": 1}, {"name": "Default Investigation", "id": 1}]}, - "start": 0, - "incomplete_results": false, - "filtered": {} -} -""" - mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode()) + { + "terms": ["process_name:notepad.exe"], + "results": [], + "elapsed": 0.01921701431274414, + "comprehensive_search": true, + "all_segments": true, + "total_results": 0, + "highlights": [], + "facets": {}, + "tagged_pids": {"00000036-0000-0a02-01d4-97e70c22b346-0167c881d4b3": [{"name": "Default Investigation", "id": 1}, {"name": "Default Investigation", "id": 1}]}, + "start": 0, + "incomplete_results": false, + "filtered": {} + } + """ + mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode(), headers=headers) entry_point = EntryPoint(connection, config) query_expression = self._create_query_list("process_name:notepad.exe")[0] @@ -80,30 +81,28 @@ def test_no_results_response(self, mock_requests_response): assert 'data' in results_response assert len(results_response['data']) == 0 - - def test_one_results_response(self, mock_requests_response): mocked_return_value = """ -{ - "terms": [ - "process_name:cmd.exe", - "start:[2019-01-22T00:00:00 TO *]" - ], - "results": [], - "elapsed": 0.05147600173950195, - "comprehensive_search": true, - "all_segments": true, - "total_results": 1, - "highlights": [], - "facets": {}, - "tagged_pids": {}, - "start": 0, - "incomplete_results": false, - "filtered": {} -} -""" - - mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode()) + { + "terms": [ + "process_name:cmd.exe", + "start:[2019-01-22T00:00:00 TO *]" + ], + "results": [], + "elapsed": 0.05147600173950195, + "comprehensive_search": true, + "all_segments": true, + "total_results": 1, + "highlights": [], + "facets": {}, + "tagged_pids": {}, + "start": 0, + "incomplete_results": false, + "filtered": {} + } + """ + + mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode(), headers=headers) entry_point = EntryPoint(connection, config) query_expression = self._create_query_list("process_name:cmd.exe start:[2019-01-22 TO *]")[0] @@ -115,10 +114,9 @@ def test_one_results_response(self, mock_requests_response): assert 'data' in results_response assert len(results_response['data']) == 0 - def test_transmit_limit_and_sort(self, mock_requests_response): mocked_return_value = '{"reason": "query_syntax_error"}' - mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode()) + mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode(), headers=headers) entry_point = EntryPoint(connection, config) query_expression = self._create_query_list("process_name:cmd.exe")[0] diff --git a/tests/utils/async_utils.py b/tests/utils/async_utils.py index 61ae5c7f3..aafef083f 100644 --- a/tests/utils/async_utils.py +++ b/tests/utils/async_utils.py @@ -1,13 +1,14 @@ -def get_mock_response(status_code, content=None, return_type='str', response=None): - return RequestMockResponse(status_code, content, return_type, response) +def get_mock_response(status_code, content=None, return_type='str', response=None, headers=None): + return RequestMockResponse(status_code, content, return_type, response, headers) def get_aws_mock_response(obj): return AWSComposeMockResponse(obj) class RequestMockResponse: - def __init__(self, status_code, content, return_type='str', response=None): + def __init__(self, status_code, content, return_type='str', response=None, headers=None): self.code = status_code + self.headers = headers self.content = content self.response = response self.object = response From c455ab2c7d1fb613d69971c05fce58eddd83b743 Mon Sep 17 00:00:00 2001 From: Md Azam Date: Wed, 22 Nov 2023 12:17:54 -0400 Subject: [PATCH 2/2] Remove future timestamp qualifier conditions (#1619) --- .../stix_translation/query_constructor.py | 4 --- .../test_aws_guardduty_stix_to_query.py | 27 +++++++++++++++---- .../stix_translation/query_constructor.py | 4 --- .../test_okta_stix_to_query.py | 11 ++++---- 4 files changed, 28 insertions(+), 18 deletions(-) diff --git a/stix_shifter_modules/aws_guardduty/stix_translation/query_constructor.py b/stix_shifter_modules/aws_guardduty/stix_translation/query_constructor.py index e472c794a..56e6a1edc 100644 --- a/stix_shifter_modules/aws_guardduty/stix_translation/query_constructor.py +++ b/stix_shifter_modules/aws_guardduty/stix_translation/query_constructor.py @@ -372,13 +372,9 @@ def _check_time_range_values(time_range_list): checks for valid start and stop time :param time_range_list: list """ - utc_timestamp = STOP_TIME.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z' - converted_utc_timestamp = QueryStringPatternTranslator._format_datetime(utc_timestamp) converted_timestamp = [] for timestamp in time_range_list: converted_time = QueryStringPatternTranslator._format_datetime(timestamp) - if converted_time > converted_utc_timestamp: - raise StartStopQualifierValueException('Start/Stop time should not be in the future UTC timestamp') converted_timestamp.append(converted_time) if converted_timestamp[0] >= converted_timestamp[1]: raise StartStopQualifierValueException('Start time should be lesser than Stop time') diff --git a/stix_shifter_modules/aws_guardduty/tests/stix_translation/test_aws_guardduty_stix_to_query.py b/stix_shifter_modules/aws_guardduty/tests/stix_translation/test_aws_guardduty_stix_to_query.py index 4fb644afc..aa0bf80b1 100644 --- a/stix_shifter_modules/aws_guardduty/tests/stix_translation/test_aws_guardduty_stix_to_query.py +++ b/stix_shifter_modules/aws_guardduty/tests/stix_translation/test_aws_guardduty_stix_to_query.py @@ -654,13 +654,30 @@ def test_multiple_observation_with_single_qualifier_with_precedence_bracket(self queries = _remove_timestamp_from_query(queries) self._test_query_assertions(query, queries) - def test_invalid_qualifier_with_future_timestamp(self): + def test_timestamp_qualifier(self): stix_pattern = "[network-traffic:src_port >= 32794]START t'2023-01-19T11:00:00.000Z' " \ "STOP t'2024-02-07T11:00:00.003Z'" - result = translation.translate('aws_guardduty', 'query', '{}', stix_pattern) - assert result['success'] is False - assert "translation_error" == result['code'] - assert 'Start/Stop time should not be in the future UTC timestamp' in result['error'] + queries = { + "queries": [ + { + "FindingCriteria": { + "Criterion": { + "service.action.networkConnectionAction.localPortDetails.port": { + "GreaterThanOrEqual": 32794 + }, + "updatedAt": { + "GreaterThanOrEqual": 1674126000000, + "LessThanOrEqual": 1707303600003 + } + } + } + } + ] + } + query = translation.translate('aws_guardduty', 'query', '{}', stix_pattern) + query = _remove_timestamp_from_query(query) + queries = _remove_timestamp_from_query(queries) + self._test_query_assertions(query, queries) def test_stop_time_lesser_than_start_time(self): stix_pattern = "[network-traffic:src_port >= 32794]START t'2023-01-19T11:00:00.000Z' " \ diff --git a/stix_shifter_modules/okta/stix_translation/query_constructor.py b/stix_shifter_modules/okta/stix_translation/query_constructor.py index 1d2df7b5d..21d733f8e 100644 --- a/stix_shifter_modules/okta/stix_translation/query_constructor.py +++ b/stix_shifter_modules/okta/stix_translation/query_constructor.py @@ -230,13 +230,9 @@ def _check_time_range_values(time_range_list): checks for valid start and stop time :param time_range_list: list """ - utc_timestamp = STOP_TIME.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z' - converted_utc_timestamp = QueryStringPatternTranslator._format_datetime(utc_timestamp) converted_timestamp = [] for timestamp in time_range_list: converted_time = QueryStringPatternTranslator._format_datetime(timestamp) - if converted_time > converted_utc_timestamp: - raise StartStopQualifierValueException('Start/Stop time should not be in the future UTC timestamp') converted_timestamp.append(converted_time) if converted_timestamp[0] >= converted_timestamp[1]: raise StartStopQualifierValueException('Start time should be lesser than Stop time') diff --git a/stix_shifter_modules/okta/test/stix_translation/test_okta_stix_to_query.py b/stix_shifter_modules/okta/test/stix_translation/test_okta_stix_to_query.py index cf3864a21..6e9755b77 100644 --- a/stix_shifter_modules/okta/test/stix_translation/test_okta_stix_to_query.py +++ b/stix_shifter_modules/okta/test/stix_translation/test_okta_stix_to_query.py @@ -317,13 +317,14 @@ def test_wildcard_characters_like_operator(self): queries = _remove_timestamp_from_query(queries) self._test_query_assertions(query, queries) - def test_invalid_qualifier_with_future_timestamp(self): + def test_timestamp_qualifier(self): stix_pattern = "[domain-name:value LIKE 'amazonaws.com'] " \ "START t'2023-01-19T11:00:00.000Z' STOP t'2024-02-07T11:00:00.003Z'" - result = translation.translate('okta', 'query', '{}', stix_pattern) - assert result['success'] is False - assert "translation_error" == result['code'] - assert 'Start/Stop time should not be in the future UTC timestamp' in result['error'] + query = translation.translate('okta', 'query', '{}', stix_pattern) + query['queries'] = _remove_timestamp_from_query(query['queries']) + queries = ["filter=securityContext.domain co \"amazonaws.com\" &since=2023-01-19T11:00:00.000Z&until=2024-02-07T11:00:00.003Z"] + queries = _remove_timestamp_from_query(queries) + self._test_query_assertions(query, queries) def test_invalid_operator_for_integer_type_field(self): stix_pattern = "[autonomous-system:number LIKE '50']"