diff --git a/docs/source/deep_dives/configs.rst b/docs/source/deep_dives/configs.rst
index 0f86a29a58..9025896b54 100644
--- a/docs/source/deep_dives/configs.rst
+++ b/docs/source/deep_dives/configs.rst
@@ -89,8 +89,8 @@ this by taking a look at the :func:`~torchtune.config.instantiate` API.
 
     def instantiate(
         config: DictConfig,
-        *args: Tuple[Any, ...],
-        **kwargs: Dict[str, Any],
+        *args: Any,
+        **kwargs: Any,
     )
 
 :func:`~torchtune.config.instantiate` also accepts positional arguments
diff --git a/torchtune/config/_instantiate.py b/torchtune/config/_instantiate.py
index 279fe22629..3a4d4f635d 100644
--- a/torchtune/config/_instantiate.py
+++ b/torchtune/config/_instantiate.py
@@ -18,11 +18,11 @@ def _create_component(
     _component_: Callable[..., Any],
     args: Tuple[Any, ...],
     kwargs: Dict[str, Any],
-):
+) -> Any:
     return _component_(*args, **kwargs)
 
 
-def _instantiate_node(node: Dict[str, Any], *args: Tuple[Any, ...]):
+def _instantiate_node(node: Dict[str, Any], *args: Any) -> Any:
     """
     Creates the object specified in _component_ field with provided positional args
     and kwargs already merged. Raises an InstantiationError if _component_ is not specified.
@@ -40,8 +40,8 @@ def _instantiate_node(node: Dict[str, Any], *args: Tuple[Any, ...]):
 
 def instantiate(
     config: DictConfig,
-    *args: Tuple[Any, ...],
-    **kwargs: Dict[str, Any],
+    *args: Any,
+    **kwargs: Any,
 ) -> Any:
     """
     Given a DictConfig with a _component_ field specifying the object to instantiate and
@@ -60,8 +60,8 @@ def instantiate(
         config (DictConfig): a single field in the OmegaConf object parsed from the yaml file.
             This is expected to have a _component_ field specifying the path of the object
             to instantiate.
-        *args (Tuple[Any, ...]): positional arguments to pass to the object to instantiate.
-        **kwargs (Dict[str, Any]): keyword arguments to pass to the object to instantiate.
+        *args (Any): positional arguments to pass to the object to instantiate.
+        **kwargs (Any): keyword arguments to pass to the object to instantiate.
 
     Examples:
         >>> config.yaml:
diff --git a/torchtune/config/_parse.py b/torchtune/config/_parse.py
index 8472094f0b..5a8e762333 100644
--- a/torchtune/config/_parse.py
+++ b/torchtune/config/_parse.py
@@ -65,7 +65,7 @@ def parse_known_args(self, *args, **kwargs) -> Tuple[Namespace, List[str]]:
         return namespace, unknown_args
 
 
-def parse(recipe_main: Recipe) -> Callable[[Recipe], Any]:
+def parse(recipe_main: Recipe) -> Callable[..., Any]:
     """
     Decorator that handles parsing the config file and CLI overrides for a recipe.
     Use it on the recipe's main function.
@@ -83,7 +83,7 @@ def parse(recipe_main: Recipe) -> Callable[[Recipe], Any]:
         >>> tune my_recipe --config config.yaml foo=bar
 
     Returns:
-        Callable[[Recipe], Any]: the decorated main
+        Callable[..., Any]: the decorated main
     """
 
     @functools.wraps(recipe_main)
diff --git a/torchtune/modules/common_utils.py b/torchtune/modules/common_utils.py
index 055252cf72..db8a3837eb 100644
--- a/torchtune/modules/common_utils.py
+++ b/torchtune/modules/common_utils.py
@@ -9,7 +9,7 @@
 import sys
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Dict, Generator, Optional, Tuple
+from typing import Any, Dict, Generator, Optional
 from warnings import warn
 
 import torch
@@ -24,10 +24,10 @@
 def reparametrize_as_dtype_state_dict_post_hook(
     model: nn.Module,
     state_dict: Dict[str, Any],
-    *args: Tuple[Any, ...],
+    *args: Any,
     dtype: torch.dtype = torch.bfloat16,
     offload_to_cpu: bool = True,
-    **kwargs: Dict[Any, Any],
+    **kwargs: Any,
 ):
     """
     A state_dict hook that replaces NF4 tensors with their restored
@@ -47,10 +47,10 @@ def reparametrize_as_dtype_state_dict_post_hook(
     Args:
         model (nn.Module): the model to take ``state_dict()`` on
        state_dict (Dict[str, Any]): the state dict to modify
-        *args (Tuple[Any, ...]): Unused args passed when running this as a state_dict hook.
+        *args (Any): Unused args passed when running this as a state_dict hook.
         dtype (torch.dtype): the dtype to restore the weight to. Default is ``torch.bfloat16``.
         offload_to_cpu (bool): whether to offload the restored weight to CPU. Default is ``True``.
-        **kwargs (Dict[Any, Any]): Unused keyword args passed when running this as a state_dict hook.
+        **kwargs (Any): Unused keyword args passed when running this as a state_dict hook.
     """
     for k, v in state_dict.items():
         if isinstance(v, NF4Tensor):
@@ -62,10 +62,10 @@ def reparametrize_as_dtype_state_dict_post_hook(
 def _low_ram_reparametrize_as_dtype_state_dict_post_hook(
     model: nn.Module,
     state_dict: Dict[str, Any],
-    *args: Tuple[Any, ...],
+    *args: Any,
     dtype: torch.dtype = torch.bfloat16,
     offload_to_cpu: bool = True,
-    **kwargs: Dict[Any, Any],
+    **kwargs: Any,
 ):
     """
     A state_dict hook that replaces NF4 tensors with their restored
@@ -88,10 +88,10 @@ def _low_ram_reparametrize_as_dtype_state_dict_post_hook(
     Args:
         model (nn.Module): the model to take ``state_dict()`` on
         state_dict (Dict[str, Any]): the state dict to modify
-        *args (Tuple[Any, ...]): Unused args passed when running this as a state_dict hook.
+        *args (Any): Unused args passed when running this as a state_dict hook.
         dtype (torch.dtype): the dtype to restore the weight to. Default is ``torch.bfloat16``.
         offload_to_cpu (bool): whether to offload the restored weight to CPU. Default is ``True``.
-        **kwargs (Dict[Any, Any]): Unused keyword args passed when running this as a state_dict hook.
+        **kwargs (Any): Unused keyword args passed when running this as a state_dict hook.
     """
     # Create a state dict of FakeTensors that matches the state_dict
    mode = FakeTensorMode()