diff --git a/CHANGELOG.md b/CHANGELOG.md index 646c105..a13b65d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.4.0] - 03/07/2023 + +### Added + +- Added `max_recursion_depth` argument to `CascadeConfig` to limit the depth of + hierarchically updating nested dictionaries. When the maximum nesting depth is + exceeded, the new dictionary will be used as-is, overwriting any previous + values under that dictionary tree. + ## [0.3.1] - 03/07/2023 ### Fixed diff --git a/cascade_config.py b/cascade_config.py index 15b9ebf..cff07ec 100644 --- a/cascade_config.py +++ b/cascade_config.py @@ -1,6 +1,6 @@ """Cascading configuration from the CLI and config files.""" -__version__ = "0.3.1" +__version__ = "0.4.0" import json import os @@ -14,7 +14,12 @@ class CascadeConfig: """Cascading configuration.""" - def __init__(self, validation_schema=None, none_overrides_value=False): + def __init__( + self, + validation_schema=None, + none_overrides_value=False, + max_recursion_depth=None, + ): """ Cascading configuration. @@ -25,6 +30,10 @@ def __init__(self, validation_schema=None, none_overrides_value=False): none_overrides_value: bool If True, a None value overrides a not-None value from the previous configuration. If False, None values will never override not-None values. + max_recursion_depth: int, optional + Maximum depth of nested dictionaries to recurse into. When the maximum recursion depth + is reached, the nested dictionary will be replaced by the newer nested dictionary. If + None, recurse into all nested dictionaries. Examples -------- @@ -36,6 +45,7 @@ def __init__(self, validation_schema=None, none_overrides_value=False): """ self.validation_schema = validation_schema self.none_overrides_value = none_overrides_value + self.max_recursion_depth = max_recursion_depth self.sources = [] @property @@ -51,21 +61,28 @@ def validation_schema(self, value): else: self._validation_schema = None - def _update_dict_recursively(self, original: Dict, updater: Dict) -> Dict: + def _update_dict_recursively(self, original: Dict, updater: Dict, depth: int) -> Dict: """Update dictionary recursively.""" + depth = depth + 1 for k, v in updater.items(): if isinstance(v, dict): - if not v: # v is not None, v is empty dictionary + if not v: + # v is an empty dictionary original[k] = dict() + elif self.max_recursion_depth and depth > self.max_recursion_depth: + # v is a populated dictionary, exceeds max depth + original[k] = v else: - original[k] = self._update_dict_recursively(original.get(k, {}), v) + # v is a populated dictionary, can be further recursed + original[k] = self._update_dict_recursively(original.get(k, {}), v, depth) elif isinstance(v, bool): - original[k] = v # v is True or False - elif v or k not in original: # v is not None, or key does not exist yet + # v is True or False original[k] = v - elif ( - self.none_overrides_value - ): # v is None, but can override previous value + elif v or k not in original: + # v is thruthy (therefore not None), or key does not exist yet + original[k] = v + elif self.none_overrides_value: + # v is None, but can override previous value original[k] = v return original @@ -114,7 +131,7 @@ def parse(self) -> Dict: """Parse all sources, cascade, validate, and return cascaded configuration.""" config = dict() for source in self.sources: - config = self._update_dict_recursively(config, source.load()) + config = self._update_dict_recursively(config, source.load(), depth=0) if self.validation_schema: jsonschema.validate(config, self.validation_schema.load()) @@ -196,9 +213,7 @@ class JSONConfigSource(_ConfigSource): def _read(self) -> Dict: if not isinstance(self.source, (str, os.PathLike)): - raise TypeError( - "JSONConfigSource `source` must be a string or path-like object" - ) + raise TypeError("JSONConfigSource `source` must be a string or path-like object") with open(self.source, "rt") as json_file: config = json.load(json_file) return config @@ -221,9 +236,7 @@ class NamespaceConfigSource(_ConfigSource): def _read(self) -> Dict: if not isinstance(self.source, Namespace): - raise TypeError( - "NamespaceConfigSource `source` must be an argparse.Namespace object" - ) + raise TypeError("NamespaceConfigSource `source` must be an argparse.Namespace object") config = vars(self.source) return config @@ -256,7 +269,5 @@ def load(self) -> Dict: elif isinstance(self.source, Dict): schema = self.source else: - raise TypeError( - "ValidationSchema `source` must be of type string, path-like, or dict" - ) + raise TypeError("ValidationSchema `source` must be of type string, path-like, or dict") return schema diff --git a/pyproject.toml b/pyproject.toml index 2595bdd..eac2bc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,3 +49,12 @@ docs = [ [tool.flit.module] name = "cascade_config" + +[tool.black] +line-length = 99 +target-version = ['py38'] + +[tool.ruff] +line-length = 99 +target-version = 'py38' + diff --git a/tests/test_cascade_config.py b/tests/test_cascade_config.py index 72a022b..1961817 100644 --- a/tests/test_cascade_config.py +++ b/tests/test_cascade_config.py @@ -2,15 +2,12 @@ import argparse import json -from os import stat import tempfile -from typing import Type -import pytest import jsonschema +import pytest import cascade_config -from cascade_config import ValidationSchema TEST_SCHEMA = { "$schema": "http://json-schema.org/draft-07/schema#", @@ -26,29 +23,29 @@ }, "log_level": { "type": "string", - "enum": ["debug", "info", "warning", "error", "critical"] - } - } + "enum": ["debug", "info", "warning", "error", "critical"], + }, + }, } - } + }, } -TEST_SAMPLE = { - "config_example": {"num_cpu": 1, "log_level": "info"} -} -TEST_SAMPLE_2 = { - "config_example": {"log_level": "debug"}, "test": True -} -TEST_SAMPLE_CASC = { - "config_example": {"num_cpu": 1, "log_level": "debug"}, "test": True +TEST_SAMPLE = {"config_example": {"num_cpu": 1, "log_level": "info"}} +TEST_SAMPLE_2 = {"config_example": {"log_level": "debug"}, "test": True} +TEST_SAMPLE_CASC = {"config_example": {"num_cpu": 1, "log_level": "debug"}, "test": True} +TEST_SAMPLE_INVALID = {"config_example": {"num_cpu": "not_a_number", "log_level": "info"}} +TEST_SAMPLE_NESTED_A = {"config_example": {"num_cpu": 1, "depth_2": {"depth_3a": True}}} +TEST_SAMPLE_NESTED_B = {"config_example": {"num_cpu": 1, "depth_2": {"depth_3b": False}}} +TEST_SAMPLE_NESTED_RESULT_NOMAX = { + "config_example": {"num_cpu": 1, "depth_2": {"depth_3a": True, "depth_3b": False}} } +TEST_SAMPLE_NESTED_RESULT_MAX1 = {"config_example": {"num_cpu": 1, "depth_2": {"depth_3b": False}}} -TEST_SAMPLE_INVALID = { - "config_example": {"num_cpu": "not_a_number", "log_level": "info"} -} def get_sample_namespace(test_sample): - flatten = lambda l: [item for sublist in l for item in sublist] + def flatten(lst): + return [item for sublist in lst for item in sublist] + test_args = flatten([[f"--{i[0]}", f"{i[1]}"] for i in test_sample.items()]) parser = argparse.ArgumentParser() parser.add_argument("--num_cpu", type=int) @@ -61,7 +58,7 @@ class TestCascadeConfig: @staticmethod def get_json_file(test_sample): - with tempfile.NamedTemporaryFile(mode='wt', delete=False) as json_file: + with tempfile.NamedTemporaryFile(mode="wt", delete=False) as json_file: json.dump(test_sample, json_file) json_file.seek(0) json_file_name = json_file.name @@ -69,7 +66,9 @@ def get_json_file(test_sample): @staticmethod def get_sample_namespace(test_sample): - flatten = lambda l: [item for sublist in l for item in sublist] + def flatten(lst): + return [item for sublist in lst for item in sublist] + test_args = flatten([[f"--{i[0]}", f"{i[1]}"] for i in test_sample.items()]) parser = argparse.ArgumentParser() parser.add_argument("--num_cpu", type=int) @@ -105,9 +104,7 @@ def test_single_config_namespace(self): subkey = "config_example" cc = cascade_config.CascadeConfig() cc.add_namespace( - get_sample_namespace(TEST_SAMPLE[subkey]), - subkey=subkey, - validation_schema=TEST_SCHEMA + get_sample_namespace(TEST_SAMPLE[subkey]), subkey=subkey, validation_schema=TEST_SCHEMA ) assert cc.parse() == TEST_SAMPLE @@ -143,6 +140,18 @@ def test_multiple_configs(self): cc.add_json(self.get_json_file(TEST_SAMPLE_2)) assert cc.parse() == TEST_SAMPLE_CASC + def test_max_recursion(self): + """Test max_recursion_depth argument.""" + cc = cascade_config.CascadeConfig(max_recursion_depth=None) + cc.add_dict(TEST_SAMPLE_NESTED_A) + cc.add_dict(TEST_SAMPLE_NESTED_B) + assert cc.parse() == TEST_SAMPLE_NESTED_RESULT_NOMAX + + cc = cascade_config.CascadeConfig(max_recursion_depth=1) + cc.add_dict(TEST_SAMPLE_NESTED_A) + cc.add_dict(TEST_SAMPLE_NESTED_B) + assert cc.parse() == TEST_SAMPLE_NESTED_RESULT_MAX1 + def test_validation_schema_from_object(self): with pytest.raises(TypeError): cascade_config.ValidationSchema.from_object(42)