From 642635fbd8cc6bb3a081670636532ca796d4e014 Mon Sep 17 00:00:00 2001
From: Joshua Rosenkranz
Date: Fri, 18 Oct 2024 18:38:32 +0000
Subject: [PATCH 1/2] added HiddenStatesExtractor, which can be used directly
 with fms forward and generate to extract hidden states; removed the
 now-unnecessary Embed* model classes

---
 speculator/train_speculator.py       |  20 +++-
 speculator/train_speculator_utils.py | 185 +++++----------------------
 2 files changed, 64 insertions(+), 141 deletions(-)

diff --git a/speculator/train_speculator.py b/speculator/train_speculator.py
index 7ef5e8f9..68bd5073 100644
--- a/speculator/train_speculator.py
+++ b/speculator/train_speculator.py
@@ -6,7 +6,9 @@
 import torch
 import torch.optim as optim
 from fms.models import get_model
-from fms.models.llama import LLaMABlock
+from fms.models.gpt_bigcode import GPTBigCode
+from fms.models.llama import LLaMA, LLaMABlock
+from fms.models.mixtral import Mixtral
 from fms.utils import generation, tokenizers
 from fms_extras.models.speculator import MLPSpeculator  # type: ignore
 from torch import distributed as dist
@@ -24,8 +26,8 @@
     setup,
     setup_environ_flags,
 )
 
-from speculator.train_speculator_utils import train_speculator
+from speculator.train_speculator_utils import HiddenStatesExtractor, train_speculator
 
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -159,6 +160,20 @@
         ),
     )
 
+    if isinstance(model, LLaMA):
+        headless_model = model._helper
+        head = model.shared.head
+    elif isinstance(model, (GPTBigCode, Mixtral)):
+        headless_model = model.base_model
+        head = model.head
+    else:
+        raise ValueError(
+            "speculative training currently only supports LLaMA, GPTBigCode, "
+            "and Mixtral architectures"
+        )
+
+    model = HiddenStatesExtractor(headless_model, head)
+
     if rank == 0:
         print(f"{time.time()}")
         print(model.config)
diff --git a/speculator/train_speculator_utils.py b/speculator/train_speculator_utils.py
index 0a265a63..3e807670 100644
--- a/speculator/train_speculator_utils.py
+++ b/speculator/train_speculator_utils.py
@@ -117,6 +117,47 @@
     return result
 
+class HiddenStatesExtractor(nn.Module):
+
+    def __init__(self, headless_model: Union[nn.Module, Callable], head: nn.Linear):
+        super().__init__()
+        self.headless_model = headless_model
+        self.head = head
+        self.hidden_states_output = None
+        # resolve a bound method (e.g. LLaMA._helper) to its owning module so
+        # that callers logging model.config keep working after wrapping
+        owner = getattr(headless_model, "__self__", headless_model)
+        self.config = getattr(owner, "config", None)
+
+    # making the assumption this is used with generate, which passes kwargs
+    def forward(self, *args, **kwargs):
+        # reset the accumulated hidden states on prefill, i.e. before any
+        # cache has been passed back in
+        if kwargs.get("past_key_value_states", None) is None:
+            self.hidden_states_output = None
+
+        # consume only_last_token here; the headless model does not accept it
+        only_last_token = kwargs.pop("only_last_token", False)
+
+        hidden_states, cache = self.headless_model(*args, **kwargs)
+
+        if only_last_token:
+            hidden_states = hidden_states[:, -1, :]
+
+        if self.hidden_states_output is None:
+            self.hidden_states_output = hidden_states
+        else:
+            self.hidden_states_output = torch.cat(
+                (self.hidden_states_output, hidden_states), dim=-2
+            )
+
+        logits = self.head(hidden_states)
+
+        if kwargs.get("use_cache", False):
+            return logits, cache
+        return logits
+
+
 # Stage 1 training
 
 
 def stage1_loss(
@@ -150,11 +191,11 @@
     Returns: scalar loss value, updated ddp stats, number of tokens in input
     """
     with torch.no_grad():
-        _, embeds = model(
+        _ = model(
             base_model_input[:, : -speculator.n_predict - 1],
-            include_embeds=True,
             use_cache=False,
         )
+        embeds = model.hidden_states_output
     if cfg.sharding_strategy == "tp":
         embeds = embeds.chunk(base_model_mesh["tp"].size())[
             base_model_mesh["tp"].get_local_rank()
@@ -211,15 +252,16 @@
     base_model_input = base_model_input[
         :, : cfg.stage2_prompt_length * grow_factor
     ].reshape(base_model_input.size(0) * grow_factor, cfg.stage2_prompt_length)
-    targs, embeds = generate(
+
+    targs = generate(
         model,
         base_model_input,
         cfg.seq_length,
         cfg.stage2_seq_length,
         do_sample=True,
         use_cache=True,
-        include_embeds=True,
     )
+    embeds = model.hidden_states_output
 
     if cfg.sharding_strategy == "tp":
         targs = targs.chunk(base_model_mesh["tp"].size())[
@@ -431,138 +473,3 @@
         tokens_seen=elapsed_tokens + n_tok,
         is_compiled=cfg.use_torch_compile,
     )
-
-
-class EmbedGPTBigCode(GPTBigCode):
-    # Overrides the forward function of GPTBigCode to allow returning embedding vectors
-    def forward(
-        self,
-        x: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        past_key_value_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
-        use_cache: bool = False,
-        only_last_token: bool = False,
-        attn_algorithm: Optional[str] = None,
-        include_embeds: bool = False,
-    ):
-        output, cache = self.base_model(
-            x,
-            mask,
-            position_ids=position_ids,
-            past_key_value_states=past_key_value_states,
-            use_cache=use_cache,
-            attn_algorithm=attn_algorithm,
-        )
-
-        preds = self.head(output)
-
-        out = [preds]
-        if use_cache:
-            out.append(cache)
-        if include_embeds:
-            out.append(output)
-        if len(out) == 1:
-            return out[0]
-        return out
-
-
-class EmbedLLaMA(LLaMA):
-    # Overrides the forward function of LLaMA to allow returning embedding vectors
-    def forward(
-        self,
-        x,
-        mask=None,
-        position_ids=None,
-        past_key_value_states=None,
-        use_cache=False,
-        only_last_token=False,
-        attn_algorithm=None,
-        include_embeds=False,
-    ):
-        output, cache = self._helper(
-            x, mask, position_ids, past_key_value_states, use_cache, attn_algorithm
-        )
-
-        if only_last_token:
-            output = output[:, -1, :]
-        preds = self.shared(output, reverse=True)
-
-        out = [preds]
-        if use_cache:
-            out.append(cache)
-        if include_embeds:
-            out.append(output)
-        if len(out) == 1:
-            return out[0]
-        return out
-
-
-class EmbedMixtral(Mixtral):  # FMS impl of Mixtral
-    # Overrides the forward function of Mixtral to allow returning embedding vectors
-    def forward(
-        self,
-        x,
-        mask=None,
-        position_ids=None,
-        past_key_value_states=None,
-        use_cache=False,
-        only_last_token=False,
-        attn_algorithm=None,
-        include_embeds=False,
-    ):
-        output, cache = self.base_model(
-            x, mask, position_ids, past_key_value_states, use_cache, attn_algorithm
-        )
-
-        if only_last_token:
-            output = output[:, -1, :]
-        preds = self.head(output)
-
-        out = [preds]
-        if use_cache:
-            out.append(cache)
-        if include_embeds:
-            out.append(output)
-        if len(out) == 1:
-            return out[0]
-        return out
-
-
-def _gpt_bigcode_factory_factory(config):
-    def factory(**kwargs):
-        return EmbedGPTBigCode(config, **kwargs)
-
-    return factory
-
-
-def _llama_factory_factory(config):
-    def factory(**kwargs):
-        return EmbedLLaMA(config, **kwargs)
-
-    return factory
-
-
-def _mixtral_factory_factory(config):
-    def factory(**kwargs):
-        return EmbedMixtral(config, **kwargs)
-
-    return factory
-
-
-# example model registrations
-register_model(
-    "embedgpt_bigcode", "20b", _gpt_bigcode_factory_factory(_gpt_bigcode_20b_config)
-)
-serialization.register_adapter("embedgpt_bigcode", "hf", _gptbigcode_hf_sd_to_fms_sd)
-
-register_model(
-    "embedllama", "7b", _llama_factory_factory(get_model_config("llama2_7b"))
-)
-register_model(
-    "embedllama", "8b", _llama_factory_factory(get_model_config("llama3_8b"))
-)
-serialization.register_adapter("embedllama", "hf", _llama_hf_sd_to_fms_sd)
-
-register_model("embedmixtral", "8x7b", _mixtral_factory_factory(MixtralConfig()))
-serialization.register_adapter("embedmixtral", "hf", _mixtral_hf_sd_to_fms_sd)

From 5ee964060c3281c1b53140c7c28be7c99385e35b Mon Sep 17 00:00:00 2001
From: Joshua Rosenkranz
Date: Fri, 18 Oct 2024 18:45:18 +0000
Subject: [PATCH 2/2] removed the local generate copy in favor of
 fms.utils.generation.generate

---
 speculator/train_speculator_utils.py | 95 +---------------------
 1 file changed, 1 insertion(+), 94 deletions(-)

diff --git a/speculator/train_speculator_utils.py b/speculator/train_speculator_utils.py
index 3e807670..53fc7594 100644
--- a/speculator/train_speculator_utils.py
+++ b/speculator/train_speculator_utils.py
@@ -16,7 +16,7 @@
 from fms.models.mixtral import Mixtral, MixtralConfig
 from fms.models.mixtral import _hf_sd_to_fms_sd as _mixtral_hf_sd_to_fms_sd
 from fms.utils import serialization, tokenizers
-from fms.utils.generation import _make_cache_contiguous
+from fms.utils.generation import generate
 from torch.nn import CrossEntropyLoss
 from torch.utils.data import DataLoader
 
@@ -24,99 +24,6 @@
 from fms_fsdp.utils.checkpointing_utils import Checkpointer
 from fms_fsdp.utils.config_utils import get_model_config
 
-
-def generate(
-    model: Union[Callable, torch.nn.Module],
-    input_ids: torch.Tensor,
-    max_seq_len: int = 2048,
-    max_new_tokens: int = 256,
-    temperature: float = 1.0,
-    top_k: int = 10,
-    do_sample: bool = True,
-    num_beams: int = 1,
-    use_cache: bool = False,
-    contiguous_cache: bool = False,
-    include_embeds: bool = True,
-):
-    """
-    A straightforward copy of the generate method in fms.utils.generation.
-    The only change is the include_embeds flag, which when true also returns
-    the embedding vectors corresponding to the tokens in the output sequence.
-    """
-    batched = False
-    if num_beams != 1:
-        raise NotImplementedError("generate() does yet not support beam search")
-    if type(input_ids) == torch.Tensor:
-        if input_ids.dim() != 1:
-            batched = True
-    else:
-        raise RuntimeError("generate() requires a tensor of token ids as the prefix")
-
-    if not batched:
-        input_ids = input_ids.unsqueeze(0)
-
-    embeds = None
-    result = input_ids
-    next_input = input_ids
-    kwargs: MutableMapping[str, Any] = dict()
-    kwargs["past_key_value_states"] = None
-    kwargs["use_cache"] = use_cache
-    kwargs["include_embeds"] = include_embeds
-
-    for _ in range(max_new_tokens):
-        input_ids = next_input[:, -max_seq_len:]
-        output = model(input_ids, **kwargs)
-        if not use_cache and not include_embeds:
-            logits = output
-        else:
-            logits = output[0]
-            if include_embeds:
-                z = output[-1]
-            if use_cache:
-                past_key_value_states = output[1]
-                # TODO: this should go away when reduce-overhead issues are fixed, or
-                # maybe could be moved into model code to be more portable.
-                if contiguous_cache:
-                    kwargs["past_key_value_states"] = _make_cache_contiguous(
-                        past_key_value_states
-                    )
-                else:
-                    kwargs["past_key_value_states"] = past_key_value_states
-        logits = logits[:, -1, :]
-
-        if do_sample:
-            # get logits from last value in sequence nad scale
-            logits = logits / temperature
-            if top_k:
-                v, _ = torch.topk(logits, top_k)
-                logits[logits < v[:, [-1]]] = -float("inf")
-
-            probs = F.softmax(logits, dim=-1)
-            next_val = torch.multinomial(probs, num_samples=1)
-        else:
-            next_val = torch.argmax(logits, dim=-1).unsqueeze(0).t()
-
-        result = torch.cat((result, next_val), dim=-1)
-
-        if use_cache:
-            next_input = next_val
-        else:
-            next_input = result
-
-        if include_embeds:
-            if embeds is None:
-                embeds = z
-            else:
-                embeds = torch.cat((embeds, z), dim=-2)
-
-    if not batched:
-        result = result[0]
-
-    if include_embeds:
-        return result, embeds
-
-    return result
-
 class HiddenStatesExtractor(nn.Module):
 
     def __init__(self, headless_model: Union[nn.Module, Callable], head: nn.Linear):
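
Usage sketch (illustrative; not part of either patch): a minimal example of how
HiddenStatesExtractor is meant to be driven once both patches are applied. The
architecture/variant strings, vocabulary size, and prompt tokens below are
placeholders, not values taken from the patches.

    import torch

    from fms.models import get_model
    from fms.models.llama import LLaMA
    from fms.utils.generation import generate

    from speculator.train_speculator_utils import HiddenStatesExtractor

    base = get_model("llama", "7b")  # placeholder variant; any fms LLaMA works
    assert isinstance(base, LLaMA)

    # same wiring as train_speculator.py: the bound _helper method acts as the
    # headless model, and the tied word-embedding head produces the logits
    model = HiddenStatesExtractor(base._helper, base.shared.head)

    prompt = torch.randint(0, 32000, (1, 16))  # stand-in prompt token ids
    out = generate(
        model, prompt, max_seq_len=2048, max_new_tokens=8,
        do_sample=False, use_cache=True,
    )

    # hidden states for the prompt and each subsequent decode step accumulate
    # on the wrapper (forward concatenates them along the sequence dimension),
    # so no include_embeds flag or custom generate loop is needed
    embeds = model.hidden_states_output  # [batch, seq, emb_dim]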