From 9685e0bda9b254a1cd223d5840736498c40626cb Mon Sep 17 00:00:00 2001 From: Optimox Date: Fri, 11 Oct 2024 17:32:00 +0200 Subject: [PATCH 01/11] feat: add gemma2b variants --- README.md | 1 + docs/source/api_ref_models.rst | 31 ++ recipes/configs/gemma2/27B_full.yaml | 95 ++++ recipes/configs/gemma2/27B_lora.yaml | 107 +++++ .../gemma2/27B_lora_single_device.yaml | 134 ++++++ .../gemma2/27B_qlora_single_device.yaml | 137 ++++++ recipes/configs/gemma2/2B_full.yaml | 74 ++++ recipes/configs/gemma2/2B_lora.yaml | 86 ++++ .../configs/gemma2/2B_lora_single_device.yaml | 113 +++++ .../gemma2/2B_qlora_single_device.yaml | 113 +++++ recipes/configs/gemma2/9B_full.yaml | 79 ++++ recipes/configs/gemma2/9B_lora.yaml | 91 ++++ .../configs/gemma2/9B_lora_single_device.yaml | 118 +++++ .../gemma2/9B_qlora_single_device.yaml | 121 ++++++ torchtune/_recipe_registry.py | 30 ++ torchtune/models/convert_weights.py | 6 +- torchtune/models/gemma/__init__.py | 2 - torchtune/models/gemma/_component_builders.py | 6 - torchtune/models/gemma2/__init__.py | 36 ++ .../models/gemma2/_component_builders.py | 410 ++++++++++++++++++ torchtune/models/gemma2/_model_builders.py | 286 ++++++++++++ torchtune/modules/attention.py | 289 ++++++++++++ torchtune/training/checkpointing/_utils.py | 2 + 23 files changed, 2357 insertions(+), 10 deletions(-) create mode 100644 recipes/configs/gemma2/27B_full.yaml create mode 100644 recipes/configs/gemma2/27B_lora.yaml create mode 100644 recipes/configs/gemma2/27B_lora_single_device.yaml create mode 100644 recipes/configs/gemma2/27B_qlora_single_device.yaml create mode 100644 recipes/configs/gemma2/2B_full.yaml create mode 100644 recipes/configs/gemma2/2B_lora.yaml create mode 100644 recipes/configs/gemma2/2B_lora_single_device.yaml create mode 100644 recipes/configs/gemma2/2B_qlora_single_device.yaml create mode 100644 recipes/configs/gemma2/9B_full.yaml create mode 100644 recipes/configs/gemma2/9B_lora.yaml create mode 100644 recipes/configs/gemma2/9B_lora_single_device.yaml create mode 100644 recipes/configs/gemma2/9B_qlora_single_device.yaml create mode 100644 torchtune/models/gemma2/__init__.py create mode 100644 torchtune/models/gemma2/_component_builders.py create mode 100644 torchtune/models/gemma2/_model_builders.py diff --git a/README.md b/README.md index a66d3ded4c..568847abd5 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ torchtune currently supports the following models. 
| [Code-Llama2](https://ai.meta.com/blog/code-llama-large-language-model-coding/) | 7B, 13B, 70B [[models](torchtune/models/code_llama2/_model_builders.py), [configs](recipes/configs/code_llama2/)] | | [Mistral](https://huggingface.co/mistralai) | 7B [[models](torchtune/models/mistral/_model_builders.py), [configs](recipes/configs/mistral/)] | | [Gemma](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b) | 2B, 7B [[models](torchtune/models/gemma/_model_builders.py), [configs](recipes/configs/gemma/)] | +| [Gemma2](https://huggingface.co/docs/transformers/main/en/model_doc/gemma2) | 2B, 9B, 27B [[models](torchtune/models/gemma2/_model_builders.py), [configs](recipes/configs/gemma2/)] | | [Microsoft Phi3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3) | Mini [[models](torchtune/models/phi3/), [configs](recipes/configs/phi3/)] | [Qwen2](https://qwenlm.github.io/blog/qwen2/) | 0.5B, 1.5B, 7B [[models](torchtune/models/qwen2/), [configs](recipes/configs/qwen2/)] diff --git a/docs/source/api_ref_models.rst b/docs/source/api_ref_models.rst index fe94104484..39d6ec2291 100644 --- a/docs/source/api_ref_models.rst +++ b/docs/source/api_ref_models.rst @@ -320,6 +320,37 @@ To download the Gemma 7B model: gemma.gemma_tokenizer +gemma2 : # TODO +----- + +Models of size 2B, 9B, 27B from the `Gemma family `_. + +Important: You need to request access on `Hugging Face `__ to use this model. + +To download the Gemma2 2B, 9B, 27B models : + +.. code-block:: bash + + tune download google/gemma-2-b --ignore-patterns "gemma-2-b.gguf" --hf-token + + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + gemma2.gemma2 + gemma2.lora_gemma + gemma2.gemma2_2b + gemma2.lora_gemma2_2b + gemma2.qlora_gemma2_2b + gemma2.gemma2_9b + gemma2.lora_gemma2_9b + gemma2.qlora_gemma2_9b + gemma2.gemma2_27b + gemma2.lora_gemma2_27b + gemma2.qlora_gemma2_27b + gemma.gemma_tokenizer + clip ----- diff --git a/recipes/configs/gemma2/27B_full.yaml b/recipes/configs/gemma2/27B_full.yaml new file mode 100644 index 0000000000..09cf4bbc38 --- /dev/null +++ b/recipes/configs/gemma2/27B_full.yaml @@ -0,0 +1,95 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a gemma2 27B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/27B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/27B_full checkpointer.checkpoint_dir= +# +# This config works only when the model is being fine-tuned on 2+ GPUs. 
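+#
+# Any other key in this file can be overridden the same way. As an illustration only
+# (the values are arbitrary), batch size and gradient accumulation could be changed with:
+#   tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/27B_full \
+#     batch_size=2 gradient_accumulation_steps=8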
+ + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma2-27b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.gemma_27b + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma2-27b/ + checkpoint_files: [ + model-00001-of-00024.safetensors, + model-00002-of-00024.safetensors, + model-00003-of-00024.safetensors, + model-00004-of-00024.safetensors, + model-00005-of-00024.safetensors, + model-00006-of-00024.safetensors, + model-00007-of-00024.safetensors, + model-00008-of-00024.safetensors, + model-00009-of-00024.safetensors, + model-00010-of-00024.safetensors, + model-00011-of-00024.safetensors, + model-00012-of-00024.safetensors, + model-00013-of-00024.safetensors, + model-00014-of-00024.safetensors, + model-00015-of-00024.safetensors, + model-00016-of-00024.safetensors, + model-00017-of-00024.safetensors, + model-00018-of-00024.safetensors, + model-00019-of-00024.safetensors, + model-00020-of-00024.safetensors, + model-00021-of-00024.safetensors, + model-00022-of-00024.safetensors, + model-00023-of-00024.safetensors, + model-00024-of-00024.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma2-27b + model_type: GEMMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 1 +epochs: 1 +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-27b-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/recipes/configs/gemma2/27B_lora.yaml b/recipes/configs/gemma2/27B_lora.yaml new file mode 100644 index 0000000000..3631675a18 --- /dev/null +++ b/recipes/configs/gemma2/27B_lora.yaml @@ -0,0 +1,107 @@ +# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# using a gemma2 27B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/27B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/27B_lora checkpointer.checkpoint_dir= +# +# This config works only when the model is being fine-tuned on 2+ GPUs. 
+ + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma2-27b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.lora_gemma2_27b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 128 + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma2-27b/ + checkpoint_files: [ + model-00001-of-00024.safetensors, + model-00002-of-00024.safetensors, + model-00003-of-00024.safetensors, + model-00004-of-00024.safetensors, + model-00005-of-00024.safetensors, + model-00006-of-00024.safetensors, + model-00007-of-00024.safetensors, + model-00008-of-00024.safetensors, + model-00009-of-00024.safetensors, + model-00010-of-00024.safetensors, + model-00011-of-00024.safetensors, + model-00012-of-00024.safetensors, + model-00013-of-00024.safetensors, + model-00014-of-00024.safetensors, + model-00015-of-00024.safetensors, + model-00016-of-00024.safetensors, + model-00017-of-00024.safetensors, + model-00018-of-00024.safetensors, + model-00019-of-00024.safetensors, + model-00020-of-00024.safetensors, + model-00021-of-00024.safetensors, + model-00022-of-00024.safetensors, + model-00023-of-00024.safetensors, + model-00024-of-00024.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma2-27b/ + model_type: GEMMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-27b-lora +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/recipes/configs/gemma2/27B_lora_single_device.yaml b/recipes/configs/gemma2/27B_lora_single_device.yaml new file mode 100644 index 0000000000..58e7170c8c --- /dev/null +++ b/recipes/configs/gemma2/27B_lora_single_device.yaml @@ -0,0 +1,134 @@ +# Config for multi-device LoRA finetuning in lora_finetune_single_device.py +# using a gemma2 27B model +# +# This config assumes that you've run the following command before launching +# this run (torchtune does not use gguf so you can ignore it to save time and space): +# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config gemma2/27B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config gemma2/27B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. 
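+#
+# If you hit out-of-memory errors on a single device, the memory-related keys defined below
+# can also be overridden at launch time. A possible combination (illustrative values only):
+#   tune run lora_finetune_single_device --config gemma2/27B_lora_single_device \
+#     batch_size=2 gradient_accumulation_steps=8 enable_activation_offloading=True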
+ +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma2-27b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.lora_gemma2_27b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 8 + lora_alpha: 16 + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma2-27b/ + checkpoint_files: [ + model-00001-of-00024.safetensors, + model-00002-of-00024.safetensors, + model-00003-of-00024.safetensors, + model-00004-of-00024.safetensors, + model-00005-of-00024.safetensors, + model-00006-of-00024.safetensors, + model-00007-of-00024.safetensors, + model-00008-of-00024.safetensors, + model-00009-of-00024.safetensors, + model-00010-of-00024.safetensors, + model-00011-of-00024.safetensors, + model-00012-of-00024.safetensors, + model-00013-of-00024.safetensors, + model-00014-of-00024.safetensors, + model-00015-of-00024.safetensors, + model-00016-of-00024.safetensors, + model-00017-of-00024.safetensors, + model-00018-of-00024.safetensors, + model-00019-of-00024.safetensors, + model-00020-of-00024.safetensors, + model-00021-of-00024.safetensors, + model-00022-of-00024.safetensors, + model-00023-of-00024.safetensors, + model-00024-of-00024.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma2-27b/ + model_type: GEMMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 5e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Fine-tuning arguments +batch_size: 8 +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 2 +compile: False + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +enable_activation_offloading: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-27b-lora +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/gemma2/27B_qlora_single_device.yaml b/recipes/configs/gemma2/27B_qlora_single_device.yaml new file mode 100644 index 0000000000..87c8b25f8c --- /dev/null +++ b/recipes/configs/gemma2/27B_qlora_single_device.yaml @@ -0,0 +1,137 @@ +# Config for multi-device QLoRA finetuning in lora_finetune_single_device.py +# using a gemma2 27B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-27b --ignore-patterns 
"gemma-2-27b.gguf" --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config gemma2/27B_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config gemma2/27B_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma2-27b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.qlora_gemma_27b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 128 + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma2-27b/ + checkpoint_files: [ + model-00001-of-00024.safetensors, + model-00002-of-00024.safetensors, + model-00003-of-00024.safetensors, + model-00004-of-00024.safetensors, + model-00005-of-00024.safetensors, + model-00006-of-00024.safetensors, + model-00007-of-00024.safetensors, + model-00008-of-00024.safetensors, + model-00009-of-00024.safetensors, + model-00010-of-00024.safetensors, + model-00011-of-00024.safetensors, + model-00012-of-00024.safetensors, + model-00013-of-00024.safetensors, + model-00014-of-00024.safetensors, + model-00015-of-00024.safetensors, + model-00016-of-00024.safetensors, + model-00017-of-00024.safetensors, + model-00018-of-00024.safetensors, + model-00019-of-00024.safetensors, + model-00020-of-00024.safetensors, + model-00021-of-00024.safetensors, + model-00022-of-00024.safetensors, + model-00023-of-00024.safetensors, + model-00024-of-00024.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma2-27b/ + model_type: GEMMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 4 +compile: False + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +enable_activation_offloading: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-27b-lora +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 + +# For 
colab use True +low_cpu_ram: False diff --git a/recipes/configs/gemma2/2B_full.yaml b/recipes/configs/gemma2/2B_full.yaml new file mode 100644 index 0000000000..f1214810a9 --- /dev/null +++ b/recipes/configs/gemma2/2B_full.yaml @@ -0,0 +1,74 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a gemma2 2B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-2b --ignore-patterns "gemma-2-2b.gguf" --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/2B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/2B_full checkpointer.checkpoint_dir= +# +# This config works only when the model is being fine-tuned on 2+ GPUs. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma2-2b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.gemma2_2b + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma2-2b/ + checkpoint_files: [ + model-00001-of-00003.safetensors, + model-00002-of-00003.safetensors, + model-00003-of-00003.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma2-2b + model_type: GEMMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/recipes/configs/gemma2/2B_lora.yaml b/recipes/configs/gemma2/2B_lora.yaml new file mode 100644 index 0000000000..ca6d8df232 --- /dev/null +++ b/recipes/configs/gemma2/2B_lora.yaml @@ -0,0 +1,86 @@ +# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# using a gemma2 2B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-2b --ignore-patterns "gemma-2-2b.gguf" --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/2B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/2B_lora checkpointer.checkpoint_dir= +# +# This config works only when the model is being fine-tuned on 2+ GPUs. 
+ +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma2-2b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.lora_gemma2_2b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 128 + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma2-2b/ + checkpoint_files: [ + model-00001-of-00003.safetensors, + model-00002-of-00003.safetensors, + model-00003-of-00003.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma2-2b + model_type: GEMMA2 +resume_from_checkpoint: False + +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-lora +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/recipes/configs/gemma2/2B_lora_single_device.yaml b/recipes/configs/gemma2/2B_lora_single_device.yaml new file mode 100644 index 0000000000..d8bbeb9a81 --- /dev/null +++ b/recipes/configs/gemma2/2B_lora_single_device.yaml @@ -0,0 +1,113 @@ +# Config for multi-device LoRA finetuning in lora_finetune_single_device.py +# using a gemma2 2B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-2b --ignore-patterns "gemma-2-2b.gguf" --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config gemma2/2B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config gemma2/2B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. 
+ +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma2-2b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.lora_gemma2_2b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 128 + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-2b/ + checkpoint_files: [ + model-00001-of-00003.safetensors, + model-00002-of-00003.safetensors, + model-00003-of-00003.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma2-2b + model_type: GEMMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 4 +compile: False + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +enable_activation_offloading: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-lora +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/gemma2/2B_qlora_single_device.yaml b/recipes/configs/gemma2/2B_qlora_single_device.yaml new file mode 100644 index 0000000000..c65367419f --- /dev/null +++ b/recipes/configs/gemma2/2B_qlora_single_device.yaml @@ -0,0 +1,113 @@ +# Config for multi-device QLoRA finetuning in lora_finetune_single_device.py +# using a gemma2 2B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-2b --ignore-patterns "gemma-2-2b.gguf" --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config gemma2/2B_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config gemma2/2B_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. 
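+#
+# QLoRA differs from the LoRA config above only in the model builder: qlora_gemma2_2b sets
+# quantize_base=True, so the frozen base weights of the LoRA-adapted linear layers are kept
+# quantized while only the LoRA adapters are trained. The adapter size can still be
+# overridden from the command line, for example (illustrative values):
+#   tune run lora_finetune_single_device --config gemma2/2B_qlora_single_device \
+#     model.lora_rank=32 model.lora_alpha=64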
+ +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma2-2b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.qlora_gemma2_2b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 128 + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma2-2b/ + checkpoint_files: [ + model-00001-of-00003.safetensors, + model-00002-of-00003.safetensors, + model-00003-of-00003.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma2-2b + model_type: GEMMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 4 +compile: False + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +enable_activation_offloading: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-lora +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/gemma2/9B_full.yaml b/recipes/configs/gemma2/9B_full.yaml new file mode 100644 index 0000000000..09d638a3b9 --- /dev/null +++ b/recipes/configs/gemma2/9B_full.yaml @@ -0,0 +1,79 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a gemma2 9B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-9b --ignore-patterns "gemma-2-9b.gguf" --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/9B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/9B_full checkpointer.checkpoint_dir= +# +# This config works only when the model is being fine-tuned on 2+ GPUs. 
+ + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma2-9b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.gemma_9b + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma2-9b/ + checkpoint_files: [ + model-00001-of-00008.safetensors, + model-00002-of-00008.safetensors, + model-00003-of-00008.safetensors, + model-00004-of-00008.safetensors, + model-00005-of-00008.safetensors, + model-00006-of-00008.safetensors, + model-00007-of-00008.safetensors, + model-00008-of-00008.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma2-9b + model_type: GEMMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 1 +epochs: 1 +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-9b-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/recipes/configs/gemma2/9B_lora.yaml b/recipes/configs/gemma2/9B_lora.yaml new file mode 100644 index 0000000000..3f27bab651 --- /dev/null +++ b/recipes/configs/gemma2/9B_lora.yaml @@ -0,0 +1,91 @@ +# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# using a gemma2 9B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-9b --ignore-patterns "gemma-2-9b.gguf" --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/9B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/9B_lora checkpointer.checkpoint_dir= +# +# This config works only when the model is being fine-tuned on 2+ GPUs. 
+ + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma2-9b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.lora_gemma2_9b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 128 + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma2-9b/ + checkpoint_files: [ + model-00001-of-00008.safetensors, + model-00002-of-00008.safetensors, + model-00003-of-00008.safetensors, + model-00004-of-00008.safetensors, + model-00005-of-00008.safetensors, + model-00006-of-00008.safetensors, + model-00007-of-00008.safetensors, + model-00008-of-00008.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma2-9b/ + model_type: GEMMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-9b-lora +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/recipes/configs/gemma2/9B_lora_single_device.yaml b/recipes/configs/gemma2/9B_lora_single_device.yaml new file mode 100644 index 0000000000..73ee146089 --- /dev/null +++ b/recipes/configs/gemma2/9B_lora_single_device.yaml @@ -0,0 +1,118 @@ +# Config for multi-device LoRA finetuning in lora_finetune_single_device.py +# using a gemma2 9B model +# +# This config assumes that you've run the following command before launching +# this run (torchtune does not use gguf so you can ignore it to save time and space): +# tune download google/gemma-2-9b --ignore-patterns "gemma-2-9b.gguf" --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config gemma2/9B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config gemma2/9B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. 
+ +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma2-9b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.lora_gemma2_9b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 8 + lora_alpha: 16 + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma2-9b/ + checkpoint_files: [ + model-00001-of-00008.safetensors, + model-00002-of-00008.safetensors, + model-00003-of-00008.safetensors, + model-00004-of-00008.safetensors, + model-00005-of-00008.safetensors, + model-00006-of-00008.safetensors, + model-00007-of-00008.safetensors, + model-00008-of-00008.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma2-9b/ + model_type: GEMMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 5e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Fine-tuning arguments +batch_size: 8 +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 2 +compile: False + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +enable_activation_offloading: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-9b-lora +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/gemma2/9B_qlora_single_device.yaml b/recipes/configs/gemma2/9B_qlora_single_device.yaml new file mode 100644 index 0000000000..6ef9a5d785 --- /dev/null +++ b/recipes/configs/gemma2/9B_qlora_single_device.yaml @@ -0,0 +1,121 @@ +# Config for multi-device QLoRA finetuning in lora_finetune_single_device.py +# using a gemma2 9B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-9b --ignore-patterns "gemma-2-9b.gguf" --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config gemma2/9B_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config gemma2/9B_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. 
+ +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma2-9b/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.qlora_gemma_9b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 128 + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma2-9b/ + checkpoint_files: [ + model-00001-of-00008.safetensors, + model-00002-of-00008.safetensors, + model-00003-of-00008.safetensors, + model-00004-of-00008.safetensors, + model-00005-of-00008.safetensors, + model-00006-of-00008.safetensors, + model-00007-of-00008.safetensors, + model-00008-of-00008.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma2-9b/ + model_type: GEMMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 4 +compile: False + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +enable_activation_offloading: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-9b-lora +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 + +# For colab use True +low_cpu_ram: False diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index e1c7f8c3c5..9fa5465543 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -93,6 +93,9 @@ class Recipe: Config(name="mistral/7B_full", file_path="mistral/7B_full.yaml"), Config(name="gemma/2B_full", file_path="gemma/2B_full.yaml"), Config(name="gemma/7B_full", file_path="gemma/7B_full.yaml"), + Config(name="gemma2/2B_full", file_path="gemma2/2B_full.yaml"), + Config(name="gemma2/9B_full", file_path="gemma2/9B_full.yaml"), + Config(name="gemma2/27B_full", file_path="gemma2/27B_full.yaml"), Config(name="phi3/mini_full", file_path="phi3/mini_full.yaml"), Config(name="qwen2/7B_full", file_path="qwen2/7B_full.yaml"), Config(name="qwen2/0.5B_full", file_path="qwen2/0.5B_full.yaml"), @@ -192,6 +195,30 @@ class Recipe: name="gemma/7B_qlora_single_device", file_path="gemma/7B_qlora_single_device.yaml", ), + Config( + name="gemma2/2B_lora_single_device", + file_path="gemma2/2B_lora_single_device.yaml", + ), + Config( + 
name="gemma2/2B_qlora_single_device", + file_path="gemma2/2B_qlora_single_device.yaml", + ), + Config( + name="gemma2/9B_lora_single_device", + file_path="gemma2/9B_lora_single_device.yaml", + ), + Config( + name="gemma2/9B_qlora_single_device", + file_path="gemma2/9B_qlora_single_device.yaml", + ), + Config( + name="gemma2/27B_lora_single_device", + file_path="gemma2/27B_lora_single_device.yaml", + ), + Config( + name="gemma2/27B_qlora_single_device", + file_path="gemma2/27B_qlora_single_device.yaml", + ), Config( name="phi3/mini_lora_single_device", file_path="phi3/mini_lora_single_device.yaml", @@ -281,6 +308,9 @@ class Recipe: Config(name="mistral/7B_lora", file_path="mistral/7B_lora.yaml"), Config(name="gemma/2B_lora", file_path="gemma/2B_lora.yaml"), Config(name="gemma/7B_lora", file_path="gemma/7B_lora.yaml"), + Config(name="gemma2/2B_lora", file_path="gemma2/2B_lora.yaml"), + Config(name="gemma2/9B_lora", file_path="gemma2/9B_lora.yaml"), + Config(name="gemma2/27B_lora", file_path="gemma2/27B_lora.yaml"), Config(name="phi3/mini_lora", file_path="phi3/mini_lora.yaml"), Config(name="qwen2/7B_lora", file_path="qwen2/7B_lora.yaml"), Config(name="qwen2/0.5B_lora", file_path="qwen2/0.5B_lora.yaml"), diff --git a/torchtune/models/convert_weights.py b/torchtune/models/convert_weights.py index c0cf2f10fc..7333af1838 100644 --- a/torchtune/models/convert_weights.py +++ b/torchtune/models/convert_weights.py @@ -38,8 +38,10 @@ "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight", "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight", "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight", - "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", - "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale", + "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", # mlp_norm.scale -> looks like a previous bug here # noqa + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.sa_scale.scale", + "model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.mlp_norm.scale", + "model.layers.{}.pre_feedforward_layernorm.weight": "layers.{}.mlp_scale.scale", "model.norm.weight": "norm.scale", "lm_head.weight": "output.weight", } diff --git a/torchtune/models/gemma/__init__.py b/torchtune/models/gemma/__init__.py index 48e4e84b10..f762de86b6 100644 --- a/torchtune/models/gemma/__init__.py +++ b/torchtune/models/gemma/__init__.py @@ -27,6 +27,4 @@ "lora_gemma_7b", "qlora_gemma_2b", "qlora_gemma_7b", - "gemma_hf_to_tune", - "gemma_tune_to_hf", ] diff --git a/torchtune/models/gemma/_component_builders.py b/torchtune/models/gemma/_component_builders.py index e7ab9b224c..ba5b666c98 100644 --- a/torchtune/models/gemma/_component_builders.py +++ b/torchtune/models/gemma/_component_builders.py @@ -46,7 +46,6 @@ def gemma( attn_dropout: float = 0.0, norm_eps: float = 1e-6, rope_base: int = 10_000, - norm_embeddings: bool = True, ) -> TransformerDecoder: """ Build the decoder associated with the gemma model. This includes: @@ -72,8 +71,6 @@ def gemma( Default: 0.0 norm_eps (float): epsilon in RMS norms Default: 1e-6 rope_base (int): base for the rotary positional embeddings. Default: 10_000 - norm_embeddings (bool): whether to apply layer norm before the self-attention - and mlp layers. Default: True Returns: TransformerDecoder: Instantiation of gemma model. 
@@ -146,7 +143,6 @@ def lora_gemma( attn_dropout: float = 0.0, norm_eps: float = 1e-6, rope_base: int = 10_000, - norm_embeddings: bool = True, # LoRA args lora_rank: int, lora_alpha: float, @@ -177,8 +173,6 @@ def lora_gemma( Default: 0.0 norm_eps (float): epsilon in RMS norms Default: 1e-6 rope_base (int): base for the rotary positional embeddings. Default: 10_000 - norm_embeddings (bool): whether to apply layer norm before the self-attention - and mlp layers. Default: True lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation lora_dropout (float): LoRA dropout probability. Default: 0.0 diff --git a/torchtune/models/gemma2/__init__.py b/torchtune/models/gemma2/__init__.py new file mode 100644 index 0000000000..9fe11db7ab --- /dev/null +++ b/torchtune/models/gemma2/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ..gemma._model_builders import gemma_tokenizer +from ..gemma._tokenizer import GemmaTokenizer # noqa +from ._component_builders import gemma2, lora_gemma2 # noqa +from ._model_builders import ( # noqa + gemma2_27b, + gemma2_2b, + gemma2_9b, + lora_gemma2_27b, + lora_gemma2_2b, + lora_gemma2_9b, + qlora_gemma2_27b, + qlora_gemma2_2b, + qlora_gemma2_9b, +) + +__all__ = [ + "GemmaTokenizer", + "gemma2", + "gemma2_2b", + "gemma2_9b", + "gemma2_27b", + "gemma_tokenizer", + "lora_gemma2", + "lora_gemma2_2b", + "lora_gemma2_9b", + "lora_gemma2_27b", + "qlora_gemma2_2b", + "qlora_gemma2_9b", + "qlora_gemma2_27b", +] diff --git a/torchtune/models/gemma2/_component_builders.py b/torchtune/models/gemma2/_component_builders.py new file mode 100644 index 0000000000..6c99ccb701 --- /dev/null +++ b/torchtune/models/gemma2/_component_builders.py @@ -0,0 +1,410 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch import nn +import torch +from typing import List +from torchtune.modules.common_utils import _register_reparametrize_state_dict_hooks +from typing import List, Optional + +from torchtune.modules import ( + FrozenNF4Linear, + RotaryPositionalEmbeddings, + TransformerSelfAttentionLayer, +) + +from torchtune.modules.attention import Gemma2Attention +from torchtune.models.gemma.rms_norm import GemmaRMSNorm +from torchtune.modules import TransformerDecoder, TiedLinear +from torchtune.models.gemma.gemma_norm_embedding import GemmaNormEmbeddings +from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear +from torchtune.models.gemma._component_builders import gemma_mlp, lora_gemma_mlp + +""" +Component builders for the Gemma2 2B, 9B models and popular variants such as LoRA. + +torchtune provides composable building blocks. Builder functions help +stitch these building blocks into higher-level components. This design has +two benefits: +- The building blocks themselves are very flexible. For example, ``MultiHeadAttention`` +can take either nn.Linear or nn.LoRALinear for ``q_proj``. +- Builder functions expose a set of configurable params which keep the constructors of +the building blocks simple. 
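+
+Compared to Gemma, the Gemma2 builders below additionally wire up: tanh soft-capping of the
+attention logits and of the final logits (``capping_value * tanh(logits / capping_value)``),
+sliding-window attention on every other layer (even layer indices), and four RMSNorms per
+transformer layer (``sa_norm``/``sa_scale`` around self-attention and ``mlp_scale``/``mlp_norm``
+around the MLP).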
+""" + +class TanhSotfCapping(nn.Module): + def __init__( + self, + capping_value: float, + ) -> None: + super().__init__() + self.capping_value = capping_value + + def forward(self, attn_weights): + attn_weights = attn_weights / self.capping_value + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * self.capping_value + + +class Gemma2FinalNorm(nn.Module): + """ + Combines RMSNorm and SoftCapping + """ + def __init__( + self, + capping_value: float, + embed_dim: int, + eps: float + ) -> None: + super().__init__() + self.capping_value = capping_value + self.rms_norm = GemmaRMSNorm(embed_dim, eps=eps) + self.logit_capping = TanhSotfCapping(capping_value) + + def forward(self, x): + x = self.rms_norm(x) + x = self.logit_capping(x) + return x + + +def gemma2( + vocab_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + num_kv_heads: int, + embed_dim: int, + intermediate_dim: int, + max_seq_len: int, + attn_dropout: float = 0.0, + norm_eps: float = 1e-6, + rope_base: int = 10_000, + hidden_capping_value: float = 50., + final_capping_value: float = 30., + sliding_window_size: int = 4096, + query_pre_attn_scalar: Optional[int] = None, +) -> TransformerDecoder: + """ + Build the decoder associated with the gemma2 model. This includes: + - Token embeddings + - num_layers number of TransformerSelfAttentionLayer blocks + - RMS Norm layer applied to the output of the transformer + - Final projection into token space + + + Args: + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + head_dim (int): dimension of head + num_kv_heads (int): number of key and value heads. + embed_dim (int): embedding dimension for self-attention + intermediate_dim (int): intermediate dimension for MLP + max_seq_len (int): maximum sequence length the model will be run with, + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + norm_eps (float): epsilon in RMS norms Default: 1e-6 + rope_base (int): base for the rotary positional embeddings. Default: 10_000 + + Returns: + TransformerDecoder: Instantiation of gemma model. 
+ """ + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + + mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + + layers = torch.nn.ModuleList() + + for layer_idx in range(num_layers): + self_att = Gemma2Attention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + # perform sliding window on half of the layers only + sliding_window_size=sliding_window_size if (layer_idx % 2)==0 else None, + softcapping=hidden_capping_value, + query_pre_attn_scalar=query_pre_attn_scalar + ) + + layer = TransformerSelfAttentionLayer( + attn=self_att, + mlp=mlp, + sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), + mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), + sa_scale=GemmaRMSNorm(embed_dim, eps=norm_eps), + mlp_scale=GemmaRMSNorm(embed_dim, eps=norm_eps), + ) + layers.append(layer) + tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim) + output_proj = TiedLinear(tok_embeddings) + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layers=layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + output=output_proj, + head_dim=head_dim, + norm=Gemma2FinalNorm(final_capping_value, embed_dim, eps=norm_eps), + ) + return model + + + +def lora_gemma2( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + *, + # gemma args + vocab_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + num_kv_heads: int, + embed_dim: int, + intermediate_dim: int, + max_seq_len: int, + attn_dropout: float = 0.0, + norm_eps: float = 1e-6, + rope_base: int = 10_000, + hidden_capping_value: float = 50., + final_capping_value: float = 30., + sliding_window_size: int = 4096, + query_pre_attn_scalar: Optional[int] = None, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + use_dora: bool = False, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Return a version of Gemma with LoRA applied based on the passed in configuration. + Note: output projection lora is not supported because it is tied to token embeddings + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + head_dim (int): dimension of head + num_kv_heads (int): number of key and value heads. + embed_dim (int): embedding dimension for self-attention + intermediate_dim (int): intermediate dimension for MLP + max_seq_len (int): maximum sequence length the model will be run with, + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + norm_eps (float): epsilon in RMS norms Default: 1e-6 + rope_base (int): base for the rotary positional embeddings. 
Default: 10_000 + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + use_dora (bool): Decompose the LoRA weight into magnitude and direction, as + introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). + quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base + weights within linear layers LoRA is applied to. The final output linear projection is not + supported for quantization currently. + + Returns: + TransformerDecoder: Instantiation of Gemma model with LoRA applied to + a subset of the attention projections in each layer. + """ + if apply_lora_to_mlp: + mlp = lora_gemma_mlp( + dim=embed_dim, + hidden_dim=intermediate_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + else: + mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base) + + tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim) + output_proj = TiedLinear(tok_embeddings) + + layers = torch.nn.ModuleList() + + for layer_idx in range(num_layers): + self_att = lora_gemma2_self_attention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False), + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + # perform sliding window on half of the layers only + sliding_window_size=sliding_window_size if (layer_idx % 2)==0 else None, + softcapping=hidden_capping_value, + query_pre_attn_scalar=query_pre_attn_scalar + ) + + layer = TransformerSelfAttentionLayer( + attn=self_att, + mlp=mlp, + sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), + mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), + sa_scale=GemmaRMSNorm(embed_dim, eps=norm_eps), + mlp_scale=GemmaRMSNorm(embed_dim, eps=norm_eps), + ) + layers.append(layer) + + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layers=layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + output=output_proj, + head_dim=head_dim, + norm=Gemma2FinalNorm(final_capping_value, embed_dim, eps=norm_eps) + ) + + if quantize_base: + # For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly + # so as to not increase peak memory + # TODO this is clowny, figure out a better way to get what precision the rest + # of the model is in + _register_reparametrize_state_dict_hooks(model, dtype=tok_embeddings.weight.dtype) + + return model + + +def lora_gemma2_self_attention( + lora_modules: List[LORA_ATTN_MODULES], + *, + # MultiHeadAttention args + embed_dim: int, + num_heads: int, + head_dim: int, + num_kv_heads: int, + max_seq_len: int, + attn_dropout: float = 0.0, + rope_base: int = 10_000, + sliding_window_size: Optional[int] = None, + softcapping: Optional[float] = 50., + query_pre_attn_scalar: Optional[int], + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + use_dora: bool = False, + quantize_base: bool = False, + +) -> Gemma2Attention: + if not lora_modules: + raise ValueError( + f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules" + ) + + num_kv_heads = 
num_kv_heads if num_kv_heads else num_heads + adapter_cls = DoRALinear if use_dora else LoRALinear + + q_proj = ( + adapter_cls( + embed_dim, + num_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "q_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False) + ) + ) + k_proj = ( + adapter_cls( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "k_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) + ) + v_proj = ( + adapter_cls( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "v_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) + ) + output_proj = ( + adapter_cls( + num_heads * head_dim, + embed_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "output_proj" in lora_modules + else ( + nn.Linear(num_heads * head_dim, embed_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(num_heads * head_dim, embed_dim, bias=False) + ) + ) + + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + + self_att = Gemma2Attention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=q_proj, + k_proj=k_proj, + v_proj=v_proj, + output_proj=output_proj, + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + sliding_window_size=sliding_window_size, + softcapping=softcapping, + query_pre_attn_scalar=query_pre_attn_scalar + ) + return self_att \ No newline at end of file diff --git a/torchtune/models/gemma2/_model_builders.py b/torchtune/models/gemma2/_model_builders.py new file mode 100644 index 0000000000..72df7747da --- /dev/null +++ b/torchtune/models/gemma2/_model_builders.py @@ -0,0 +1,286 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import List + +from torchtune.models.gemma2._component_builders import gemma2, lora_gemma2 +from torchtune.modules import TransformerDecoder + +from torchtune.modules.peft import LORA_ATTN_MODULES +from functools import partial + +""" +Model builders build specific instantiations using component builders. For example +the ``gemma_2b`` model builder uses the ``gemma2`` component builder. 
+""" + + +def gemma2_2b() -> TransformerDecoder: + """ + Builder for creating a Gemma2 2B model initialized w/ the default 2b parameter values + from: https://github.com/google/gemma_pytorch/blob/main/gemma/config.py + + Returns: + TransformerDecoder: Instantiation of Gemma2 2B model + """ + return gemma2( + vocab_size=256_000, + num_layers=18, + num_heads=8, + head_dim=256, + num_kv_heads=1, + embed_dim=2048, + intermediate_dim=16384, + max_seq_len=8192, + attn_dropout=0.0, + norm_eps=1e-6, + hidden_capping_value=30.0, + final_capping_value=50.0, + sliding_window_size=4096, + ) + + +def lora_gemma2_2b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.0, + use_dora: bool = False, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Gemma2 2B model with LoRA enabled. + + The Gemma defaults are the same as in :func:`~torchtune.models.gemma.gemma_2b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0 + use_dora (bool): Decompose the LoRA weight into magnitude and direction, as + introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Gemma2 2B model with LoRA applied + """ + return lora_gemma2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + vocab_size=256_000, + num_layers=18, + num_heads=8, + head_dim=256, + num_kv_heads=1, + embed_dim=2048, + intermediate_dim=16384, + max_seq_len=8192, + attn_dropout=0.0, + norm_eps=1e-6, + hidden_capping_value=30.0, + final_capping_value=50.0, + sliding_window_size=4096, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + +qlora_gemma2_2b = partial(lora_gemma2_2b, quantize_base=True) + +qlora_gemma2_2b.__doc__ = """ +Builder for creating a Gemma2 model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_gemm2a_2b` for full API arguments. 
+""" + + + +def gemma2_9b() -> TransformerDecoder: + """ + Builder for creating a Gemma2 9B model initialized w/ the default 9b parameter values + from: https://github.com/google/gemma_pytorch/blob/main/gemma/config.py + + Returns: + TransformerDecoder: Instantiation of Gemma 9B model + """ + return gemma2( + vocab_size=256_000, + num_layers=42, + num_heads=16, + head_dim=256, + num_kv_heads=16, + embed_dim=3584, + intermediate_dim=14336, + max_seq_len=8192, + attn_dropout=0.0, + norm_eps=1e-6, + hidden_capping_value=30.0, + final_capping_value=50.0, + sliding_window_size=4096, + ) + + +def lora_gemma2_9b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.0, + use_dora: bool = False, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Gemma 9B model with LoRA enabled. + + The Gemma defaults are the same as in :func:`~torchtune.models.gemma.gemma_7b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0 + use_dora (bool): Decompose the LoRA weight into magnitude and direction, as + introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Gemma2 9B model with LoRA applied + """ + return lora_gemma2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + vocab_size=256_000, + num_layers=42, + num_heads=16, + head_dim=256, + num_kv_heads=16, + embed_dim=3584, + intermediate_dim=14336, + max_seq_len=8192, + attn_dropout=0.0, + norm_eps=1e-6, + hidden_capping_value=30.0, + final_capping_value=50.0, + sliding_window_size=4096, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + +qlora_gemma2_9b = partial(lora_gemma2_9b, quantize_base=True) + +qlora_gemma2_9b.__doc__ = """ +Builder for creating a Gemma model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_gemma2_9b` for full API arguments. 
+""" + +def gemma2_27b() -> TransformerDecoder: + """ + Builder for creating a Gemma2 27B model initialized w/ the default 27b parameter values + from: https://github.com/google/gemma_pytorch/blob/main/gemma/config.py + + Returns: + TransformerDecoder: Instantiation of Gemma2 27B model + """ + return gemma2( + vocab_size=256_000, + num_layers=46, + num_heads=32, + head_dim=128, + num_kv_heads=16, + embed_dim=4608, + intermediate_dim=36864, + max_seq_len=8192, + attn_dropout=0.0, + norm_eps=1e-6, + hidden_capping_value=30.0, + final_capping_value=50.0, + sliding_window_size=4096, + query_pre_attn_scalar=144, + ) + + +def lora_gemma2_27b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.0, + use_dora: bool = False, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Gemma2 27B model with LoRA enabled. + + The Gemma defaults are the same as in :func:`~torchtune.models.gemma.gemma_7b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0 + use_dora (bool): Decompose the LoRA weight into magnitude and direction, as + introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Gemma2 27B model with LoRA applied + """ + return lora_gemma2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + vocab_size=256_000, + num_layers=46, + num_heads=32, + head_dim=128, + num_kv_heads=16, + embed_dim=4608, + intermediate_dim=36864, + max_seq_len=8192, + attn_dropout=0.0, + norm_eps=1e-6, + hidden_capping_value=30.0, + final_capping_value=50.0, + sliding_window_size=4096, + query_pre_attn_scalar=144, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + +qlora_gemma2_27b = partial(lora_gemma2_27b, quantize_base=True) + +qlora_gemma2_27b.__doc__ = """ +Builder for creating a Gemma model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_gemma2_27b` for full API arguments. 
+""" diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index 2dfeaddc9a..2fbd4d23e7 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -8,6 +8,7 @@ from typing import Optional import torch +import torch.nn.functional as F from torch import nn from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention from torchtune.modules.kv_cache import KVCache @@ -306,3 +307,291 @@ def forward( # reshape the output to be the same shape as the input output = output.transpose(1, 2).contiguous().view(b, s_x, -1) return self.output_proj(output) + + +class Gemma2Attention(nn.Module): + """ + Adapated from official Google Pytorch Implementation: + https://github.com/google/gemma_pytorch/blob/80881c2e6e797ef1913a4a705d4b40394791cc58/gemma/model.py#L213 + to match torchtune style. + A new attention had to be added since nn.functional.scaled_dot_product_attention does allow soft capping + Args: + embed_dim (int): embedding dimension for the model + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + ``num_heads % num_kv_heads == 0``. For standard MHA set ``num_kv_heads == num_heads``, + for GQA ``num_kv_heads < num_heads``, and for MQA set ``num_kv_heads == 1``. + head_dim (int): dimension of each head, calculated by ``embed_dim // num_heads``. + q_proj (nn.Module): projection layer for query. + k_proj (nn.Module): projection layer for key. + v_proj (nn.Module): projection layer for value. + output_proj (nn.Module): projection layer for output. + pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings. + q_norm (Optional[nn.Module]): normalization layer for query, e.g. RMSNorm. For decoding, this is applied + before updating from kv_cache. This means it will only support token wide normalization and not + batch or sequence wide normalization. + k_norm (Optional[nn.Module]): normalization layer for key, must be set if q_norm is. + kv_cache (Optional[KVCache]): KVCache object used to cache key and value + max_seq_len (int): maximum sequence length supported by the model. + This is needed to compute the RoPE Cache. Default: 4096. + is_causal (bool): sets the default mask to causal when no mask is provided + attn_dropout (float): dropout value passed onto the + scaled_dot_product_attention function. This argument is ignored if the + self.training is False. Default value is 0.0. 
+ sliding_window_size (Optional[int]): size of the sliding window if None no sliding window is applied + softcapping (Optional[float]): capping value used for soft caping, if None no capping is performed + query_pre_attn_scalar (Optional[int]): value used for pre attention normalisation, if None head_dim is used instead + Raises: + ValueError: If ``num_heads % num_kv_heads != 0`` + ValueError: If ``embed_dim % num_heads != 0`` + ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1`` + ValueError: if q_norm is defined without k_norm or vice versa + """ + + def __init__( + self, + *, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + q_proj: nn.Module, + k_proj: nn.Module, + v_proj: nn.Module, + output_proj: nn.Module, + pos_embeddings: Optional[nn.Module] = None, + q_norm: Optional[nn.Module] = None, + k_norm: Optional[nn.Module] = None, + kv_cache: Optional[KVCache] = None, + max_seq_len: int = 4096, + is_causal: bool = True, + attn_dropout: float = 0.0, + sliding_window_size: Optional[int] = None, + softcapping: Optional[float] = 50.0, + query_pre_attn_scalar: Optional[int] = None, + ) -> None: + super().__init__() + if num_heads % num_kv_heads != 0: + raise ValueError( + f"num_heads ({num_heads}) must be divisible by " + f"num_kv_heads ({num_kv_heads})" + ) + + if embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim ({embed_dim}) must be divisible by " + f"num_heads ({num_heads})" + ) + + if attn_dropout < 0 or attn_dropout > 1: + raise ValueError(f"attn_dropout ({embed_dim}) must be between 0.0 and 1.0") + + if bool(q_norm) ^ bool(k_norm): + raise ValueError("q and k norm must be set together") + + # Set attributes + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.embed_dim = embed_dim + self.attn_dropout = attn_dropout + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.is_causal = is_causal + + # Set layers + self.kv_cache = kv_cache + self.q_proj = q_proj + self.k_proj = k_proj + self.v_proj = v_proj + self.output_proj = output_proj + self.q_norm = q_norm + self.k_norm = k_norm + self.pos_embeddings = pos_embeddings + + # gemma related parameters + self.sliding_window_size = sliding_window_size + self.softcapping = softcapping + if query_pre_attn_scalar is not None: + self.scaling = query_pre_attn_scalar**-0.5 + else: + self.scaling = self.head_dim**-0.5 + + def setup_cache( + self, batch_size: int, dtype: torch.dtype, max_seq_len: int + ) -> None: + """Setup key value caches for attention calculation. If called + after kv_cache is already setup, this will be skipped. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + max_seq_len (int): maximum sequence length model will be run with. + """ + # Don't overwrite user defined kv_cache from init + if self.kv_cache is not None: + logger.warning( + "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping." + ) + else: + self.kv_cache = KVCache( + batch_size=batch_size, + max_seq_len=max_seq_len, + num_heads=self.num_heads, + head_dim=self.head_dim, + dtype=dtype, + ) + + def reset_cache(self): + """Reset the key value caches.""" + if self.kv_cache is None: + raise RuntimeError( + "Key value caches are not setup. Call ``setup_caches()`` first." 
+ ) + self.kv_cache.reset() + + def forward( + self, + x: torch.Tensor, + y: Optional[torch.Tensor] = None, + *, + mask: Optional[_MaskType] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x (torch.Tensor): input tensor with shape [b x s_x x d] for the query + y (Optional[torch.Tensor]): second input tensor with shape [b x s_y x d], is the input + for k and v. For self attention, x=y. Optional only with kv_cache enabled. + mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication + and before the softmax. Either: + + A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``, + or ``[b x s x self.encoder_max_cache_seq_len]`` if using KV-cacheing with encoder/decoder layers. + A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means + token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask + is used by default. + + A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence + created via `create_block_mask `_. We use + :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks. + Default is None. + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b x s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Raises: + ValueError: If no ``y`` input and ``kv_cache`` is not enabled. + + Returns: + torch.Tensor: output tensor with attention applied + + Notation used for tensor shapes: + - b: batch size + - s_x: sequence length for x + - s_y: sequence length for y + - n_h: num heads + - n_kv: num kv heads + - d: embed dim + - h_d: head dim + """ + # x has shape [b, s_x, d] + # y has shape [b, s_y, d] + b, s_x, _ = x.shape + s_y = y.shape[1] if y is not None else 0 + + # q has shape [b, s_x, num_heads * head_dim] + q = self.q_proj(x) + + # number of queries per key/value + q_per_kv = self.num_heads // self.num_kv_heads + q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim) + + # Apply positional embeddings + if self.pos_embeddings is not None: + q = self.pos_embeddings(q, input_pos=input_pos) + + # [b, n_h, s_x, h_d] + q = q.transpose(1, 2) + + # Normalize q + if self.q_norm is not None: + q = self.q_norm(q) + + if y is None: + if self.kv_cache is None: + raise ValueError( + "Must provide y input or use kv_cache to enable streaming decoding" + ) + k = self.kv_cache.k_cache + v = self.kv_cache.v_cache + else: + # Update k and v shape, positional embeddings, and normalization + + # k has shape [b, s_y, num_kv_heads * head_dim] + # v has shape [b, s_y, num_kv_heads * head_dim] + k = self.k_proj(y) + v = self.v_proj(y) + + # Apply positional embeddings + # k: [b, s_y, n_kv, h_d] + k = k.view(b, s_y, -1, self.head_dim) + if self.pos_embeddings is not None: + k = self.pos_embeddings(k, input_pos=input_pos) + + # View + expand + reshape bring num_kv_heads to num_heads for k and v + # to match q. 
+ + # k: [b, s_y, n_kv, 1, h_d] + # v: [b, s_y, n_kv, 1, h_d] + k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim) + v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim) + + # If needed, expand the key and value tensors to have the same shape + # as the query tensor by copying values across the relevant dim + if self.num_heads != self.num_kv_heads: + k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim) + v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim) + + # [b, s, n_h, h_d] + k = k.reshape(b, s_y, -1, self.head_dim) + v = v.reshape(b, s_y, -1, self.head_dim) + + # [b, n_h, s, h_d] + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Normalize k + if self.k_norm is not None: + k = self.k_norm(k) + + # Update key-value cache + if self.kv_cache is not None: + k, v = self.kv_cache.update(k, v) + + q.mul_(self.scaling) + output = torch.matmul(q, k.transpose(2, 3)) + + if self.sliding_window_size is not None: + all_ones = torch.ones_like(mask) + sliding_mask = torch.triu( + all_ones, -1 * self.sliding_window_size + 1 + ) * torch.tril(all_ones, self.sliding_window_size - 1) + mask = torch.where(sliding_mask == 1, mask, -2.3819763e38) + + if self.softcapping is not None: + output = output / self.softcapping + output = torch.tanh(output) + output = output * self.softcapping + + output = output + mask + output = F.softmax(output.float(), dim=-1).type_as(q) + + # [batch_size, n_local_heads, input_len, head_dim] + output = torch.matmul(output, v) + + # reshape the output to be the same shape as the input + output = output.transpose(1, 2).contiguous().view(b, s_x, -1) + return self.output_proj(output) diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index 2d353b007c..0eb4c7ebdf 100644 --- a/torchtune/training/checkpointing/_utils.py +++ b/torchtune/training/checkpointing/_utils.py @@ -45,6 +45,7 @@ class ModelType(Enum): Attributes: GEMMA (str): Gemma family of models. See :func:`~torchtune.models.gemma.gemma` + GEMMA2 (str): Gemma family of models. See :func:`~torchtune.models.gemma2.gemma2` LLAMA2 (str): Llama2 family of models. See :func:`~torchtune.models.llama2.llama2` LLAMA3 (str): Llama3 family of models. See :func:`~torchtune.models.llama3.llama3` LLAMA3_2 (str): Llama3.2 family of models. 
See :func:`~torchtune.models.llama3_2.llama3_2` @@ -65,6 +66,7 @@ class ModelType(Enum): """ GEMMA: str = "gemma" + GEMMA2: str = "gemma2" LLAMA2: str = "llama2" LLAMA3: str = "llama3" LLAMA3_2: str = "llama3_2" From e87f8787360152c0f27b5710871b7c93ce25523d Mon Sep 17 00:00:00 2001 From: Optimox Date: Fri, 18 Oct 2024 14:39:00 +0200 Subject: [PATCH 02/11] WIP: working gemma2 2b pipeline --- docs/source/api_ref_models.rst | 2 +- recipes/configs/gemma2/27B_full.yaml | 29 +- recipes/configs/gemma2/27B_lora.yaml | 29 +- .../gemma2/27B_lora_single_device.yaml | 29 +- .../gemma2/27B_qlora_single_device.yaml | 29 +- recipes/configs/gemma2/9B_full.yaml | 13 +- recipes/configs/gemma2/9B_lora.yaml | 13 +- .../configs/gemma2/9B_lora_single_device.yaml | 13 +- .../gemma2/9B_qlora_single_device.yaml | 13 +- torchtune/models/convert_weights.py | 6 +- torchtune/models/gemma2/_attention.py | 305 ++++++++++++++++++ .../models/gemma2/_component_builders.py | 20 +- torchtune/models/gemma2/_convert_weights.py | 132 ++++++++ torchtune/models/gemma2/_model_builders.py | 20 +- torchtune/modules/attention.py | 289 ----------------- .../training/checkpointing/_checkpointer.py | 20 ++ 16 files changed, 505 insertions(+), 457 deletions(-) create mode 100644 torchtune/models/gemma2/_attention.py create mode 100644 torchtune/models/gemma2/_convert_weights.py diff --git a/docs/source/api_ref_models.rst b/docs/source/api_ref_models.rst index 39d6ec2291..8bd039805a 100644 --- a/docs/source/api_ref_models.rst +++ b/docs/source/api_ref_models.rst @@ -320,7 +320,7 @@ To download the Gemma 7B model: gemma.gemma_tokenizer -gemma2 : # TODO +gemma2 : ----- Models of size 2B, 9B, 27B from the `Gemma family `_. diff --git a/recipes/configs/gemma2/27B_full.yaml b/recipes/configs/gemma2/27B_full.yaml index 09cf4bbc38..17a6e895f5 100644 --- a/recipes/configs/gemma2/27B_full.yaml +++ b/recipes/configs/gemma2/27B_full.yaml @@ -34,32 +34,9 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer checkpoint_dir: /tmp/gemma2-27b/ - checkpoint_files: [ - model-00001-of-00024.safetensors, - model-00002-of-00024.safetensors, - model-00003-of-00024.safetensors, - model-00004-of-00024.safetensors, - model-00005-of-00024.safetensors, - model-00006-of-00024.safetensors, - model-00007-of-00024.safetensors, - model-00008-of-00024.safetensors, - model-00009-of-00024.safetensors, - model-00010-of-00024.safetensors, - model-00011-of-00024.safetensors, - model-00012-of-00024.safetensors, - model-00013-of-00024.safetensors, - model-00014-of-00024.safetensors, - model-00015-of-00024.safetensors, - model-00016-of-00024.safetensors, - model-00017-of-00024.safetensors, - model-00018-of-00024.safetensors, - model-00019-of-00024.safetensors, - model-00020-of-00024.safetensors, - model-00021-of-00024.safetensors, - model-00022-of-00024.safetensors, - model-00023-of-00024.safetensors, - model-00024-of-00024.safetensors, - ] + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: 00024 recipe_checkpoint: null output_dir: /tmp/gemma2-27b model_type: GEMMA2 diff --git a/recipes/configs/gemma2/27B_lora.yaml b/recipes/configs/gemma2/27B_lora.yaml index 3631675a18..8cc22e4dd1 100644 --- a/recipes/configs/gemma2/27B_lora.yaml +++ b/recipes/configs/gemma2/27B_lora.yaml @@ -39,32 +39,9 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer checkpoint_dir: /tmp/gemma2-27b/ - checkpoint_files: [ - model-00001-of-00024.safetensors, - model-00002-of-00024.safetensors, - 
model-00003-of-00024.safetensors, - model-00004-of-00024.safetensors, - model-00005-of-00024.safetensors, - model-00006-of-00024.safetensors, - model-00007-of-00024.safetensors, - model-00008-of-00024.safetensors, - model-00009-of-00024.safetensors, - model-00010-of-00024.safetensors, - model-00011-of-00024.safetensors, - model-00012-of-00024.safetensors, - model-00013-of-00024.safetensors, - model-00014-of-00024.safetensors, - model-00015-of-00024.safetensors, - model-00016-of-00024.safetensors, - model-00017-of-00024.safetensors, - model-00018-of-00024.safetensors, - model-00019-of-00024.safetensors, - model-00020-of-00024.safetensors, - model-00021-of-00024.safetensors, - model-00022-of-00024.safetensors, - model-00023-of-00024.safetensors, - model-00024-of-00024.safetensors, - ] + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: 00024 recipe_checkpoint: null output_dir: /tmp/gemma2-27b/ model_type: GEMMA2 diff --git a/recipes/configs/gemma2/27B_lora_single_device.yaml b/recipes/configs/gemma2/27B_lora_single_device.yaml index 58e7170c8c..11ca14eceb 100644 --- a/recipes/configs/gemma2/27B_lora_single_device.yaml +++ b/recipes/configs/gemma2/27B_lora_single_device.yaml @@ -38,32 +38,9 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer checkpoint_dir: /tmp/gemma2-27b/ - checkpoint_files: [ - model-00001-of-00024.safetensors, - model-00002-of-00024.safetensors, - model-00003-of-00024.safetensors, - model-00004-of-00024.safetensors, - model-00005-of-00024.safetensors, - model-00006-of-00024.safetensors, - model-00007-of-00024.safetensors, - model-00008-of-00024.safetensors, - model-00009-of-00024.safetensors, - model-00010-of-00024.safetensors, - model-00011-of-00024.safetensors, - model-00012-of-00024.safetensors, - model-00013-of-00024.safetensors, - model-00014-of-00024.safetensors, - model-00015-of-00024.safetensors, - model-00016-of-00024.safetensors, - model-00017-of-00024.safetensors, - model-00018-of-00024.safetensors, - model-00019-of-00024.safetensors, - model-00020-of-00024.safetensors, - model-00021-of-00024.safetensors, - model-00022-of-00024.safetensors, - model-00023-of-00024.safetensors, - model-00024-of-00024.safetensors, - ] + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: 00024 recipe_checkpoint: null output_dir: /tmp/gemma2-27b/ model_type: GEMMA2 diff --git a/recipes/configs/gemma2/27B_qlora_single_device.yaml b/recipes/configs/gemma2/27B_qlora_single_device.yaml index 87c8b25f8c..9f612cc3c0 100644 --- a/recipes/configs/gemma2/27B_qlora_single_device.yaml +++ b/recipes/configs/gemma2/27B_qlora_single_device.yaml @@ -38,32 +38,9 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer checkpoint_dir: /tmp/gemma2-27b/ - checkpoint_files: [ - model-00001-of-00024.safetensors, - model-00002-of-00024.safetensors, - model-00003-of-00024.safetensors, - model-00004-of-00024.safetensors, - model-00005-of-00024.safetensors, - model-00006-of-00024.safetensors, - model-00007-of-00024.safetensors, - model-00008-of-00024.safetensors, - model-00009-of-00024.safetensors, - model-00010-of-00024.safetensors, - model-00011-of-00024.safetensors, - model-00012-of-00024.safetensors, - model-00013-of-00024.safetensors, - model-00014-of-00024.safetensors, - model-00015-of-00024.safetensors, - model-00016-of-00024.safetensors, - model-00017-of-00024.safetensors, - model-00018-of-00024.safetensors, - model-00019-of-00024.safetensors, - model-00020-of-00024.safetensors, - 
model-00021-of-00024.safetensors, - model-00022-of-00024.safetensors, - model-00023-of-00024.safetensors, - model-00024-of-00024.safetensors, - ] + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: 00024 recipe_checkpoint: null output_dir: /tmp/gemma2-27b/ model_type: GEMMA2 diff --git a/recipes/configs/gemma2/9B_full.yaml b/recipes/configs/gemma2/9B_full.yaml index 09d638a3b9..aa746080a0 100644 --- a/recipes/configs/gemma2/9B_full.yaml +++ b/recipes/configs/gemma2/9B_full.yaml @@ -34,16 +34,9 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer checkpoint_dir: /tmp/gemma2-9b/ - checkpoint_files: [ - model-00001-of-00008.safetensors, - model-00002-of-00008.safetensors, - model-00003-of-00008.safetensors, - model-00004-of-00008.safetensors, - model-00005-of-00008.safetensors, - model-00006-of-00008.safetensors, - model-00007-of-00008.safetensors, - model-00008-of-00008.safetensors, - ] + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: 00008 recipe_checkpoint: null output_dir: /tmp/gemma2-9b model_type: GEMMA2 diff --git a/recipes/configs/gemma2/9B_lora.yaml b/recipes/configs/gemma2/9B_lora.yaml index 3f27bab651..f1cc3e3337 100644 --- a/recipes/configs/gemma2/9B_lora.yaml +++ b/recipes/configs/gemma2/9B_lora.yaml @@ -39,16 +39,9 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer checkpoint_dir: /tmp/gemma2-9b/ - checkpoint_files: [ - model-00001-of-00008.safetensors, - model-00002-of-00008.safetensors, - model-00003-of-00008.safetensors, - model-00004-of-00008.safetensors, - model-00005-of-00008.safetensors, - model-00006-of-00008.safetensors, - model-00007-of-00008.safetensors, - model-00008-of-00008.safetensors, - ] + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: 00008 recipe_checkpoint: null output_dir: /tmp/gemma2-9b/ model_type: GEMMA2 diff --git a/recipes/configs/gemma2/9B_lora_single_device.yaml b/recipes/configs/gemma2/9B_lora_single_device.yaml index 73ee146089..cb5461a5e3 100644 --- a/recipes/configs/gemma2/9B_lora_single_device.yaml +++ b/recipes/configs/gemma2/9B_lora_single_device.yaml @@ -38,16 +38,9 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer checkpoint_dir: /tmp/gemma2-9b/ - checkpoint_files: [ - model-00001-of-00008.safetensors, - model-00002-of-00008.safetensors, - model-00003-of-00008.safetensors, - model-00004-of-00008.safetensors, - model-00005-of-00008.safetensors, - model-00006-of-00008.safetensors, - model-00007-of-00008.safetensors, - model-00008-of-00008.safetensors, - ] + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: 00008 recipe_checkpoint: null output_dir: /tmp/gemma2-9b/ model_type: GEMMA2 diff --git a/recipes/configs/gemma2/9B_qlora_single_device.yaml b/recipes/configs/gemma2/9B_qlora_single_device.yaml index 6ef9a5d785..38a158a2f6 100644 --- a/recipes/configs/gemma2/9B_qlora_single_device.yaml +++ b/recipes/configs/gemma2/9B_qlora_single_device.yaml @@ -38,16 +38,9 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer checkpoint_dir: /tmp/gemma2-9b/ - checkpoint_files: [ - model-00001-of-00008.safetensors, - model-00002-of-00008.safetensors, - model-00003-of-00008.safetensors, - model-00004-of-00008.safetensors, - model-00005-of-00008.safetensors, - model-00006-of-00008.safetensors, - model-00007-of-00008.safetensors, - model-00008-of-00008.safetensors, - ] + checkpoint_files: + filename_format: 
model-{}-of-{}.safetensors + max_filename: 00008 recipe_checkpoint: null output_dir: /tmp/gemma2-9b/ model_type: GEMMA2 diff --git a/torchtune/models/convert_weights.py b/torchtune/models/convert_weights.py index 7333af1838..c0cf2f10fc 100644 --- a/torchtune/models/convert_weights.py +++ b/torchtune/models/convert_weights.py @@ -38,10 +38,8 @@ "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight", "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight", "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight", - "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", # mlp_norm.scale -> looks like a previous bug here # noqa - "model.layers.{}.post_attention_layernorm.weight": "layers.{}.sa_scale.scale", - "model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.mlp_norm.scale", - "model.layers.{}.pre_feedforward_layernorm.weight": "layers.{}.mlp_scale.scale", + "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale", "model.norm.weight": "norm.scale", "lm_head.weight": "output.weight", } diff --git a/torchtune/models/gemma2/_attention.py b/torchtune/models/gemma2/_attention.py new file mode 100644 index 0000000000..c83212f7b5 --- /dev/null +++ b/torchtune/models/gemma2/_attention.py @@ -0,0 +1,305 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn +from torchtune.modules.attention_utils import _MaskType +from torchtune.modules.kv_cache import KVCache + + +logger = logging.getLogger(__name__) + + +class Gemma2Attention(nn.Module): + """ + Adapated from official Google Pytorch Implementation: + https://github.com/google/gemma_pytorch/blob/80881c2e6e797ef1913a4a705d4b40394791cc58/gemma/model.py#L213 + to match torchtune style. + A new attention had to be added since nn.functional.scaled_dot_product_attention does allow soft capping + Args: + embed_dim (int): embedding dimension for the model + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + ``num_heads % num_kv_heads == 0``. For standard MHA set ``num_kv_heads == num_heads``, + for GQA ``num_kv_heads < num_heads``, and for MQA set ``num_kv_heads == 1``. + head_dim (int): dimension of each head, calculated by ``embed_dim // num_heads``. + q_proj (nn.Module): projection layer for query. + k_proj (nn.Module): projection layer for key. + v_proj (nn.Module): projection layer for value. + output_proj (nn.Module): projection layer for output. + pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings. + q_norm (Optional[nn.Module]): normalization layer for query, e.g. RMSNorm. For decoding, this is applied + before updating from kv_cache. This means it will only support token wide normalization and not + batch or sequence wide normalization. + k_norm (Optional[nn.Module]): normalization layer for key, must be set if q_norm is. + kv_cache (Optional[KVCache]): KVCache object used to cache key and value + max_seq_len (int): maximum sequence length supported by the model. + This is needed to compute the RoPE Cache. Default: 4096. 
+ is_causal (bool): sets the default mask to causal when no mask is provided + attn_dropout (float): dropout value passed onto the + scaled_dot_product_attention function. This argument is ignored if the + self.training is False. Default value is 0.0. + sliding_window_size (Optional[int]): size of the sliding window if None no sliding window is applied + softcapping (Optional[float]): capping value used for soft caping, if None no capping is performed + query_pre_attn_scalar (Optional[int]): value used for pre attention normalisation, if None head_dim is used instead + Raises: + ValueError: If ``num_heads % num_kv_heads != 0`` + ValueError: If ``embed_dim % num_heads != 0`` + ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1`` + ValueError: if q_norm is defined without k_norm or vice versa + """ + + def __init__( + self, + *, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + q_proj: nn.Module, + k_proj: nn.Module, + v_proj: nn.Module, + output_proj: nn.Module, + pos_embeddings: Optional[nn.Module] = None, + q_norm: Optional[nn.Module] = None, + k_norm: Optional[nn.Module] = None, + kv_cache: Optional[KVCache] = None, + max_seq_len: int = 4096, + is_causal: bool = True, + attn_dropout: float = 0.0, + sliding_window_size: Optional[int] = None, + softcapping: Optional[float] = 50.0, + query_pre_attn_scalar: Optional[int] = None, + ) -> None: + super().__init__() + if num_heads % num_kv_heads != 0: + raise ValueError( + f"num_heads ({num_heads}) must be divisible by " + f"num_kv_heads ({num_kv_heads})" + ) + + if embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim ({embed_dim}) must be divisible by " + f"num_heads ({num_heads})" + ) + + if attn_dropout < 0 or attn_dropout > 1: + raise ValueError(f"attn_dropout ({embed_dim}) must be between 0.0 and 1.0") + + if bool(q_norm) ^ bool(k_norm): + raise ValueError("q and k norm must be set together") + + # Set attributes + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.embed_dim = embed_dim + self.attn_dropout = attn_dropout + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.is_causal = is_causal + + # Set layers + self.kv_cache = kv_cache + self.q_proj = q_proj + self.k_proj = k_proj + self.v_proj = v_proj + self.output_proj = output_proj + self.q_norm = q_norm + self.k_norm = k_norm + self.pos_embeddings = pos_embeddings + + # gemma related parameters + self.sliding_window_size = sliding_window_size + self.softcapping = softcapping + if query_pre_attn_scalar is not None: + self.scaling = query_pre_attn_scalar**-0.5 + else: + self.scaling = self.head_dim**-0.5 + + def setup_cache( + self, batch_size: int, dtype: torch.dtype, max_seq_len: int + ) -> None: + """Setup key value caches for attention calculation. If called + after kv_cache is already setup, this will be skipped. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + max_seq_len (int): maximum sequence length model will be run with. + """ + # Don't overwrite user defined kv_cache from init + if self.kv_cache is not None: + logger.warning( + "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping." + ) + else: + self.kv_cache = KVCache( + batch_size=batch_size, + max_seq_len=max_seq_len, + num_heads=self.num_heads, + head_dim=self.head_dim, + dtype=dtype, + ) + + def reset_cache(self): + """Reset the key value caches.""" + if self.kv_cache is None: + raise RuntimeError( + "Key value caches are not setup. 
Call ``setup_caches()`` first." + ) + self.kv_cache.reset() + + def forward( + self, + x: torch.Tensor, + y: Optional[torch.Tensor] = None, + *, + mask: Optional[_MaskType] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x (torch.Tensor): input tensor with shape [b x s_x x d] for the query + y (Optional[torch.Tensor]): second input tensor with shape [b x s_y x d], is the input + for k and v. For self attention, x=y. Optional only with kv_cache enabled. + mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication + and before the softmax. Either: + + A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``, + or ``[b x s x self.encoder_max_cache_seq_len]`` if using KV-cacheing with encoder/decoder layers. + A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means + token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask + is used by default. + + A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence + created via `create_block_mask `_. We use + :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks. + Default is None. + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b x s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Raises: + ValueError: If no ``y`` input and ``kv_cache`` is not enabled. + + Returns: + torch.Tensor: output tensor with attention applied + + Notation used for tensor shapes: + - b: batch size + - s_x: sequence length for x + - s_y: sequence length for y + - n_h: num heads + - n_kv: num kv heads + - d: embed dim + - h_d: head dim + """ + # x has shape [b, s_x, d] + # y has shape [b, s_y, d] + b, s_x, _ = x.shape + s_y = y.shape[1] if y is not None else 0 + + # q has shape [b, s_x, num_heads * head_dim] + q = self.q_proj(x) + + # number of queries per key/value + q_per_kv = self.num_heads // self.num_kv_heads + q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim) + + # Apply positional embeddings + if self.pos_embeddings is not None: + q = self.pos_embeddings(q, input_pos=input_pos) + + # [b, n_h, s_x, h_d] + q = q.transpose(1, 2) + + # Normalize q + if self.q_norm is not None: + q = self.q_norm(q) + + if y is None: + if self.kv_cache is None: + raise ValueError( + "Must provide y input or use kv_cache to enable streaming decoding" + ) + k = self.kv_cache.k_cache + v = self.kv_cache.v_cache + else: + # Update k and v shape, positional embeddings, and normalization + + # k has shape [b, s_y, num_kv_heads * head_dim] + # v has shape [b, s_y, num_kv_heads * head_dim] + k = self.k_proj(y) + v = self.v_proj(y) + + # Apply positional embeddings + # k: [b, s_y, n_kv, h_d] + k = k.view(b, s_y, -1, self.head_dim) + if self.pos_embeddings is not None: + k = self.pos_embeddings(k, input_pos=input_pos) + + # View + expand + reshape bring num_kv_heads to num_heads for k and v + # to match q. 
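+            # Each of the n_kv key/value heads serves q_per_kv = num_heads // num_kv_heads
+            # query heads: the unsqueeze to [b, s_y, n_kv, 1, h_d] plus the expand below
+            # repeats every kv head across its q_per_kv query heads, and the final reshape
+            # flattens (n_kv, q_per_kv) back into n_h so k and v line up with q one-to-one.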
+ + # k: [b, s_y, n_kv, 1, h_d] + # v: [b, s_y, n_kv, 1, h_d] + k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim) + v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim) + + # If needed, expand the key and value tensors to have the same shape + # as the query tensor by copying values across the relevant dim + if self.num_heads != self.num_kv_heads: + k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim) + v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim) + + # [b, s, n_h, h_d] + k = k.reshape(b, s_y, -1, self.head_dim) + v = v.reshape(b, s_y, -1, self.head_dim) + + # [b, n_h, s, h_d] + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Normalize k + if self.k_norm is not None: + k = self.k_norm(k) + + # Update key-value cache + if self.kv_cache is not None: + k, v = self.kv_cache.update(k, v) + + q.mul_(self.scaling) + output = torch.matmul(q, k.transpose(2, 3)) + + if self.sliding_window_size is not None: + all_ones = torch.ones_like(mask) + sliding_mask = torch.triu( + all_ones, -1 * self.sliding_window_size + 1 + ) * torch.tril(all_ones, self.sliding_window_size - 1) + mask = torch.where(sliding_mask == 1, mask, -2.3819763e38) + + if self.softcapping is not None: + output = output / self.softcapping + output = torch.tanh(output) + output = output * self.softcapping + + output = output + mask + output = F.softmax(output.float(), dim=-1).type_as(q) + + # [batch_size, n_local_heads, input_len, head_dim] + output = torch.matmul(output, v) + + # reshape the output to be the same shape as the input + output = output.transpose(1, 2).contiguous().view(b, s_x, -1) + return self.output_proj(output) diff --git a/torchtune/models/gemma2/_component_builders.py b/torchtune/models/gemma2/_component_builders.py index 6c99ccb701..6478d8ec31 100644 --- a/torchtune/models/gemma2/_component_builders.py +++ b/torchtune/models/gemma2/_component_builders.py @@ -16,7 +16,7 @@ TransformerSelfAttentionLayer, ) -from torchtune.modules.attention import Gemma2Attention +from torchtune.models.gemma2._attention import Gemma2Attention from torchtune.models.gemma.rms_norm import GemmaRMSNorm from torchtune.modules import TransformerDecoder, TiedLinear from torchtune.models.gemma.gemma_norm_embedding import GemmaNormEmbeddings @@ -35,7 +35,7 @@ the building blocks simple. 
""" -class TanhSotfCapping(nn.Module): +class TanhSoftCapping(nn.Module): def __init__( self, capping_value: float, @@ -62,7 +62,7 @@ def __init__( super().__init__() self.capping_value = capping_value self.rms_norm = GemmaRMSNorm(embed_dim, eps=eps) - self.logit_capping = TanhSotfCapping(capping_value) + self.logit_capping = TanhSoftCapping(capping_value) def forward(self, x): x = self.rms_norm(x) @@ -246,21 +246,23 @@ def lora_gemma2( for layer_idx in range(num_layers): self_att = lora_gemma2_self_attention( + lora_modules=lora_attn_modules, embed_dim=embed_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_dim, - q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), - k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), - v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), - output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False), - kv_cache=None, + rope_base=rope_base, max_seq_len=max_seq_len, attn_dropout=attn_dropout, # perform sliding window on half of the layers only sliding_window_size=sliding_window_size if (layer_idx % 2)==0 else None, softcapping=hidden_capping_value, - query_pre_attn_scalar=query_pre_attn_scalar + query_pre_attn_scalar=query_pre_attn_scalar, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora = use_dora, + quantize_base = quantize_base, ) layer = TransformerSelfAttentionLayer( diff --git a/torchtune/models/gemma2/_convert_weights.py b/torchtune/models/gemma2/_convert_weights.py new file mode 100644 index 0000000000..fa4df0e469 --- /dev/null +++ b/torchtune/models/gemma2/_convert_weights.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import torch + +from torchtune.models.convert_weights import get_mapped_key + +""" +Gemma 2 and Gemma original implementations share different normalization but with +the same name, so it is mandatory to differentiate their state dict in order to map +correctly the different weights. +They are essentially the same except for "model.layers.{}.post_attention_layernorm.weight" key. 
+See discussion here: https://github.com/pytorch/torchtune/pull/1835#discussion_r1803410251 +""" + +_GEMMA2_FROM_HF = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight", + "model.layers.{}.self_attn.rotary_emb.inv_freq": None, + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.sa_scale.scale", + "model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.mlp_norm.scale", + "model.layers.{}.pre_feedforward_layernorm.weight": "layers.{}.mlp_scale.scale", + "model.norm.weight": "norm.rms_norm.scale", + "lm_head.weight": "output.weight", +} + + +def gemma2_hf_to_tune( + state_dict: Dict[str, torch.Tensor], + num_heads: int = 32, + num_kv_heads: int = 32, + dim: int = 4096, + head_dim: int = None, +) -> Dict[str, torch.Tensor]: + """ + Convert a state dict from HF's format to torchtune's format. State dicts + from multiple checkpoint files should be consolidated into a single state dict + before calling this function. + + Eg of HF-format state dict can be found in the ``meta-llama/Llama-2-7b-hf`` + repo in HF (https://huggingface.co/meta-llama/Llama-2-7b-hf). + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in HF's format. + num_heads (int): Number of heads in the model. + num_kv_heads (int): Number of heads in the key/value projection layers. + dim (int): Dimension of the model. + head_dim (int): Dimension of the head. If not provided, it will be calculated + as dim // num_heads. + + Returns: + Dict[str, torch.Tensor]: State dict in torchtune's format. + """ + converted_state_dict = {} + if head_dim is None: + head_dim = dim // num_heads + + def _permute(t, n_heads): + return ( + t.view(n_heads, 2, head_dim // 2, dim) + .transpose(1, 2) + .reshape((head_dim * n_heads), dim) + ) + + for key, value in state_dict.items(): + if "rotary_emb.inv_freq" not in key: # Skip loading the position embeddings + new_key = get_mapped_key(key, _GEMMA2_FROM_HF) + if "q_proj" in key: + value = _permute(value, num_heads) + elif "k_proj" in key: + value = _permute(value, num_kv_heads) + + converted_state_dict[new_key] = value + return converted_state_dict + + +def gemma2_tune_to_hf( + state_dict: Dict[str, torch.Tensor], + num_heads: int = 32, + num_kv_heads: int = 32, + dim: int = 4096, + head_dim: int = None, +): + """ + Convert a state dict from torchtune's format to HF's format. This function + doesn't handle any sharding or splitting of state dicts. It follows the + state_dict IN -> state_dict OUT pattern. + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format. + num_heads (int): Number of heads in the model. + num_kv_heads (int): Number of heads in the key/value projection layers. + dim (int): Dimension of the model. + head_dim (int): Dimension of model attention heads. Default None. + + Returns: + Dict[str, torch.Tensor]: State dict in HF's format. 
+ """ + converted_state_dict = {} + inverted_mapping_dict = {v: k for k, v in _GEMMA2_FROM_HF.items()} + + if head_dim is None: + head_dim = dim // num_heads + + def _permute(t, n_heads): + return ( + t.view(n_heads, head_dim // 2, 2, dim) + .transpose(1, 2) + .reshape((head_dim * n_heads), dim) + ) + + for key, value in state_dict.items(): + new_key = get_mapped_key(key, inverted_mapping_dict) + if "q_proj" in key: + value = _permute(value, num_heads) + elif "k_proj" in key: + value = _permute(value, num_kv_heads) + converted_state_dict[new_key] = value + + return converted_state_dict diff --git a/torchtune/models/gemma2/_model_builders.py b/torchtune/models/gemma2/_model_builders.py index 72df7747da..a07021c518 100644 --- a/torchtune/models/gemma2/_model_builders.py +++ b/torchtune/models/gemma2/_model_builders.py @@ -27,12 +27,12 @@ def gemma2_2b() -> TransformerDecoder: """ return gemma2( vocab_size=256_000, - num_layers=18, + num_layers=26, num_heads=8, head_dim=256, - num_kv_heads=1, - embed_dim=2048, - intermediate_dim=16384, + num_kv_heads=4, + embed_dim=2304, + intermediate_dim=9216, max_seq_len=8192, attn_dropout=0.0, norm_eps=1e-6, @@ -78,12 +78,12 @@ def lora_gemma2_2b( lora_attn_modules=lora_attn_modules, apply_lora_to_mlp=apply_lora_to_mlp, vocab_size=256_000, - num_layers=18, + num_layers=26, num_heads=8, head_dim=256, - num_kv_heads=1, - embed_dim=2048, - intermediate_dim=16384, + num_kv_heads=4, + embed_dim=2304, + intermediate_dim=9216, max_seq_len=8192, attn_dropout=0.0, norm_eps=1e-6, @@ -120,7 +120,7 @@ def gemma2_9b() -> TransformerDecoder: num_layers=42, num_heads=16, head_dim=256, - num_kv_heads=16, + num_kv_heads=8, embed_dim=3584, intermediate_dim=14336, max_seq_len=8192, @@ -171,7 +171,7 @@ def lora_gemma2_9b( num_layers=42, num_heads=16, head_dim=256, - num_kv_heads=16, + num_kv_heads=8, embed_dim=3584, intermediate_dim=14336, max_seq_len=8192, diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index 2fbd4d23e7..2dfeaddc9a 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -8,7 +8,6 @@ from typing import Optional import torch -import torch.nn.functional as F from torch import nn from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention from torchtune.modules.kv_cache import KVCache @@ -307,291 +306,3 @@ def forward( # reshape the output to be the same shape as the input output = output.transpose(1, 2).contiguous().view(b, s_x, -1) return self.output_proj(output) - - -class Gemma2Attention(nn.Module): - """ - Adapated from official Google Pytorch Implementation: - https://github.com/google/gemma_pytorch/blob/80881c2e6e797ef1913a4a705d4b40394791cc58/gemma/model.py#L213 - to match torchtune style. - A new attention had to be added since nn.functional.scaled_dot_product_attention does allow soft capping - Args: - embed_dim (int): embedding dimension for the model - num_heads (int): number of query heads. For MHA this is also the - number of heads for key and value - num_kv_heads (int): number of key and value heads. User should ensure - ``num_heads % num_kv_heads == 0``. For standard MHA set ``num_kv_heads == num_heads``, - for GQA ``num_kv_heads < num_heads``, and for MQA set ``num_kv_heads == 1``. - head_dim (int): dimension of each head, calculated by ``embed_dim // num_heads``. - q_proj (nn.Module): projection layer for query. - k_proj (nn.Module): projection layer for key. - v_proj (nn.Module): projection layer for value. - output_proj (nn.Module): projection layer for output. 
- pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings. - q_norm (Optional[nn.Module]): normalization layer for query, e.g. RMSNorm. For decoding, this is applied - before updating from kv_cache. This means it will only support token wide normalization and not - batch or sequence wide normalization. - k_norm (Optional[nn.Module]): normalization layer for key, must be set if q_norm is. - kv_cache (Optional[KVCache]): KVCache object used to cache key and value - max_seq_len (int): maximum sequence length supported by the model. - This is needed to compute the RoPE Cache. Default: 4096. - is_causal (bool): sets the default mask to causal when no mask is provided - attn_dropout (float): dropout value passed onto the - scaled_dot_product_attention function. This argument is ignored if the - self.training is False. Default value is 0.0. - sliding_window_size (Optional[int]): size of the sliding window if None no sliding window is applied - softcapping (Optional[float]): capping value used for soft caping, if None no capping is performed - query_pre_attn_scalar (Optional[int]): value used for pre attention normalisation, if None head_dim is used instead - Raises: - ValueError: If ``num_heads % num_kv_heads != 0`` - ValueError: If ``embed_dim % num_heads != 0`` - ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1`` - ValueError: if q_norm is defined without k_norm or vice versa - """ - - def __init__( - self, - *, - embed_dim: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - q_proj: nn.Module, - k_proj: nn.Module, - v_proj: nn.Module, - output_proj: nn.Module, - pos_embeddings: Optional[nn.Module] = None, - q_norm: Optional[nn.Module] = None, - k_norm: Optional[nn.Module] = None, - kv_cache: Optional[KVCache] = None, - max_seq_len: int = 4096, - is_causal: bool = True, - attn_dropout: float = 0.0, - sliding_window_size: Optional[int] = None, - softcapping: Optional[float] = 50.0, - query_pre_attn_scalar: Optional[int] = None, - ) -> None: - super().__init__() - if num_heads % num_kv_heads != 0: - raise ValueError( - f"num_heads ({num_heads}) must be divisible by " - f"num_kv_heads ({num_kv_heads})" - ) - - if embed_dim % num_heads != 0: - raise ValueError( - f"embed_dim ({embed_dim}) must be divisible by " - f"num_heads ({num_heads})" - ) - - if attn_dropout < 0 or attn_dropout > 1: - raise ValueError(f"attn_dropout ({embed_dim}) must be between 0.0 and 1.0") - - if bool(q_norm) ^ bool(k_norm): - raise ValueError("q and k norm must be set together") - - # Set attributes - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.embed_dim = embed_dim - self.attn_dropout = attn_dropout - self.head_dim = head_dim - self.max_seq_len = max_seq_len - self.is_causal = is_causal - - # Set layers - self.kv_cache = kv_cache - self.q_proj = q_proj - self.k_proj = k_proj - self.v_proj = v_proj - self.output_proj = output_proj - self.q_norm = q_norm - self.k_norm = k_norm - self.pos_embeddings = pos_embeddings - - # gemma related parameters - self.sliding_window_size = sliding_window_size - self.softcapping = softcapping - if query_pre_attn_scalar is not None: - self.scaling = query_pre_attn_scalar**-0.5 - else: - self.scaling = self.head_dim**-0.5 - - def setup_cache( - self, batch_size: int, dtype: torch.dtype, max_seq_len: int - ) -> None: - """Setup key value caches for attention calculation. If called - after kv_cache is already setup, this will be skipped. - - Args: - batch_size (int): batch size for the caches. 
- dtype (torch.dtype): dtype for the caches. - max_seq_len (int): maximum sequence length model will be run with. - """ - # Don't overwrite user defined kv_cache from init - if self.kv_cache is not None: - logger.warning( - "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping." - ) - else: - self.kv_cache = KVCache( - batch_size=batch_size, - max_seq_len=max_seq_len, - num_heads=self.num_heads, - head_dim=self.head_dim, - dtype=dtype, - ) - - def reset_cache(self): - """Reset the key value caches.""" - if self.kv_cache is None: - raise RuntimeError( - "Key value caches are not setup. Call ``setup_caches()`` first." - ) - self.kv_cache.reset() - - def forward( - self, - x: torch.Tensor, - y: Optional[torch.Tensor] = None, - *, - mask: Optional[_MaskType] = None, - input_pos: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Args: - x (torch.Tensor): input tensor with shape [b x s_x x d] for the query - y (Optional[torch.Tensor]): second input tensor with shape [b x s_y x d], is the input - for k and v. For self attention, x=y. Optional only with kv_cache enabled. - mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication - and before the softmax. Either: - - A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``, - or ``[b x s x self.encoder_max_cache_seq_len]`` if using KV-cacheing with encoder/decoder layers. - A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means - token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask - is used by default. - - A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence - created via `create_block_mask `_. We use - :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks. - Default is None. - input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids - of each token. During training, this is used to indicate the positions - of each token relative to its sample when packed, shape [b x s]. - During inference, this indicates the position of the current token. - If none, assume the index of the token is its position id. Default is None. - - Raises: - ValueError: If no ``y`` input and ``kv_cache`` is not enabled. 
- - Returns: - torch.Tensor: output tensor with attention applied - - Notation used for tensor shapes: - - b: batch size - - s_x: sequence length for x - - s_y: sequence length for y - - n_h: num heads - - n_kv: num kv heads - - d: embed dim - - h_d: head dim - """ - # x has shape [b, s_x, d] - # y has shape [b, s_y, d] - b, s_x, _ = x.shape - s_y = y.shape[1] if y is not None else 0 - - # q has shape [b, s_x, num_heads * head_dim] - q = self.q_proj(x) - - # number of queries per key/value - q_per_kv = self.num_heads // self.num_kv_heads - q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim) - - # Apply positional embeddings - if self.pos_embeddings is not None: - q = self.pos_embeddings(q, input_pos=input_pos) - - # [b, n_h, s_x, h_d] - q = q.transpose(1, 2) - - # Normalize q - if self.q_norm is not None: - q = self.q_norm(q) - - if y is None: - if self.kv_cache is None: - raise ValueError( - "Must provide y input or use kv_cache to enable streaming decoding" - ) - k = self.kv_cache.k_cache - v = self.kv_cache.v_cache - else: - # Update k and v shape, positional embeddings, and normalization - - # k has shape [b, s_y, num_kv_heads * head_dim] - # v has shape [b, s_y, num_kv_heads * head_dim] - k = self.k_proj(y) - v = self.v_proj(y) - - # Apply positional embeddings - # k: [b, s_y, n_kv, h_d] - k = k.view(b, s_y, -1, self.head_dim) - if self.pos_embeddings is not None: - k = self.pos_embeddings(k, input_pos=input_pos) - - # View + expand + reshape bring num_kv_heads to num_heads for k and v - # to match q. - - # k: [b, s_y, n_kv, 1, h_d] - # v: [b, s_y, n_kv, 1, h_d] - k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim) - v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim) - - # If needed, expand the key and value tensors to have the same shape - # as the query tensor by copying values across the relevant dim - if self.num_heads != self.num_kv_heads: - k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim) - v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim) - - # [b, s, n_h, h_d] - k = k.reshape(b, s_y, -1, self.head_dim) - v = v.reshape(b, s_y, -1, self.head_dim) - - # [b, n_h, s, h_d] - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - # Normalize k - if self.k_norm is not None: - k = self.k_norm(k) - - # Update key-value cache - if self.kv_cache is not None: - k, v = self.kv_cache.update(k, v) - - q.mul_(self.scaling) - output = torch.matmul(q, k.transpose(2, 3)) - - if self.sliding_window_size is not None: - all_ones = torch.ones_like(mask) - sliding_mask = torch.triu( - all_ones, -1 * self.sliding_window_size + 1 - ) * torch.tril(all_ones, self.sliding_window_size - 1) - mask = torch.where(sliding_mask == 1, mask, -2.3819763e38) - - if self.softcapping is not None: - output = output / self.softcapping - output = torch.tanh(output) - output = output * self.softcapping - - output = output + mask - output = F.softmax(output.float(), dim=-1).type_as(q) - - # [batch_size, n_local_heads, input_len, head_dim] - output = torch.matmul(output, v) - - # reshape the output to be the same shape as the input - output = output.transpose(1, 2).contiguous().view(b, s_x, -1) - return self.output_proj(output) diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index 75a7fc950d..18ecc86946 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -488,6 +488,16 @@ def load_checkpoint(self) -> Dict[str, Any]: "supported_aspect_ratios", 
None ), ) + elif self._model_type == ModelType.GEMMA2: + from torchtune.models.gemma2._convert_weights import gemma2_hf_to_tune + + converted_state_dict[training.MODEL_KEY] = gemma2_hf_to_tune( + merged_state_dict, + num_heads=self._config["num_attention_heads"], + num_kv_heads=self._config["num_key_value_heads"], + dim=self._config["hidden_size"], + head_dim=self._config.get("head_dim", None), + ) else: converted_state_dict[training.MODEL_KEY] = convert_weights.hf_to_tune( merged_state_dict, @@ -578,6 +588,16 @@ def save_checkpoint( "supported_aspect_ratios", None ), ) + elif self._model_type == ModelType.GEMMA2: + from torchtune.models.gemma2._convert_weights import gemma2_tune_to_hf + + state_dict[training.MODEL_KEY] = gemma2_tune_to_hf( + state_dict[training.MODEL_KEY], + num_heads=self._config["num_attention_heads"], + num_kv_heads=self._config["num_key_value_heads"], + dim=self._config["hidden_size"], + head_dim=self._config.get("head_dim", None), + ) else: state_dict[training.MODEL_KEY] = convert_weights.tune_to_hf( state_dict[training.MODEL_KEY], From 6f89920c35b2a48b68267265c40ea7bb0c65c7e9 Mon Sep 17 00:00:00 2001 From: Optimox Date: Tue, 22 Oct 2024 11:05:47 +0200 Subject: [PATCH 03/11] WIP: non working flex attention --- recipes/configs/gemma2/2B_full.yaml | 6 +- recipes/configs/gemma2/2B_lora.yaml | 6 +- .../configs/gemma2/2B_lora_single_device.yaml | 8 +- .../gemma2/2B_qlora_single_device.yaml | 6 +- recipes/lora_finetune_single_device.py | 1 - torchtune/models/gemma2/_attention.py | 309 +++++++++++++++++- torchtune/models/gemma2/_attention_utils.py | 96 ++++++ .../models/gemma2/_component_builders.py | 26 +- 8 files changed, 438 insertions(+), 20 deletions(-) create mode 100644 torchtune/models/gemma2/_attention_utils.py diff --git a/recipes/configs/gemma2/2B_full.yaml b/recipes/configs/gemma2/2B_full.yaml index f1214810a9..9386fae4b9 100644 --- a/recipes/configs/gemma2/2B_full.yaml +++ b/recipes/configs/gemma2/2B_full.yaml @@ -19,7 +19,7 @@ # Tokenizer tokenizer: _component_: torchtune.models.gemma.gemma_tokenizer - path: /tmp/gemma2-2b/tokenizer.model + path: /tmp/gemma-2-2b/tokenizer.model # Dataset dataset: @@ -33,14 +33,14 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/gemma2-2b/ + checkpoint_dir: /tmp/gemma-2-2b/ checkpoint_files: [ model-00001-of-00003.safetensors, model-00002-of-00003.safetensors, model-00003-of-00003.safetensors, ] recipe_checkpoint: null - output_dir: /tmp/gemma2-2b + output_dir: /tmp/gemma-2-2b model_type: GEMMA2 resume_from_checkpoint: False diff --git a/recipes/configs/gemma2/2B_lora.yaml b/recipes/configs/gemma2/2B_lora.yaml index ca6d8df232..e6ef6e6e9e 100644 --- a/recipes/configs/gemma2/2B_lora.yaml +++ b/recipes/configs/gemma2/2B_lora.yaml @@ -18,7 +18,7 @@ # Tokenizer tokenizer: _component_: torchtune.models.gemma.gemma_tokenizer - path: /tmp/gemma2-2b/tokenizer.model + path: /tmp/gemma-2-2b/tokenizer.model # Dataset dataset: @@ -37,14 +37,14 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/gemma2-2b/ + checkpoint_dir: /tmp/gemma-2-2b/ checkpoint_files: [ model-00001-of-00003.safetensors, model-00002-of-00003.safetensors, model-00003-of-00003.safetensors, ] recipe_checkpoint: null - output_dir: /tmp/gemma2-2b + output_dir: /tmp/gemma-2-2b model_type: GEMMA2 resume_from_checkpoint: False diff --git a/recipes/configs/gemma2/2B_lora_single_device.yaml b/recipes/configs/gemma2/2B_lora_single_device.yaml index d8bbeb9a81..49b59846c4 
100644 --- a/recipes/configs/gemma2/2B_lora_single_device.yaml +++ b/recipes/configs/gemma2/2B_lora_single_device.yaml @@ -18,7 +18,7 @@ # Tokenizer tokenizer: _component_: torchtune.models.gemma.gemma_tokenizer - path: /tmp/gemma2-2b/tokenizer.model + path: /tmp/gemma-2-2b/tokenizer.model # Dataset dataset: @@ -44,7 +44,7 @@ checkpointer: model-00003-of-00003.safetensors, ] recipe_checkpoint: null - output_dir: /tmp/gemma2-2b + output_dir: /tmp/gemma-2-2b model_type: GEMMA2 resume_from_checkpoint: False save_adapter_weights_only: False @@ -62,10 +62,10 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Fine-tuning arguments -batch_size: 4 +batch_size: 8 epochs: 3 max_steps_per_epoch: null -gradient_accumulation_steps: 4 +gradient_accumulation_steps: 2 compile: False # Training env diff --git a/recipes/configs/gemma2/2B_qlora_single_device.yaml b/recipes/configs/gemma2/2B_qlora_single_device.yaml index c65367419f..b5d7c9147d 100644 --- a/recipes/configs/gemma2/2B_qlora_single_device.yaml +++ b/recipes/configs/gemma2/2B_qlora_single_device.yaml @@ -18,7 +18,7 @@ # Tokenizer tokenizer: _component_: torchtune.models.gemma.gemma_tokenizer - path: /tmp/gemma2-2b/tokenizer.model + path: /tmp/gemma-2-2b/tokenizer.model # Dataset dataset: @@ -37,14 +37,14 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/gemma2-2b/ + checkpoint_dir: /tmp/gemma-2-2b/ checkpoint_files: [ model-00001-of-00003.safetensors, model-00002-of-00003.safetensors, model-00003-of-00003.safetensors, ] recipe_checkpoint: null - output_dir: /tmp/gemma2-2b + output_dir: /tmp/gemma-2-2b model_type: GEMMA2 resume_from_checkpoint: False save_adapter_weights_only: False diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 5d39b72086..4f567e2c9a 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -616,7 +616,6 @@ def save_checkpoint(self, epoch: int) -> None: def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: # Shape [b, s], needed for the loss not the model labels = batch.pop("labels") - # run model with self.activations_handling_ctx: logits = self._model(**batch) diff --git a/torchtune/models/gemma2/_attention.py b/torchtune/models/gemma2/_attention.py index c83212f7b5..e4d0949d0f 100644 --- a/torchtune/models/gemma2/_attention.py +++ b/torchtune/models/gemma2/_attention.py @@ -12,8 +12,15 @@ from torch import nn from torchtune.modules.attention_utils import _MaskType from torchtune.modules.kv_cache import KVCache - - +from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION + +if _SUPPORTS_FLEX_ATTENTION: + from torch.nn.attention.flex_attention import create_block_mask + from torchtune.models.gemma2._attention_utils import ( + compile_friendly_flex_attention_with_score_and_block, + flex_causal_sliding_window, + flex_tanh_soft_capping_with_scaling, + ) logger = logging.getLogger(__name__) @@ -282,8 +289,18 @@ def forward( q.mul_(self.scaling) output = torch.matmul(q, k.transpose(2, 3)) + # if mask is None: default to causal mask + if mask is None: + mask = torch.tril( + torch.ones( + size=(s_x, s_x), + dtype=torch.bool, + ).to(x.device) + ) + if self.sliding_window_size is not None: all_ones = torch.ones_like(mask) + sliding_mask = torch.triu( all_ones, -1 * self.sliding_window_size + 1 ) * torch.tril(all_ones, self.sliding_window_size - 1) @@ -303,3 +320,291 @@ def forward( # reshape the output to be the same shape as the input output = 
output.transpose(1, 2).contiguous().view(b, s_x, -1) return self.output_proj(output) + + +class FlexGemma2Attention(nn.Module): + """ + Adapated from official Google Pytorch Implementation: + https://github.com/google/gemma_pytorch/blob/80881c2e6e797ef1913a4a705d4b40394791cc58/gemma/model.py#L213 + to match torchtune style. + A new attention had to be added since nn.functional.scaled_dot_product_attention does allow soft capping + Args: + embed_dim (int): embedding dimension for the model + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + ``num_heads % num_kv_heads == 0``. For standard MHA set ``num_kv_heads == num_heads``, + for GQA ``num_kv_heads < num_heads``, and for MQA set ``num_kv_heads == 1``. + head_dim (int): dimension of each head, calculated by ``embed_dim // num_heads``. + q_proj (nn.Module): projection layer for query. + k_proj (nn.Module): projection layer for key. + v_proj (nn.Module): projection layer for value. + output_proj (nn.Module): projection layer for output. + pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings. + q_norm (Optional[nn.Module]): normalization layer for query, e.g. RMSNorm. For decoding, this is applied + before updating from kv_cache. This means it will only support token wide normalization and not + batch or sequence wide normalization. + k_norm (Optional[nn.Module]): normalization layer for key, must be set if q_norm is. + kv_cache (Optional[KVCache]): KVCache object used to cache key and value + max_seq_len (int): maximum sequence length supported by the model. + This is needed to compute the RoPE Cache. Default: 4096. + is_causal (bool): sets the default mask to causal when no mask is provided + attn_dropout (float): dropout value passed onto the + scaled_dot_product_attention function. This argument is ignored if the + self.training is False. Default value is 0.0. 
+ sliding_window_size (Optional[int]): size of the sliding window if None no sliding window is applied + softcapping (Optional[float]): capping value used for soft caping, if None no capping is performed + query_pre_attn_scalar (Optional[int]): value used for pre attention normalisation, if None head_dim is used instead + Raises: + ValueError: If ``num_heads % num_kv_heads != 0`` + ValueError: If ``embed_dim % num_heads != 0`` + ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1`` + ValueError: if q_norm is defined without k_norm or vice versa + """ + + def __init__( + self, + *, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + q_proj: nn.Module, + k_proj: nn.Module, + v_proj: nn.Module, + output_proj: nn.Module, + pos_embeddings: Optional[nn.Module] = None, + q_norm: Optional[nn.Module] = None, + k_norm: Optional[nn.Module] = None, + kv_cache: Optional[KVCache] = None, + max_seq_len: int = 4096, + is_causal: bool = True, + attn_dropout: float = 0.0, + sliding_window_size: Optional[int] = None, + softcapping: Optional[float] = 50.0, + query_pre_attn_scalar: Optional[int] = None, + ) -> None: + super().__init__() + if num_heads % num_kv_heads != 0: + raise ValueError( + f"num_heads ({num_heads}) must be divisible by " + f"num_kv_heads ({num_kv_heads})" + ) + + if embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim ({embed_dim}) must be divisible by " + f"num_heads ({num_heads})" + ) + + if attn_dropout < 0 or attn_dropout > 1: + raise ValueError(f"attn_dropout ({embed_dim}) must be between 0.0 and 1.0") + + if bool(q_norm) ^ bool(k_norm): + raise ValueError("q and k norm must be set together") + + # Set attributes + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.embed_dim = embed_dim + self.attn_dropout = attn_dropout + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.is_causal = is_causal + + # Set layers + self.kv_cache = kv_cache + self.q_proj = q_proj + self.k_proj = k_proj + self.v_proj = v_proj + self.output_proj = output_proj + self.q_norm = q_norm + self.k_norm = k_norm + self.pos_embeddings = pos_embeddings + + # gemma related parameters + self.sliding_window_size = sliding_window_size + self.softcapping = softcapping + if query_pre_attn_scalar is not None: + # flex attention will always make the head_dim**-0.5 normalization so it should be included in scaling + self.scaling = query_pre_attn_scalar**-0.5 / self.head_dim**-0.5 + else: + self.scaling = None + + self.mask_mod = flex_causal_sliding_window(self.sliding_window_size) + self.score_mod = flex_tanh_soft_capping_with_scaling( + self.softcapping, self.scaling + ) + + def setup_cache( + self, batch_size: int, dtype: torch.dtype, max_seq_len: int + ) -> None: + """Setup key value caches for attention calculation. If called + after kv_cache is already setup, this will be skipped. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + max_seq_len (int): maximum sequence length model will be run with. + """ + # Don't overwrite user defined kv_cache from init + if self.kv_cache is not None: + logger.warning( + "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping." + ) + else: + self.kv_cache = KVCache( + batch_size=batch_size, + max_seq_len=max_seq_len, + num_heads=self.num_heads, + head_dim=self.head_dim, + dtype=dtype, + ) + + def reset_cache(self): + """Reset the key value caches.""" + if self.kv_cache is None: + raise RuntimeError( + "Key value caches are not setup. 
Call ``setup_caches()`` first." + ) + self.kv_cache.reset() + + def forward( + self, + x: torch.Tensor, + y: Optional[torch.Tensor] = None, + *, + mask: Optional[_MaskType] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x (torch.Tensor): input tensor with shape [b x s_x x d] for the query + y (Optional[torch.Tensor]): second input tensor with shape [b x s_y x d], is the input + for k and v. For self attention, x=y. Optional only with kv_cache enabled. + mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication + and before the softmax. Either: + + A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``, + or ``[b x s x self.encoder_max_cache_seq_len]`` if using KV-cacheing with encoder/decoder layers. + A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means + token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask + is used by default. + + A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence + created via `create_block_mask `_. We use + :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks. + Default is None. + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b x s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Raises: + ValueError: If no ``y`` input and ``kv_cache`` is not enabled. + + Returns: + torch.Tensor: output tensor with attention applied + + Notation used for tensor shapes: + - b: batch size + - s_x: sequence length for x + - s_y: sequence length for y + - n_h: num heads + - n_kv: num kv heads + - d: embed dim + - h_d: head dim + """ + # x has shape [b, s_x, d] + # y has shape [b, s_y, d] + b, s_x, _ = x.shape + s_y = y.shape[1] if y is not None else 0 + + # q has shape [b, s_x, num_heads * head_dim] + q = self.q_proj(x) + + # number of queries per key/value + q_per_kv = self.num_heads // self.num_kv_heads + q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim) + + # Apply positional embeddings + if self.pos_embeddings is not None: + q = self.pos_embeddings(q, input_pos=input_pos) + + # [b, n_h, s_x, h_d] + q = q.transpose(1, 2) + + # Normalize q + if self.q_norm is not None: + q = self.q_norm(q) + + if y is None: + if self.kv_cache is None: + raise ValueError( + "Must provide y input or use kv_cache to enable streaming decoding" + ) + k = self.kv_cache.k_cache + v = self.kv_cache.v_cache + else: + # Update k and v shape, positional embeddings, and normalization + + # k has shape [b, s_y, num_kv_heads * head_dim] + # v has shape [b, s_y, num_kv_heads * head_dim] + k = self.k_proj(y) + v = self.v_proj(y) + + # Apply positional embeddings + # k: [b, s_y, n_kv, h_d] + k = k.view(b, s_y, -1, self.head_dim) + if self.pos_embeddings is not None: + k = self.pos_embeddings(k, input_pos=input_pos) + + # View + expand + reshape bring num_kv_heads to num_heads for k and v + # to match q. 
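+            # As a concrete example, with the gemma2_9b settings (num_heads=16,
+            # num_kv_heads=8, head_dim=256, so q_per_kv=2), k and v go
+            # [b, s_y, 8, 256] -> [b, s_y, 8, 1, 256] -> expand -> [b, s_y, 8, 2, 256]
+            # -> reshape -> [b, s_y, 16, 256], matching the query's 16 heads.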
+ + # k: [b, s_y, n_kv, 1, h_d] + # v: [b, s_y, n_kv, 1, h_d] + k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim) + v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim) + + # If needed, expand the key and value tensors to have the same shape + # as the query tensor by copying values across the relevant dim + if self.num_heads != self.num_kv_heads: + k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim) + v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim) + + # [b, s, n_h, h_d] + k = k.reshape(b, s_y, -1, self.head_dim) + v = v.reshape(b, s_y, -1, self.head_dim) + + # [b, n_h, s, h_d] + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Normalize k + if self.k_norm is not None: + k = self.k_norm(k) + + # Update key-value cache + if self.kv_cache is not None: + k, v = self.kv_cache.update(k, v) + + # TODO: how to avoid to compute same block mask at every layer ? + # https://pytorch.org/blog/flexattention/#q-when-should-we-recompute-the-blockmask + block_mask = create_block_mask( + mask_mod=self.mask_mod, + B=b, + H=self.num_heads, + Q_LEN=s_x, + KV_LEN=s_x, + device=q.device, + ) + + output = compile_friendly_flex_attention_with_score_and_block( + q, k, v, score_mod=self.score_mod, block_mask=block_mask + ) + + # reshape the output to be the same shape as the input + output = output.transpose(1, 2).contiguous().view(b, s_x, -1) + return self.output_proj(output) diff --git a/torchtune/models/gemma2/_attention_utils.py b/torchtune/models/gemma2/_attention_utils.py new file mode 100644 index 0000000000..534ad9e051 --- /dev/null +++ b/torchtune/models/gemma2/_attention_utils.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any + +import torch + +from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION + +if _SUPPORTS_FLEX_ATTENTION: + from functools import lru_cache + + from torch.nn.attention.flex_attention import ( + BlockMask, + create_block_mask, + flex_attention, + ) + + # flex_attention_compiled = torch.compile(flex_attention, dynamic=False) + + @lru_cache + def create_block_mask_cached(score_mod, b, h, m, n, device="cuda"): + block_mask = create_block_mask(score_mod, b, h, m, n, device=device) + return block_mask + + # We cannot do nested compile, but flex attention only has perf benefits + # when compiled. To insulate it from the compiler, we wrap it with + # compiler.disable so that it can be used regardless of whether the model + # is compiled or not, and flex attention always remains compiled. + @torch.compiler.disable(recursive=False) + def compile_friendly_flex_attention_with_score_and_block( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_mask: BlockMask, + score_mod: Any, + ) -> torch.Tensor: + """ + Flex attention does not seem to work with my A6000 with the default options. 
+        Using proposed options here: https://github.com/pytorch/pytorch/issues/133254
+        """
+        return flex_attention(
+            q,
+            k,
+            v,
+            score_mod=score_mod,
+            block_mask=block_mask,
+            # kernel_options={
+            #     "BLOCK_M": 64,
+            #     "BLOCK_N": 64,
+            #     "BLOCK_M1": 32,
+            #     "BLOCK_N1": 64,
+            #     "BLOCK_M2": 64,
+            #     "BLOCK_N2": 32,
+            # },
+        )
+
+
+def flex_causal_sliding_window(sliding_window_size):
+    def sliding_window_causal_mask(b, h, q_idx, kv_idx):
+        """Causal mask and sliding window as proposed here:
+        https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
+        """
+        causal_mask = q_idx >= kv_idx
+        if sliding_window_size is None:
+            # if no sliding window return causal mask
+            return causal_mask
+        else:
+            windowed_mask = q_idx - kv_idx <= sliding_window_size
+
+            return causal_mask & windowed_mask
+
+    return sliding_window_causal_mask
+
+
+def flex_tanh_soft_capping_with_scaling(softcapping, query_pre_attn_scalar):
+    def tanh_soft_capping_with_scaling(score, b, h, q_idx, kv_idx):
+        """
+        This handles both simple tanh soft capping and custom scaling
+        """
+        if query_pre_attn_scalar is None:
+            # usual scaling included in FlexAttention
+            # TODO: could be made faster with approximate tanh ?
+            # https://github.com/pytorch-labs/attention-gym/blob/f7c93ded4abf9fd8d7dc9d8bcbf57e420b891e2d/examples/flex_attn.ipynb#L733
+            score = score / softcapping
+            score = torch.tanh(score)
+            return score * softcapping
+        else:
+            score = score / softcapping * query_pre_attn_scalar**-0.5
+            score = torch.tanh(score)
+            return score * softcapping
+
+    return tanh_soft_capping_with_scaling
diff --git a/torchtune/models/gemma2/_component_builders.py b/torchtune/models/gemma2/_component_builders.py
index 6478d8ec31..915430ce4a 100644
--- a/torchtune/models/gemma2/_component_builders.py
+++ b/torchtune/models/gemma2/_component_builders.py
@@ -22,7 +22,24 @@ from torchtune.models.gemma.gemma_norm_embedding import GemmaNormEmbeddings
 from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear
 from torchtune.models.gemma._component_builders import gemma_mlp, lora_gemma_mlp
+from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION
+import logging
+from torchtune.utils._logging import get_logger, log_once
+
+_log: logging.Logger = get_logger()
+
+
+if _SUPPORTS_FLEX_ATTENTION:
+    from torchtune.models.gemma2._attention import FlexGemma2Attention
+    log_once(
+        _log,
+        "Using flex attention for Gemma2 attention computation.",
+        level=logging.DEBUG,
+    )
+    _flex_or_native_gemma2_attention = FlexGemma2Attention
+else:
+    _flex_or_native_gemma2_attention = Gemma2Attention

 """
 Component builders for the Gemma2 2B, 9B models and popular variants such as LoRA.
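For reference, a minimal sketch of how the two helpers added in _attention_utils.py above plug into flex attention. This assumes a PyTorch build where torch.nn.attention.flex_attention is available (the _SUPPORTS_FLEX_ATTENTION path) and a CUDA device; the shapes, the 64-token window and the 50.0 cap are illustrative values, not taken from any config in this series:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from torchtune.models.gemma2._attention_utils import (
    flex_causal_sliding_window,
    flex_tanh_soft_capping_with_scaling,
)

b, n_h, s, h_d = 2, 8, 128, 256  # illustrative batch, heads, seq len, head dim
q, k, v = (torch.randn(b, n_h, s, h_d, device="cuda") for _ in range(3))

# causal mask restricted to a 64-token sliding window
mask_mod = flex_causal_sliding_window(sliding_window_size=64)
# tanh soft capping at 50.0, keeping flex attention's default head_dim**-0.5 scaling
score_mod = flex_tanh_soft_capping_with_scaling(softcapping=50.0, query_pre_attn_scalar=None)

block_mask = create_block_mask(mask_mod, B=b, H=n_h, Q_LEN=s, KV_LEN=s, device="cuda")
out = flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)  # [b, n_h, s, h_d]

FlexGemma2Attention routes the same call through compile_friendly_flex_attention_with_score_and_block, which is wrapped in torch.compiler.disable so it can be used whether or not the surrounding model is compiled.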
@@ -47,7 +64,7 @@ def forward(self, attn_weights): attn_weights = attn_weights / self.capping_value attn_weights = torch.tanh(attn_weights) attn_weights = attn_weights * self.capping_value - + return attn_weights class Gemma2FinalNorm(nn.Module): """ @@ -120,7 +137,8 @@ def gemma2( layers = torch.nn.ModuleList() for layer_idx in range(num_layers): - self_att = Gemma2Attention( + + self_att = _flex_or_native_gemma2_attention( embed_dim=embed_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, @@ -316,7 +334,7 @@ def lora_gemma2_self_attention( use_dora: bool = False, quantize_base: bool = False, -) -> Gemma2Attention: +) -> _flex_or_native_gemma2_attention: if not lora_modules: raise ValueError( f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules" @@ -392,7 +410,7 @@ def lora_gemma2_self_attention( rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) - self_att = Gemma2Attention( + self_att = _flex_or_native_gemma2_attention( embed_dim=embed_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, From 0d5366012a44f269beada3b9b120532c73946749 Mon Sep 17 00:00:00 2001 From: Optimox Date: Thu, 24 Oct 2024 15:56:53 +0200 Subject: [PATCH 04/11] update recipes for 9b and 27b --- recipes/configs/gemma2/27B_full.yaml | 6 +++--- recipes/configs/gemma2/27B_lora.yaml | 6 +++--- recipes/configs/gemma2/27B_lora_single_device.yaml | 6 +++--- recipes/configs/gemma2/27B_qlora_single_device.yaml | 6 +++--- recipes/configs/gemma2/9B_full.yaml | 6 +++--- recipes/configs/gemma2/9B_lora.yaml | 6 +++--- recipes/configs/gemma2/9B_lora_single_device.yaml | 6 +++--- recipes/configs/gemma2/9B_qlora_single_device.yaml | 8 ++++---- 8 files changed, 25 insertions(+), 25 deletions(-) diff --git a/recipes/configs/gemma2/27B_full.yaml b/recipes/configs/gemma2/27B_full.yaml index 17a6e895f5..eebeefbd4f 100644 --- a/recipes/configs/gemma2/27B_full.yaml +++ b/recipes/configs/gemma2/27B_full.yaml @@ -19,7 +19,7 @@ # Tokenizer tokenizer: _component_: torchtune.models.gemma.gemma_tokenizer - path: /tmp/gemma2-27b/tokenizer.model + path: /tmp/gemma-2-27b/tokenizer.model # Dataset dataset: @@ -33,12 +33,12 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/gemma2-27b/ + checkpoint_dir: /tmp/gemma-2-27b/ checkpoint_files: filename_format: model-{}-of-{}.safetensors max_filename: 00024 recipe_checkpoint: null - output_dir: /tmp/gemma2-27b + output_dir: /tmp/gemma-2-27b model_type: GEMMA2 resume_from_checkpoint: False diff --git a/recipes/configs/gemma2/27B_lora.yaml b/recipes/configs/gemma2/27B_lora.yaml index 8cc22e4dd1..e78b40633a 100644 --- a/recipes/configs/gemma2/27B_lora.yaml +++ b/recipes/configs/gemma2/27B_lora.yaml @@ -19,7 +19,7 @@ # Tokenizer tokenizer: _component_: torchtune.models.gemma.gemma_tokenizer - path: /tmp/gemma2-27b/tokenizer.model + path: /tmp/gemma-2-27b/tokenizer.model # Dataset dataset: @@ -38,12 +38,12 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/gemma2-27b/ + checkpoint_dir: /tmp/gemma-2-27b/ checkpoint_files: filename_format: model-{}-of-{}.safetensors max_filename: 00024 recipe_checkpoint: null - output_dir: /tmp/gemma2-27b/ + output_dir: /tmp/gemma-2-27b/ model_type: GEMMA2 resume_from_checkpoint: False save_adapter_weights_only: False diff --git a/recipes/configs/gemma2/27B_lora_single_device.yaml b/recipes/configs/gemma2/27B_lora_single_device.yaml index 11ca14eceb..7879dd1fce 100644 --- a/recipes/configs/gemma2/27B_lora_single_device.yaml +++ 
b/recipes/configs/gemma2/27B_lora_single_device.yaml @@ -18,7 +18,7 @@ # Tokenizer tokenizer: _component_: torchtune.models.gemma.gemma_tokenizer - path: /tmp/gemma2-27b/tokenizer.model + path: /tmp/gemma-2-27b/tokenizer.model # Dataset dataset: @@ -37,12 +37,12 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/gemma2-27b/ + checkpoint_dir: /tmp/gemma-2-27b/ checkpoint_files: filename_format: model-{}-of-{}.safetensors max_filename: 00024 recipe_checkpoint: null - output_dir: /tmp/gemma2-27b/ + output_dir: /tmp/gemma-2-27b/ model_type: GEMMA2 resume_from_checkpoint: False save_adapter_weights_only: False diff --git a/recipes/configs/gemma2/27B_qlora_single_device.yaml b/recipes/configs/gemma2/27B_qlora_single_device.yaml index 9f612cc3c0..a1b7fcd377 100644 --- a/recipes/configs/gemma2/27B_qlora_single_device.yaml +++ b/recipes/configs/gemma2/27B_qlora_single_device.yaml @@ -18,7 +18,7 @@ # Tokenizer tokenizer: _component_: torchtune.models.gemma.gemma_tokenizer - path: /tmp/gemma2-27b/tokenizer.model + path: /tmp/gemma-2-27b/tokenizer.model # Dataset dataset: @@ -37,12 +37,12 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/gemma2-27b/ + checkpoint_dir: /tmp/gemma-2-27b/ checkpoint_files: filename_format: model-{}-of-{}.safetensors max_filename: 00024 recipe_checkpoint: null - output_dir: /tmp/gemma2-27b/ + output_dir: /tmp/gemma-2-27b/ model_type: GEMMA2 resume_from_checkpoint: False save_adapter_weights_only: False diff --git a/recipes/configs/gemma2/9B_full.yaml b/recipes/configs/gemma2/9B_full.yaml index aa746080a0..d599970a2a 100644 --- a/recipes/configs/gemma2/9B_full.yaml +++ b/recipes/configs/gemma2/9B_full.yaml @@ -19,7 +19,7 @@ # Tokenizer tokenizer: _component_: torchtune.models.gemma.gemma_tokenizer - path: /tmp/gemma2-9b/tokenizer.model + path: /tmp/gemma-2-9b/tokenizer.model # Dataset dataset: @@ -33,12 +33,12 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/gemma2-9b/ + checkpoint_dir: /tmp/gemma-2-9b/ checkpoint_files: filename_format: model-{}-of-{}.safetensors max_filename: 00008 recipe_checkpoint: null - output_dir: /tmp/gemma2-9b + output_dir: /tmp/gemma-2-9b model_type: GEMMA2 resume_from_checkpoint: False diff --git a/recipes/configs/gemma2/9B_lora.yaml b/recipes/configs/gemma2/9B_lora.yaml index f1cc3e3337..1cf209a249 100644 --- a/recipes/configs/gemma2/9B_lora.yaml +++ b/recipes/configs/gemma2/9B_lora.yaml @@ -19,7 +19,7 @@ # Tokenizer tokenizer: _component_: torchtune.models.gemma.gemma_tokenizer - path: /tmp/gemma2-9b/tokenizer.model + path: /tmp/gemma-2-9b/tokenizer.model # Dataset dataset: @@ -38,12 +38,12 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/gemma2-9b/ + checkpoint_dir: /tmp/gemma-2-9b/ checkpoint_files: filename_format: model-{}-of-{}.safetensors max_filename: 00008 recipe_checkpoint: null - output_dir: /tmp/gemma2-9b/ + output_dir: /tmp/gemma-2-9b/ model_type: GEMMA2 resume_from_checkpoint: False save_adapter_weights_only: False diff --git a/recipes/configs/gemma2/9B_lora_single_device.yaml b/recipes/configs/gemma2/9B_lora_single_device.yaml index cb5461a5e3..57d066bb0a 100644 --- a/recipes/configs/gemma2/9B_lora_single_device.yaml +++ b/recipes/configs/gemma2/9B_lora_single_device.yaml @@ -18,7 +18,7 @@ # Tokenizer tokenizer: _component_: torchtune.models.gemma.gemma_tokenizer - path: /tmp/gemma2-9b/tokenizer.model + path: 
/tmp/gemma-2-9b/tokenizer.model # Dataset dataset: @@ -37,12 +37,12 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/gemma2-9b/ + checkpoint_dir: /tmp/gemma-2-9b/ checkpoint_files: filename_format: model-{}-of-{}.safetensors max_filename: 00008 recipe_checkpoint: null - output_dir: /tmp/gemma2-9b/ + output_dir: /tmp/gemma-2-9b/ model_type: GEMMA2 resume_from_checkpoint: False save_adapter_weights_only: False diff --git a/recipes/configs/gemma2/9B_qlora_single_device.yaml b/recipes/configs/gemma2/9B_qlora_single_device.yaml index 38a158a2f6..3c198bead8 100644 --- a/recipes/configs/gemma2/9B_qlora_single_device.yaml +++ b/recipes/configs/gemma2/9B_qlora_single_device.yaml @@ -18,7 +18,7 @@ # Tokenizer tokenizer: _component_: torchtune.models.gemma.gemma_tokenizer - path: /tmp/gemma2-9b/tokenizer.model + path: /tmp/gemma-2-9b/tokenizer.model # Dataset dataset: @@ -28,7 +28,7 @@ shuffle: True # Model Arguments model: - _component_: torchtune.models.gemma2.qlora_gemma_9b + _component_: torchtune.models.gemma2.qlora_gemma2_9b lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] apply_lora_to_mlp: True lora_rank: 64 @@ -37,12 +37,12 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/gemma2-9b/ + checkpoint_dir: /tmp/gemma-2-9b/ checkpoint_files: filename_format: model-{}-of-{}.safetensors max_filename: 00008 recipe_checkpoint: null - output_dir: /tmp/gemma2-9b/ + output_dir: /tmp/gemma-2-9b/ model_type: GEMMA2 resume_from_checkpoint: False save_adapter_weights_only: False From 6b50916c0ea28d3bdedd29f627f6acefeb12fe38 Mon Sep 17 00:00:00 2001 From: Optimox Date: Sat, 26 Oct 2024 11:59:13 +0200 Subject: [PATCH 05/11] fix mlp and kv cache, disable flex attention --- .../gemma2/27B_lora_single_device.yaml | 4 +-- .../configs/gemma2/2B_lora_single_device.yaml | 2 +- torchtune/models/gemma2/_attention.py | 22 +++++++++++-- torchtune/models/gemma2/_attention_utils.py | 19 ++++++----- .../models/gemma2/_component_builders.py | 33 ++++++++++--------- 5 files changed, 50 insertions(+), 30 deletions(-) diff --git a/recipes/configs/gemma2/27B_lora_single_device.yaml b/recipes/configs/gemma2/27B_lora_single_device.yaml index 7879dd1fce..56727e5290 100644 --- a/recipes/configs/gemma2/27B_lora_single_device.yaml +++ b/recipes/configs/gemma2/27B_lora_single_device.yaml @@ -60,10 +60,10 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Fine-tuning arguments -batch_size: 8 +batch_size: 2 epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 2 +gradient_accumulation_steps: 8 compile: False # Training env diff --git a/recipes/configs/gemma2/2B_lora_single_device.yaml b/recipes/configs/gemma2/2B_lora_single_device.yaml index 49b59846c4..484f133b43 100644 --- a/recipes/configs/gemma2/2B_lora_single_device.yaml +++ b/recipes/configs/gemma2/2B_lora_single_device.yaml @@ -37,7 +37,7 @@ model: checkpointer: _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/gemma-2b/ + checkpoint_dir: /tmp/gemma-2-2b/ checkpoint_files: [ model-00001-of-00003.safetensors, model-00002-of-00003.safetensors, diff --git a/torchtune/models/gemma2/_attention.py b/torchtune/models/gemma2/_attention.py index e4d0949d0f..c769c5b479 100644 --- a/torchtune/models/gemma2/_attention.py +++ b/torchtune/models/gemma2/_attention.py @@ -12,7 +12,11 @@ from torch import nn from torchtune.modules.attention_utils import _MaskType from torchtune.modules.kv_cache import KVCache -from 
torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION + +# from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION +# The flex attention implementation for gemma2 is not working yet +# flex attention is disabled for now untill we solve the case +_SUPPORTS_FLEX_ATTENTION = False if _SUPPORTS_FLEX_ATTENTION: from torch.nn.attention.flex_attention import create_block_mask @@ -132,6 +136,11 @@ def __init__( else: self.scaling = self.head_dim**-0.5 + # this flag indicates whether to update the kv-cache during forward + # passes. when disabled, we can have the cache setup but still + # perform normal forward passes + self.cache_enabled = False + def setup_cache( self, batch_size: int, dtype: torch.dtype, max_seq_len: int ) -> None: @@ -156,6 +165,7 @@ def setup_cache( head_dim=self.head_dim, dtype=dtype, ) + self.cache_enabled = True def reset_cache(self): """Reset the key value caches.""" @@ -283,7 +293,7 @@ def forward( k = self.k_norm(k) # Update key-value cache - if self.kv_cache is not None: + if self.kv_cache is not None and self.cache_enabled: k, v = self.kv_cache.update(k, v) q.mul_(self.scaling) @@ -436,6 +446,11 @@ def __init__( self.softcapping, self.scaling ) + # this flag indicates whether to update the kv-cache during forward + # passes. when disabled, we can have the cache setup but still + # perform normal forward passes + self.cache_enabled = False + def setup_cache( self, batch_size: int, dtype: torch.dtype, max_seq_len: int ) -> None: @@ -460,6 +475,7 @@ def setup_cache( head_dim=self.head_dim, dtype=dtype, ) + self.cache_enabled = True def reset_cache(self): """Reset the key value caches.""" @@ -587,7 +603,7 @@ def forward( k = self.k_norm(k) # Update key-value cache - if self.kv_cache is not None: + if self.kv_cache is not None and self.cache_enabled: k, v = self.kv_cache.update(k, v) # TODO: how to avoid to compute same block mask at every layer ? diff --git a/torchtune/models/gemma2/_attention_utils.py b/torchtune/models/gemma2/_attention_utils.py index 534ad9e051..8a17c8ec86 100644 --- a/torchtune/models/gemma2/_attention_utils.py +++ b/torchtune/models/gemma2/_attention_utils.py @@ -8,7 +8,10 @@ import torch -from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION +# from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION +# The flex attention implementation for gemma2 is not working yet +# flex attention is disabled for now untill we solve the case +_SUPPORTS_FLEX_ATTENTION = False if _SUPPORTS_FLEX_ATTENTION: from functools import lru_cache @@ -19,7 +22,9 @@ flex_attention, ) - # flex_attention_compiled = torch.compile(flex_attention, dynamic=False) + flex_attention_compiled = torch.compile( + flex_attention, dynamic=False, mode="max-autotune" + ) @lru_cache def create_block_mask_cached(score_mod, b, h, m, n, device="cuda"): @@ -40,7 +45,7 @@ def compile_friendly_flex_attention_with_score_and_block( ) -> torch.Tensor: """ Flex attention does not seem to work with my A6000 with the default options. 
- Using proposed options here: https://github.com/pytorch/pytorch/issues/133254 + Using proposed options here: https://github.com/pytorch/torchtune/pull/1835#discussion_r1815058279 """ return flex_attention( q, @@ -49,12 +54,8 @@ def compile_friendly_flex_attention_with_score_and_block( score_mod=score_mod, block_mask=block_mask, # kernel_options={ - # "BLOCK_M": 64, - # "BLOCK_N": 64, - # "BLOCK_M1": 32, - # "BLOCK_N1": 64, - # "BLOCK_M2": 64, - # "BLOCK_N2": 32, + # "BLOCK_M": 32, + # "BLOCK_N": 32, # }, ) diff --git a/torchtune/models/gemma2/_component_builders.py b/torchtune/models/gemma2/_component_builders.py index 915430ce4a..253cea2f87 100644 --- a/torchtune/models/gemma2/_component_builders.py +++ b/torchtune/models/gemma2/_component_builders.py @@ -22,7 +22,10 @@ from torchtune.models.gemma.gemma_norm_embedding import GemmaNormEmbeddings from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear from torchtune.models.gemma._component_builders import gemma_mlp, lora_gemma_mlp -from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION +# from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION +# The flex attention implementation for gemma2 is not working yet +# flex attention is disabled for now untill we solve the case +_SUPPORTS_FLEX_ATTENTION = False import logging from torchtune.utils._logging import get_logger, log_once @@ -132,12 +135,12 @@ def gemma2( """ rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) - mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim) - layers = torch.nn.ModuleList() for layer_idx in range(num_layers): + mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + self_att = _flex_or_native_gemma2_attention( embed_dim=embed_dim, num_heads=num_heads, @@ -244,18 +247,6 @@ def lora_gemma2( TransformerDecoder: Instantiation of Gemma model with LoRA applied to a subset of the attention projections in each layer. 
""" - if apply_lora_to_mlp: - mlp = lora_gemma_mlp( - dim=embed_dim, - hidden_dim=intermediate_dim, - lora_rank=lora_rank, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - use_dora=use_dora, - quantize_base=quantize_base, - ) - else: - mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base) tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim) output_proj = TiedLinear(tok_embeddings) @@ -263,6 +254,18 @@ def lora_gemma2( layers = torch.nn.ModuleList() for layer_idx in range(num_layers): + if apply_lora_to_mlp: + mlp = lora_gemma_mlp( + dim=embed_dim, + hidden_dim=intermediate_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + else: + mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base) self_att = lora_gemma2_self_attention( lora_modules=lora_attn_modules, embed_dim=embed_dim, From 54a237cc1ce8177d0cc6f1b7d2be726df5a4da1e Mon Sep 17 00:00:00 2001 From: Optimox Date: Sun, 27 Oct 2024 10:21:32 +0100 Subject: [PATCH 06/11] update configs to match parallel PR --- recipes/configs/gemma2/27B_full.yaml | 4 +++- recipes/configs/gemma2/27B_lora.yaml | 4 +++- recipes/configs/gemma2/27B_lora_single_device.yaml | 5 +++-- recipes/configs/gemma2/27B_qlora_single_device.yaml | 5 +++-- recipes/configs/gemma2/2B_full.yaml | 4 +++- recipes/configs/gemma2/2B_lora.yaml | 4 +++- recipes/configs/gemma2/2B_lora_single_device.yaml | 5 +++-- recipes/configs/gemma2/2B_qlora_single_device.yaml | 5 +++-- recipes/configs/gemma2/9B_full.yaml | 4 +++- recipes/configs/gemma2/9B_lora.yaml | 4 +++- recipes/configs/gemma2/9B_lora_single_device.yaml | 5 +++-- recipes/configs/gemma2/9B_qlora_single_device.yaml | 5 +++-- torchtune/training/checkpointing/_utils.py | 2 +- 13 files changed, 37 insertions(+), 19 deletions(-) diff --git a/recipes/configs/gemma2/27B_full.yaml b/recipes/configs/gemma2/27B_full.yaml index eebeefbd4f..dee049024a 100644 --- a/recipes/configs/gemma2/27B_full.yaml +++ b/recipes/configs/gemma2/27B_full.yaml @@ -23,6 +23,7 @@ tokenizer: # Dataset dataset: + packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset seed: null shuffle: True @@ -53,6 +54,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 +compile: False # pytorch compile, set to true for perf/memory improvement # Training env device: cuda @@ -69,4 +71,4 @@ metric_logger: log_dir: ${output_dir} output_dir: /tmp/alpaca-gemma2-27b-finetune log_every_n_steps: 1 -log_peak_memory_stats: False +log_peak_memory_stats: True diff --git a/recipes/configs/gemma2/27B_lora.yaml b/recipes/configs/gemma2/27B_lora.yaml index e78b40633a..265895090d 100644 --- a/recipes/configs/gemma2/27B_lora.yaml +++ b/recipes/configs/gemma2/27B_lora.yaml @@ -23,6 +23,7 @@ tokenizer: # Dataset dataset: + packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset seed: null shuffle: True @@ -65,6 +66,7 @@ batch_size: 4 epochs: 3 max_steps_per_epoch: null gradient_accumulation_steps: 1 +compile: False # pytorch compile, set to true for perf/memory improvement # Training env device: cuda @@ -81,4 +83,4 @@ metric_logger: log_dir: ${output_dir} output_dir: /tmp/alpaca-gemma2-27b-lora log_every_n_steps: 1 -log_peak_memory_stats: False +log_peak_memory_stats: True diff --git a/recipes/configs/gemma2/27B_lora_single_device.yaml 
b/recipes/configs/gemma2/27B_lora_single_device.yaml index 56727e5290..e245aafa92 100644 --- a/recipes/configs/gemma2/27B_lora_single_device.yaml +++ b/recipes/configs/gemma2/27B_lora_single_device.yaml @@ -22,6 +22,7 @@ tokenizer: # Dataset dataset: + packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset seed: null shuffle: True @@ -64,7 +65,7 @@ batch_size: 2 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 8 -compile: False +compile: False # pytorch compile, set to true for perf/memory improvement # Training env device: cuda @@ -82,7 +83,7 @@ metric_logger: log_dir: ${output_dir} output_dir: /tmp/alpaca-gemma2-27b-lora log_every_n_steps: 1 -log_peak_memory_stats: False +log_peak_memory_stats: True # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/gemma2/27B_qlora_single_device.yaml b/recipes/configs/gemma2/27B_qlora_single_device.yaml index a1b7fcd377..2f0e7d6cad 100644 --- a/recipes/configs/gemma2/27B_qlora_single_device.yaml +++ b/recipes/configs/gemma2/27B_qlora_single_device.yaml @@ -22,6 +22,7 @@ tokenizer: # Dataset dataset: + packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset seed: null shuffle: True @@ -64,7 +65,7 @@ batch_size: 4 epochs: 3 max_steps_per_epoch: null gradient_accumulation_steps: 4 -compile: False +compile: False # pytorch compile, set to true for perf/memory improvement # Training env device: cuda @@ -82,7 +83,7 @@ metric_logger: log_dir: ${output_dir} output_dir: /tmp/alpaca-gemma2-27b-lora log_every_n_steps: 1 -log_peak_memory_stats: False +log_peak_memory_stats: True # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/gemma2/2B_full.yaml b/recipes/configs/gemma2/2B_full.yaml index 9386fae4b9..e302dd759d 100644 --- a/recipes/configs/gemma2/2B_full.yaml +++ b/recipes/configs/gemma2/2B_full.yaml @@ -23,6 +23,7 @@ tokenizer: # Dataset dataset: + packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset seed: null shuffle: True @@ -55,6 +56,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 +compile: False # pytorch compile, set to true for perf/memory improvement # Training env device: cuda @@ -71,4 +73,4 @@ metric_logger: log_dir: ${output_dir} output_dir: /tmp/alpaca-gemma2-finetune log_every_n_steps: 1 -log_peak_memory_stats: False +log_peak_memory_stats: True diff --git a/recipes/configs/gemma2/2B_lora.yaml b/recipes/configs/gemma2/2B_lora.yaml index e6ef6e6e9e..9a439ee0a3 100644 --- a/recipes/configs/gemma2/2B_lora.yaml +++ b/recipes/configs/gemma2/2B_lora.yaml @@ -22,6 +22,7 @@ tokenizer: # Dataset dataset: + packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset seed: null shuffle: True @@ -67,6 +68,7 @@ batch_size: 4 epochs: 3 max_steps_per_epoch: null gradient_accumulation_steps: 1 +compile: False # pytorch compile, set to true for perf/memory improvement # Training env device: cuda @@ -83,4 +85,4 @@ metric_logger: log_dir: ${output_dir} output_dir: /tmp/alpaca-gemma2-lora log_every_n_steps: 1 -log_peak_memory_stats: False +log_peak_memory_stats: True diff --git a/recipes/configs/gemma2/2B_lora_single_device.yaml b/recipes/configs/gemma2/2B_lora_single_device.yaml index 484f133b43..1a2703fb47 100644 --- 
a/recipes/configs/gemma2/2B_lora_single_device.yaml +++ b/recipes/configs/gemma2/2B_lora_single_device.yaml @@ -22,6 +22,7 @@ tokenizer: # Dataset dataset: + packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset seed: null shuffle: True @@ -66,7 +67,7 @@ batch_size: 8 epochs: 3 max_steps_per_epoch: null gradient_accumulation_steps: 2 -compile: False +compile: False # pytorch compile, set to true for perf/memory improvement # Training env device: cuda @@ -84,7 +85,7 @@ metric_logger: log_dir: ${output_dir} output_dir: /tmp/alpaca-gemma2-lora log_every_n_steps: 1 -log_peak_memory_stats: False +log_peak_memory_stats: True # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/gemma2/2B_qlora_single_device.yaml b/recipes/configs/gemma2/2B_qlora_single_device.yaml index b5d7c9147d..c2525460ff 100644 --- a/recipes/configs/gemma2/2B_qlora_single_device.yaml +++ b/recipes/configs/gemma2/2B_qlora_single_device.yaml @@ -22,6 +22,7 @@ tokenizer: # Dataset dataset: + packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset seed: null shuffle: True @@ -66,7 +67,7 @@ batch_size: 4 epochs: 3 max_steps_per_epoch: null gradient_accumulation_steps: 4 -compile: False +compile: False # pytorch compile, set to true for perf/memory improvement # Training env device: cuda @@ -84,7 +85,7 @@ metric_logger: log_dir: ${output_dir} output_dir: /tmp/alpaca-gemma2-lora log_every_n_steps: 1 -log_peak_memory_stats: False +log_peak_memory_stats: True # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/gemma2/9B_full.yaml b/recipes/configs/gemma2/9B_full.yaml index d599970a2a..0002b1c3b9 100644 --- a/recipes/configs/gemma2/9B_full.yaml +++ b/recipes/configs/gemma2/9B_full.yaml @@ -23,6 +23,7 @@ tokenizer: # Dataset dataset: + packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset seed: null shuffle: True @@ -53,6 +54,7 @@ loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 +compile: False # pytorch compile, set to true for perf/memory improvement # Training env device: cuda @@ -69,4 +71,4 @@ metric_logger: log_dir: ${output_dir} output_dir: /tmp/alpaca-gemma2-9b-finetune log_every_n_steps: 1 -log_peak_memory_stats: False +log_peak_memory_stats: True diff --git a/recipes/configs/gemma2/9B_lora.yaml b/recipes/configs/gemma2/9B_lora.yaml index 1cf209a249..5b0141e9ef 100644 --- a/recipes/configs/gemma2/9B_lora.yaml +++ b/recipes/configs/gemma2/9B_lora.yaml @@ -23,6 +23,7 @@ tokenizer: # Dataset dataset: + packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset seed: null shuffle: True @@ -65,6 +66,7 @@ batch_size: 4 epochs: 3 max_steps_per_epoch: null gradient_accumulation_steps: 1 +compile: False # pytorch compile, set to true for perf/memory improvement # Training env device: cuda @@ -81,4 +83,4 @@ metric_logger: log_dir: ${output_dir} output_dir: /tmp/alpaca-gemma2-9b-lora log_every_n_steps: 1 -log_peak_memory_stats: False +log_peak_memory_stats: True diff --git a/recipes/configs/gemma2/9B_lora_single_device.yaml b/recipes/configs/gemma2/9B_lora_single_device.yaml index 57d066bb0a..197ee121ae 100644 --- a/recipes/configs/gemma2/9B_lora_single_device.yaml +++ b/recipes/configs/gemma2/9B_lora_single_device.yaml @@ -22,6 +22,7 @@ 
tokenizer: # Dataset dataset: + packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset seed: null shuffle: True @@ -64,7 +65,7 @@ batch_size: 8 epochs: 1 max_steps_per_epoch: null gradient_accumulation_steps: 2 -compile: False +compile: False # pytorch compile, set to true for perf/memory improvement # Training env device: cuda @@ -82,7 +83,7 @@ metric_logger: log_dir: ${output_dir} output_dir: /tmp/alpaca-gemma2-9b-lora log_every_n_steps: 1 -log_peak_memory_stats: False +log_peak_memory_stats: True # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/gemma2/9B_qlora_single_device.yaml b/recipes/configs/gemma2/9B_qlora_single_device.yaml index 3c198bead8..80a3303104 100644 --- a/recipes/configs/gemma2/9B_qlora_single_device.yaml +++ b/recipes/configs/gemma2/9B_qlora_single_device.yaml @@ -22,6 +22,7 @@ tokenizer: # Dataset dataset: + packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset seed: null shuffle: True @@ -64,7 +65,7 @@ batch_size: 4 epochs: 3 max_steps_per_epoch: null gradient_accumulation_steps: 4 -compile: False +compile: False # pytorch compile, set to true for perf/memory improvement # Training env device: cuda @@ -82,7 +83,7 @@ metric_logger: log_dir: ${output_dir} output_dir: /tmp/alpaca-gemma2-9b-lora log_every_n_steps: 1 -log_peak_memory_stats: False +log_peak_memory_stats: True # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index 0eb4c7ebdf..2fa7265194 100644 --- a/torchtune/training/checkpointing/_utils.py +++ b/torchtune/training/checkpointing/_utils.py @@ -45,7 +45,7 @@ class ModelType(Enum): Attributes: GEMMA (str): Gemma family of models. See :func:`~torchtune.models.gemma.gemma` - GEMMA2 (str): Gemma family of models. See :func:`~torchtune.models.gemma2.gemma2` + GEMMA2 (str): Gemma 2 family of models. See :func:`~torchtune.models.gemma2.gemma2` LLAMA2 (str): Llama2 family of models. See :func:`~torchtune.models.llama2.llama2` LLAMA3 (str): Llama3 family of models. See :func:`~torchtune.models.llama3.llama3` LLAMA3_2 (str): Llama3.2 family of models. 
See :func:`~torchtune.models.llama3_2.llama3_2` From 2c216de9c67f9a49f7a49c67645cd321b06ea7d3 Mon Sep 17 00:00:00 2001 From: Optimox Date: Wed, 30 Oct 2024 09:38:04 +0100 Subject: [PATCH 07/11] remove flex attention --- torchtune/models/gemma2/_attention.py | 306 ------------------ torchtune/models/gemma2/_attention_utils.py | 97 ------ .../models/gemma2/_component_builders.py | 26 +- 3 files changed, 3 insertions(+), 426 deletions(-) delete mode 100644 torchtune/models/gemma2/_attention_utils.py diff --git a/torchtune/models/gemma2/_attention.py b/torchtune/models/gemma2/_attention.py index c769c5b479..b944aae1f8 100644 --- a/torchtune/models/gemma2/_attention.py +++ b/torchtune/models/gemma2/_attention.py @@ -13,18 +13,6 @@ from torchtune.modules.attention_utils import _MaskType from torchtune.modules.kv_cache import KVCache -# from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION -# The flex attention implementation for gemma2 is not working yet -# flex attention is disabled for now untill we solve the case -_SUPPORTS_FLEX_ATTENTION = False - -if _SUPPORTS_FLEX_ATTENTION: - from torch.nn.attention.flex_attention import create_block_mask - from torchtune.models.gemma2._attention_utils import ( - compile_friendly_flex_attention_with_score_and_block, - flex_causal_sliding_window, - flex_tanh_soft_capping_with_scaling, - ) logger = logging.getLogger(__name__) @@ -330,297 +318,3 @@ def forward( # reshape the output to be the same shape as the input output = output.transpose(1, 2).contiguous().view(b, s_x, -1) return self.output_proj(output) - - -class FlexGemma2Attention(nn.Module): - """ - Adapated from official Google Pytorch Implementation: - https://github.com/google/gemma_pytorch/blob/80881c2e6e797ef1913a4a705d4b40394791cc58/gemma/model.py#L213 - to match torchtune style. - A new attention had to be added since nn.functional.scaled_dot_product_attention does allow soft capping - Args: - embed_dim (int): embedding dimension for the model - num_heads (int): number of query heads. For MHA this is also the - number of heads for key and value - num_kv_heads (int): number of key and value heads. User should ensure - ``num_heads % num_kv_heads == 0``. For standard MHA set ``num_kv_heads == num_heads``, - for GQA ``num_kv_heads < num_heads``, and for MQA set ``num_kv_heads == 1``. - head_dim (int): dimension of each head, calculated by ``embed_dim // num_heads``. - q_proj (nn.Module): projection layer for query. - k_proj (nn.Module): projection layer for key. - v_proj (nn.Module): projection layer for value. - output_proj (nn.Module): projection layer for output. - pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings. - q_norm (Optional[nn.Module]): normalization layer for query, e.g. RMSNorm. For decoding, this is applied - before updating from kv_cache. This means it will only support token wide normalization and not - batch or sequence wide normalization. - k_norm (Optional[nn.Module]): normalization layer for key, must be set if q_norm is. - kv_cache (Optional[KVCache]): KVCache object used to cache key and value - max_seq_len (int): maximum sequence length supported by the model. - This is needed to compute the RoPE Cache. Default: 4096. - is_causal (bool): sets the default mask to causal when no mask is provided - attn_dropout (float): dropout value passed onto the - scaled_dot_product_attention function. This argument is ignored if the - self.training is False. Default value is 0.0. 
- sliding_window_size (Optional[int]): size of the sliding window if None no sliding window is applied - softcapping (Optional[float]): capping value used for soft caping, if None no capping is performed - query_pre_attn_scalar (Optional[int]): value used for pre attention normalisation, if None head_dim is used instead - Raises: - ValueError: If ``num_heads % num_kv_heads != 0`` - ValueError: If ``embed_dim % num_heads != 0`` - ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1`` - ValueError: if q_norm is defined without k_norm or vice versa - """ - - def __init__( - self, - *, - embed_dim: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - q_proj: nn.Module, - k_proj: nn.Module, - v_proj: nn.Module, - output_proj: nn.Module, - pos_embeddings: Optional[nn.Module] = None, - q_norm: Optional[nn.Module] = None, - k_norm: Optional[nn.Module] = None, - kv_cache: Optional[KVCache] = None, - max_seq_len: int = 4096, - is_causal: bool = True, - attn_dropout: float = 0.0, - sliding_window_size: Optional[int] = None, - softcapping: Optional[float] = 50.0, - query_pre_attn_scalar: Optional[int] = None, - ) -> None: - super().__init__() - if num_heads % num_kv_heads != 0: - raise ValueError( - f"num_heads ({num_heads}) must be divisible by " - f"num_kv_heads ({num_kv_heads})" - ) - - if embed_dim % num_heads != 0: - raise ValueError( - f"embed_dim ({embed_dim}) must be divisible by " - f"num_heads ({num_heads})" - ) - - if attn_dropout < 0 or attn_dropout > 1: - raise ValueError(f"attn_dropout ({embed_dim}) must be between 0.0 and 1.0") - - if bool(q_norm) ^ bool(k_norm): - raise ValueError("q and k norm must be set together") - - # Set attributes - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.embed_dim = embed_dim - self.attn_dropout = attn_dropout - self.head_dim = head_dim - self.max_seq_len = max_seq_len - self.is_causal = is_causal - - # Set layers - self.kv_cache = kv_cache - self.q_proj = q_proj - self.k_proj = k_proj - self.v_proj = v_proj - self.output_proj = output_proj - self.q_norm = q_norm - self.k_norm = k_norm - self.pos_embeddings = pos_embeddings - - # gemma related parameters - self.sliding_window_size = sliding_window_size - self.softcapping = softcapping - if query_pre_attn_scalar is not None: - # flex attention will always make the head_dim**-0.5 normalization so it should be included in scaling - self.scaling = query_pre_attn_scalar**-0.5 / self.head_dim**-0.5 - else: - self.scaling = None - - self.mask_mod = flex_causal_sliding_window(self.sliding_window_size) - self.score_mod = flex_tanh_soft_capping_with_scaling( - self.softcapping, self.scaling - ) - - # this flag indicates whether to update the kv-cache during forward - # passes. when disabled, we can have the cache setup but still - # perform normal forward passes - self.cache_enabled = False - - def setup_cache( - self, batch_size: int, dtype: torch.dtype, max_seq_len: int - ) -> None: - """Setup key value caches for attention calculation. If called - after kv_cache is already setup, this will be skipped. - - Args: - batch_size (int): batch size for the caches. - dtype (torch.dtype): dtype for the caches. - max_seq_len (int): maximum sequence length model will be run with. - """ - # Don't overwrite user defined kv_cache from init - if self.kv_cache is not None: - logger.warning( - "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping." 
- ) - else: - self.kv_cache = KVCache( - batch_size=batch_size, - max_seq_len=max_seq_len, - num_heads=self.num_heads, - head_dim=self.head_dim, - dtype=dtype, - ) - self.cache_enabled = True - - def reset_cache(self): - """Reset the key value caches.""" - if self.kv_cache is None: - raise RuntimeError( - "Key value caches are not setup. Call ``setup_caches()`` first." - ) - self.kv_cache.reset() - - def forward( - self, - x: torch.Tensor, - y: Optional[torch.Tensor] = None, - *, - mask: Optional[_MaskType] = None, - input_pos: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Args: - x (torch.Tensor): input tensor with shape [b x s_x x d] for the query - y (Optional[torch.Tensor]): second input tensor with shape [b x s_y x d], is the input - for k and v. For self attention, x=y. Optional only with kv_cache enabled. - mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication - and before the softmax. Either: - - A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``, - or ``[b x s x self.encoder_max_cache_seq_len]`` if using KV-cacheing with encoder/decoder layers. - A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means - token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask - is used by default. - - A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence - created via `create_block_mask `_. We use - :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks. - Default is None. - input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids - of each token. During training, this is used to indicate the positions - of each token relative to its sample when packed, shape [b x s]. - During inference, this indicates the position of the current token. - If none, assume the index of the token is its position id. Default is None. - - Raises: - ValueError: If no ``y`` input and ``kv_cache`` is not enabled. 
- - Returns: - torch.Tensor: output tensor with attention applied - - Notation used for tensor shapes: - - b: batch size - - s_x: sequence length for x - - s_y: sequence length for y - - n_h: num heads - - n_kv: num kv heads - - d: embed dim - - h_d: head dim - """ - # x has shape [b, s_x, d] - # y has shape [b, s_y, d] - b, s_x, _ = x.shape - s_y = y.shape[1] if y is not None else 0 - - # q has shape [b, s_x, num_heads * head_dim] - q = self.q_proj(x) - - # number of queries per key/value - q_per_kv = self.num_heads // self.num_kv_heads - q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim) - - # Apply positional embeddings - if self.pos_embeddings is not None: - q = self.pos_embeddings(q, input_pos=input_pos) - - # [b, n_h, s_x, h_d] - q = q.transpose(1, 2) - - # Normalize q - if self.q_norm is not None: - q = self.q_norm(q) - - if y is None: - if self.kv_cache is None: - raise ValueError( - "Must provide y input or use kv_cache to enable streaming decoding" - ) - k = self.kv_cache.k_cache - v = self.kv_cache.v_cache - else: - # Update k and v shape, positional embeddings, and normalization - - # k has shape [b, s_y, num_kv_heads * head_dim] - # v has shape [b, s_y, num_kv_heads * head_dim] - k = self.k_proj(y) - v = self.v_proj(y) - - # Apply positional embeddings - # k: [b, s_y, n_kv, h_d] - k = k.view(b, s_y, -1, self.head_dim) - if self.pos_embeddings is not None: - k = self.pos_embeddings(k, input_pos=input_pos) - - # View + expand + reshape bring num_kv_heads to num_heads for k and v - # to match q. - - # k: [b, s_y, n_kv, 1, h_d] - # v: [b, s_y, n_kv, 1, h_d] - k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim) - v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim) - - # If needed, expand the key and value tensors to have the same shape - # as the query tensor by copying values across the relevant dim - if self.num_heads != self.num_kv_heads: - k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim) - v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim) - - # [b, s, n_h, h_d] - k = k.reshape(b, s_y, -1, self.head_dim) - v = v.reshape(b, s_y, -1, self.head_dim) - - # [b, n_h, s, h_d] - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - # Normalize k - if self.k_norm is not None: - k = self.k_norm(k) - - # Update key-value cache - if self.kv_cache is not None and self.cache_enabled: - k, v = self.kv_cache.update(k, v) - - # TODO: how to avoid to compute same block mask at every layer ? - # https://pytorch.org/blog/flexattention/#q-when-should-we-recompute-the-blockmask - block_mask = create_block_mask( - mask_mod=self.mask_mod, - B=b, - H=self.num_heads, - Q_LEN=s_x, - KV_LEN=s_x, - device=q.device, - ) - - output = compile_friendly_flex_attention_with_score_and_block( - q, k, v, score_mod=self.score_mod, block_mask=block_mask - ) - - # reshape the output to be the same shape as the input - output = output.transpose(1, 2).contiguous().view(b, s_x, -1) - return self.output_proj(output) diff --git a/torchtune/models/gemma2/_attention_utils.py b/torchtune/models/gemma2/_attention_utils.py deleted file mode 100644 index 8a17c8ec86..0000000000 --- a/torchtune/models/gemma2/_attention_utils.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from typing import Any - -import torch - -# from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION -# The flex attention implementation for gemma2 is not working yet -# flex attention is disabled for now untill we solve the case -_SUPPORTS_FLEX_ATTENTION = False - -if _SUPPORTS_FLEX_ATTENTION: - from functools import lru_cache - - from torch.nn.attention.flex_attention import ( - BlockMask, - create_block_mask, - flex_attention, - ) - - flex_attention_compiled = torch.compile( - flex_attention, dynamic=False, mode="max-autotune" - ) - - @lru_cache - def create_block_mask_cached(score_mod, b, h, m, n, device="cuda"): - block_mask = create_block_mask(score_mod, b, h, m, n, device=device) - return block_mask - - # We cannot do nested compile, but flex attention only has perf benefits - # when compiled. To insulate it from the compiler, we wrap it with - # compiler.disable so that it can be used regardless of whether the model - # is compiled or not, and flex attention always remains compiled. - @torch.compiler.disable(recursive=False) - def compile_friendly_flex_attention_with_score_and_block( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - block_mask: BlockMask, - score_mod: Any, - ) -> torch.Tensor: - """ - Flex attention does not seem to work with my A6000 with the default options. - Using proposed options here: https://github.com/pytorch/torchtune/pull/1835#discussion_r1815058279 - """ - return flex_attention( - q, - k, - v, - score_mod=score_mod, - block_mask=block_mask, - # kernel_options={ - # "BLOCK_M": 32, - # "BLOCK_N": 32, - # }, - ) - - -def flex_causal_sliding_window(sliding_window_size): - def sliding_window_causal_mask(b, h, q_idx, kv_idx): - """Causal mask and sliding window as proposed here: - https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb - """ - causal_mask = q_idx >= kv_idx - if sliding_window_size is None: - # if no sliding window return causal mask - return causal_mask - else: - windowed_mask = q_idx - kv_idx <= sliding_window_size - - return causal_mask & windowed_mask - - return sliding_window_causal_mask - - -def flex_tanh_soft_capping_with_scaling(softcapping, query_pre_attn_scalar): - def tanh_soft_capping_with_scaling(score, b, h, q_idx, kv_idx): - """ - This handle both simple tanh soft capping and custom scaling - """ - if query_pre_attn_scalar is None: - # usual scaling included in FlexAttention - # TODO: could be made faster with approximate tanh ? 
- # https://github.com/pytorch-labs/attention-gym/blob/f7c93ded4abf9fd8d7dc9d8bcbf57e420b891e2d/examples/flex_attn.ipynb#L733 - score = score / softcapping - score = torch.tanh(score) - return score * softcapping - else: - score = score / softcapping * query_pre_attn_scalar**-0.5 - score = torch.tanh(score) - return score * softcapping - - return tanh_soft_capping_with_scaling diff --git a/torchtune/models/gemma2/_component_builders.py b/torchtune/models/gemma2/_component_builders.py index 253cea2f87..0ddef36857 100644 --- a/torchtune/models/gemma2/_component_builders.py +++ b/torchtune/models/gemma2/_component_builders.py @@ -22,27 +22,7 @@ from torchtune.models.gemma.gemma_norm_embedding import GemmaNormEmbeddings from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear from torchtune.models.gemma._component_builders import gemma_mlp, lora_gemma_mlp -# from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION -# The flex attention implementation for gemma2 is not working yet -# flex attention is disabled for now untill we solve the case -_SUPPORTS_FLEX_ATTENTION = False -import logging -from torchtune.utils._logging import get_logger, log_once - -_log: logging.Logger = get_logger() - - -if _SUPPORTS_FLEX_ATTENTION: - from torchtune.models.gemma2._attention import FlexGemma2Attention - log_once( - _log, - "Using flex attention for Gemma2 attention computation.", - level=logging.DEBUG, - ) - _flex_or_native_gemma2_attention = FlexGemma2Attention -else: - _flex_or_native_gemma2_attention = Gemma2Attention """ Component builders for the Gemma2 2B, 9B models and popular variants such as LoRA. @@ -141,7 +121,7 @@ def gemma2( mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim) - self_att = _flex_or_native_gemma2_attention( + self_att = Gemma2Attention( embed_dim=embed_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, @@ -337,7 +317,7 @@ def lora_gemma2_self_attention( use_dora: bool = False, quantize_base: bool = False, -) -> _flex_or_native_gemma2_attention: +) -> Gemma2Attention: if not lora_modules: raise ValueError( f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules" @@ -413,7 +393,7 @@ def lora_gemma2_self_attention( rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) - self_att = _flex_or_native_gemma2_attention( + self_att = Gemma2Attention( embed_dim=embed_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, From 5127db927bbbb8b87dcef5281c71433b8938a1ca Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Thu, 31 Oct 2024 18:40:25 +0000 Subject: [PATCH 08/11] Update docs/source/api_ref_models.rst Co-authored-by: ebsmothers --- docs/source/api_ref_models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/api_ref_models.rst b/docs/source/api_ref_models.rst index 8bd039805a..cb76c3eb5d 100644 --- a/docs/source/api_ref_models.rst +++ b/docs/source/api_ref_models.rst @@ -339,7 +339,7 @@ To download the Gemma2 2B, 9B, 27B models : :nosignatures: gemma2.gemma2 - gemma2.lora_gemma + gemma2.lora_gemma2 gemma2.gemma2_2b gemma2.lora_gemma2_2b gemma2.qlora_gemma2_2b From 599c828e142b9581acd933660df2ae628fbbef2b Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Thu, 31 Oct 2024 19:55:02 +0000 Subject: [PATCH 09/11] Update docs/source/api_ref_models.rst --- docs/source/api_ref_models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/api_ref_models.rst b/docs/source/api_ref_models.rst index cb76c3eb5d..ebbc12a932 100644 --- 
a/docs/source/api_ref_models.rst +++ b/docs/source/api_ref_models.rst @@ -321,7 +321,7 @@ To download the Gemma 7B model: gemma2 : ------ +-------- Models of size 2B, 9B, 27B from the `Gemma family `_. From 7adb55a4da1b07bdc1dc56aff79d6446a96aee9e Mon Sep 17 00:00:00 2001 From: Optimox Date: Mon, 4 Nov 2024 13:38:14 +0100 Subject: [PATCH 10/11] fix: update mask for causal sliding window attention --- docs/source/api_ref_models.rst | 2 +- torchtune/models/gemma2/_attention.py | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/docs/source/api_ref_models.rst b/docs/source/api_ref_models.rst index 8bd039805a..cb76c3eb5d 100644 --- a/docs/source/api_ref_models.rst +++ b/docs/source/api_ref_models.rst @@ -339,7 +339,7 @@ To download the Gemma2 2B, 9B, 27B models : :nosignatures: gemma2.gemma2 - gemma2.lora_gemma + gemma2.lora_gemma2 gemma2.gemma2_2b gemma2.lora_gemma2_2b gemma2.qlora_gemma2_2b diff --git a/torchtune/models/gemma2/_attention.py b/torchtune/models/gemma2/_attention.py index b944aae1f8..65683e4ab3 100644 --- a/torchtune/models/gemma2/_attention.py +++ b/torchtune/models/gemma2/_attention.py @@ -285,7 +285,9 @@ def forward( k, v = self.kv_cache.update(k, v) q.mul_(self.scaling) - output = torch.matmul(q, k.transpose(2, 3)) + output = torch.matmul( + q, k.transpose(2, 3) + ) # [batch_size, n_local_heads, input_len, head_dim] # if mask is None: default to causal mask if mask is None: @@ -296,6 +298,12 @@ def forward( ).to(x.device) ) + # update masks bias to be 0 for visible tokens and -2.3819763e38 otherwise + # this is similar to what torch sdpa is doing: + # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + if mask.dtype == torch.bool: + mask = torch.where(mask.logical_not(), -2.3819763e38, 0) + if self.sliding_window_size is not None: all_ones = torch.ones_like(mask) @@ -304,6 +312,11 @@ def forward( ) * torch.tril(all_ones, self.sliding_window_size - 1) mask = torch.where(sliding_mask == 1, mask, -2.3819763e38) + if mask.dim() == 3: + # This is the case for block masks where attention is different per sample + # we want mask to be broadcastable with output so we aim for (bs, 1, s_x, s_y) + mask = mask.unsqueeze(1) + if self.softcapping is not None: output = output / self.softcapping output = torch.tanh(output) From 53eed4019fa1d41997df7b2af90fa86defc90aa9 Mon Sep 17 00:00:00 2001 From: Optimox Date: Fri, 8 Nov 2024 16:04:20 +0100 Subject: [PATCH 11/11] add error for block masks --- recipes/configs/gemma2/27B_full.yaml | 4 ++-- recipes/configs/gemma2/27B_lora.yaml | 2 +- recipes/configs/gemma2/27B_lora_single_device.yaml | 2 +- recipes/configs/gemma2/27B_qlora_single_device.yaml | 4 ++-- recipes/configs/gemma2/9B_full.yaml | 4 ++-- recipes/configs/gemma2/9B_lora.yaml | 2 +- recipes/configs/gemma2/9B_lora_single_device.yaml | 2 +- recipes/configs/gemma2/9B_qlora_single_device.yaml | 2 +- torchtune/models/gemma2/_attention.py | 6 ++++++ 9 files changed, 17 insertions(+), 11 deletions(-) diff --git a/recipes/configs/gemma2/27B_full.yaml b/recipes/configs/gemma2/27B_full.yaml index dee049024a..ddc89b38b2 100644 --- a/recipes/configs/gemma2/27B_full.yaml +++ b/recipes/configs/gemma2/27B_full.yaml @@ -30,14 +30,14 @@ shuffle: True # Model Arguments model: - _component_: torchtune.models.gemma2.gemma_27b + _component_: torchtune.models.gemma2.gemma2_27b checkpointer: _component_: torchtune.training.FullModelHFCheckpointer checkpoint_dir: /tmp/gemma-2-27b/ checkpoint_files: filename_format: 
model-{}-of-{}.safetensors - max_filename: 00024 + max_filename: "00024" recipe_checkpoint: null output_dir: /tmp/gemma-2-27b model_type: GEMMA2 diff --git a/recipes/configs/gemma2/27B_lora.yaml b/recipes/configs/gemma2/27B_lora.yaml index 265895090d..a138441199 100644 --- a/recipes/configs/gemma2/27B_lora.yaml +++ b/recipes/configs/gemma2/27B_lora.yaml @@ -42,7 +42,7 @@ checkpointer: checkpoint_dir: /tmp/gemma-2-27b/ checkpoint_files: filename_format: model-{}-of-{}.safetensors - max_filename: 00024 + max_filename: "00024" recipe_checkpoint: null output_dir: /tmp/gemma-2-27b/ model_type: GEMMA2 diff --git a/recipes/configs/gemma2/27B_lora_single_device.yaml b/recipes/configs/gemma2/27B_lora_single_device.yaml index e245aafa92..577b0715c5 100644 --- a/recipes/configs/gemma2/27B_lora_single_device.yaml +++ b/recipes/configs/gemma2/27B_lora_single_device.yaml @@ -41,7 +41,7 @@ checkpointer: checkpoint_dir: /tmp/gemma-2-27b/ checkpoint_files: filename_format: model-{}-of-{}.safetensors - max_filename: 00024 + max_filename: "00024" recipe_checkpoint: null output_dir: /tmp/gemma-2-27b/ model_type: GEMMA2 diff --git a/recipes/configs/gemma2/27B_qlora_single_device.yaml b/recipes/configs/gemma2/27B_qlora_single_device.yaml index 2f0e7d6cad..14d9b75ba7 100644 --- a/recipes/configs/gemma2/27B_qlora_single_device.yaml +++ b/recipes/configs/gemma2/27B_qlora_single_device.yaml @@ -29,7 +29,7 @@ shuffle: True # Model Arguments model: - _component_: torchtune.models.gemma2.qlora_gemma_27b + _component_: torchtune.models.gemma2.qlora_gemma2_27b lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] apply_lora_to_mlp: True lora_rank: 64 @@ -41,7 +41,7 @@ checkpointer: checkpoint_dir: /tmp/gemma-2-27b/ checkpoint_files: filename_format: model-{}-of-{}.safetensors - max_filename: 00024 + max_filename: "00024" recipe_checkpoint: null output_dir: /tmp/gemma-2-27b/ model_type: GEMMA2 diff --git a/recipes/configs/gemma2/9B_full.yaml b/recipes/configs/gemma2/9B_full.yaml index 0002b1c3b9..0fc7e6e4e4 100644 --- a/recipes/configs/gemma2/9B_full.yaml +++ b/recipes/configs/gemma2/9B_full.yaml @@ -30,14 +30,14 @@ shuffle: True # Model Arguments model: - _component_: torchtune.models.gemma2.gemma_9b + _component_: torchtune.models.gemma2.gemma2_9b checkpointer: _component_: torchtune.training.FullModelHFCheckpointer checkpoint_dir: /tmp/gemma-2-9b/ checkpoint_files: filename_format: model-{}-of-{}.safetensors - max_filename: 00008 + max_filename: "00008" recipe_checkpoint: null output_dir: /tmp/gemma-2-9b model_type: GEMMA2 diff --git a/recipes/configs/gemma2/9B_lora.yaml b/recipes/configs/gemma2/9B_lora.yaml index 5b0141e9ef..960e4fa881 100644 --- a/recipes/configs/gemma2/9B_lora.yaml +++ b/recipes/configs/gemma2/9B_lora.yaml @@ -42,7 +42,7 @@ checkpointer: checkpoint_dir: /tmp/gemma-2-9b/ checkpoint_files: filename_format: model-{}-of-{}.safetensors - max_filename: 00008 + max_filename: "00008" recipe_checkpoint: null output_dir: /tmp/gemma-2-9b/ model_type: GEMMA2 diff --git a/recipes/configs/gemma2/9B_lora_single_device.yaml b/recipes/configs/gemma2/9B_lora_single_device.yaml index 197ee121ae..e9d6c22a73 100644 --- a/recipes/configs/gemma2/9B_lora_single_device.yaml +++ b/recipes/configs/gemma2/9B_lora_single_device.yaml @@ -41,7 +41,7 @@ checkpointer: checkpoint_dir: /tmp/gemma-2-9b/ checkpoint_files: filename_format: model-{}-of-{}.safetensors - max_filename: 00008 + max_filename: "00008" recipe_checkpoint: null output_dir: /tmp/gemma-2-9b/ model_type: GEMMA2 diff --git 
a/recipes/configs/gemma2/9B_qlora_single_device.yaml b/recipes/configs/gemma2/9B_qlora_single_device.yaml
index 80a3303104..8991ba9ece 100644
--- a/recipes/configs/gemma2/9B_qlora_single_device.yaml
+++ b/recipes/configs/gemma2/9B_qlora_single_device.yaml
@@ -41,7 +41,7 @@ checkpointer:
   checkpoint_dir: /tmp/gemma-2-9b/
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00008
+    max_filename: "00008"
   recipe_checkpoint: null
   output_dir: /tmp/gemma-2-9b/
   model_type: GEMMA2
diff --git a/torchtune/models/gemma2/_attention.py b/torchtune/models/gemma2/_attention.py
index 65683e4ab3..b00612d032 100644
--- a/torchtune/models/gemma2/_attention.py
+++ b/torchtune/models/gemma2/_attention.py
@@ -210,6 +210,12 @@ def forward(
             - d: embed dim
             - h_d: head dim
         """
+        # until a flex attention implementation exists, we do not accept block masks
+        if (mask is not None) and (not isinstance(mask, torch.Tensor)):
+            raise NotImplementedError(
+                "Block masks are not implemented yet, use packed=False"
+            )
+
         # x has shape [b, s_x, d]
         # y has shape [b, s_y, d]
         b, s_x, _ = x.shape
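
Note on the eager attention path touched by the last two patches: the queries are pre-scaled, the raw scores are soft-capped with tanh, an additive 0 / -2.3819763e38 mask bias is applied (optionally restricted to a causal sliding window), and only then is the softmax taken. Below is a minimal, self-contained sketch of that scoring path for readers who want to see it outside the diff context. It is illustrative only and not part of the patch series: the function name, argument names, and the plain-MHA shapes are assumptions, and the ordering of soft-capping before the mask bias follows the upstream Gemma reference implementation rather than any single hunk visible above.

# Illustrative sketch only -- NOT part of the patch. Names are hypothetical;
# shapes assume plain MHA (q, k, v share num_heads) and float32/bfloat16 activations.
from typing import Optional

import torch


def gemma2_eager_attention_sketch(
    q: torch.Tensor,                      # [b, n_h, s, h_d]
    k: torch.Tensor,                      # [b, n_h, s, h_d]
    v: torch.Tensor,                      # [b, n_h, s, h_d]
    scaling: float,                       # e.g. head_dim**-0.5 or query_pre_attn_scalar**-0.5
    sliding_window_size: Optional[int] = None,
    softcapping: Optional[float] = 50.0,
) -> torch.Tensor:
    s = q.shape[-2]

    # query pre-scaling (mirrors q.mul_(self.scaling) in the patched forward)
    q = q * scaling

    # raw attention scores: [b, n_h, s, s]
    scores = torch.matmul(q, k.transpose(2, 3))

    # tanh soft-capping keeps the logits in (-softcapping, softcapping)
    if softcapping is not None:
        scores = torch.tanh(scores / softcapping) * softcapping

    # causal mask as an additive bias: 0 for visible tokens, -2.3819763e38 otherwise
    causal = torch.tril(torch.ones(s, s, dtype=torch.bool, device=q.device))
    bias = torch.zeros(s, s, dtype=q.dtype, device=q.device)
    bias = bias.masked_fill(~causal, -2.3819763e38)

    # optionally restrict attention to a band of width sliding_window_size,
    # using the same triu/tril construction as the patched forward
    if sliding_window_size is not None:
        all_ones = torch.ones(s, s, device=q.device)
        band = torch.triu(all_ones, -sliding_window_size + 1) * torch.tril(
            all_ones, sliding_window_size - 1
        )
        bias = bias.masked_fill(band == 0, -2.3819763e38)

    scores = scores + bias  # bias broadcasts over [b, n_h, s, s]
    probs = torch.softmax(scores.float(), dim=-1).to(q.dtype)
    return torch.matmul(probs, v)  # [b, n_h, s, h_d]


# Example call with small hypothetical sizes (batch 1, 4 heads, 8 tokens, head_dim 32):
# out = gemma2_eager_attention_sketch(
#     torch.randn(1, 4, 8, 32), torch.randn(1, 4, 8, 32), torch.randn(1, 4, 8, 32),
#     scaling=32 ** -0.5, sliding_window_size=4,
# )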