Merge branch 'main' into code_llama2_evaluation
ReemaAlzaid authored Dec 26, 2024
2 parents 8495571 + aa8f365 commit 26126c6
Showing 19 changed files with 584 additions and 433 deletions.
300 changes: 166 additions & 134 deletions docs/source/deep_dives/checkpointer.rst

Large diffs are not rendered by default.

601 changes: 332 additions & 269 deletions docs/source/tutorials/e2e_flow.rst

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_lora_dpo.yaml
@@ -32,7 +32,7 @@ model:
 tokenizer:
   _component_: torchtune.models.llama2.llama2_tokenizer
   path: /tmp/Llama-2-7b-hf/tokenizer.model
-  max_seq_len: 1024
+  max_seq_len: 1024 # higher increases memory

 checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_lora_dpo_single_device.yaml
@@ -31,7 +31,7 @@ model:
 tokenizer:
   _component_: torchtune.models.llama2.llama2_tokenizer
   path: /tmp/Llama-2-7b-hf/tokenizer.model
-  max_seq_len: 1024
+  max_seq_len: 1024 # higher increases memory

 checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
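
The two llama2 DPO configs above keep tokenizer.max_seq_len at 1024 and add a note that raising it increases peak memory. As a minimal sketch (the override value of 2048 is an arbitrary example, not part of this commit), the field can also be adjusted at launch time with torchtune's key=value CLI overrides instead of editing the YAML:

# Hedged sketch: override the tokenizer's max sequence length from the CLI.
# 2048 is an illustrative value; higher settings use more memory.
tune run lora_dpo_single_device \
    --config llama2/7B_lora_dpo_single_device \
    tokenizer.max_seq_len=2048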
9 changes: 4 additions & 5 deletions recipes/configs/llama3/8B_qat_lora.yaml
@@ -83,6 +83,10 @@ dtype: bf16
 enable_activation_checkpointing: False # True reduces memory
 enable_activation_offloading: False # True reduces memory

+# QAT arguments
+quantizer:
+  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
+  groupsize: 256
+
 # Profiler (disabled)
 profiler:

@@ -108,8 +112,3 @@ profiler:
   warmup_steps: 3
   active_steps: 2
   num_cycles: 1
-
-# QAT arguments
-quantizer:
-  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
-  groupsize: 256
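
The two hunks above move the quantizer block next to the other training flags; its values (Int8DynActInt4WeightQATQuantizer with groupsize 256) are unchanged. A minimal launch sketch, assuming a 2-GPU node (the --nproc_per_node value is a placeholder, not part of this commit):

# Hedged sketch: run the distributed QAT + LoRA recipe with this config.
tune run --nproc_per_node 2 qat_lora_finetune_distributed \
    --config llama3/8B_qat_lora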
2 changes: 1 addition & 1 deletion recipes/configs/llama3_1/8B_lora_dpo.yaml
@@ -32,7 +32,7 @@ model:
 tokenizer:
   _component_: torchtune.models.llama3.llama3_tokenizer
   path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
-  max_seq_len: null
+  max_seq_len: 1024 # higher increases memory

 checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
2 changes: 1 addition & 1 deletion recipes/configs/llama3_1/8B_lora_dpo_single_device.yaml
@@ -31,7 +31,7 @@ model:
 tokenizer:
   _component_: torchtune.models.llama3.llama3_tokenizer
   path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
-  max_seq_len: null
+  max_seq_len: 1024 # higher increases memory

 checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
9 changes: 4 additions & 5 deletions recipes/configs/llama3_1/8B_qat_lora.yaml
@@ -86,6 +86,10 @@ dtype: bf16
 enable_activation_checkpointing: False # True reduces memory
 enable_activation_offloading: False # True reduces memory

+# QAT arguments
+quantizer:
+  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
+  groupsize: 256
+
 # Profiler (disabled)
 profiler:

@@ -111,8 +115,3 @@ profiler:
   warmup_steps: 3
   active_steps: 2
   num_cycles: 1
-
-# QAT arguments
-quantizer:
-  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
-  groupsize: 256
9 changes: 4 additions & 5 deletions recipes/configs/llama3_2/1B_qat_lora.yaml
@@ -82,6 +82,10 @@ dtype: bf16
 enable_activation_checkpointing: False # True reduces memory
 enable_activation_offloading: False # True reduces memory

+# QAT arguments
+quantizer:
+  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
+  groupsize: 256
+
 # Profiler (disabled)
 profiler:

@@ -107,8 +111,3 @@ profiler:
   warmup_steps: 3
   active_steps: 2
   num_cycles: 1
-
-# QAT arguments
-quantizer:
-  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
-  groupsize: 256
9 changes: 4 additions & 5 deletions recipes/configs/llama3_2/3B_qat_lora.yaml
@@ -83,6 +83,10 @@ dtype: bf16
 enable_activation_checkpointing: False # True reduces memory
 enable_activation_offloading: False # True reduces memory

+# QAT arguments
+quantizer:
+  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
+  groupsize: 256
+
 # Profiler (disabled)
 profiler:

@@ -108,8 +112,3 @@ profiler:
   warmup_steps: 3
   active_steps: 2
   num_cycles: 1
-
-# QAT arguments
-quantizer:
-  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
-  groupsize: 256
42 changes: 42 additions & 0 deletions recipes/configs/llama3_2/evaluation.yaml
@@ -0,0 +1,42 @@
+# Config for EleutherEvalRecipe in eleuther_eval.py
+#
+# To launch, run the following command:
+# tune run eleuther_eval --config llama3_2/evaluation
+
+# Model Arguments
+model:
+  _component_: torchtune.models.llama3_2.llama3_2_3b
+
+# Checkpointer
+checkpointer:
+  _component_: torchtune.training.FullModelHFCheckpointer
+  checkpoint_dir: /tmp/Llama-3.2-3B-Instruct
+  checkpoint_files: [
+    model-00001-of-00002.safetensors,
+    model-00002-of-00002.safetensors,
+  ]
+  recipe_checkpoint: null
+  output_dir: ${output_dir}
+  model_type: LLAMA3_2
+resume_from_checkpoint: False
+
+# Tokenizer
+tokenizer:
+  _component_: torchtune.models.llama3.llama3_tokenizer
+  path: /tmp/Llama-3.2-3B-Instruct/original/tokenizer.model
+  max_seq_len: null
+
+# Environment
+device: cpu
+dtype: bf16
+seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed
+
+# EleutherAI specific eval args
+tasks: ["truthfulqa_mc2"]
+limit: null
+max_seq_length: 4096
+batch_size: 8
+enable_kv_cache: True
+
+# Quantization specific args
+quantizer: null
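
The new file registers a CPU-based EleutherAI evaluation config for Llama 3.2 3B. Per its header comment it is launched with tune run eleuther_eval --config llama3_2/evaluation; as a minimal sketch (the override values below are arbitrary examples, not part of this commit), individual fields can also be overridden from the CLI:

# Run the new eval config as written.
tune run eleuther_eval --config llama3_2/evaluation

# Hedged sketch: override selected fields at launch time; limit, batch_size,
# and device here are illustrative values only.
tune run eleuther_eval --config llama3_2/evaluation \
    limit=100 batch_size=4 device=cuda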
6 changes: 6 additions & 0 deletions recipes/qat_distributed.py
@@ -118,6 +118,7 @@ class QATRecipeDistributed(FTRecipeInterface):
     Raises:
         ValueError: If ``dtype`` is set to fp16.
+        ValueError: If ``compile`` is set to True.
         RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
         RuntimeError: If ``left_pad_sequence`` is set as the data collator.
         RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.

@@ -133,6 +134,11 @@ def __init__(self, cfg: DictConfig) -> None:
                 "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
             )

+        if cfg.get("compile", False):
+            raise ValueError(
+                "Compile is not yet supported for QAT. Please set compile=False."
+            )
+
         # logging attributes
         self._output_dir = cfg.output_dir
         self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
8 changes: 7 additions & 1 deletion recipes/qat_lora_finetune_distributed.py
@@ -126,7 +126,8 @@ class QATLoRAFinetuneRecipeDistributed(FTRecipeInterface):
     Raises:
         ValueError: If ``dtype`` is set to fp16.
-        ValueError: If world_size is 1
+        ValueError: If world_size is 1.
+        ValueError: If ``compile`` is set to True.
         RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
         RuntimeError: If ``left_pad_sequence`` is set as the data collator.
         RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.

@@ -149,6 +150,11 @@ def __init__(self, cfg: DictConfig) -> None:
                 "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
            )

+        if cfg.get("compile", False):
+            raise ValueError(
+                "Compile is not yet supported for QAT. Please set compile=False."
+            )
+
         _, rank = utils.get_world_size_and_rank()

         # _is_rank_zero is used primarily for logging. In the future, the logger
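
Both QAT recipes above now fail fast when compilation is requested, raising a ValueError from __init__ instead of starting a run. A hedged sketch of the behavior, assuming the config exposes a compile flag or accepts one as a CLI override (neither is shown in this diff):

# Hedged sketch: with this change, requesting compile on a QAT recipe raises
# "Compile is not yet supported for QAT. Please set compile=False."
tune run --nproc_per_node 2 qat_lora_finetune_distributed \
    --config llama3/8B_qat_lora compile=True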
1 change: 1 addition & 0 deletions tests/torchtune/modules/_export/test_attention.py
@@ -159,6 +159,7 @@ def test_attention_export(self):
             (self.x, self.x),
             kwargs={"input_pos": self.input_pos},
             dynamic_shapes=self.dynamic_shapes,
+            strict=True,
         )
         et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
(Changes to the tile positional embedding export tests; the file header was not rendered.)

@@ -51,14 +51,14 @@ def test_tile_positional_embedding_smoke(self):
         torch_version_ge("2.6.0.dev20241117"), reason="Need recent fixes for export"
     )
     def test_tile_positional_embedding_export(self):
-
         tpe_ep = torch.export.export(
             self.tpe,
             (self.x, self.aspect_ratio),
             dynamic_shapes=(
                 self.dynamic_shape,
                 None,
             ), # assuming aspect ratio is static
+            strict=True,
         )

         y = tpe_ep.module()(self.x, self.aspect_ratio)

@@ -129,14 +129,14 @@ def test_tiled_token_positional_embedding_smoke(self):
         torch_version_ge("2.6.0.dev20241117"), reason="Need recent fixes for export"
     )
     def test_tiled_token_positional_embedding_export(self):
-
         tpe_ep = torch.export.export(
             self.tpe,
             (self.x, self.aspect_ratio),
             dynamic_shapes=(
                 self.dynamic_shape,
                 None,
             ), # assuming aspect ratio is static
+            strict=True,
         )

         y = tpe_ep.module()(self.x, self.aspect_ratio)

@@ -155,6 +155,7 @@ def test_tiled_token_positional_embedding_aoti(self):
                 self.dynamic_shape,
                 None,
             ), # assuming aspect ratio is static
+            strict=True,
         )

         with tempfile.TemporaryDirectory() as tmpdir:
4 changes: 4 additions & 0 deletions torchtune/_recipe_registry.py
@@ -469,6 +469,10 @@ class Recipe:
             name="mistral/evaluation",
             file_path="mistral/evaluation.yaml",
         ),
+        Config(
+            name="llama3_2/evaluation",
+            file_path="llama3_2/evaluation.yaml",
+        ),
         Config(
             name="code_llama2/evaluation",
             file_path="code_llama2/evaluation.yaml",
2 changes: 1 addition & 1 deletion torchtune/generation/_generation.py
@@ -139,7 +139,7 @@ def get_causal_mask_from_padding_mask(
             - [bsz, seq_length, target_seq_len] if ``target_seq_len`` was specified.

     Raises:
-        AssertionError: if ``target_seq_len > seq_len``, the sequence length of the padding mask.
+        AssertionError: if ``target_seq_len < seq_len``, the sequence length of the padding mask.

     Example:
         >>> padding_mask = torch.tensor([[False, True, True, True]])
2 changes: 1 addition & 1 deletion torchtune/training/metric_logging.py
@@ -222,7 +222,7 @@ def log_config(self, config: DictConfig) -> None:
         try:
             output_config_fname = Path(
                 os.path.join(
-                    config.checkpointer.checkpoint_dir,
+                    config.output_dir,
                     "torchtune_config.yaml",
                 )
             )
2 changes: 1 addition & 1 deletion version.txt
@@ -1 +1 @@
-0.5.0
+0.6.0
