Skip to content

Commit

Permalink
Update typing in args_validators (#730)
Browse files Browse the repository at this point in the history
  • Loading branch information
ramo-j authored Feb 22, 2023
1 parent cf351bd commit 9ff3ba8
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 37 deletions.
46 changes: 18 additions & 28 deletions dftimewolf/lib/args_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import ipaddress
import re

from typing import Any, Dict, Optional, Union, Tuple
from typing import Any, Dict, Union, Tuple

import datetime
from urllib.parse import urlparse
Expand All @@ -23,7 +23,7 @@ def __init__(self) -> None:
@abc.abstractmethod
def Validate(self,
operand: Any,
validator_params: Optional[Dict[str, Any]]) -> Tuple[bool, str]:
validator_params: Dict[str, Any]) -> Tuple[bool, str]:
"""Validate the parameter.
Args:
Expand All @@ -49,7 +49,7 @@ class CommaSeparatedValidator(AbstractValidator):

def Validate(self,
operand: str,
validator_params: Optional[Dict[str, Any]] = None
validator_params: Dict[str, Any]
) -> Tuple[bool, str]:
"""Split the string by commas if validator_params['comma_separated'] == True
and validate each component in ValidateSingle.
Expand All @@ -67,9 +67,6 @@ def Validate(self,
Raises:
errors.RecipeArgsValidatorError: An error in validation.
"""

if not validator_params:
validator_params = {}
if 'comma_separated' not in validator_params:
validator_params['comma_separated'] = False

Expand All @@ -88,7 +85,7 @@ def Validate(self,
@abc.abstractmethod
def ValidateSingle(self,
operand: str,
validator_params: Optional[Dict[str, Any]]
validator_params: Dict[str, Any]
) -> Tuple[bool, str]:
"""Validate a single operand from a comma separated list.
Expand All @@ -114,7 +111,7 @@ class DefaultValidator(AbstractValidator):

def Validate(self,
operand: Any,
validator_params: Optional[Dict[str, Any]]
validator_params: Dict[str, Any]
) -> Tuple[bool, str]:
"""Always passes."""
return True, ''
Expand Down Expand Up @@ -143,7 +140,7 @@ class AWSRegionValidator(AbstractValidator):

def Validate(self,
operand: Any,
validator_params: Optional[Dict[str, Any]] = None
validator_params: Dict[str, Any]
) -> Tuple[bool, str]:
"""Validate operand is a valid AWS region.
Expand Down Expand Up @@ -191,7 +188,7 @@ class AzureRegionValidator(AbstractValidator):

def Validate(self,
operand: Any,
validator_params: Optional[Dict[str, Any]]) -> Tuple[bool, str]:
validator_params: Dict[str, Any]) -> Tuple[bool, str]:
"""Validate that operand is a valid Azure region.
Args:
Expand Down Expand Up @@ -251,7 +248,7 @@ class GCPZoneValidator(AbstractValidator):

def Validate(self,
operand: Any,
validator_params: Optional[Dict[str, Any]]) -> Tuple[bool, str]:
validator_params: Dict[str, Any]) -> Tuple[bool, str]:
"""Validate that operand is a valid GCP zone.
Args:
Expand All @@ -275,7 +272,7 @@ class RegexValidator(CommaSeparatedValidator):

def ValidateSingle(self,
operand: str,
validator_params: Optional[Dict[str, Any]] = None
validator_params: Dict[str, Any]
) -> Tuple[bool, str]:
"""Validate a string according to a regular expression.
Expand All @@ -292,7 +289,7 @@ def ValidateSingle(self,
Raises:
errors.RecipeArgsValidatorError: If no regex is found to use.
"""
if not validator_params or 'regex' not in validator_params:
if 'regex' not in validator_params:
raise errors.RecipeArgsValidatorError(
'Missing validator parameter: regex')

Expand All @@ -310,7 +307,7 @@ class SubnetValidator(CommaSeparatedValidator):

def ValidateSingle(self,
operand: str,
validator_params: Optional[Dict[str, Any]]
validator_params: Dict[str, Any]
) -> Tuple[bool, str]:
"""Validate that operand is a valid subnet string.
Expand Down Expand Up @@ -377,7 +374,7 @@ class DatetimeValidator(AbstractValidator):

def Validate(self,
operand: Any,
validator_params: Optional[Dict[str, Any]]) -> Tuple[bool, str]:
validator_params: Dict[str, Any]) -> Tuple[bool, str]:
"""Validate that operand is a valid GCP zone.
Args:
Expand All @@ -395,7 +392,7 @@ def Validate(self,
Raises:
errors.RecipeArgsValidatorError: An error in validation.
"""
if not validator_params or 'format_string' not in validator_params:
if 'format_string' not in validator_params:
raise errors.RecipeArgsValidatorError(
'Missing validator parameter: format_string')

Expand Down Expand Up @@ -464,7 +461,7 @@ class HostnameValidator(RegexValidator):

def ValidateSingle(self,
operand: str,
validator_params: Optional[Dict[str, Any]] = None
validator_params: Dict[str, Any]
) -> Tuple[bool, str]:
"""Validate an FQDN.
Expand All @@ -478,9 +475,6 @@ def ValidateSingle(self,
boolean: True if operand is a valid FQDN, False otherwise.
str: A message for validation failure. Only set if the boolean is false.
"""
if not validator_params:
validator_params = {}

