From 9ff3ba8cf2deb356ff014599ae097ff07dd1daeb Mon Sep 17 00:00:00 2001 From: Ramo Date: Wed, 22 Feb 2023 15:56:42 +1100 Subject: [PATCH] Update typing in args_validators (#730) --- dftimewolf/lib/args_validator.py | 46 +++++++++++++------------------- dftimewolf/lib/resources.py | 2 +- tests/lib/args_validator.py | 16 +++++------ 3 files changed, 27 insertions(+), 37 deletions(-) diff --git a/dftimewolf/lib/args_validator.py b/dftimewolf/lib/args_validator.py index 27bc469d..db601743 100644 --- a/dftimewolf/lib/args_validator.py +++ b/dftimewolf/lib/args_validator.py @@ -4,7 +4,7 @@ import ipaddress import re -from typing import Any, Dict, Optional, Union, Tuple +from typing import Any, Dict, Union, Tuple import datetime from urllib.parse import urlparse @@ -23,7 +23,7 @@ def __init__(self) -> None: @abc.abstractmethod def Validate(self, operand: Any, - validator_params: Optional[Dict[str, Any]]) -> Tuple[bool, str]: + validator_params: Dict[str, Any]) -> Tuple[bool, str]: """Validate the parameter. Args: @@ -49,7 +49,7 @@ class CommaSeparatedValidator(AbstractValidator): def Validate(self, operand: str, - validator_params: Optional[Dict[str, Any]] = None + validator_params: Dict[str, Any] ) -> Tuple[bool, str]: """Split the string by commas if validator_params['comma_separated'] == True and validate each component in ValidateSingle. @@ -67,9 +67,6 @@ def Validate(self, Raises: errors.RecipeArgsValidatorError: An error in validation. """ - - if not validator_params: - validator_params = {} if 'comma_separated' not in validator_params: validator_params['comma_separated'] = False @@ -88,7 +85,7 @@ def Validate(self, @abc.abstractmethod def ValidateSingle(self, operand: str, - validator_params: Optional[Dict[str, Any]] + validator_params: Dict[str, Any] ) -> Tuple[bool, str]: """Validate a single operand from a comma separated list. @@ -114,7 +111,7 @@ class DefaultValidator(AbstractValidator): def Validate(self, operand: Any, - validator_params: Optional[Dict[str, Any]] + validator_params: Dict[str, Any] ) -> Tuple[bool, str]: """Always passes.""" return True, '' @@ -143,7 +140,7 @@ class AWSRegionValidator(AbstractValidator): def Validate(self, operand: Any, - validator_params: Optional[Dict[str, Any]] = None + validator_params: Dict[str, Any] ) -> Tuple[bool, str]: """Validate operand is a valid AWS region. @@ -191,7 +188,7 @@ class AzureRegionValidator(AbstractValidator): def Validate(self, operand: Any, - validator_params: Optional[Dict[str, Any]]) -> Tuple[bool, str]: + validator_params: Dict[str, Any]) -> Tuple[bool, str]: """Validate that operand is a valid Azure region. Args: @@ -251,7 +248,7 @@ class GCPZoneValidator(AbstractValidator): def Validate(self, operand: Any, - validator_params: Optional[Dict[str, Any]]) -> Tuple[bool, str]: + validator_params: Dict[str, Any]) -> Tuple[bool, str]: """Validate that operand is a valid GCP zone. Args: @@ -275,7 +272,7 @@ class RegexValidator(CommaSeparatedValidator): def ValidateSingle(self, operand: str, - validator_params: Optional[Dict[str, Any]] = None + validator_params: Dict[str, Any] ) -> Tuple[bool, str]: """Validate a string according to a regular expression. @@ -292,7 +289,7 @@ def ValidateSingle(self, Raises: errors.RecipeArgsValidatorError: If no regex is found to use. """ - if not validator_params or 'regex' not in validator_params: + if 'regex' not in validator_params: raise errors.RecipeArgsValidatorError( 'Missing validator parameter: regex') @@ -310,7 +307,7 @@ class SubnetValidator(CommaSeparatedValidator): def ValidateSingle(self, operand: str, - validator_params: Optional[Dict[str, Any]] + validator_params: Dict[str, Any] ) -> Tuple[bool, str]: """Validate that operand is a valid subnet string. @@ -377,7 +374,7 @@ class DatetimeValidator(AbstractValidator): def Validate(self, operand: Any, - validator_params: Optional[Dict[str, Any]]) -> Tuple[bool, str]: + validator_params: Dict[str, Any]) -> Tuple[bool, str]: """Validate that operand is a valid GCP zone. Args: @@ -395,7 +392,7 @@ def Validate(self, Raises: errors.RecipeArgsValidatorError: An error in validation. """ - if not validator_params or 'format_string' not in validator_params: + if 'format_string' not in validator_params: raise errors.RecipeArgsValidatorError( 'Missing validator parameter: format_string') @@ -464,7 +461,7 @@ class HostnameValidator(RegexValidator): def ValidateSingle(self, operand: str, - validator_params: Optional[Dict[str, Any]] = None + validator_params: Dict[str, Any] ) -> Tuple[bool, str]: """Validate an FQDN. @@ -478,9 +475,6 @@ def ValidateSingle(self, boolean: True if operand is a valid FQDN, False otherwise. str: A message for validation failure. Only set if the boolean is false. """ - if not validator_params: - validator_params = {} - regexes = [self.FQDN_REGEX] if not validator_params.get(self.FQDN_ONLY_FLAG, False): regexes.append(self.HOSTNAME_REGEX) @@ -507,7 +501,7 @@ class GRRHostValidator(HostnameValidator): def ValidateSingle(self, operand: str, - validator_params: Optional[Dict[str, Any]] = None + validator_params: Dict[str, Any] ) -> Tuple[bool, str]: """Validate a Grr host ID. @@ -520,8 +514,6 @@ def ValidateSingle(self, boolean: True if operand is a valid Grr ID, False otherwise. str: A message for validation failure. Only set if the boolean is false. """ - if not validator_params: - validator_params = {} validator_params['regex'] = self.GRR_REGEX bases = [RegexValidator, HostnameValidator] @@ -540,7 +532,7 @@ class URLValidator(HostnameValidator): def ValidateSingle(self, operand: str, - validator_params: Optional[Dict[str, Any]] = None + validator_params: Dict[str, Any] ) -> Tuple[bool, str]: """Validates a URL. @@ -564,8 +556,6 @@ def ValidateSingle(self, except ValueError: pass - if not validator_params: - validator_params = {} validator_params['regex'] = self.HOSTNAME_REGEX bases = [RegexValidator, HostnameValidator] @@ -605,7 +595,7 @@ def RegisterValidator(self, validator: AbstractValidator) -> None: def Validate(self, operand: Any, - validator_params: Optional[Dict[str, Any]]=None + validator_params: Dict[str, Any] ) -> Tuple[bool, str]: """Validate a operand. @@ -621,7 +611,7 @@ def Validate(self, Raises: errors.RecipeArgsValidatorError: Raised on validator config errors. """ - if validator_params is None: + if 'format' not in validator_params: validator = self._default_validator else: if validator_params['format'] not in self._validators: diff --git a/dftimewolf/lib/resources.py b/dftimewolf/lib/resources.py index 6cc90ef1..7b8c2a37 100644 --- a/dftimewolf/lib/resources.py +++ b/dftimewolf/lib/resources.py @@ -11,7 +11,7 @@ class RecipeArgs: switch: str = '' help_text: str = '' default: Any = None - format: Dict[str, Any] = None # type: ignore + format: Dict[str, Any] = dataclasses.field(default_factory=dict) class Recipe(object): diff --git a/tests/lib/args_validator.py b/tests/lib/args_validator.py index 0bcce031..2a2665b9 100644 --- a/tests/lib/args_validator.py +++ b/tests/lib/args_validator.py @@ -344,7 +344,7 @@ def test_ValidateSuccess(self): 'grr-server' ] for fqdn in fqdns: - val, _ = self.validator.Validate(fqdn) + val, _ = self.validator.Validate(fqdn, {}) self.assertTrue(val) val, _ = self.validator.Validate(','.join(fqdns), {'comma_separated': True}) @@ -354,7 +354,7 @@ def test_ValidationFailure(self): """Tests validation failures.""" fqdns = ['a-.com', '-a.com'] for fqdn in fqdns: - val, msg = self.validator.Validate(fqdn) + val, msg = self.validator.Validate(fqdn, {}) self.assertFalse(val) self.assertEqual(msg, f"'{fqdn}' is an invalid hostname.") @@ -396,7 +396,7 @@ def test_ValidateSuccess(self): 'grr-client-ubuntu.c.ramoj-playground.internal', 'grr-client'] for fqdn in fqdns: - val, _ = self.validator.Validate(fqdn) + val, _ = self.validator.Validate(fqdn, {}) self.assertTrue(val) val, _ = self.validator.Validate(','.join(fqdns), {'comma_separated': True}) @@ -406,7 +406,7 @@ def test_ValidationFailure(self): """Tests validation failures.""" fqdns = ['a-.com', 'C.a', 'C.01234567890123456789'] for fqdn in fqdns: - val, msg = self.validator.Validate(fqdn) + val, msg = self.validator.Validate(fqdn, {}) self.assertFalse(val) self.assertEqual(msg, f"'{fqdn}' is an invalid Grr host ID.") @@ -443,7 +443,7 @@ def test_ValidateSuccess(self): 'https://grr.ramoj-playground.internal', ] for fqdn in fqdns: - val, _ = self.validator.Validate(fqdn) + val, _ = self.validator.Validate(fqdn, {}) self.assertTrue(val, f'{fqdn} failed validation') val, _ = self.validator.Validate(','.join(fqdns), {'comma_separated': True}) @@ -457,7 +457,7 @@ def test_ValidationFailure(self): 'http://one.*.com' ] for fqdn in fqdns: - val, msg = self.validator.Validate(fqdn) + val, msg = self.validator.Validate(fqdn, {}) self.assertFalse(val) self.assertEqual(msg, f"'{fqdn}' is an invalid URL.") @@ -521,13 +521,13 @@ def test_Validation(self): def test_DefaultValidation(self): """Tests param validation with DefaultValidator.""" - val, _ = self.vm.Validate('operand') + val, _ = self.vm.Validate('operand', {}) self.assertTrue(val) def test_ValidationFailure(self): """Tests validation failure.""" val, msg = self.vm.Validate('invalid', - {'format': 'subnet', 'comma_separated': False}) + {'format': 'subnet', 'comma_separated': False}) self.assertFalse(val) self.assertEqual(msg, 'invalid is not a valid subnet.')