Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nit: Fix/add some type annotations #1982

Merged
merged 8 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/deep_dives/configs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ this by taking a look at the :func:`~torchtune.config.instantiate` API.

def instantiate(
config: DictConfig,
*args: Tuple[Any, ...],
**kwargs: Dict[str, Any],
*args: Any,
**kwargs: Any,
)

:func:`~torchtune.config.instantiate` also accepts positional arguments
Expand Down
12 changes: 6 additions & 6 deletions torchtune/config/_instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ def _create_component(
_component_: Callable[..., Any],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
bradhilton marked this conversation as resolved.
Show resolved Hide resolved
):
) -> Any:
return _component_(*args, **kwargs)


def _instantiate_node(node: Dict[str, Any], *args: Tuple[Any, ...]):
def _instantiate_node(node: Dict[str, Any], *args: Any) -> Any:
"""
Creates the object specified in _component_ field with provided positional args
and kwargs already merged. Raises an InstantiationError if _component_ is not specified.
Expand All @@ -40,8 +40,8 @@ def _instantiate_node(node: Dict[str, Any], *args: Tuple[Any, ...]):

def instantiate(
config: DictConfig,
*args: Tuple[Any, ...],
**kwargs: Dict[str, Any],
*args: Any,
**kwargs: Any,
) -> Any:
"""
Given a DictConfig with a _component_ field specifying the object to instantiate and
Expand All @@ -60,8 +60,8 @@ def instantiate(
config (DictConfig): a single field in the OmegaConf object parsed from the yaml file.
This is expected to have a _component_ field specifying the path of the object
to instantiate.
*args (Tuple[Any, ...]): positional arguments to pass to the object to instantiate.
**kwargs (Dict[str, Any]): keyword arguments to pass to the object to instantiate.
*args (Any): positional arguments to pass to the object to instantiate.
**kwargs (Any): keyword arguments to pass to the object to instantiate.

Examples:
>>> config.yaml:
Expand Down
4 changes: 2 additions & 2 deletions torchtune/config/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def parse_known_args(self, *args, **kwargs) -> Tuple[Namespace, List[str]]:
return namespace, unknown_args


def parse(recipe_main: Recipe) -> Callable[[Recipe], Any]:
def parse(recipe_main: Recipe) -> Callable[..., Any]:
"""
Decorator that handles parsing the config file and CLI overrides
for a recipe. Use it on the recipe's main function.
Expand All @@ -83,7 +83,7 @@ def parse(recipe_main: Recipe) -> Callable[[Recipe], Any]:
>>> tune my_recipe --config config.yaml foo=bar

Returns:
Callable[[Recipe], Any]: the decorated main
Callable[..., Any]: the decorated main
"""

@functools.wraps(recipe_main)
Expand Down
18 changes: 9 additions & 9 deletions torchtune/modules/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sys
from collections import OrderedDict
from functools import partial
from typing import Any, Dict, Generator, Optional, Tuple
from typing import Any, Dict, Generator, Optional
from warnings import warn

import torch
Expand All @@ -24,10 +24,10 @@
def reparametrize_as_dtype_state_dict_post_hook(
model: nn.Module,
state_dict: Dict[str, Any],
*args: Tuple[Any, ...],
*args: Any,
dtype: torch.dtype = torch.bfloat16,
offload_to_cpu: bool = True,
**kwargs: Dict[Any, Any],
**kwargs: Any,
):
"""
A state_dict hook that replaces NF4 tensors with their restored
Expand All @@ -47,10 +47,10 @@ def reparametrize_as_dtype_state_dict_post_hook(
Args:
model (nn.Module): the model to take ``state_dict()`` on
state_dict (Dict[str, Any]): the state dict to modify
*args (Tuple[Any, ...]): Unused args passed when running this as a state_dict hook.
*args (Any): Unused args passed when running this as a state_dict hook.
dtype (torch.dtype): the dtype to restore the weight to. Default is ``torch.bfloat16``.
offload_to_cpu (bool): whether to offload the restored weight to CPU. Default is ``True``.
**kwargs (Dict[Any, Any]): Unused keyword args passed when running this as a state_dict hook.
**kwargs (Any): Unused keyword args passed when running this as a state_dict hook.
"""
for k, v in state_dict.items():
if isinstance(v, NF4Tensor):
Expand All @@ -62,10 +62,10 @@ def reparametrize_as_dtype_state_dict_post_hook(
def _low_ram_reparametrize_as_dtype_state_dict_post_hook(
model: nn.Module,
state_dict: Dict[str, Any],
*args: Tuple[Any, ...],
*args: Any,
dtype: torch.dtype = torch.bfloat16,
offload_to_cpu: bool = True,
**kwargs: Dict[Any, Any],
**kwargs: Any,
):
"""
A state_dict hook that replaces NF4 tensors with their restored
Expand All @@ -88,10 +88,10 @@ def _low_ram_reparametrize_as_dtype_state_dict_post_hook(
Args:
model (nn.Module): the model to take ``state_dict()`` on
state_dict (Dict[str, Any]): the state dict to modify
*args (Tuple[Any, ...]): Unused args passed when running this as a state_dict hook.
*args (Any): Unused args passed when running this as a state_dict hook.
dtype (torch.dtype): the dtype to restore the weight to. Default is ``torch.bfloat16``.
offload_to_cpu (bool): whether to offload the restored weight to CPU. Default is ``True``.
**kwargs (Dict[Any, Any]): Unused keyword args passed when running this as a state_dict hook.
**kwargs (Any): Unused keyword args passed when running this as a state_dict hook.
"""
# Create a state dict of FakeTensors that matches the state_dict
mode = FakeTensorMode()
Expand Down
Loading