Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed Nov 4, 2024
1 parent bc6ae29 commit 108e67e
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 49 deletions.
3 changes: 3 additions & 0 deletions cobaya/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from cobaya.log import LoggedError, get_logger
from cobaya.parameterization import expand_info_param
from cobaya import mpi
import cobaya.typing

# Logger
logger = get_logger(__name__)
Expand Down Expand Up @@ -141,6 +142,8 @@ def load_info_overrides(*infos_or_yaml_or_files, **flags) -> InputDict:
for flag, value in flags.items():
if value is not None:
info[flag] = value
if cobaya.typing.enforce_type_checking:
cobaya.typing.validate_type(InputDict, info)
return info


Expand Down
30 changes: 16 additions & 14 deletions cobaya/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class ParamDict(TypedDict, total=False):
value: Union[float, Callable, str]
derived: Union[bool, str, Callable]
prior: Union[None, Sequence[float], SciPyDistDict, SciPyMinMaxDict]
ref: Union[None, Sequence[float], SciPyDistDict, SciPyMinMaxDict]
ref: Union[None, float, Sequence[float], SciPyDistDict, SciPyMinMaxDict]
proposal: Optional[float]
renames: Union[str, Sequence[str]]
latex: str
Expand Down Expand Up @@ -129,7 +129,7 @@ def validate_type(expected_type: type, value: Any, path: str = ''):
if expected_type is int:
if not (value in (np.inf, -np.inf) or isinstance(value, numbers.Integral)):
raise TypeError(
f"{curr_path} must be an integer or infinity, got {type(value).__name__}"
f"{curr_path} must be an integer, got {type(value).__name__}"
)
return

Expand All @@ -140,14 +140,13 @@ def validate_type(expected_type: type, value: Any, path: str = ''):
return

if expected_type is bool:
if not hasattr(value, '__bool__') and not isinstance(value, (str, np.ndarray)):
if not isinstance(value, bool):
# if not hasattr(value, '__bool__') and not isinstance(value, (str, np.ndarray)):
raise TypeError(
f"{curr_path} must be boolean, got {type(value).__name__}"
)
return

# special case for Cobaya

if sys.version_info < (3, 10):
from typing_extensions import is_typeddict
else:
Expand All @@ -163,7 +162,7 @@ def validate_type(expected_type: type, value: Any, path: str = ''):
f"'{expected_type.__name__}': {invalid_keys}")
for key, val in value.items():
validate_type(type_hints[key], val, f"{path}.{key}" if path else str(key))
return True
return

if (origin := typing.get_origin(expected_type)) and (
args := typing.get_args(expected_type)):
Expand All @@ -178,6 +177,8 @@ def validate_type(expected_type: type, value: Any, path: str = ''):
return validate_type(t, value, path)
except TypeError as e:
error_msg = str(e)
if ' any Union type' in error_msg:
raise
error_path = error_msg.split(' ')[0].strip("'")

# If error is about the current path, it's a structural error
Expand Down Expand Up @@ -210,10 +211,12 @@ def validate_type(expected_type: type, value: Any, path: str = ''):
if origin is typing.ClassVar:
return validate_type(args[0], value, path)

if isinstance(value, Mapping) != issubclass(origin, Mapping):
raise TypeError(
f"{curr_path} must be {args[0]}, got {type(value).__name__}"
)

if issubclass(origin, Mapping):
if not isinstance(value, Mapping):
raise TypeError(f"{curr_path} must be a mapping, "
f"got {type(value).__name__}")
for k, v in value.items():
key_path = f"{path}[{k!r}]" if path else f"[{k!r}]"
validate_type(args[0], k, f"{key_path} (key)")
Expand All @@ -231,12 +234,11 @@ def validate_type(expected_type: type, value: Any, path: str = ''):
)
return

if not isinstance(value, Iterable):
raise TypeError(
f"{curr_path} must be iterable, got {type(value).__name__}"
)

if len(args) == 1:
if not isinstance(value, Iterable):
raise TypeError(
f"{curr_path} must be iterable, got {type(value).__name__}"
)
for i, item in enumerate(value):
validate_type(args[0], item, f"{path}[{i}]" if path else f"[{i}]")
else:
Expand Down
36 changes: 1 addition & 35 deletions tests/test_type_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,8 @@
import pytest

from cobaya.component import CobayaComponent
from cobaya.likelihood import Likelihood
from cobaya.tools import NumberWithUnits
from cobaya.typing import InputDict, ParamDict, Sequence
from cobaya.run import run


class GenericLike(Likelihood):
any: Any
classvar: ClassVar[int] = 1
infinity: int = float("inf")
mean: NumberWithUnits = 1
noise: float = 0
none: int = None
numpy_int: int = np.int64(1)
optional: Optional[int] = None
paramdict_params: ParamDict = {"prior": [0.0, 1.0]}
params: Dict[str, List[float]] = {"a": [0.0, 1.0], "b": [0, 1]}
tuple_params: Tuple[float, float] = (0.0, 1.0)

_enforce_types = True

def logp(self, **params_values):
return 1


def test_sampler_types():
original_info: InputDict = {
"likelihood": {"like": GenericLike},
"sampler": {"mcmc": {"max_samples": 1}},
}
_ = run(original_info)

info = original_info.copy()
info["sampler"]["mcmc"]["max_samples"] = "not_an_int"
with pytest.raises(TypeError):
run(info)
from cobaya.typing import ParamDict, Sequence


class GenericComponent(CobayaComponent):
Expand Down

0 comments on commit 108e67e

Please sign in to comment.