diff --git a/src/mlopus/utils/dicts.py b/src/mlopus/utils/dicts.py index dfee024..b050c0a 100644 --- a/src/mlopus/utils/dicts.py +++ b/src/mlopus/utils/dicts.py @@ -94,12 +94,16 @@ def deep_merge(*dicts: dict): """Merge dicts at the level of leaf-values.""" retval = {} - def _update(tgt: dict, src: dict, prefix_keys: List[str]): + def _update(tgt: dict, src: Mapping, prefix_keys: List[str]): for key, val in src.items(): - if isinstance(val, dict): - _update(tgt, val, prefix_keys + [key]) + _key = prefix_keys + [key] + + if isinstance(val, Mapping) and (val or isinstance(get_nested(tgt, _key, None), Mapping)): + # Treat value as nested if it's a non-empty dict or if the target is already nested + _update(tgt, val, _key) else: - set_nested(tgt, prefix_keys + [key], deepcopy(val)) + # Treat value as a leaf (scalar) otherwise + set_nested(tgt, _key, deepcopy(val)) for _dict in dicts: _update(retval, _dict, prefix_keys=[])