
Add Ascend NPU as a backend #1826

Merged · 1 commit merged into pytorch:main on Nov 6, 2024

Conversation

@noemotiovon (Contributor) commented Oct 14, 2024

What does this PR do?

Overview

🚀 This PR enables torchtune users to leverage the Ascend NPU for better fine-tuning and inference performance.

This PR primarily addresses the initial refactoring of device-independent code. In upcoming changes, we’ll make further adjustments, using the NPU as an example to refine each recipe and complete the remaining device-independent modifications. For now, this PR only touches the lora_finetune_single_device and full_finetune_single_device recipes.

For more details, see: [#1797].

Environment

  • OS: ubuntu 20.04
  • NPU: Atlas 300T A2
  • CANN: 8.0.RC2
  • torch-npu: 2.4.0 rc1
  • torch: 2.4.0

Note

To properly install CANN, see [here] for more details.

The version of torch-npu should match that of torch, see [here] for more details.

In addition, torch_npu has a pre-release version, 2.4.0 RC1, which is also the basis for this test. For more information, please visit [here].
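
As a quick way to confirm that the installed versions line up, a small check along these lines can be run (an illustrative snippet, not part of this PR's diff; it assumes torch_npu exposes __version__ and patches the torch.npu namespace on import):

# Illustrative version/availability check, not part of this PR's changes.
import torch

try:
    import torch_npu  # importing torch_npu registers the `npu` device type
    print(f"torch {torch.__version__} / torch_npu {torch_npu.__version__}")
    print("NPU available:", torch.npu.is_available())
except ImportError:
    print("torch_npu is not installed; the NPU backend will be unavailable.")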

Examples

To start with, the torch_npu library should be correctly installed and imported. Part of the code is shown below:

torchtune/utils/_device_support.py:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch

def is_torch_npu_available() -> bool:
    """Return True if the torch_npu extension is installed and an NPU is visible."""
    try:
        # Importing torch_npu registers the `npu` device type and the
        # `torch.npu` namespace on top of stock PyTorch.
        import torch_npu  # noqa: F401
    except ImportError:
        return False
    return torch.npu.is_available()
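
For reference, here is a minimal sketch of how such a check could feed into device selection (illustrative only; the helper name below is hypothetical rather than torchtune's actual API, and the import path mirrors the file shown above, which may move during review):

# Hypothetical device-selection helper built on the check above
# (a sketch, not torchtune's actual implementation).
import torch

from torchtune.utils._device_support import is_torch_npu_available

def _pick_device_type() -> str:
    """Prefer CUDA, then NPU, then fall back to CPU."""
    if torch.cuda.is_available():
        return "cuda"
    if is_torch_npu_available():
        return "npu"
    return "cpu"

device = torch.device(_pick_device_type())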

In addition, a few other places in the code may need to be adjusted, but the changes will be small.

Feel free to leave comments to guide me in further improvements 😊.

pytorch-bot commented Oct 14, 2024

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1826

✅ No failures as of commit 1485314 with merge base 24d3579.

@facebook-github-bot added the CLA Signed label Oct 14, 2024
@noemotiovon marked this pull request as draft October 14, 2024 10:21
Inline review comments (resolved): torchtune/utils/_device_support.py, recipes/lora_finetune_single_device.py, recipes/quantize.py
@noemotiovon marked this pull request as ready for review October 21, 2024 10:49
@noemotiovon (Contributor, author) commented:

Hi @ebsmothers, @RdoubleA:

I hope you’re doing well! Could you please help me review my code? I would really appreciate it if you could take a look and share any feedback or suggestions. Thank you so much in advance for your time and support! 😊

Best regards

@ebsmothers (Contributor) left a review comment:


Hi @noemotiovon thanks for the PR! And apologies for the delay in getting to the review here. A couple other questions I have that don't really fit neatly anywhere inline:

  1. Do we expect compile to work? If so, we should test that. If not, we could raise an error
  2. Do we expect quant-related APIs (e.g. QLoRA or QAT) from torchao to work? Same as point 1: if so we should test or possibly raise an error
  3. PyTorch has now released 2.5 as stable. In general we do not claim to support anything but the latest stable release of PyTorch -- do you know the contract on torch_npu releases here?

Inline review comments: torchtune/training/precision.py, recipes/full_finetune_distributed.py, recipes/knowledge_distillation_single_device.py, tests/torchtune/utils/test_device.py, torchtune/training/_activation_offloading.py, torchtune/utils/_device.py
@elfisworking commented Oct 22, 2024

Distributed training seems to have problems, e.g. qat_distributed @noemotiovon.
In torchtune/training/_distributed.py, the function load_from_full_model_state_dict hits an error because sharded_meta_param.device_mesh.device_type is cpu when the model is loaded.
Are you willing to connect with me through my GitHub profile email? Maybe we can discuss how to support the Ascend NPU.
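
For illustration, the kind of guard that would surface this mismatch early might look like the following sketch (not torchtune's actual code; it only reads the attribute named above, and the helper name is hypothetical):

# Illustrative guard (hypothetical helper, not torchtune's code): fail fast when an
# FSDP-sharded parameter's device mesh was built for a different device type
# (e.g. "cpu") than the recipe's target device (e.g. "npu").
def _check_mesh_device(sharded_meta_param, expected_device_type: str) -> None:
    mesh_type = sharded_meta_param.device_mesh.device_type
    if mesh_type != expected_device_type:
        raise RuntimeError(
            f"Sharded parameter lives on a '{mesh_type}' device mesh, but the "
            f"recipe expects '{expected_device_type}'."
        )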

@noemotiovon (Contributor, author) replied:

> distributed training seems to have problems e.g qat_distributed @noemotiovon function torchtune/training/_distributed.py/load_from_full_model_state_dict sharded_meta_param.device_mesh.device_type is cpu when loading model which would raise error. Are you willing to connect me through my github profile email? maybe we can discuss how to support ascend npu

I would be very happy to! I will contact you via email.

@elfisworking replied:
> > distributed training seems to have problems e.g qat_distributed @noemotiovon function torchtune/training/_distributed.py/load_from_full_model_state_dict sharded_meta_param.device_mesh.device_type is cpu when loading model which would raise error. Are you willing to connect me through my github profile email? maybe we can discuss how to support ascend npu
>
> I would be very happy to! I will contact you via email.

@noemotiovon Through my 126 (126.com) email, thanks. Looking forward to your email.

Inline review comments (resolved): torchtune/utils/_device.py (2 threads)
@noemotiovon (Contributor, author) commented:

Basic Usage Test

A single-device fine-tuning process was performed on the Llama 3.1 8B model using the LoRA (Low-Rank Adaptation) technique.

  • Recipe: lora_finetune_single_device

  • Model: Meta-Llama-3.1-8B-Instruct

  • Config:

    # Config for single device LoRA finetuning in lora_finetune_single_device.py
    # using a Llama3.1 8B Instruct model
    #
    # This config assumes that you've run the following command before launching
    # this run:
    #   tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
    #
    # To launch on a single device, run the following command from root:
    #   tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device
    #
    # You can add specific overrides through the command line. For example
    # to override the checkpointer directory while launching training
    # you can run:
    #   tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
    #
    # This config works only for training on single device.
    
    
    # Model Arguments
    model:
      _component_: torchtune.models.llama3_1.lora_llama3_1_8b
      lora_attn_modules: ['q_proj', 'v_proj']
      apply_lora_to_mlp: False
      apply_lora_to_output: False
      lora_rank: 8
      lora_alpha: 16
      lora_dropout: 0.0
    
    # Tokenizer
    tokenizer:
      _component_: torchtune.models.llama3.llama3_tokenizer
      path: /home/lcg/.cache/modelscope/hub/LLM-Research/Meta-Llama-3___1-8B-Instruct/original/tokenizer.model
      max_seq_len: null
    
    checkpointer:
      _component_: torchtune.training.FullModelHFCheckpointer
      checkpoint_dir: /home/lcg/.cache/modelscope/hub/LLM-Research/Meta-Llama-3___1-8B-Instruct
      checkpoint_files: [
        model-00001-of-00004.safetensors,
        model-00002-of-00004.safetensors,
        model-00003-of-00004.safetensors,
        model-00004-of-00004.safetensors
      ]
      recipe_checkpoint: null
      output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/
      model_type: LLAMA3
    resume_from_checkpoint: False
    save_adapter_weights_only: False
    
    # Dataset and Sampler
    dataset:
      _component_: torchtune.datasets.alpaca_cleaned_dataset
    seed: null
    shuffle: True
    batch_size: 30
    
    # Optimizer and Scheduler
    optimizer:
      _component_: torch.optim.AdamW
      fused: False
      weight_decay: 0.01
      lr: 3e-4
    lr_scheduler:
      _component_: torchtune.modules.get_cosine_schedule_with_warmup
      num_warmup_steps: 100
    
    loss:
      _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
    
    # Training
    epochs: 1
    max_steps_per_epoch: null
    gradient_accumulation_steps: 64
    compile: False
    
    # Logging
    output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
    metric_logger:
      _component_: torchtune.training.metric_logging.DiskLogger
      log_dir: ${output_dir}
    log_every_n_steps: 1
    log_peak_memory_stats: False
    
    # Environment
    device: npu
    dtype: bf16
    
    # Activations Memory
    enable_activation_checkpointing: True
    enable_activation_offloading: False
    
    # Profiler (disabled)
    profiler:
      _component_: torchtune.training.setup_torch_profiler
      enabled: False
    
      #Output directory of trace artifacts
      output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
    
      #`torch.profiler.ProfilerActivity` types to trace
      cpu: True
      cuda: True
    
      #trace options passed to `torch.profiler.profile`
      profile_memory: False
      with_stack: False
      record_shapes: True
      with_flops: False
    
      # `torch.profiler.schedule` options:
      # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
      wait_steps: 5
      warmup_steps: 5
      active_steps: 2
      num_cycles: 1
  • Logs:

    INFO:torchtune.utils._logging:Running LoRAFinetuneRecipeSingleDevice with resolved config:
    
    batch_size: 30
    checkpointer:
      _component_: torchtune.training.FullModelHFCheckpointer
      checkpoint_dir: /home/lcg/.cache/modelscope/hub/LLM-Research/Meta-Llama-3___1-8B-Instruct
      checkpoint_files:
      - model-00001-of-00004.safetensors
      - model-00002-of-00004.safetensors
      - model-00003-of-00004.safetensors
      - model-00004-of-00004.safetensors
      model_type: LLAMA3
      output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/
      recipe_checkpoint: null
    compile: false
    dataset:
      _component_: torchtune.datasets.alpaca_cleaned_dataset
    device: npu
    dtype: bf16
    enable_activation_checkpointing: true
    enable_activation_offloading: false
    epochs: 1
    gradient_accumulation_steps: 64
    log_every_n_steps: 1
    log_peak_memory_stats: false
    loss:
      _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
    lr_scheduler:
      _component_: torchtune.modules.get_cosine_schedule_with_warmup
      num_warmup_steps: 100
    max_steps_per_epoch: null
    metric_logger:
      _component_: torchtune.training.metric_logging.DiskLogger
      log_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
    model:
      _component_: torchtune.models.llama3_1.lora_llama3_1_8b
      apply_lora_to_mlp: false
      apply_lora_to_output: false
      lora_alpha: 16
      lora_attn_modules:
      - q_proj
      - v_proj
      lora_dropout: 0.0
      lora_rank: 8
    optimizer:
      _component_: torch.optim.AdamW
      fused: false
      lr: 0.0003
      weight_decay: 0.01
    output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
    profiler:
      _component_: torchtune.training.setup_torch_profiler
      active_steps: 2
      cpu: true
      cuda: true
      enabled: false
      num_cycles: 1
      output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
      profile_memory: false
      record_shapes: true
      wait_steps: 5
      warmup_steps: 5
      with_flops: false
      with_stack: false
    resume_from_checkpoint: false
    save_adapter_weights_only: false
    seed: null
    shuffle: true
    tokenizer:
      _component_: torchtune.models.llama3.llama3_tokenizer
      max_seq_len: null
      path: /home/lcg/.cache/modelscope/hub/LLM-Research/Meta-Llama-3___1-8B-Instruct/original/tokenizer.model
    
    DEBUG:torchtune.utils._logging:Setting manual seed to local seed 4173222699. Local seed is seed + rank = 4173222699 + 0
    Writing logs to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test/log_1728874769.txt
    INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
    INFO:torchtune.utils._logging:Memory stats after model init:
            NPU peak memory allocation: 16.98 GiB
            NPU peak memory reserved: 17.00 GiB
            NPU peak memory active: 16.98 GiB
    INFO:torchtune.utils._logging:Tokenizer is initialized from file.
    INFO:torchtune.utils._logging:Optimizer and loss are initialized.
    INFO:torchtune.utils._logging:Loss is initialized.
    Using the latest cached version of the dataset since yahma/alpaca-cleaned couldn't be found on the Hugging Face Hub
    WARNING:datasets.load:Using the latest cached version of the dataset since yahma/alpaca-cleaned couldn't be found on the Hugging Face Hub
    Found the latest cached dataset configuration 'default' at /home/lcg/.cache/huggingface/datasets/yahma___alpaca-cleaned/default/0.0.0/12567cabf869d7c92e573c7c783905fc160e9639 (last modified on Fri Oct 11 01:21:44 2024).
    WARNING:datasets.packaged_modules.cache.cache:Found the latest cached dataset configuration 'default' at /home/lcg/.cache/huggingface/datasets/yahma___alpaca-cleaned/default/0.0.0/12567cabf869d7c92e573c7c783905fc160e9639 (last modified on Fri Oct 11 01:21:44 2024).
    INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
    INFO:torchtune.utils._logging:Learning rate scheduler is initialized.
    WARNING:torchtune.utils._logging: Profiling disabled.
    INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
      0%|                                                                                                                                                            | 0/26 [00:00<?, ?it/s]/home/lcg/miniconda3/envs/torchtune/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
      with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
      1|1|Loss: 1.7533539533615112:   4%  1/26 [07:09]
      1|2|Loss: 1.7825285196304321:   8%  2/26 [14:09]
      1|3|Loss: 1.7610299587249756:  12%  3/26 [21:17]
      1|4|Loss: 1.7874119281768799:  15%  4/26 [28:28]
      1|5|Loss: 1.7903798818588257:  19%  5/26 [35:36]
      1|6|Loss: 1.776786208152771:   23%  6/26 [42:52]
      1|7|Loss: 1.7698196172714233:  27%  7/26 [49:45<2:14:29, 424.69s/it]
     *  History restored

    (torchtune) (base) lcg@lcg-docker:~/github/torchtune$ tune run lora_finetune_single_device --config my_custom_config.yaml
    INFO:torchtune.utils._logging:Running LoRAFinetuneRecipeSingleDevice with resolved config:
    
    batch_size: 30
    checkpointer:
      _component_: torchtune.training.FullModelHFCheckpointer
      checkpoint_dir: /home/lcg/.cache/modelscope/hub/LLM-Research/Meta-Llama-3___1-8B-Instruct
      checkpoint_files:
      - model-00001-of-00004.safetensors
      - model-00002-of-00004.safetensors
      - model-00003-of-00004.safetensors
      - model-00004-of-00004.safetensors
      model_type: LLAMA3
      output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/
      recipe_checkpoint: null
    compile: false
    dataset:
      _component_: torchtune.datasets.alpaca_cleaned_dataset
    device: npu
    dtype: bf16
    enable_activation_checkpointing: true
    enable_activation_offloading: false
    epochs: 1
    gradient_accumulation_steps: 64
    log_every_n_steps: 1
    log_peak_memory_stats: false
    loss:
      _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
    lr_scheduler:
      _component_: torchtune.modules.get_cosine_schedule_with_warmup
      num_warmup_steps: 100
    max_steps_per_epoch: null
    metric_logger:
      _component_: torchtune.training.metric_logging.DiskLogger
      log_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
    model:
      _component_: torchtune.models.llama3_1.lora_llama3_1_8b
      apply_lora_to_mlp: false
      apply_lora_to_output: false
      lora_alpha: 16
      lora_attn_modules:
      - q_proj
      - v_proj
      lora_dropout: 0.0
      lora_rank: 8
    optimizer:
      _component_: torch.optim.AdamW
      fused: false
      lr: 0.0003
      weight_decay: 0.01
    output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
    profiler:
      _component_: torchtune.training.setup_torch_profiler
      active_steps: 2
      cpu: true
      cuda: true
      enabled: false
      num_cycles: 1
      output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
      profile_memory: false
      record_shapes: true
      wait_steps: 5
      warmup_steps: 5
      with_flops: false
      with_stack: false
    resume_from_checkpoint: false
    save_adapter_weights_only: false
    seed: null
    shuffle: true
    tokenizer:
      _component_: torchtune.models.llama3.llama3_tokenizer
      max_seq_len: null
      path: /home/lcg/.cache/modelscope/hub/LLM-Research/Meta-Llama-3___1-8B-Instruct/original/tokenizer.model
    
    DEBUG:torchtune.utils._logging:Setting manual seed to local seed 1031355438. Local seed is seed + rank = 1031355438 + 0
    Writing logs to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test/log_1728878132.txt
    INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
    INFO:torchtune.utils._logging:Memory stats after model init:
            NPU peak memory allocation: 16.98 GiB
            NPU peak memory reserved: 17.00 GiB
            NPU peak memory active: 16.98 GiB
    INFO:torchtune.utils._logging:Tokenizer is initialized from file.
    INFO:torchtune.utils._logging:Optimizer and loss are initialized.
    INFO:torchtune.utils._logging:Loss is initialized.
    Using the latest cached version of the dataset since yahma/alpaca-cleaned couldn't be found on the Hugging Face Hub
    WARNING:datasets.load:Using the latest cached version of the dataset since yahma/alpaca-cleaned couldn't be found on the Hugging Face Hub
    Found the latest cached dataset configuration 'default' at /home/lcg/.cache/huggingface/datasets/yahma___alpaca-cleaned/default/0.0.0/12567cabf869d7c92e573c7c783905fc160e9639 (last modified on Fri Oct 11 01:21:44 2024).
    WARNING:datasets.packaged_modules.cache.cache:Found the latest cached dataset configuration 'default' at /home/lcg/.cache/huggingface/datasets/yahma___alpaca-cleaned/default/0.0.0/12567cabf869d7c92e573c7c783905fc160e9639 (last modified on Fri Oct 11 01:21:44 2024).
    INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
    INFO:torchtune.utils._logging:Learning rate scheduler is initialized.
    WARNING:torchtune.utils._logging: Profiling disabled.
    INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
      0%|                                                                                                                                             | 0/26 [00:00<?, ?it/s]/home/lcg/miniconda3/envs/torchtune/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
      with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
    1|26|Loss: 1.427944302558899: 100%  26/26 [3:04:04<00:00, 418.75s/it]
    INFO:torchtune.utils._logging:Starting checkpoint save...
    INFO:torchtune.utils._logging:Model checkpoint of size 4.98 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/hf_model_0001_0.pt
    INFO:torchtune.utils._logging:Model checkpoint of size 5.00 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/hf_model_0002_0.pt
    INFO:torchtune.utils._logging:Model checkpoint of size 4.92 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/hf_model_0003_0.pt
    INFO:torchtune.utils._logging:Model checkpoint of size 1.17 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/hf_model_0004_0.pt
    INFO:torchtune.utils._logging:Adapter checkpoint of size 0.01 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/adapter_0.pt
    INFO:torchtune.utils._logging:Adapter checkpoint of size 0.01 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/adapter_model.bin
    INFO:torchtune.utils._logging:Adapter checkpoint of size 0.00 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/adapter_config.json
    INFO:torchtune.utils._logging:Saving final epoch checkpoint.
    INFO:torchtune.utils._logging:The full model checkpoint, including all weights and configurations, has been saved successfully.You can now use this checkpoint for further training or inference.
    INFO:torchtune.utils._logging:Checkpoint saved in 65.93 seconds.
    1|26|Loss: 1.427944302558899: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [3:11:40<00:00, 442.34s/it]
    
  • Result: The test results demonstrate the successful completion of a single-device LoRA fine-tuning process on the Llama 3.1 8B model. The configuration included a batch size of 30, gradient accumulation over 64 steps, and one epoch of training on an NPU device using the bf16 data type. Activation checkpointing was enabled, and LoRA fine-tuning was applied to attention modules. The process utilized AdamW as the optimizer with a learning rate of 0.0003 and a cosine learning rate scheduler.

@noemotiovon (Contributor, author) commented:

Basic Usage Test

A single-device full (non-LoRA) fine-tuning process was performed on the Qwen2 0.5B model.

  • Recipe: full_finetune_single_device

  • Model: Qwen2-0.5B-Instruct

  • Config:

    # Tokenizer
    tokenizer:
      _component_: torchtune.models.qwen2.qwen2_tokenizer
      path: /home/lcg/.cache/modelscope/hub/Qwen/Qwen2-0___5B-Instruct/vocab.json
      merges_file: /home/lcg/.cache/modelscope/hub/Qwen/Qwen2-0___5B-Instruct/merges.txt
      max_seq_len: null
    
    # Dataset
    dataset:
      _component_: torchtune.datasets.alpaca_cleaned_dataset
    seed: null
    shuffle: True
    
    # Model Arguments
    model:
      _component_: torchtune.models.qwen2.qwen2_0_5b
    
    checkpointer:
      _component_: torchtune.training.FullModelHFCheckpointer
      checkpoint_dir: /home/lcg/.cache/modelscope/hub/Qwen/Qwen2-0___5B-Instruct
      checkpoint_files: [
        model.safetensors
      ]
      recipe_checkpoint: null
      output_dir: /home/lcg/tmp/torchtune/Qwen2-0.5B-Instruct-finetune
      model_type: QWEN2
    resume_from_checkpoint: False
    
    # Fine-tuning arguments
    batch_size: 15
    epochs: 1
    optimizer:
      _component_: torch.optim.AdamW
      fused: False
      lr: 2e-5
    
    loss:
      _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
    optimizer_in_bwd: False
    
    max_steps_per_epoch: null
    gradient_accumulation_steps: 8
    compile: False
    
    # Training environment
    device: npu
    
    # Memory management
    enable_activation_checkpointing: True
    
    # Reduced precision
    dtype: bf16
    
    # Logging
    metric_logger:
      _component_: torchtune.training.metric_logging.DiskLogger
      log_dir: ${output_dir}
    output_dir: /tmp/Qwen2-0.5B-Instruct-finetune
    log_every_n_steps: 1
    log_peak_memory_stats: False
  • Logs:

    INFO:torchtune.utils._logging:Running FullFinetuneRecipeSingleDevice with resolved config:
    
    batch_size: 15
    checkpointer:
      _component_: torchtune.training.FullModelHFCheckpointer
      checkpoint_dir: /home/lcg/.cache/modelscope/hub/Qwen/Qwen2-0___5B-Instruct
      checkpoint_files:
      - model.safetensors
      model_type: QWEN2
      output_dir: /home/lcg/tmp/torchtune/Qwen2-0.5B-Instruct-finetune
      recipe_checkpoint: null
    compile: false
    dataset:
      _component_: torchtune.datasets.alpaca_cleaned_dataset
    device: npu
    dtype: bf16
    enable_activation_checkpointing: true
    epochs: 1
    gradient_accumulation_steps: 8
    log_every_n_steps: 1
    log_peak_memory_stats: false
    loss:
      _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
    max_steps_per_epoch: null
    metric_logger:
      _component_: torchtune.training.metric_logging.DiskLogger
      log_dir: /tmp/Qwen2-0.5B-Instruct-finetune
    model:
      _component_: torchtune.models.qwen2.qwen2_0_5b
    optimizer:
      _component_: torch.optim.AdamW
      fused: false
      lr: 2.0e-05
    optimizer_in_bwd: false
    output_dir: /tmp/Qwen2-0.5B-Instruct-finetune
    resume_from_checkpoint: false
    seed: null
    shuffle: true
    tokenizer:
      _component_: torchtune.models.qwen2.qwen2_tokenizer
      max_seq_len: null
      merges_file: /home/lcg/.cache/modelscope/hub/Qwen/Qwen2-0___5B-Instruct/merges.txt
      path: /home/lcg/.cache/modelscope/hub/Qwen/Qwen2-0___5B-Instruct/vocab.json
    
    DEBUG:torchtune.utils._logging:Setting manual seed to local seed 3364767838. Local seed is seed + rank = 3364767838 + 0
    Writing logs to /tmp/Qwen2-0.5B-Instruct-finetune/log_1729914193.txt
    /home/lcg/miniconda3/envs/torchtune/lib/python3.10/site-packages/torch_npu/utils/storage.py:38: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
      if self.device.type != 'cpu':
    INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
    INFO:torchtune.utils._logging:Memory stats after model init:
            NPU peak memory allocation: 1.55 GiB
            NPU peak memory reserved: 1.61 GiB
            NPU peak memory active: 1.55 GiB
    INFO:torchtune.utils._logging:Tokenizer is initialized from file.
    INFO:torchtune.utils._logging:Optimizer is initialized.
    INFO:torchtune.utils._logging:Loss is initialized.
    Using the latest cached version of the dataset since yahma/alpaca-cleaned couldn't be found on the Hugging Face Hub
    WARNING:datasets.load:Using the latest cached version of the dataset since yahma/alpaca-cleaned couldn't be found on the Hugging Face Hub
    Found the latest cached dataset configuration 'default' at /home/lcg/.cache/huggingface/datasets/yahma___alpaca-cleaned/default/0.0.0/12567cabf869d7c92e573c7c783905fc160e9639 (last modified on Fri Oct 11 01:21:44 2024).
    WARNING:datasets.packaged_modules.cache.cache:Found the latest cached dataset configuration 'default' at /home/lcg/.cache/huggingface/datasets/yahma___alpaca-cleaned/default/0.0.0/12567cabf869d7c92e573c7c783905fc160e9639 (last modified on Fri Oct 11 01:21:44 2024).
    INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
    INFO:torchtune.utils._logging:No learning rate scheduler configured. Using constant learning rate.
    WARNING:torchtune.utils._logging: Profiling disabled.
    INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
      0%|                                                                                                                                                                    | 0/431 [00:00<?, ?it/s]/home/lcg/miniconda3/envs/torchtune/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
      with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
    1|431|Loss: 1.136042833328247: 100%  431/431 [31:03<00:00,  4.45s/it]
    INFO:torchtune.utils._logging:Model checkpoint of size 0.99 GB saved to /home/lcg/tmp/torchtune/Qwen2-0.5B-Instruct-finetune/hf_model_0001_0.pt
    INFO:torchtune.utils._logging:Saving final epoch checkpoint.
    INFO:torchtune.utils._logging:The full model checkpoint, including all weights and configurations, has been saved successfully.You can now use this checkpoint for further training or inference.
    1|431|Loss: 1.136042833328247: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 431/431 [31:07<00:00,  4.33s/it]

@noemotiovon (Contributor, author) commented Oct 26, 2024:

> Hi @noemotiovon thanks for the PR! And apologies for the delay in getting to the review here. A couple other questions I have that don't really fit neatly anywhere inline:
>
>   1. Do we expect compile to work? If so, we should test that. If not, we could raise an error
>   2. Do we expect quant-related APIs (e.g. QLoRA or QAT) from torchao to work? Same as point 1: if so we should test or possibly raise an error
>   3. PyTorch has now released 2.5 as stable. In general we do not claim to support anything but the latest stable release of PyTorch -- do you know the contract on torch_npu releases here?

Hi @ebsmothers, thank you very much for reviewing my code! ☺️ I’ve made the suggested changes. The goal of this PR is to carry out the device-independent modifications for torchtune, using the NPU as an example; eventually every recipe will be adapted so that torchtune becomes device-independent, and this PR specifically covers full_finetune_single_device and lora_finetune_single_device. Regarding the third point, torch-npu will release the 2.5.0 RC version on November 7, and I’ll optimize the code based on PyTorch’s new features as well! Hope you have a fantastic day ahead!

@noemotiovon (Contributor, author) commented:

Hi @ebsmothers, could you please take a moment to review the code ☺️ ? This update currently supports full_finetune_single_device and lora_finetune_single_device on NPU, and we’ll be adding support for additional recipes in the future. I really appreciate your help—thank you!
Best regards

@ebsmothers (Contributor) commented:

Hi @noemotiovon sorry for the delay! I will take a look tomorrow if that's alright. Until then I'll tag @RdoubleA and @joecummings in case either of them gets a minute to take a look

@noemotiovon (Contributor, author) commented:

Hi @ebsmothers, when you have a moment, could you take a quick look at the recent changes I made? Your feedback would be greatly appreciated. Thank you!

@ebsmothers (Contributor) left a review comment:


Thanks @noemotiovon for your patience! I left a handful more comments, please let me know if anything is unclear

@noemotiovon (Contributor, author) replied:

> Thanks @noemotiovon for your patience! I left a handful more comments, please let me know if anything is unclear

Thank you for your review! Your feedback is very clear, and I will make the necessary code changes as soon as possible based on your suggestions. ☺️

@noemotiovon (Contributor, author) commented:

Hi @ebsmothers, I’ve made the code changes based on your suggestions; could you please review it again? ☺️

Additionally:

  1. For NPU devices, bf16 support is currently checked based only on the device model, encapsulated in the torch.npu.is_bf16_supported() method (see the sketch after this list).
  2. Support for distributed functionality is still being debugged and will be integrated gradually in a separate PR.
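
A rough sketch of how that check could slot into a device-aware helper follows (illustrative only; torchtune's precision utilities may structure this differently, and the helper name is hypothetical):

# Device-aware bf16 capability check (hypothetical helper, not the exact
# torchtune implementation). torch.npu.is_bf16_supported() is available once
# torch_npu has been imported.
import torch

def _supports_bf16(device: torch.device) -> bool:
    if device.type == "cuda":
        return torch.cuda.is_bf16_supported()
    if device.type == "npu":
        return torch.npu.is_bf16_supported()
    # Assume plain CPU bf16 is acceptable for this sketch.
    return True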

Best regards

Inline review comments (resolved): torchtune/utils/_device.py (4 threads)
@ebsmothers (Contributor) left a review comment:


Thanks @noemotiovon for the updates! I left a couple more comments but I think this is pretty close now. It looks like a unit test is failing in CI though, can you take a look? Happy to provide any debugging pointers if you need

@noemotiovon (Contributor, author) replied:

> Thanks @noemotiovon for the updates! I left a couple more comments but I think this is pretty close now. It looks like a unit test is failing in CI though, can you take a look? Happy to provide any debugging pointers if you need

Hi @ebsmothers,
Your review comments were excellent, and I’ve learned a lot from you! I’ve made changes based on your feedback and fixed the CI failure issue. Could you please review my code again?

Best regards

@ebsmothers (Contributor) commented:

Hi @noemotiovon looks like CI is green now but there are merge conflicts. Can you pull from latest main and merge with your changes?

@noemotiovon (Contributor, author) replied:

> Hi @noemotiovon looks like CI is green now but there are merge conflicts. Can you pull from latest main and merge with your changes?

Hi @ebsmothers, I’d be happy to! I’ve resolved the conflicts and made the necessary adjustments. Could you please help merge it? Thank you for your continued support, and wishing you all the best in work and life! 😊


@RdoubleA merged commit 4389b4d into pytorch:main Nov 6, 2024 (17 checks passed)
@ebsmothers mentioned this pull request Nov 26, 2024 (44 tasks)
Labels: CLA Signed (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)
Projects: none
7 participants