regexes = [self.FQDN_REGEX]
if not validator_params.get(self.FQDN_ONLY_FLAG, False):
regexes.append(self.HOSTNAME_REGEX)
Expand All @@ -507,7 +501,7 @@ class GRRHostValidator(HostnameValidator):

def ValidateSingle(self,
operand: str,
validator_params: Optional[Dict[str, Any]] = None
validator_params: Dict[str, Any]
) -> Tuple[bool, str]:
"""Validate a Grr host ID.
Expand All @@ -520,8 +514,6 @@ def ValidateSingle(self,
boolean: True if operand is a valid Grr ID, False otherwise.
str: A message for validation failure. Only set if the boolean is false.
"""
if not validator_params:
validator_params = {}
validator_params['regex'] = self.GRR_REGEX

bases = [RegexValidator, HostnameValidator]
Expand All @@ -540,7 +532,7 @@ class URLValidator(HostnameValidator):

def ValidateSingle(self,
operand: str,
validator_params: Optional[Dict[str, Any]] = None
validator_params: Dict[str, Any]
) -> Tuple[bool, str]:
"""Validates a URL.
Expand All @@ -564,8 +556,6 @@ def ValidateSingle(self,
except ValueError:
pass

if not validator_params:
validator_params = {}
validator_params['regex'] = self.HOSTNAME_REGEX

bases = [RegexValidator, HostnameValidator]
Expand Down Expand Up @@ -605,7 +595,7 @@ def RegisterValidator(self, validator: AbstractValidator) -> None:

def Validate(self,
operand: Any,
validator_params: Optional[Dict[str, Any]]=None
validator_params: Dict[str, Any]
) -> Tuple[bool, str]:
"""Validate a operand.
Expand All @@ -621,7 +611,7 @@ def Validate(self,
Raises:
errors.RecipeArgsValidatorError: Raised on validator config errors.
"""
if validator_params is None:
if 'format' not in validator_params:
validator = self._default_validator
else:
if validator_params['format'] not in self._validators:
Expand Down
2 changes: 1 addition & 1 deletion dftimewolf/lib/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class RecipeArgs:
switch: str = ''
help_text: str = ''
default: Any = None
format: Dict[str, Any] = None # type: ignore
format: Dict[str, Any] = dataclasses.field(default_factory=dict)


class Recipe(object):
Expand Down
16 changes: 8 additions & 8 deletions tests/lib/args_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def test_ValidateSuccess(self):
'grr-server'
]
for fqdn in fqdns:
val, _ = self.validator.Validate(fqdn)
val, _ = self.validator.Validate(fqdn, {})
self.assertTrue(val)

val, _ = self.validator.Validate(','.join(fqdns), {'comma_separated': True})
Expand All @@ -354,7 +354,7 @@ def test_ValidationFailure(self):
"""Tests validation failures."""
fqdns = ['a-.com', '-a.com']
for fqdn in fqdns:
val, msg = self.validator.Validate(fqdn)
val, msg = self.validator.Validate(fqdn, {})
self.assertFalse(val)
self.assertEqual(msg, f"'{fqdn}' is an invalid hostname.")

Expand Down Expand Up @@ -396,7 +396,7 @@ def test_ValidateSuccess(self):
'grr-client-ubuntu.c.ramoj-playground.internal',
'grr-client']
for fqdn in fqdns:
val, _ = self.validator.Validate(fqdn)
val, _ = self.validator.Validate(fqdn, {})
self.assertTrue(val)

val, _ = self.validator.Validate(','.join(fqdns), {'comma_separated': True})
Expand All @@ -406,7 +406,7 @@ def test_ValidationFailure(self):
"""Tests validation failures."""
fqdns = ['a-.com', 'C.a', 'C.01234567890123456789']
for fqdn in fqdns:
val, msg = self.validator.Validate(fqdn)
val, msg = self.validator.Validate(fqdn, {})
self.assertFalse(val)
self.assertEqual(msg, f"'{fqdn}' is an invalid Grr host ID.")

Expand Down Expand Up @@ -443,7 +443,7 @@ def test_ValidateSuccess(self):
'https://grr.ramoj-playground.internal',
]
for fqdn in fqdns:
val, _ = self.validator.Validate(fqdn)
val, _ = self.validator.Validate(fqdn, {})
self.assertTrue(val, f'{fqdn} failed validation')

val, _ = self.validator.Validate(','.join(fqdns), {'comma_separated': True})
Expand All @@ -457,7 +457,7 @@ def test_ValidationFailure(self):
'http://one.*.com'
]
for fqdn in fqdns:
val, msg = self.validator.Validate(fqdn)
val, msg = self.validator.Validate(fqdn, {})
self.assertFalse(val)
self.assertEqual(msg, f"'{fqdn}' is an invalid URL.")

Expand Down Expand Up @@ -521,13 +521,13 @@ def test_Validation(self):

def test_DefaultValidation(self):
"""Tests param validation with DefaultValidator."""
val, _ = self.vm.Validate('operand')
val, _ = self.vm.Validate('operand', {})
self.assertTrue(val)

def test_ValidationFailure(self):
"""Tests validation failure."""
val, msg = self.vm.Validate('invalid',
{'format': 'subnet', 'comma_separated': False})
{'format': 'subnet', 'comma_separated': False})
self.assertFalse(val)
self.assertEqual(msg, 'invalid is not a valid subnet.')

Expand Down

0 comments on commit 9ff3ba8

Please sign in to comment.