diff --git a/src/resp_ode/config.py b/src/resp_ode/config.py index 9950eac..c52391e 100644 --- a/src/resp_ode/config.py +++ b/src/resp_ode/config.py @@ -107,14 +107,31 @@ def assert_valid_configuration(self): validator_funcs = make_list_if_not(validator_funcs) vals = [getattr(self, k) for k in key] # val_func() throws assert errors if incongruence arrises - [ - ( - val_func(key[0], vals[0]) - if len(key) == 1 # convert back to floats if needed - else val_func(key, vals) - ) - for val_func in validator_funcs - ] + try: + [ + ( + val_func(key[0], vals[0]) + if len(key) + == 1 # convert back to floats if needed + else val_func(key, vals) + ) + for val_func in validator_funcs + ] + except Exception as e: + if len(key) > 1: + err_text = """There was an issue validating your Config object. + The error was caused by the intersection of the following parameters: %s. + %s""" % ( + key, + e, + ) + else: + err_text = """The following error occured while validating the %s + parameter in your configuration file: %s""" % ( + key[0], + e, + ) + raise ConfigValidationError(err_text) def make_list_if_not(obj): @@ -291,6 +308,16 @@ def test_positive(key, value): ) +def test_enum_len(key, enum, expected_len): + assert ( + len(enum) == expected_len + ), "Expected %s to have %s entries, got %s" % ( + key, + expected_len, + len(enum), + ) + + def test_not_negative(key, value): """ checks if a value is not negative. @@ -654,6 +681,13 @@ class is accepted to modify/create the downstream parameters. key, [(vals[0], 2 ** vals[0]), vals[1]] ), }, + { + "name": ["NUM_STRAINS", "STRAIN_IDX"], + # check that len(STRAIN_IDX)==NUM_STRAINS + "validate": lambda keys, vals: test_enum_len( + keys[1], vals[1], vals[0] + ), + }, { "name": "MAX_VACCINATION_COUNT", "validate": test_not_negative, @@ -852,4 +886,14 @@ class is accepted to modify/create the downstream parameters. class ConfigParserError(Exception): + """A basic class meant to denote when the Config + class is having an issue parsing a configuration file""" + + pass + + +class ConfigValidationError(Exception): + """A basic class meant to denote when the Config + class is having an issue validating a configuration file""" + pass diff --git a/tests/test_config.py b/tests/test_config.py index 6ab94c7..7134f5f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -7,7 +7,7 @@ import numpyro.distributions as dist import pytest -from resp_ode.config import Config, ConfigParserError +from resp_ode.config import Config, ConfigParserError, ConfigValidationError GLOBAL_TEST_CONFIG = "tests/test_config_global.json" PATH_VARIABLES = [ @@ -38,7 +38,7 @@ def test_valid_path_variables(): def test_invalid_type_path_variables(): for path_var in PATH_VARIABLES: example_input_json = """{"%s":%d}""" % (path_var, 10) - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(example_input_json) for path_var in PATH_VARIABLES: example_input_json = """{"%s":"%s"}""" % ( @@ -46,55 +46,55 @@ def test_invalid_type_path_variables(): "some_random_incorrect_path.json", ) print(example_input_json) - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(example_input_json) def test_non_ascending_age_limits(): input_json = """{"AGE_LIMITS": [10, 1, 50, 60]}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_out_of_bounds_age_limits(): input_json = """{"AGE_LIMITS": [0, 18, 50, 64, 95]}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_negative_age_limits(): input_json = """{"AGE_LIMITS": [-10, 18, 50, 64]}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_float_ages(): input_json = """{"AGE_LIMITS": [0, 5.5, 18, 50, 64]}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_invalid_seasonality_amplitude_type(): input_json = """{"SEASONALITY_AMPLITUDE": [0]}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_invalid_seasonality_amplitude_val(): input_json = """{"SEASONALITY_AMPLITUDE": 4.0}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_invalid_vax_age_coefs_type(): input_json = """{"AGE_DOSE_SPECIFIC_VAX_COEF": "blah"}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_invalid_vax_path(): input_json = """{"VACCINATION_MODEL_DATA": "blah"}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) @@ -105,7 +105,7 @@ def test_invalid_vax_age_coefs_vals(): "NUM_AGE_GROUPS":3, "MAX_VACCINATION_COUNT": 2 }""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) @@ -123,7 +123,7 @@ def test_valid_vax_age_coefs(): def test_invalid_seasonality_amplitude_val_negative(): input_json = """{"SEASONALITY_AMPLITUDE": -4.0}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) @@ -135,7 +135,7 @@ def test_invalid_seasonality_amplitude_dist(): "loc": 0 } }}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) @@ -157,19 +157,19 @@ def test_valid_seasonality_amplitude_dist(): def test_invalid_seasonality_second_wave_type(): input_json = """{"SEASONALITY_SECOND_WAVE": [0]}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_invalid_seasonality_second_wave_val(): input_json = """{"SEASONALITY_SECOND_WAVE": 1.5}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_invalid_seasonality_second_wave_val_negative(): input_json = """{"SEASONALITY_SECOND_WAVE": -1.5}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) @@ -181,7 +181,7 @@ def test_invalid_seasonalit_second_wave_dist(): "loc": 0 } }}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) @@ -203,19 +203,19 @@ def test_valid_seasonality_second_wave_dist(): def test_invalid_seasonality_shift_type(): input_json = """{"SEASONALITY_SHIFT": [0]}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_invalid_seasonality_shift_val(): input_json = """{"SEASONALITY_SHIFT": 183}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_invalid_seasonality_shift_val_negative(): input_json = """{"SEASONALITY_SHIFT": -183}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) @@ -227,7 +227,7 @@ def test_invalid_seasonalit_shift_dist(): "loc": 0 } }}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) @@ -254,13 +254,13 @@ def test_valid_seasonality_shift_dist(): def test_invalid_introduction_perc_type(): input_json = """{"INTRODUCTION_PCTS": 0.1}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_invalid_introduction_perc_val(): input_json = """{"INTRODUCTION_PCTS":[-1]}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) @@ -271,13 +271,13 @@ def test_valid_introduction_perc(): def test_invalid_introduction_times_type(): input_json = """{"INTRODUCTION_TIMES": 0}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_invalid_introduction_times_val(): input_json = """{"INTRODUCTION_TIMES": [-1]}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) @@ -288,13 +288,13 @@ def test_valid_introduction_times_val(): def test_invalid_introduction_scale_type(): input_json = """{"INTRODUCTION_SCALES": 5}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_invalid_introduction_scale_val(): input_json = """{"INTRODUCTION_SCALES": [-1]}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) @@ -310,13 +310,13 @@ def test_valid_age_limits(): def test_negative_population_size(): input_json = """{"POP_SIZE": -1}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_str_population_size(): input_json = """{"POP_SIZE": "5"}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) @@ -327,31 +327,31 @@ def test_valid_population_size(): def test_negative_initial_infections(): input_json = """{"INITIAL_INFECTIONS": -5}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_negative_initial_infections_scale(): input_json = """{"INITIAL_INFECTIONS_SCALE": -1.2}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_negative_tree_depth(): input_json = """{"MAX_TREE_DEPTH": -1}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_zero_tree_depth(): input_json = """{"MAX_TREE_DEPTH": 0}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_float_tree_depth(): input_json = """{"MAX_TREE_DEPTH": 1.2}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) @@ -362,7 +362,7 @@ def test_valid_tree_depth(): def test_str_initial_infections(): input_json = """{"INITIAL_INFECTIONS": "5"}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) @@ -378,10 +378,28 @@ def test_valid_initial_infections_float(): def test_init_infections_greater_than_pop_size(): input_json = """{"POP_SIZE": 1, "INITIAL_INFECTIONS": 5}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) +def test_strain_idx_shorter_than_num_strains(): + input_json = """{"NUM_STRAINS": 3, "STRAIN_IDX": ["x", "y"]}""" + with pytest.raises(ConfigValidationError): + Config(input_json) + + +def test_strain_idx_longer_than_num_strains(): + input_json = """{"NUM_STRAINS": 2, "STRAIN_IDX": ["x", "y", "z"]}""" + with pytest.raises(ConfigValidationError): + Config(input_json) + + +def test_strain_idx_equal_to_num_strains(): + input_json = """{"NUM_STRAINS": 2, "STRAIN_IDX": ["x", "y"]}""" + c = Config(input_json) + assert c.NUM_STRAINS == 2 and len(c.STRAIN_IDX) == 2 + + def test_init_infections_less_than_pop_size(): input_json = """{"POP_SIZE": 5, "INITIAL_INFECTIONS": 1}""" c = Config(input_json) @@ -396,13 +414,13 @@ def test_valid_infectious_period(): def test_negative_infectious_period(): input_json = """{"INFECTIOUS_PERIOD": -5}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) def test_invalid_support_infectious_period(): input_json = """{"INFECTIOUS_PERIOD": {"distribution": "Normal", "params": {"loc": 0, "scale":1}}}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) @@ -414,7 +432,7 @@ def test_valid_support_infectious_period(): def test_invalid_step_size(): input_json = """{"CONSTANT_STEP_SIZE": -1.0}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json) @@ -515,7 +533,7 @@ def test_invalid_support_nested_distribution_infectious_period(): } } }}""" - with pytest.raises(AssertionError): + with pytest.raises(ConfigValidationError): Config(input_json)