Skip to content

Commit

Permalink
feat: return a default value if key is missing in image.get
Browse files Browse the repository at this point in the history
This copies `deep_get` and some tests from Taskgraph. It's being
duplicated to avoid depending on Taskgraph.
  • Loading branch information
ahal committed Jan 6, 2025
1 parent 1c3d7b4 commit adf7946
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 7 deletions.
30 changes: 23 additions & 7 deletions src/ciadmin/generate/ciconfig/worker_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at http://mozilla.org/MPL/2.0/.

from typing import Any

import attr

from .get import get_ciconfig_file
from ciadmin.generate.ciconfig.get import get_ciconfig_file
from ciadmin.util.templates import deep_get


@attr.s(frozen=True)
Expand All @@ -31,19 +34,32 @@ def mk(image_name, info):
[mk(image_name, info) for image_name, info in worker_images.items()]
)

def get(self, cloud, *keys):
def get(self, cloud: str, key: str | None=None, default: Any|None=None) -> Any:
"""
Look up a key under the given cloud for this worker image.
Look up a key under the given cloud config for this worker image.
Args:
cloud (str): The cloud provider (provider_id) to obtain data from.
key (str): The key to obtain a value from (optional).
If not specified then the entire value of the specified cloud is returned. If
specified, the value of the matching key will be obtained. This can optionally
use dot path notation (e.g "key.subkey") to obtain a value from nested
dictionaries. If the key or any nested subkey along the dot path does not exist,
`None` is returned.
Returns:
Any: The value defined under the specified cloud.
"""
if cloud not in self.clouds:
raise KeyError(
f"{cloud} not present for {self.image_name} - "
"maybe you need to update worker-images.yml?"
)
v = self.clouds[cloud]
for k in keys:
v = v[k]
return v
cfg = self.clouds[cloud]
if not key:
return cfg

return deep_get(cfg, key, default)


class WorkerImages:
Expand Down
23 changes: 23 additions & 0 deletions src/ciadmin/util/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


import copy
from typing import Any


def merge_to(source, dest):
Expand Down Expand Up @@ -48,3 +49,25 @@ def merge(*objects):
if len(objects) == 1:
return copy.deepcopy(objects[0])
return merge_to(objects[-1], merge(*objects[:-1]))


def deep_get(dict_: dict[str, Any], field: str, default: Any|None=None) -> Any:
"""
Return a key from nested dictionaries using dot path notation
(e.g "key.subkey").
Args:
dict_: The dictionary to retrieve a value from.
field: The key to retrieve, can use dot path notation.
default: A default value to return if key does not exist
(default: None).
"""
container, subfield = dict_, field
while "." in subfield:
f, subfield = subfield.split(".", 1)
if f not in container:
return default

container = container[f]

return container.get(subfield, default)
73 changes: 73 additions & 0 deletions tests/ciadmin/test_util_templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at http://mozilla.org/MPL/2.0/.

import pytest

from ciadmin.util.templates import deep_get, merge, merge_to

print(__file__)


def test_merge_to_dicts():
source = {"a": 1, "b": 2}
dest = {"b": "20", "c": 30}
expected = {
"a": 1, # source only
"b": 2, # source overrides dest
"c": 30, # dest only
}
assert merge_to(source, dest) == expected
assert dest == expected


def test_merge_to_lists():
source = {"x": [3, 4]}
dest = {"x": [1, 2]}
expected = {"x": [1, 2, 3, 4]} # dest first
assert merge_to(source, dest) == expected
assert dest == expected


def test_merge_diff_types():
source = {"x": [1, 2]}
dest = {"x": "abc"}
expected = {"x": [1, 2]} # source wins
assert merge_to(source, dest) == expected
assert dest == expected


def test_merge():
first = {"a": 1, "b": 2, "d": 11}
second = {"b": 20, "c": 30}
third = {"c": 300, "d": 400}
expected = {
"a": 1,
"b": 20,
"c": 300,
"d": 400,
}
assert merge(first, second, third) == expected

# inputs haven't changed..
assert first == {"a": 1, "b": 2, "d": 11}
assert second == {"b": 20, "c": 30}
assert third == {"c": 300, "d": 400}


@pytest.mark.parametrize(
"args,expected",
(
pytest.param(({}, "foo"), None, id="not found"),
pytest.param(({}, "foo", True), True, id="not found default"),
pytest.param(({"foo": "bar"}, "foo"), "bar", id="single"),
pytest.param(({"foo": {"bar": {"baz": 1}}}, "foo.bar.baz"), 1, id="dot path"),
pytest.param(
({"foo": {"bar": {"baz": 1}}}, "foo.missing.baz"),
None,
id="not found middle",
),
),
)
def test_deep_get(args, expected):
assert deep_get(*args) == expected

0 comments on commit adf7946

Please sign in to comment.