diff --git a/recipes/configs/mistral/7B_full_ppo_low_memory.yaml b/recipes/configs/mistral/7B_full_ppo_low_memory.yaml index 166fbeac1d..19e17c362a 100644 --- a/recipes/configs/mistral/7B_full_ppo_low_memory.yaml +++ b/recipes/configs/mistral/7B_full_ppo_low_memory.yaml @@ -124,10 +124,10 @@ shuffle: True device: cuda # Training arguments -batch_size: 64 +batch_size: 128 num_steps: 10000 -ppo_epochs: 2 -ppo_batch_size: 32 +ppo_epochs: 1 +ppo_batch_size: 128 gradient_accumulation_steps: 1 # Use to increase effective batch size # Memory management and performance @@ -137,13 +137,14 @@ optimizer: lr: 3e-6 optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 log_peak_memory_stats: True -enable_activation_checkpointing: True # True reduces memory +enable_activation_checkpointing: True # True reduces memory +enable_kv_cache: True # Reduced precision dtype: bf16 # batch size for forward pass during generation -forward_batch_size: 16 +forward_batch_size: 128 max_generated_tokens: 58 temperature: 0.7 top_k: null @@ -180,3 +181,27 @@ metric_logger: log_dir: ${output_dir}/logs log_every_n_steps: 1 + +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: True + with_stack: False + record_shapes: False + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 3 + num_cycles: 1 diff --git a/recipes/ppo_full_finetune_single_device.py b/recipes/ppo_full_finetune_single_device.py index 1030217d74..fb5e0559aa 100644 --- a/recipes/ppo_full_finetune_single_device.py +++ b/recipes/ppo_full_finetune_single_device.py @@ -4,12 +4,13 @@ # This source code 
is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import contextlib import math -import os import sys +import time from functools import partial from itertools import chain -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from warnings import warn import torch @@ -20,11 +21,16 @@ from torchtune import config, generation, modules, rlhf, training, utils from torchtune.data import padded_collate from torchtune.datasets import ConcatDataset +from torchtune.modules import local_kv_cache from torchtune.recipe_interfaces import FTRecipeInterface from torchtune.rlhf import PPOStats, Trajectory +from torchtune.training import DummyProfiler, PROFILER_KEY from tqdm import tqdm log = utils.get_logger("DEBUG") +# enabling compile results in slightly more recompiles than the default cache limit (8) +# so we set a higher limit here +torch._dynamo.config.cache_size_limit = 16 class PPOFullFinetuneRecipeSingleDevice(FTRecipeInterface): @@ -32,8 +38,8 @@ class PPOFullFinetuneRecipeSingleDevice(FTRecipeInterface): Full finetuning recipe for RLHF with PPO for dense transformer-based LLMs such as LLama2. This recipe is optimized for single GPU training. Training on CPU is not supported. - This implementation is based on `Learning to summarize from human feedback ). Features: - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` @@ -175,8 +181,9 @@ def setup(self, cfg: DictConfig) -> None: # ``_setup_model`` handles initialization and loading the state dict. 
This method # should be called before ``_setup_optimizer`` since transforming the optimizer # state dict requires the model - self._model_compile = cfg.compile + self.compile = cfg.compile self._optimizer_in_bwd = cfg.optimizer_in_bwd + ( self._policy_model, self._value_model, @@ -186,7 +193,7 @@ def setup(self, cfg: DictConfig) -> None: cfg_model=cfg.policy_model, cfg_reward_value_model=cfg.reward_and_value_model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, - compile_model=self._model_compile, + compile_model=self.compile, policy_state_dict=policy_model_checkpoint_dict[training.MODEL_KEY], ref_policy_state_dict=ref_policy_state_dict[training.MODEL_KEY], value_model_state_dict=value_model_checkpoint_dict[training.MODEL_KEY], @@ -213,7 +220,7 @@ def setup(self, cfg: DictConfig) -> None: log.info("Loss is initialized.") # sampler and dataloader depends on the tokenizer and should be set - # setup afterit is initialized + # setup after it is initialized self._sampler, self._dataloader = self._setup_data( cfg_dataset=cfg.dataset, shuffle=cfg.shuffle, @@ -223,6 +230,21 @@ def setup(self, cfg: DictConfig) -> None: self._setup_training_parameters(cfg) self._setup_training_hyperparameters(cfg) + # setup a context manager for enabling KV-caching during + # trajectory generation if enabled in the config + self.cache_ctx_manager = lambda enable_kv_cache: ( + local_kv_cache( + self._policy_model, + batch_size=self._forward_batch_size, + dtype=self._dtype, + decoder_max_seq_len=self._tokenizer.max_seq_len + + self._max_generated_tokens, + device=self._device, + ) + if enable_kv_cache + else contextlib.nullcontext() + ) + if self._resume_from_checkpoint: self._update_recipe_state(policy_model_checkpoint_dict) @@ -233,6 +255,77 @@ def setup(self, cfg: DictConfig) -> None: * (self.batch_size // self._ppo_batch_size) ) + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if
`cfg.profiler.enabled = False` + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. + + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed when profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) +
log.info(f" Profiler config after instantiation: {profiler_cfg}") + + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + + return profiler + def _setup_training_hyperparameters(self, cfg) -> None: """ Sets up the training hyperparameters for the recipe. This includes the GAE hyperparameters, @@ -295,6 +388,7 @@ def _setup_training_parameters(self, cfg: DictConfig) -> None: self._ppo_backward_batch_size = ( cfg.ppo_batch_size // self._gradient_accumulation_steps ) + self.enable_kv_cache = cfg.enable_kv_cache if self.batch_size % self._forward_batch_size != 0: raise ValueError( @@ -423,6 +517,12 @@ def _setup_models( reward_model = config.instantiate(cfg_reward_value_model) value_model = config.instantiate(cfg_reward_value_model) + if compile_model: + training.compile_model(policy_model) + training.compile_model(ref_policy_model) + training.compile_model(value_model) + training.compile_model(reward_model) + if enable_activation_checkpointing: training.set_activation_checkpointing( policy_model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} @@ -450,6 +550,7 @@ def _setup_models( value_model.load_state_dict(value_model_state_dict) # Validate models were loaded in with the expected dtype. + training.validate_expected_param_dtype( value_model.named_parameters(), dtype=self._dtype ) @@ -490,16 +591,6 @@ def _setup_models( for p in ref_policy_model.parameters(): p.requires_grad = False - # Compile model, if enabled. 
- if compile_model: - backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") - log.info("Compiling models with torch.compile...") - - policy_model.compile(backend=backend) - reward_model.compile(backend=backend) - ref_policy_model.compile(backend=backend) - value_model.compile(backend=backend) - if self._device.type == "cuda": memory_stats = training.get_memory_stats(device=self._device) training.log_memory_stats(memory_stats) @@ -585,7 +676,6 @@ def _setup_data( dataset=ds, sampler=sampler, batch_size=batch_size, - # dropping last avoids shape issues with compile + flex attention drop_last=True, collate_fn=partial( padded_collate, @@ -688,19 +778,19 @@ def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory: Trajectory: An instance of :class:`~torchtune.rlhf.Trajectory` comprising the current trajectory. """ - batch_size, context_length = input_ids.shape # step 1: generate responses, and logits corresponding to the responses using the current policy - query_responses, logits = generation.generate( - model=self._policy_model, - prompt=input_ids, - max_generated_tokens=self._max_generated_tokens, - temperature=self._temperature, - top_k=self._top_k, - pad_id=self._tokenizer.pad_id, - rng=self._rng, - ) - + with self.cache_ctx_manager(self.enable_kv_cache): + query_responses, logits = generation.generate( + model=self._policy_model, + prompt=input_ids, + max_generated_tokens=self._max_generated_tokens, + temperature=self._temperature, + top_k=self._top_k, + pad_id=self._tokenizer.pad_id, + rng=self._rng, + ) + _, context_length = input_ids.shape responses = query_responses[:, context_length:].clone() query_response_padding_masks = query_responses != self._tokenizer.pad_id @@ -715,7 +805,6 @@ def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory: del query_response_padding_masks # step 2. 
estimate logprobs of the responses using the current policy - logits = logits[:, context_length - 1 :] logprobs = rlhf.logits_to_logprobs(logits, responses, self._temperature) del logits @@ -751,7 +840,9 @@ def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory: # step 5.1 the scores from the reward model are the logits for the last non-padding token in # each (query, truncated-response) pair seq_lens = training.get_unmasked_sequence_lengths(response_padding_masks) - scores = scores[torch.arange(batch_size), seq_lens + context_length].squeeze(-1) + scores = scores.gather(1, (seq_lens + context_length)[:, None, None]).squeeze( + (-1, -2) + ) # step 5.2 if configured, apply any penalties for sequences without EOS tokens # or shorter than a certain length @@ -775,11 +866,9 @@ def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory: seq_lens, ) value_padding_masks = response_padding_masks.clone() - value_padding_masks[ - torch.arange(batch_size, device=value_padding_masks.device), - value_seq_idxs, - ] = False - + value_padding_masks = value_padding_masks.scatter_( + 1, value_seq_idxs.unsqueeze(-1), False + ) values[value_padding_masks] = 0.0 return Trajectory( @@ -798,8 +887,8 @@ def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory: def generate_trajectory_batched(self, input_ids: torch.Tensor) -> Trajectory: """ - Generates a ``self.batch_size`` batch of trajectories using `self._forward_batch_size` batch sizes. - See ``generate_trajectory`` for more details. + Generates a self.batch_size batch of trajectories using self._forward_batch_size batch sizes. + See generate_trajectory for more details. 
Args: input_ids (torch.Tensor): tensor of input token IDs with shape [b, seq_length] @@ -814,6 +903,7 @@ def generate_trajectory_batched(self, input_ids: torch.Tensor) -> Trajectory: batch_input_ids = input_ids[ batch_start : batch_start + self._forward_batch_size ] + trajectories.append(self.generate_trajectory(batch_input_ids)) return Trajectory(*map(torch.cat, zip(*trajectories))) @@ -821,7 +911,7 @@ def train(self) -> None: """ The core training loop.""" - if self._model_compile: + if self.compile: log.info( "NOTE: torch.compile is enabled and model is compiled in first forward." "Expect a relatively slow first iteration." @@ -831,25 +921,33 @@ def train(self) -> None: self._optimizer.zero_grad() training_completed = False + self._profiler.start() pbar = tqdm(total=self._total_steps, initial=self._steps_run) for curr_epoch in range(self._epochs_run, self._total_epochs): # Update the sampler to ensure data is correctly shuffled across epochs # in case shuffle is True self._sampler.set_epoch(curr_epoch) - for _, batch in enumerate(self._dataloader): + for idx, batch in enumerate(self._dataloader): + + # Start tracking CUDA memory for active steps for just the first epoch + if ( + curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + ): + torch.cuda.memory._record_memory_history() + batch = batch["tokens"].to(self._device) _, context_length = batch.shape + num_tokens = batch.numel() - # step 1. generate the trajectory using: - # - the current policy (pi_theta) - # - the current value function (V_phi) - # - the reference frozen policy model (pi_theta_0) + # step 1. generate the trajectory + t0_traj = time.perf_counter() trajectory = self.generate_trajectory_batched(batch) + traj_time = time.perf_counter() - t0_traj - # step 2. get the rewards for the current trajectory. 
these are based on: - # - the divergence between the current policy and the reference policy - # - the scores from the reward model + # step 2. get the rewards for the current trajectory rewards, kl, kl_rewards = rlhf.get_rewards_ppo( trajectory.scores, trajectory.logprobs, @@ -867,7 +965,8 @@ def train(self) -> None: masks=~trajectory.response_padding_masks, ) - # step 4. optimise using the PPO objective over multiple epochs + # step 4. optimise using the PPO objective over multiple epochs + t0_ppo = time.perf_counter() ppo_stats: List[PPOStats] = [] for _ in range(self._ppo_epochs): batch_idxs = torch.randperm(self.batch_size, device=self._device) @@ -893,7 +992,7 @@ def train(self) -> None: ) ) batch_ppo_stats.append( - self._ppo_step( + self.ppo_step( batch_trajectory, advantages[backward_batch_idxs], returns[backward_batch_idxs], @@ -909,6 +1008,7 @@ def train(self) -> None: self._optimizer.zero_grad(set_to_none=True) self.global_step += 1 + ppo_time = time.perf_counter() - t0_ppo # step 5.
profit self._steps_run += 1 @@ -918,11 +1018,28 @@ def train(self) -> None: PPOStats(*map(torch.stack, zip(*ppo_stats))), kl, kl_rewards, + num_tokens / traj_time, + num_tokens / ppo_time, ) self.cleanup_after_step( trajectory, ppo_stats, advantages, returns, kl, kl_rewards ) pbar.update(1) + + # Stop tracking CUDA memory now that active steps are complete + if ( + curr_epoch == 0 + and self.profiler_profile_memory + and idx + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + ): + torch.cuda.memory._record_memory_history(enabled=None) + + # Step the profiler + self._profiler.step() + if self._steps_run == self._total_steps: training_completed = True break @@ -934,9 +1051,12 @@ def train(self) -> None: curr_epoch, is_intermediate_checkpoint=not training_completed ) if training_completed: + self._profiler.stop() return - def _ppo_step( + self._profiler.stop() + + def ppo_step( self, trajectory: Trajectory, advantages: torch.Tensor, @@ -1023,6 +1143,8 @@ def log_metrics( ppo_stats: PPOStats, kl: torch.Tensor, kl_rewards: torch.Tensor, + tokens_per_second_trajectory: torch.Tensor, + tokens_per_second_loss: torch.Tensor, ) -> None: """ Log metrics and statistics for the current step to the metric logger. 
@@ -1040,6 +1162,8 @@ def log_metrics( "ratios": ppo_stats.ratios.mean(), "approx_policy_kl": ppo_stats.approx_policy_kls.mean(), "response_lengths": trajectory.seq_lens.float().mean(), + "tokens_per_second_per_gpu_trajectory": tokens_per_second_trajectory, + "tokens_per_second_per_gpu_ppo": tokens_per_second_loss, } if self._device.type == "cuda" and self._log_peak_memory_stats: log_dict.update(training.get_memory_stats(device=self._device)) diff --git a/tests/torchtune/generation/test_generation.py b/tests/torchtune/generation/test_generation.py index 4efd1e3acd..b740e7afbf 100644 --- a/tests/torchtune/generation/test_generation.py +++ b/tests/torchtune/generation/test_generation.py @@ -245,7 +245,7 @@ def test_reproducibility(self, request, model1, model2, prompt_tokens): top_k = 100 torch.manual_seed(42) - outputs_first, _ = generate( + outputs_first, logits_first = generate( model=model1, prompt=prompt_tokens, max_generated_tokens=10, @@ -254,17 +254,15 @@ def test_reproducibility(self, request, model1, model2, prompt_tokens): ) torch.manual_seed(42) - outputs_second, _ = generate( + outputs_second, logits_second = generate( model=model2, prompt=prompt_tokens, max_generated_tokens=10, temperature=temperature, top_k=top_k, ) - - # slicing for the last 18 tokens - this is the whole sequence for unpadded inputs - # and excludes the first two tokens for padded inputs, which are padding tokens assert torch.equal(outputs_first, outputs_second) + torch.testing.assert_close(logits_first, logits_second) @pytest.mark.parametrize( "model1", @@ -303,7 +301,7 @@ def test_reproducibility_batched(self, request, model1, model2, prompt1, prompt2 top_k = 100 torch.manual_seed(42) - outputs_first, _ = generate( + outputs_first, logits_first = generate( model=model1, prompt=prompt1, max_generated_tokens=10, @@ -312,7 +310,7 @@ def test_reproducibility_batched(self, request, model1, model2, prompt1, prompt2 ) torch.manual_seed(42) - outputs_second, _ = generate( + outputs_second, 
logits_second = generate( model=model2, prompt=prompt2, max_generated_tokens=10, @@ -323,6 +321,8 @@ def test_reproducibility_batched(self, request, model1, model2, prompt1, prompt2 # slicing for the last 18 tokens - this is the whole sequence for unpadded inputs # and excludes the first two tokens for padded inputs, which are padding tokens assert torch.equal(outputs_first[:, -18:], outputs_second[:, -18:]) + # logits are only ever returned for the generated tokens, so no slicing needed + torch.testing.assert_close(logits_first, logits_second, atol=1e-4, rtol=1e-6) @pytest.mark.parametrize( "model", @@ -343,7 +343,8 @@ def test_stop_tokens_batched(self, request, model, prompt, expected_tokens_batch top_k = 100 # This is the first token generated by the model - # so it should stop immediately + # so it should stop immediately resulting in only a single + # token being generated stop_tokens = [3987, 3958, 3989] torch.manual_seed(42) @@ -465,7 +466,6 @@ def test_stop_tokens_batched_uneven_stopping_left_padded( [0, 0, 2, 3, 4, 5, 6, 7, 8, 9, 3989, 0, 0], ] ) - assert torch.equal(outputs, expected_output) diff --git a/tests/torchtune/rlhf/test_rewards.py b/tests/torchtune/rlhf/test_rewards.py index 4284d1d63d..ecc5422cf5 100644 --- a/tests/torchtune/rlhf/test_rewards.py +++ b/tests/torchtune/rlhf/test_rewards.py @@ -9,7 +9,7 @@ class TestGetRewards: - def test_get_rewards(self): + def test_get_rewards_ppo(self): scores = torch.tensor([1.0, 2.0, 3.0]) logprobs = torch.tensor( [ @@ -25,7 +25,7 @@ def test_get_rewards(self): [0.9, 1.0, 1.1], ] ) - kl_controller_value = 0.5 + kl_coeff = 0.5 # expected kl is logprobs - ref_logprobs expected_kl = torch.tensor( @@ -36,7 +36,7 @@ def test_get_rewards(self): ] ) - # expected kl_rewards is -kl_controller_value * kl + # expected kl_rewards is -kl_coeff * kl expected_kl_rewards = torch.tensor( [ [0.05, 0.05, 0.05], @@ -55,7 +55,22 @@ def test_get_rewards(self): ) rewards, kl, kl_rewards = rlhf.get_rewards_ppo( - scores, logprobs, 
ref_logprobs, kl_controller_value + scores, logprobs, ref_logprobs, kl_coeff + ) + + torch.testing.assert_close(kl, expected_kl, rtol=1e-4, atol=1e-4) + torch.testing.assert_close( + kl_rewards, expected_kl_rewards, rtol=1e-4, atol=1e-4 + ) + torch.testing.assert_close(rewards, expected_rewards, rtol=1e-4, atol=1e-4) + + # add a test to ensure valid_score_idxs works as expected + rewards, kl, kl_rewards = rlhf.get_rewards_ppo( + scores, + logprobs, + ref_logprobs, + kl_coeff, + valid_score_idxs=torch.tensor([2, 2, 2]), ) torch.testing.assert_close(kl, expected_kl, rtol=1e-4, atol=1e-4) @@ -137,7 +152,7 @@ def test_masked_var(self): mask = torch.tensor([True, True, True, False, False]) expected_var = torch.tensor(1.0) - output = rlhf.masked_var(x, mask) + output = rlhf.masked_var(x - rlhf.masked_mean(x, mask), mask) torch.testing.assert_close(output, expected_var, rtol=1e-4, atol=1e-4) diff --git a/tests/torchtune/training/test_pooling.py b/tests/torchtune/training/test_pooling.py index bb1204b6bf..9f2bd956d7 100644 --- a/tests/torchtune/training/test_pooling.py +++ b/tests/torchtune/training/test_pooling.py @@ -7,7 +7,7 @@ from torchtune.training.pooling import get_unmasked_sequence_lengths -class TestGetLastUnmaskedTokenIdx: +class TestGetUnmaskedSeqenceLengths: def test_get_last_unmasked_token_idx_multi_batch(self): """ Tests that the last non-padding tokens are correctly selected for a multi-batch input. diff --git a/torchtune/generation/_generation.py b/torchtune/generation/_generation.py index bb4b1ff0b0..eb550e6e19 100644 --- a/torchtune/generation/_generation.py +++ b/torchtune/generation/_generation.py @@ -94,15 +94,15 @@ def generate_next_token( - tokens (torch.Tensor): tensor with the generated tokens, with shape [bsz x 1]. - logits (torch.Tensor): tensor with the logits associated with the generated tokens, - with shape [bsz x seq_length x vocab_size]. + with shape [bsz x 1 x vocab_size]. 
""" # model produces logits in [bsz, seq_length, vocab_size] # we want to take the last token's logits as the input to the next model call - logits = model(x, input_pos=input_pos, mask=mask) + logits = model(x, input_pos=input_pos, mask=mask)[:, -1] return ( - sample(logits[:, -1].clone(), temperature=temperature, top_k=top_k, q=q), - logits, + sample(logits.clone(), temperature=temperature, top_k=top_k, q=q), + logits.unsqueeze(1), ) @@ -189,7 +189,7 @@ def get_position_ids_from_padding_mask( return ((padding_mask.cumsum(-1) - 1) * padding_mask).to(torch.int) -@torch.inference_mode() +@torch.no_grad() def generate( model: TransformerDecoder, prompt: torch.Tensor, @@ -241,7 +241,7 @@ def generate( with shape ``[bsz x seq_len + num_generated_tokens]`` where ``num_generated_tokens`` may be less than ``max_generated_tokens`` if ``stop_tokens`` are provided. - logits (torch.Tensor): tensor with the logits associated with the generated tokens, - with shape ``[bsz x seq_len + num_generated_tokens x vocab_size]``. + with shape ``[bsz x num_generated_tokens x vocab_size]``. 
""" prompt = prompt.view(1, -1) if prompt.ndim == 1 else prompt @@ -355,8 +355,8 @@ def generate( # if incremental decoding is enabled, we can use the current position # otherwise, we take the whole sequence up to the current position if incremental_decoding: - curr_input_pos = input_pos[:, curr_pos] - curr_masks = masks[:, curr_pos, None, :] + curr_input_pos = input_pos[:, curr_pos].contiguous() + curr_masks = masks[:, curr_pos, None, :].contiguous() else: tokens = generated_tokens.clone() curr_input_pos = input_pos[:, : curr_pos + 1] @@ -377,11 +377,8 @@ def generate( q=q, ) generated_tokens = torch.cat([generated_tokens, tokens], dim=-1) + generated_logits = torch.cat([generated_logits, logits], dim=1) curr_pos += 1 - if incremental_decoding: - generated_logits = torch.cat([generated_logits, logits], dim=1) - else: - generated_logits = logits if stop_tokens is not None: stop_token_reached = update_stop_tokens_tracker( @@ -393,6 +390,6 @@ def generate( # mask out generated tokens in seqs that already hit a stop token if stop_tokens is not None: generated_tokens *= stop_token_mask - generated_logits *= stop_token_mask[:, :-1, None] + generated_logits *= stop_token_mask[:, -generated_logits.shape[1] :, None] return generated_tokens, generated_logits diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index e96491c22a..e01c9ca10a 100644 --- a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -109,6 +109,6 @@ def update( # this allows us to track the current position in the cache # after the last update in a compile-friendly way without any dynamism # e.g. 
relying on an int size tracker, or re-creating cache_pos every time - self.cache_pos += seq_len + self.cache_pos.add_(seq_len) return k_out, v_out diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 66ac92002f..5f739d845b 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -418,12 +418,11 @@ def setup_caches( encoder_max_seq_len (Optional[int]): maximum encoder cache sequence length. decoder_max_seq_len (Optional[int]): maximum decoder cache sequence length. """ - has_encoder_layers = any( isinstance(m, TransformerCrossAttentionLayer) for m in self.modules() ) has_decoder_layers = any( - isinstance(l, TransformerSelfAttentionLayer) for l in self.layers + isinstance(m, TransformerSelfAttentionLayer) for m in self.modules() ) if has_encoder_layers: @@ -437,7 +436,6 @@ def setup_caches( self.decoder_max_cache_seq_len = decoder_max_seq_len else: self.decoder_max_cache_seq_len = self.max_seq_len - for layer in self.layers: layer.setup_caches( batch_size, diff --git a/torchtune/rlhf/loss/ppo.py b/torchtune/rlhf/loss/ppo.py index d4770802f7..2c3d48a0e8 100644 --- a/torchtune/rlhf/loss/ppo.py +++ b/torchtune/rlhf/loss/ppo.py @@ -82,7 +82,9 @@ def forward( policy_losses_clipped = -advantages * clipped_ratios policy_losses_unclipped = -advantages * ratios - clipfrac = (policy_losses_clipped > policy_losses_unclipped).float() + clipfrac = (policy_losses_clipped > policy_losses_unclipped).to( + pi_logprobs.dtype + ) clipfrac = ( clipfrac.mean() if padding_masks is None diff --git a/torchtune/rlhf/rewards.py b/torchtune/rlhf/rewards.py index f0e42ca58c..f5882908bc 100644 --- a/torchtune/rlhf/rewards.py +++ b/torchtune/rlhf/rewards.py @@ -76,10 +76,6 @@ def get_rewards_ppo( - response_len: model response length """ - # 1. calculate kl between logprobs and reflogprobs - # 2. calculate kl reward using adaptive scaling value - # 3. 
calculate total reward by summing above - # return all kl = logprobs - ref_logprobs kl_reward = -kl_coeff * kl @@ -89,9 +85,9 @@ def get_rewards_ppo( # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L153 if valid_score_idxs is not None: - total_reward[ - torch.arange(scores.shape[0], device=scores.device), valid_score_idxs - ] += scores + total_reward.scatter_add_( + 1, valid_score_idxs.unsqueeze(-1), scores.unsqueeze(-1) + ) else: total_reward[:, -1] += scores @@ -113,17 +109,17 @@ def masked_mean( Returns: torch.Tensor: The mean tensor. """ - return (x * mask).sum(dim=dim) / mask.sum(dim=dim) + return (x * mask).sum(dim=dim) / (mask.sum(dim=dim) + 1e-8) def masked_var( - x: torch.Tensor, mask: torch.Tensor, unbiased: bool = True + centered_values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True ) -> torch.Tensor: """ - Compute variance of tensor with masked values. Taken from https://github.com/huggingface/trl/blob/main/trl/core.py - + Compute variance of mean-centered tensor with masked values. Taken from https://github.com/huggingface/trl/blob/main/trl/core.py + We use ``centered_values`` to avoid repeated calls to ``masked_mean``. Args: - x (torch.Tensor): The input tensor. + centered_values (torch.Tensor): The mean-centered tensor e.g. ``x - masked_mean(x)``. mask (torch.Tensor): The bool mask tensor, where True indicates the corresponding value in ``x`` should participate in the mean calculation. unbiased (bool): Whether to use the unbiased variance. @@ -131,21 +127,10 @@ def masked_var( Returns: torch.Tensor: The variance tensor. - Raises: - ValueError: If the sum of the mask is zero. 
""" - mean = masked_mean(x, mask) - centered_values = x - mean var = masked_mean(centered_values.pow(2), mask) if unbiased: - mask_sum = mask.sum() - if mask_sum == 0: - raise ValueError( - "The sum of the mask is zero, which can happen when ``ppo_batch_size=1``;" - "try increase the ``ppo_batch_size`` or ``gradient_accumulation_steps``" - ) - # note that if mask_sum == 1, then there is a division by zero issue - # to avoid it you just need to use a larger minibatch_size + mask_sum = mask.sum() + 1e-8 bessel_correction = mask_sum / (mask_sum - 1) var = var * bessel_correction return var @@ -158,16 +143,16 @@ def whiten( Whiten (normalises) values, optionally with masked values. Taken from https://github.com/huggingface/trl/blob/main/trl/core.py Args: x (torch.Tensor): The input tensor. - mask (Optional[torch.Tensor]): The bool mask tensor, where True indicates the corresponding value in ``x`` - should participate in the mean calculation. Default None. - shift_mean (bool): Whether to shift normalised values by the mean. + mask (Optional[torch.Tensor]): The bool mask tensor with the same shape as ``x``, and where True indicates + the corresponding value in ``x`` should participate in the mean calculation. Default None. + shift_mean (bool): Whether to shift normalised values by the mean. Default True. Returns: torch.Tensor: The whitened tensor. 
""" if mask is not None: mean = masked_mean(x, mask) - var = masked_var(x, mask) if mask.any() else x.var() + var = masked_var(x - mean, mask) else: mean, var = x.mean(), x.var() whitened = (x - mean) * torch.rsqrt(var + 1e-8) @@ -228,10 +213,8 @@ def estimate_advantages( returns = advantages + values # normalize advantages across the batch of trajectories to reduce variance + advantages = whiten(advantages, mask=masks) if masks is not None: - advantages = whiten(advantages, mask=masks) advantages[~masks] = 0.0 - else: - advantages = whiten(advantages) return advantages, returns diff --git a/torchtune/training/pooling.py b/torchtune/training/pooling.py index 3e0ba41507..e0fe204a5b 100644 --- a/torchtune/training/pooling.py +++ b/torchtune/training/pooling.py @@ -8,7 +8,7 @@ def get_unmasked_sequence_lengths(mask: torch.Tensor) -> torch.Tensor: """ - Returns the sequence lengths for each batch element, excluding masked tokens. + Returns the sequence lengths (0-indexed) for each batch element, excluding masked tokens. Args: mask (torch.Tensor): Boolean mask with shape [b x s], where True indicates a value to be masked out @@ -37,13 +37,6 @@ def get_unmasked_sequence_lengths(mask: torch.Tensor) -> torch.Tensor: """ # calculate per-batch-element sequence lengths by finding last valid tokens - if mask.any(): - sequence_lengths = ( - (~mask).sum(-1).sub(1).clip(0).to(mask.device, dtype=torch.long) - ) - else: - sequence_lengths = torch.full( - (mask.shape[0],), mask.shape[1] - 1, dtype=torch.long, device=mask.device - ) - - return sequence_lengths + sequence_lengths = (~mask).cumsum(dim=-1).argmax(dim=-1).to(dtype=torch.long) + + return sequence_lengths.clip(0, mask.shape[1] - 1)