Skip to content

Commit

Permalink
trivial
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed Nov 6, 2024
1 parent 3642d4f commit c76a95f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions cobaya/theory.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(self, info: TheoryDictIn = empty_dict,
standalone=standalone)

# set to Provider instance before calculations
self.provider: Any = None
self.provider: Optional['Provider'] = None
# Generate cache states, to avoid recomputing.
# Default 3, but can be changed by sampler
self.set_cache_size(3)
Expand Down Expand Up @@ -415,7 +415,7 @@ def __init__(self, model, requirement_providers: Dict[str, Theory]):
self.requirement_providers = requirement_providers
self.params = {}

def set_current_input_params(self, params):
def set_current_input_params(self, params: ParamValuesDict):
self.params = params

def get_param(self, param: Union[str, Iterable[str]]) -> Union[float, List[float]]:
Expand Down
4 changes: 2 additions & 2 deletions cobaya/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def validate_type(expected_type: type, value: Any, path: str = ''):

if expected_type is float:
if not (isinstance(value, numbers.Real) or
(isinstance(value, np.ndarray) and value.shape == ())):
(isinstance(value, np.ndarray) and value.ndim == 0)):
raise TypeError(f"{curr_path} must be a float, got {type(value).__name__}")
return

Expand Down Expand Up @@ -218,7 +218,7 @@ def validate_type(expected_type: type, value: Any, path: str = ''):

if issubclass(origin, Iterable):
if isinstance(value, np.ndarray):
if not value.shape:
if value.ndim == 0:
raise TypeError(f"{curr_path} numpy array zero rank")
if len(args) == 1 and not np.issubdtype(value.dtype, args[0]):
raise TypeError(
Expand Down

0 comments on commit c76a95f

Please sign in to comment.