
consolidate checks
gau-nernst committed Dec 10, 2024
1 parent 0699aa3 commit be9c0fb
Showing 3 changed files with 43 additions and 14 deletions.
18 changes: 11 additions & 7 deletions recipes/full_finetune_distributed.py
@@ -27,6 +27,7 @@
from torchtune.training import DummyProfiler, PROFILER_KEY
from torchtune.training.activations import apply_selective_activation_checkpointing
from torchtune.training.lr_schedulers import get_lr
+from torchtune.training.quantization import Int8MixedPrecisionTrainingQuantizer

from tqdm import tqdm

@@ -182,9 +183,14 @@ def __init__(self, cfg: DictConfig) -> None:
)

        if cfg.mixed_precision.enabled:
-            if not cfg.compile or not cfg.dataset.packed:
-                raise ValueError(
-                    "When mixed_precision.enabled is True, both compile and dataset.packed must be True."
+            if (
+                cfg.mixed_precision._component_
+                == "torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer"
+            ):
+                Int8MixedPrecisionTrainingQuantizer.validate_config(
+                    compile=cfg.compile,
+                    dataset_packed=cfg.dataset.packed,
+                    optimizer_path=cfg.optimizer._component_,
                )

# These are public properties which are updated by the checkpoint loader
@@ -274,7 +280,7 @@ def setup(self, cfg: DictConfig) -> None:
model_state_dict=checkpoint_dict[training.MODEL_KEY],
ac_mode=cfg.get("ac_mode", None),
ac_option=cfg.get("ac_option", None),
-            mixed_precision_cfg=cfg.get("mixed_precision", None),
+            mixed_precision_cfg=cfg.mixed_precision,
)
self._tokenizer = config.instantiate(cfg.tokenizer)

@@ -512,9 +518,7 @@ def _setup_model(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

-        if mixed_precision_cfg is not None and mixed_precision_cfg.get(
-            "enabled", False
-        ):
+        if mixed_precision_cfg is not None and mixed_precision_cfg.enabled:
log.info(f"Preparing model with {mixed_precision_cfg._component_}")
cfg = mixed_precision_cfg.copy()
cfg.pop("enabled", None)
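
As a readability aid (not part of this commit), here is a minimal sketch of the config shape that the consolidated check above reads. It builds a DictConfig with OmegaConf, since that is what the recipes receive; the field names follow the hunks above, but the values and the flat structure are illustrative assumptions, and real torchtune configs carry many more keys.

# Illustrative only: the fields touched by the consolidated mixed-precision check.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "compile": True,              # must be True for INT8 mixed-precision training
        "dataset": {"packed": True},  # must be True for INT8 mixed-precision training
        "optimizer": {"_component_": "torch.optim.AdamW"},
        "mixed_precision": {
            "_component_": "torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer",
            "enabled": True,
        },
    }
)
# The recipe __init__ would then pass cfg.compile, cfg.dataset.packed, and
# cfg.optimizer._component_ into Int8MixedPrecisionTrainingQuantizer.validate_config(...)
# as shown in the hunk above.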
18 changes: 11 additions & 7 deletions recipes/full_finetune_single_device.py
@@ -24,6 +24,7 @@
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY
from torchtune.training.lr_schedulers import get_lr
+from torchtune.training.quantization import Int8MixedPrecisionTrainingQuantizer

from tqdm import tqdm

@@ -183,9 +184,14 @@ def __init__(self, cfg: DictConfig) -> None:
)

        if cfg.mixed_precision.enabled:
-            if not cfg.compile or not cfg.dataset.packed:
-                raise ValueError(
-                    "When mixed_precision.enabled is True, both compile and dataset.packed must be True."
+            if (
+                cfg.mixed_precision._component_
+                == "torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer"
+            ):
+                Int8MixedPrecisionTrainingQuantizer.validate_config(
+                    compile=cfg.compile,
+                    dataset_packed=cfg.dataset.packed,
+                    optimizer_path=cfg.optimizer._component_,
                )

# These are public properties which are updated by the checkpoint loader
@@ -277,7 +283,7 @@ def setup(self, cfg: DictConfig) -> None:
enable_activation_offloading=self._enable_activation_offloading,
compile_model=self._compile,
model_state_dict=ckpt_dict[training.MODEL_KEY],
-            mixed_precision_cfg=cfg.get("mixed_precision", None),
+            mixed_precision_cfg=cfg.mixed_precision,
)
self._tokenizer = config.instantiate(cfg.tokenizer)
log.info("Tokenizer is initialized from file.")
@@ -437,9 +443,7 @@ def _setup_model(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

-        if mixed_precision_cfg is not None and mixed_precision_cfg.get(
-            "enabled", False
-        ):
+        if mixed_precision_cfg is not None and mixed_precision_cfg.enabled:
log.info(f"Preparing model with {mixed_precision_cfg._component_}")
cfg = mixed_precision_cfg.copy()
cfg.pop("enabled", None)
21 changes: 21 additions & 0 deletions torchtune/training/quantization.py
@@ -222,6 +222,27 @@ def __init__(
grad_weight=grad_weight,
)

+    @staticmethod
+    def validate_config(
+        *, compile: bool, dataset_packed: bool, optimizer_path: str
+    ) -> None:
+        if not (compile and dataset_packed):
+            raise ValueError(
+                "Both compile and dataset.packed must be True to use INT8 mixed-precision training."
+            )
+
+        if not optimizer_path.startswith("torch.optim."):
+            warn(
+                "Using a low-bit optimizer might cause convergence issues with INT8 mixed-precision training. "
+                "If you observe divergence, try again with the standard torch.optim.AdamW instead."
+            )
+
+        warn(
+            "INT8 mixed-precision might not speed up training if the model and/or batch size is too small "
+            "for the current GPU(s). If that is the case, try increasing the batch size or sequence length. "
+            "On A100, Llama-3-8B only sees a speedup for batch_size=4, max_seq_len=2048 and above."
+        )
+

    def prepare(self, model: nn.Module) -> nn.Module:
        # We use a module-swap implementation so that the state_dict remains plain tensors
        # and for better FSDP compatibility in torchtune.
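
Below is a minimal usage sketch (not part of this commit) of the new consolidated check, assuming only the signature and messages added above. The low-bit optimizer path is just an illustrative non-torch.optim string; every call also emits the general speedup warning.

# Illustrative only: exercising Int8MixedPrecisionTrainingQuantizer.validate_config.
from torchtune.training.quantization import Int8MixedPrecisionTrainingQuantizer

# Valid setup: compile and packed dataset enabled, standard torch.optim optimizer.
Int8MixedPrecisionTrainingQuantizer.validate_config(
    compile=True,
    dataset_packed=True,
    optimizer_path="torch.optim.AdamW",
)

# Invalid setup: compile disabled, so a ValueError is raised.
try:
    Int8MixedPrecisionTrainingQuantizer.validate_config(
        compile=False,
        dataset_packed=True,
        optimizer_path="torch.optim.AdamW",
    )
except ValueError as err:
    print(err)

# A non-torch.optim optimizer path only triggers the convergence warning, not an error.
Int8MixedPrecisionTrainingQuantizer.validate_config(
    compile=True,
    dataset_packed=True,
    optimizer_path="torchao.prototype.low_bit_optim.AdamW8bit",  # assumed example path
)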
