Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed Nov 4, 2024
1 parent a65605e commit 3bdbf5a
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions cobaya/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,6 @@ def validate_type(expected_type: type, value: Any, path: str = ''):
return

# special case for Cobaya
if expected_type.__name__ == 'NumberWithUnits':
if not isinstance(value, (numbers.Real, str)):
raise TypeError(f"{curr_path} must be a number or string for NumberWithUnits,"
f" got {type(value).__name__}")
return

if sys.version_info < (3, 10):
from typing_extensions import is_typeddict
Expand Down Expand Up @@ -214,7 +209,17 @@ def validate_type(expected_type: type, value: Any, path: str = ''):
if origin is typing.ClassVar:
return validate_type(args[0], value, path)

if origin in (list, tuple, set, Sequence, Iterable, np.ndarray):
if origin in (dict, Mapping):
if not isinstance(value, Mapping):
raise TypeError(f"{curr_path} must be a mapping, "
f"got {type(value).__name__}")
for k, v in value.items():
key_path = f"{path}[{k!r}]" if path else f"[{k!r}]"
validate_type(args[0], k, f"{key_path} (key)")
validate_type(args[1], v, key_path)
return

if issubclass(origin, Iterable):
if isinstance(value, np.ndarray):
if not value.shape:
raise TypeError(f"{curr_path} numpy array zero rank")
Expand Down Expand Up @@ -244,16 +249,16 @@ def validate_type(expected_type: type, value: Any, path: str = ''):
validate_type(t, v, f"{path}[{i}]" if path else f"[{i}]")
return

if origin in (dict, Mapping):
if not isinstance(value, Mapping):
raise TypeError(f"{curr_path} must be a mapping, "
f"got {type(value).__name__}")
for k, v in value.items():
key_path = f"{path}[{k!r}]" if path else f"[{k!r}]"
validate_type(args[0], k, f"{key_path} (key)")
validate_type(args[1], v, key_path)
return
if not (isinstance(value, expected_type) or
expected_type is Sequence and isinstance(value, np.ndarray)):

# special case for Cobaya's NumberWithUnits, if not instance yet
if getattr(expected_type, "__name__", "") == 'NumberWithUnits':
if not isinstance(value, (numbers.Real, str)):
raise TypeError(
f"{curr_path} must be a number or string for NumberWithUnits,"
f" got {type(value).__name__}")
return

raise TypeError(f"{curr_path} must be of type {expected_type.__name__}, "
f"got {type(value).__name__}")

0 comments on commit 3bdbf5a

Please sign in to comment.