diff --git a/.github/workflows/recipe_test_multi_gpu.yaml b/.github/workflows/recipe_test_multi_gpu.yaml index 7b5c182c82..2d66feaa64 100644 --- a/.github/workflows/recipe_test_multi_gpu.yaml +++ b/.github/workflows/recipe_test_multi_gpu.yaml @@ -24,6 +24,7 @@ jobs: strategy: matrix: python-version: ['3.8', '3.9', '3.10', '3.11'] + torch-version: ["nightly", "stable"] steps: - name: Check out repo uses: actions/checkout@v3 @@ -45,9 +46,14 @@ jobs: - name: Install S3 CLI run: | python -m pip install awscli==1.32.6 - - name: Install dependencies + - name: Install torch nightly + if: ${{ matrix.torch-version == 'nightly' }} + run: python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 + - name: Install torch stable + if: ${{ matrix.torch-version == 'stable' }} + run: python -m pip install torch + - name: Install remaining dependencies run: | - python -m pip install torch python -m pip install -e ".[dev]" python -m pip install lm-eval==0.4.* - name: Run recipe tests with coverage diff --git a/README.md b/README.md index df1de940f7..bf732beb30 100644 --- a/README.md +++ b/README.md @@ -176,6 +176,13 @@ Check out `tune --help` for all possible CLI commands and options. --- +#### Support for torch.compile + +Note that support for ``torch.compile`` is currently in progress and only enabled for a couple of recipes. In particular, ``torch.compile`` support is +enabled for [full_finetune_single_device](https://github.com/pytorch/torchtune/blob/main/recipes/full_finetune_single_device.py) and [lora_finetune_single_device](https://github.com/pytorch/torchtune/blob/main/recipes/lora_finetune_single_device.py) and validated on Llama2-7b. Other recipes and models are currently not tested, though we welcome contributions from the community. + +--- + ## Design Principles TorchTune embodies PyTorch’s design philosophy [[details](https://pytorch.org/docs/stable/community/design.html)], especially "usability over everything else". diff --git a/docs/source/examples/finetune_llm.rst b/docs/source/examples/finetune_llm.rst index bc454e1415..33223f0341 100644 --- a/docs/source/examples/finetune_llm.rst +++ b/docs/source/examples/finetune_llm.rst @@ -70,7 +70,7 @@ To run the recipe without any changes on 4 GPUs, launch a training run using Tun .. code-block:: bash - tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config full_finetune_distributed + tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config full_finetune_distributed Dataset ------- diff --git a/docs/source/examples/lora_finetune.rst b/docs/source/examples/lora_finetune.rst index 5f14b236aa..6a3fe5244b 100644 --- a/docs/source/examples/lora_finetune.rst +++ b/docs/source/examples/lora_finetune.rst @@ -253,7 +253,7 @@ You can then run the following command to perform a LoRA finetune of Llama2-7B u .. code-block:: bash - tune --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config lora_finetune_distributed + tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config lora_finetune_distributed .. note:: Make sure to point to the location of your Llama2 weights and tokenizer. This can be done @@ -288,7 +288,7 @@ Let's run this experiment. We can also increase alpha (in general it is good pra .. code-block:: bash - tune --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config lora_finetune_distributed \ + tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config lora_finetune_distributed \ lora_attn_modules='[q_proj, k_proj, v_proj, output_proj]' \ lora_rank=32 lora_alpha=64 output_dir=./lora_experiment_1 diff --git a/recipes/configs/gemma/2B_full.yaml b/recipes/configs/gemma/2B_full.yaml index ebabb23c24..53dc03f6e7 100644 --- a/recipes/configs/gemma/2B_full.yaml +++ b/recipes/configs/gemma/2B_full.yaml @@ -3,18 +3,18 @@ # # This config assumes that you've run the following command before launching # this run: -# tune download --repo-id google/gemma-2b \ +# tune download google/gemma-2b \ # --hf-token \ # --output-dir /tmp/gemma2 # # To launch on 4 devices, run the following command from root: -# tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \ +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed \ # --config gemma/2B_full \ # # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training # you can run: -# tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \ +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed \ # --config gemma/2B_full \ # checkpointer.checkpoint_dir= # diff --git a/recipes/configs/llama2/13B_full.yaml b/recipes/configs/llama2/13B_full.yaml index abbd9c45c5..73c9c0c922 100644 --- a/recipes/configs/llama2/13B_full.yaml +++ b/recipes/configs/llama2/13B_full.yaml @@ -14,7 +14,7 @@ # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training # you can run: -# tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \ +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed \ # --config llama2/13B_full \ # checkpointer.checkpoint_dir= # diff --git a/recipes/configs/llama2/13B_lora.yaml b/recipes/configs/llama2/13B_lora.yaml index 947faf7c6a..a82d2dde6b 100644 --- a/recipes/configs/llama2/13B_lora.yaml +++ b/recipes/configs/llama2/13B_lora.yaml @@ -14,7 +14,7 @@ # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training # you can run: -# tune --nnodes 1 --nproc_per_node 4 lora_finetune_distributed \ +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed \ # --config llama2/13B_lora \ # checkpointer.checkpoint_dir= # diff --git a/recipes/configs/llama2/7B_full.yaml b/recipes/configs/llama2/7B_full.yaml index 16f3dcb3ec..c510efa8b4 100644 --- a/recipes/configs/llama2/7B_full.yaml +++ b/recipes/configs/llama2/7B_full.yaml @@ -14,7 +14,7 @@ # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training # you can run: -# tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \ +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed \ # --config llama2/7B_full \ # checkpointer.checkpoint_dir= # diff --git a/recipes/configs/llama2/7B_full_single_device.yaml b/recipes/configs/llama2/7B_full_single_device.yaml index 1d297a28ec..b1d4e59a90 100644 --- a/recipes/configs/llama2/7B_full_single_device.yaml +++ b/recipes/configs/llama2/7B_full_single_device.yaml @@ -14,7 +14,7 @@ # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training # you can run: -# tune --nnodes 1 --nproc_per_node 1 full_finetune_single_device \ +# tune run --nnodes 1 --nproc_per_node 1 full_finetune_single_device \ # --config llama2/7B_full_single_device \ # checkpointer.checkpoint_dir= # @@ -56,6 +56,7 @@ loss: _component_: torch.nn.CrossEntropyLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 +compile: False optimizer_in_bwd: False diff --git a/recipes/configs/llama2/7B_full_single_device_low_memory.yaml b/recipes/configs/llama2/7B_full_single_device_low_memory.yaml index c1bfd5cb6f..4c36bf774d 100644 --- a/recipes/configs/llama2/7B_full_single_device_low_memory.yaml +++ b/recipes/configs/llama2/7B_full_single_device_low_memory.yaml @@ -14,7 +14,7 @@ # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training # you can run: -# tune --nnodes 1 --nproc_per_node 1 full_finetune_single_device \ +# tune run --nnodes 1 --nproc_per_node 1 full_finetune_single_device \ # --config llama2/7B_full_single_device_low_memory \ # checkpointer.checkpoint_dir= # @@ -68,6 +68,9 @@ enable_activation_checkpointing: True # Reduced precision dtype: bf16 +# Model compilation +compile: False + # Logging metric_logger: _component_: torchtune.utils.metric_logging.DiskLogger diff --git a/recipes/configs/llama2/7B_lora.yaml b/recipes/configs/llama2/7B_lora.yaml index 16053b7168..ddfe03f9cb 100644 --- a/recipes/configs/llama2/7B_lora.yaml +++ b/recipes/configs/llama2/7B_lora.yaml @@ -14,7 +14,7 @@ # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training # you can run: -# tune --nnodes 1 --nproc_per_node 4 lora_finetune_distributed \ +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed \ # --config llama2/7B_lora \ # checkpointer.checkpoint_dir= # diff --git a/recipes/configs/llama2/7B_lora_single_device.yaml b/recipes/configs/llama2/7B_lora_single_device.yaml index 249baf7187..eb5d0448e2 100644 --- a/recipes/configs/llama2/7B_lora_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_single_device.yaml @@ -14,7 +14,7 @@ # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training # you can run: -# tune --nnodes 1 --nproc_per_node 1 lora_finetune_single_device \ +# tune run --nnodes 1 --nproc_per_node 1 lora_finetune_single_device \ # --config 7B_lora_single_device \ # checkpointer.checkpoint_dir= # @@ -69,6 +69,7 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 1 +compile: False # Logging output_dir: /tmp/lora_finetune_output diff --git a/recipes/configs/llama2/7B_qlora_single_device.yaml b/recipes/configs/llama2/7B_qlora_single_device.yaml index 644541b1bf..7f916b2b55 100644 --- a/recipes/configs/llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/llama2/7B_qlora_single_device.yaml @@ -14,7 +14,7 @@ # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training # you can run: -# tune --nnodes 1 --nproc_per_node 1 lora_finetune_single_device \ +# tune run --nnodes 1 --nproc_per_node 1 lora_finetune_single_device \ # --config 7B_qlora_single_device \ # checkpointer.checkpoint_dir= # @@ -69,6 +69,9 @@ loss: epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 1 +# Note: compile for QLoRA is only supported on nightly +# PyTorch (>= 2.4.0.dev20240408) +compile: False # Logging output_dir: /tmp/qlora_finetune_output/ diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index db1e3ce6f3..75e513defd 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -101,13 +101,13 @@ def __init__(self, cfg: DictConfig) -> None: self.max_steps_per_epoch = cfg.max_steps_per_epoch self.total_training_steps = 0 - def load_checkpoint(self, cfg: DictConfig) -> Dict[str, Any]: + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ Extract the checkpoint state from file and validate. If resume_from_checkpoint is True, this also includes the recipe state. """ self._checkpointer = config.instantiate( - cfg, + cfg_checkpointer, resume_from_checkpoint=self._resume_from_checkpoint, ) checkpoint_dict = self._checkpointer.load_checkpoint() diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index 96167713ba..3658fa8205 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import sys - from functools import partial from typing import Any, Dict, Optional, Tuple from warnings import warn @@ -103,13 +102,13 @@ def __init__(self, cfg: DictConfig) -> None: self.max_steps_per_epoch = cfg.max_steps_per_epoch self.total_training_steps = 0 - def load_checkpoint(self, cfg: DictConfig) -> Dict[str, Any]: + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ Extract the checkpoint state from file and validate. If resume_from_checkpoint is True, this also includes the recipe state. """ self._checkpointer = config.instantiate( - cfg, + cfg_checkpointer, resume_from_checkpoint=self._resume_from_checkpoint, ) checkpoint_dict = self._checkpointer.load_checkpoint() @@ -156,9 +155,11 @@ def setup(self, cfg: DictConfig) -> None: # ``_setup_model`` handles initialization and loading the state dict. This method # should be called before ``_setup_optimizer`` since transforming the optimizer # state dict requires the model + self._model_compile = cfg.compile self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, + compile_model=self._model_compile, model_state_dict=ckpt_dict[utils.MODEL_KEY], ) self._tokenizer = config.instantiate(cfg.tokenizer) @@ -206,6 +207,7 @@ def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, + compile_model: bool, model_state_dict: Dict[str, Any], ) -> nn.Module: """ @@ -224,6 +226,11 @@ def _setup_model( # Validate model was loaded in with the expected dtype. utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype) log.info(f"Model is initialized with precision {self._dtype}.") + + # Compile model, if enabled. + if compile_model: + log.info("Compiling model with torch.compile...") + model = utils.wrap_compile(model) log.info( utils.memory_stats_log( "Memory Stats after model init:", device=self._device @@ -341,6 +348,11 @@ def train(self) -> None: The core training loop. Supports training on subsets of the dataset using the ``max_steps_per_epoch``. """ + + if self._model_compile: + log.info( + "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration." + ) # zero out the gradients before starting training if not self._optimizer_in_bwd: self._optimizer.zero_grad() diff --git a/recipes/gemma_full_finetune_distributed.py b/recipes/gemma_full_finetune_distributed.py index dad0ded7bb..cd918a9ebc 100644 --- a/recipes/gemma_full_finetune_distributed.py +++ b/recipes/gemma_full_finetune_distributed.py @@ -100,13 +100,13 @@ def __init__(self, cfg: DictConfig) -> None: self.max_steps_per_epoch = cfg.max_steps_per_epoch self.total_training_steps = 0 - def load_checkpoint(self, cfg: DictConfig) -> Dict[str, Any]: + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ Extract the checkpoint state from file and validate. If resume_from_checkpoint is True, this also includes the recipe state. """ self._checkpointer = config.instantiate( - cfg, + cfg_checkpointer, resume_from_checkpoint=self._resume_from_checkpoint, ) checkpoint_dict = self._checkpointer.load_checkpoint() diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 3f4711eb72..335fd11169 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -98,14 +98,14 @@ def __init__(self, cfg: DictConfig) -> None: self._resume_from_checkpoint = cfg.resume_from_checkpoint self._gradient_accumulation_steps = cfg.gradient_accumulation_steps - def load_checkpoint(self, cfg: DictConfig) -> Dict[str, Any]: + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ Extract the checkpoint state from file and validate. This includes the base model weights. If resume_from_checkpoint is True, this also includes the adapter weights and recipe state """ self._checkpointer = config.instantiate( - cfg, + cfg_checkpointer, resume_from_checkpoint=self._resume_from_checkpoint, ) checkpoint_dict = self._checkpointer.load_checkpoint() @@ -146,12 +146,13 @@ def setup(self, cfg: DictConfig) -> None: model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. """ self._metric_logger = config.instantiate(cfg.metric_logger) - - checkpoint_dict = self.load_checkpoint(cfg=cfg.checkpointer) + self._model_compile = cfg.compile + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, + compile_model=cfg.compile, base_model_state_dict=checkpoint_dict[utils.MODEL_KEY], lora_weights_state_dict=( checkpoint_dict[utils.ADAPTER_KEY] @@ -214,6 +215,7 @@ def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, + compile_model: bool, base_model_state_dict: Dict[str, Any], lora_weights_state_dict: Optional[Dict[str, Any]] = None, ) -> nn.Module: @@ -257,6 +259,10 @@ def _setup_model( ) log.info(f"Model is initialized with precision {self._dtype}.") + # Compile model, if enabled. + if compile_model: + log.info("Compiling model with torch.compile...") + model = utils.wrap_compile(model) log.info( utils.memory_stats_log( "Memory Stats after model init:", device=self._device @@ -380,6 +386,11 @@ def train(self) -> None: The core training loop. """ + if self._model_compile: + log.info( + "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration." + ) + # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): # Update the sampler to ensure data is correctly shuffled across epochs diff --git a/tests/cache_artifacts.sh b/tests/cache_artifacts.sh index 0a54c593c9..df4079670d 100755 --- a/tests/cache_artifacts.sh +++ b/tests/cache_artifacts.sh @@ -13,7 +13,6 @@ # In all cases, if the files already exist locally they will not be downloaded from S3. SMALL_MODEL_URLS=( - "s3://pytorch-multimodal/small-ckpt-01242024" "s3://pytorch-multimodal/small-ckpt-tune-03082024.pt" "s3://pytorch-multimodal/small-ckpt-meta-03082024.pt" "s3://pytorch-multimodal/small-ckpt-hf-03082024.pt" diff --git a/tests/recipes/test_full_finetune_single_device.py b/tests/recipes/test_full_finetune_single_device.py index 55a95d2b55..b992df6bc9 100644 --- a/tests/recipes/test_full_finetune_single_device.py +++ b/tests/recipes/test_full_finetune_single_device.py @@ -50,12 +50,16 @@ def _fetch_expected_loss_values(self): @pytest.mark.parametrize( "config", ["full_single_device_low_memory", "full_single_device"] ) - def test_loss(self, config, tmpdir, monkeypatch): + @pytest.mark.parametrize("compile", [True, False]) + def test_loss(self, compile, config, tmpdir, monkeypatch): ckpt = "small_test_ckpt_meta" ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) ckpt_dir = ckpt_path.parent log_file = gen_log_file_name(tmpdir) + # To workaround https://github.com/pytorch/torchtune/issues/676 + if compile: + os.environ["TORCH_COMPILE_BACKEND"] = "aot_eager" cmd = f""" tune run full_finetune_single_device \ --config llama2/7B_{config} \ @@ -66,6 +70,7 @@ def test_loss(self, config, tmpdir, monkeypatch): checkpointer.output_dir={tmpdir} \ checkpointer.model_type=LLAMA2 \ metric_logger.filename={log_file} \ + compile={compile} \ """.split() model_config = llama2_test_config() diff --git a/tests/recipes/test_lora_finetune_single_device.py b/tests/recipes/test_lora_finetune_single_device.py index b74d7d8faf..3eece0ef24 100644 --- a/tests/recipes/test_lora_finetune_single_device.py +++ b/tests/recipes/test_lora_finetune_single_device.py @@ -22,6 +22,7 @@ CKPT_MODEL_PATHS, gen_log_file_name, get_loss_values_from_metric_logger, + torch_version_ge, ) from torchtune import config @@ -52,12 +53,17 @@ def _fetch_qlora_expected_loss_values(self, dtype): return [10.5059, 10.5571, 10.5181, 10.4897] @pytest.mark.integration_test - def test_loss(self, tmpdir, monkeypatch): + @pytest.mark.parametrize("compile", [True, False]) + def test_loss(self, compile, tmpdir, monkeypatch): ckpt = "small_test_ckpt_meta" ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) ckpt_dir = ckpt_path.parent log_file = gen_log_file_name(tmpdir) + # To workaround https://github.com/pytorch/torchtune/issues/676 + if compile: + os.environ["TORCH_COMPILE_BACKEND"] = "aot_eager" + cmd = f""" tune run lora_finetune_single_device \ --config llama2/7B_lora_single_device \ @@ -68,6 +74,7 @@ def test_loss(self, tmpdir, monkeypatch): checkpointer.output_dir={tmpdir} \ checkpointer.model_type=LLAMA2 \ metric_logger.filename={log_file} \ + compile={compile} \ """.split() model_config = lora_llama2_test_config( @@ -91,12 +98,21 @@ def test_loss(self, tmpdir, monkeypatch): @pytest.mark.integration_test @pytest.mark.parametrize("dtype", ["fp32", "bf16"]) - def test_loss_qlora(self, dtype, tmpdir, monkeypatch): + @pytest.mark.parametrize("compile", [True, False]) + @pytest.mark.skipif( + not torch_version_ge("2.4.0"), + reason="Please install a nightly build of torch to run this test.", + ) + def test_loss_qlora(self, compile, dtype, tmpdir, monkeypatch): ckpt = "small_test_ckpt_meta" ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) ckpt_dir = ckpt_path.parent log_file = gen_log_file_name(tmpdir) + # To workaround https://github.com/pytorch/torchtune/issues/676 + if compile: + os.environ["TORCH_COMPILE_BACKEND"] = "aot_eager" + cmd = f""" tune run lora_finetune_single_device --config llama2/7B_qlora_single_device \ @@ -107,6 +123,7 @@ def test_loss_qlora(self, dtype, tmpdir, monkeypatch): checkpointer.output_dir={tmpdir} \ checkpointer.model_type=LLAMA2 \ metric_logger.filename={log_file} \ + compile={compile} \ """.split() model_config = lora_llama2_test_config( diff --git a/tests/regression_tests/test_llama2_7b.py b/tests/regression_tests/test_llama2_7b.py index a4743711ce..32de865b02 100644 --- a/tests/regression_tests/test_llama2_7b.py +++ b/tests/regression_tests/test_llama2_7b.py @@ -55,7 +55,7 @@ def test_loss(self, tmpdir, monkeypatch): log_file = gen_log_file_name(tmpdir) cmd = f""" - tune --nnodes 1 --nproc_per_node 2 full_finetune_distributed + tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed --config llama2/7B_full \ output_dir={tmpdir} \ checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer @@ -99,7 +99,7 @@ def test_finetune_and_eval(self, tmpdir, capsys, monkeypatch): # Run on prod LoRA FT config but with only 10 steps for now ft_cmd = f""" - tune --nnodes 1 --nproc_per_node 2 lora_finetune_distributed + tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora \ output_dir={tmpdir} \ checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer diff --git a/tests/test_utils.py b/tests/test_utils.py index 2ca3fcfee6..c575134e8f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -31,6 +31,14 @@ "llama2_7b": "/tmp/test-artifacts/llama2-7b-torchtune.pt", } + +def torch_version_ge(version: str) -> bool: + """ + Check if torch version is greater than or equal to the given version + """ + return version in torch.__version__ or torch.__version__ >= version + + # Inherit from tokenizer class to reuse its tokenize_messages method class DummyTokenizer(Tokenizer): def __init__(self): diff --git a/tests/torchtune/utils/test_wrap_compile.py b/tests/torchtune/utils/test_wrap_compile.py new file mode 100644 index 0000000000..83e4707eb6 --- /dev/null +++ b/tests/torchtune/utils/test_wrap_compile.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + + +import torch +from torchtune import utils + + +class TestWrapCompile: + def test_wrap_compile(self) -> None: + """ + Ensures that compile prefix is removed in compiled model + state_dict and can be loaded into non-compiled model. + """ + m = torch.nn.Linear(5, 5) + m = utils.wrap_compile(m) + assert isinstance(m, torch._dynamo.eval_frame.OptimizedModule) + load_m = torch.nn.Linear(5, 5) + missing, unexpected = load_m.load_state_dict(m.state_dict(), strict=False) + assert not missing + assert not unexpected diff --git a/torchtune/utils/__init__.py b/torchtune/utils/__init__.py index 27365deb56..7d23fc723a 100644 --- a/torchtune/utils/__init__.py +++ b/torchtune/utils/__init__.py @@ -11,6 +11,8 @@ ModelType, transform_opt_state_dict, ) + +from ._compile_utils import wrap_compile from ._device import get_device from ._distributed import ( # noqa contains_fsdp, @@ -77,6 +79,7 @@ "set_default_dtype", "set_seed", "validate_expected_param_dtype", + "wrap_compile", "TuneRecipeArgumentParser", "CheckpointableDataLoader", "OptimizerInBackwardWrapper", diff --git a/torchtune/utils/_checkpointing/_checkpointer.py b/torchtune/utils/_checkpointing/_checkpointer.py index 7c4e589a64..35352d4430 100644 --- a/torchtune/utils/_checkpointing/_checkpointer.py +++ b/torchtune/utils/_checkpointing/_checkpointer.py @@ -265,7 +265,7 @@ class FullModelHFCheckpointer(_CheckpointerInterface): the following flow: 1. Download the model from the HF repo using tune download - tune download --repo-id meta-llama/Llama-2-7b-hf \ + tune download meta-llama/Llama-2-7b-hf \ --output-dir \ --hf-token @@ -511,7 +511,7 @@ class FullModelMetaCheckpointer(_CheckpointerInterface): the following flow: 1. Download the model from the HF repo using tune download - tune download --repo-id meta-llama/Llama-2-7b \ + tune download meta-llama/Llama-2-7b \ --output-dir \ --hf-token diff --git a/torchtune/utils/_compile_utils.py b/torchtune/utils/_compile_utils.py new file mode 100644 index 0000000000..273292c026 --- /dev/null +++ b/torchtune/utils/_compile_utils.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os + +import torch +from torch import nn + +_TORCH_COMPILE_WRAPPER_PREFIX = "_orig_mod." + + +def wrap_compile(model: nn.Module) -> None: + """ + Wraps the model with torch.compile. This function will also + register a state_dict post hook that allows state_dicts produced + with torch.compile training to behave as regular eager mode models. + In particular, it strips away a torch.compile specific prefix + added to the state_dict by torch.compile. + + Args: + model (nn.Module): model to wrap with compile. + + Returns: + None + """ + # TORCH_COMPILE_BACKEND can be set as an env var to override default torch.compile backend. + # Currently only used in unittesting to work around https://github.com/pytorch/torchtune/issues/676 + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + model = torch.compile(model, backend=backend) + model._register_state_dict_hook(_remove_torch_compile_prefix) + return model + + +def _remove_torch_compile_prefix(model, state_dict, *args, **kwargs): + keys = list(state_dict.keys()) + for key in keys: + if key.startswith(_TORCH_COMPILE_WRAPPER_PREFIX): + newkey = key[len(_TORCH_COMPILE_WRAPPER_PREFIX) :] + state_dict[newkey] = state_dict.pop(key)