Enable HF PretrainedModel loading for speculative model training #122

Draft
wants to merge 2 commits into base: main
16 changes: 14 additions & 2 deletions speculator/train_speculator.py
@@ -6,7 +6,9 @@
import torch
import torch.optim as optim
from fms.models import get_model
from fms.models.gpt_bigcode import GPTBigCode
from fms.models.llama import LLaMA, LLaMABlock
from fms.models.mixtral import Mixtral
from fms.utils import generation, tokenizers
from fms_extras.models.speculator import MLPSpeculator # type: ignore
from torch import distributed as dist
@@ -24,8 +26,7 @@
setup,
setup_environ_flags,
)
from speculator.train_speculator_utils import train_speculator

from speculator.train_speculator_utils import HiddenStatesExtractor, train_speculator

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -159,6 +160,17 @@ def main(**kwargs):
),
)

    # Split the base model into its headless trunk and LM head so the
    # HiddenStatesExtractor wrapper can expose hidden states during training.
    if isinstance(model, LLaMA):
        headless_model = model._helper
        head = model.shared.head
    elif isinstance(model, (GPTBigCode, Mixtral)):
        headless_model = model.base_model
        head = model.head
    else:
        raise ValueError(
            "speculative training currently only supports LLaMA, GPTBigCode, "
            "and Mixtral architectures"
        )

    model = HiddenStatesExtractor(headless_model, head)

if rank == 0:
print(f"{time.time()}")
print(model.config)
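For context only, and not part of this diff: the PR title refers to loading HF PretrainedModels, while the hunk above wires up the fms architectures. Below is a minimal sketch of how the same wrapping might look for an HF LLaMA-style checkpoint, assuming transformers' AutoModelForCausalLM layout (hf_model.model as the decoder stack, hf_model.lm_head as the head) plus a small adapter to match the (hidden_states, cache) interface that HiddenStatesExtractor expects; every name below is an assumption rather than code from this PR.

    from torch import nn
    from transformers import AutoModelForCausalLM

    class HFHeadlessAdapter(nn.Module):
        # Adapts an HF decoder stack to the (hidden_states, cache) tuple and the
        # fms-style kwargs that HiddenStatesExtractor expects.
        def __init__(self, hf_decoder):
            super().__init__()
            self.hf_decoder = hf_decoder

        def forward(self, x, past_key_value_states=None, use_cache=False, **_):
            out = self.hf_decoder(
                input_ids=x,
                past_key_values=past_key_value_states,
                use_cache=use_cache,
            )
            return out.last_hidden_state, out.past_key_values

    hf_model = AutoModelForCausalLM.from_pretrained("path/to/hf/llama/checkpoint")  # placeholder path
    model = HiddenStatesExtractor(HFHeadlessAdapter(hf_model.model), hf_model.lm_head)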
252 changes: 29 additions & 223 deletions speculator/train_speculator_utils.py
@@ -16,106 +16,45 @@
from fms.models.mixtral import Mixtral, MixtralConfig
from fms.models.mixtral import _hf_sd_to_fms_sd as _mixtral_hf_sd_to_fms_sd
from fms.utils import serialization, tokenizers
from fms.utils.generation import _make_cache_contiguous
from fms.utils.generation import generate
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from fms_fsdp.config import train_config
from fms_fsdp.utils.checkpointing_utils import Checkpointer
from fms_fsdp.utils.config_utils import get_model_config

class HiddenStatesExtractor(nn.Module):
    """Wraps a headless base model and its LM head, accumulating the hidden
    states produced during generation so the speculator can train on them."""

def generate(
model: Union[Callable, torch.nn.Module],
input_ids: torch.Tensor,
max_seq_len: int = 2048,
max_new_tokens: int = 256,
temperature: float = 1.0,
top_k: int = 10,
do_sample: bool = True,
num_beams: int = 1,
use_cache: bool = False,
contiguous_cache: bool = False,
include_embeds: bool = True,
):
"""
A straightforward copy of the generate method in fms.utils.generation.
The only change is the include_embeds flag, which when true also returns
the embedding vectors corresponding to the tokens in the output sequence.
"""
batched = False
if num_beams != 1:
raise NotImplementedError("generate() does yet not support beam search")
if type(input_ids) == torch.Tensor:
if input_ids.dim() != 1:
batched = True
else:
raise RuntimeError("generate() requires a tensor of token ids as the prefix")

if not batched:
input_ids = input_ids.unsqueeze(0)

embeds = None
result = input_ids
next_input = input_ids
kwargs: MutableMapping[str, Any] = dict()
kwargs["past_key_value_states"] = None
kwargs["use_cache"] = use_cache
kwargs["include_embeds"] = include_embeds

for _ in range(max_new_tokens):
input_ids = next_input[:, -max_seq_len:]
output = model(input_ids, **kwargs)
if not use_cache and not include_embeds:
logits = output
else:
logits = output[0]
if include_embeds:
z = output[-1]
if use_cache:
past_key_value_states = output[1]
# TODO: this should go away when reduce-overhead issues are fixed, or
# maybe could be moved into model code to be more portable.
if contiguous_cache:
kwargs["past_key_value_states"] = _make_cache_contiguous(
past_key_value_states
)
else:
kwargs["past_key_value_states"] = past_key_value_states
logits = logits[:, -1, :]

if do_sample:
# get logits from last value in sequence nad scale
logits = logits / temperature
if top_k:
v, _ = torch.topk(logits, top_k)
logits[logits < v[:, [-1]]] = -float("inf")

probs = F.softmax(logits, dim=-1)
next_val = torch.multinomial(probs, num_samples=1)
else:
next_val = torch.argmax(logits, dim=-1).unsqueeze(0).t()
        result = torch.cat((result, next_val), dim=-1)

        if use_cache:
            next_input = next_val
        else:
            next_input = result

        if include_embeds:
            if embeds is None:
                embeds = z
            else:
                embeds = torch.cat((embeds, z), dim=-2)

    if not batched:
        result = result[0]

    if include_embeds:
        return result, embeds

    return result


    def __init__(self, headless_model: Union[nn.Module, Callable], head: nn.Linear):
        super().__init__()
        self.headless_model = headless_model
        self.head = head
        self.hidden_states_output = None

    # Assumes this wrapper is driven by generate(), which passes these kwargs.
    def forward(self, *args, **kwargs):
        # reset the accumulated hidden states on prefill (no cache passed yet)
        if not kwargs.get("past_key_value_states", None):
            self.hidden_states_output = None

        hidden_states, cache = self.headless_model(*args, **kwargs)

        if kwargs.get("only_last_token", None):
            hidden_states = hidden_states[:, -1, :]

        if self.hidden_states_output is None:
            self.hidden_states_output = hidden_states
        else:
            self.hidden_states_output = torch.cat(
                (self.hidden_states_output, hidden_states), dim=-2
            )

        logits = self.head(hidden_states)

        if kwargs.get("use_cache", None):
            return logits, cache
        return logits
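# Illustrative sketch (not part of this PR's changes): a minimal, self-contained
# demonstration of how HiddenStatesExtractor accumulates hidden states across a
# prefill call and a decode call. `_ToyHeadless`, the fake cache object, and the
# tensor shapes below are assumptions made purely for illustration; real callers
# pass model._helper / model.base_model and drive the wrapper via generate().
def _hidden_states_extractor_demo():
    class _ToyHeadless(nn.Module):
        def __init__(self, vocab: int = 16, width: int = 8):
            super().__init__()
            self.emb = nn.Embedding(vocab, width)

        def forward(self, x, past_key_value_states=None, use_cache=True, **_):
            # return hidden states plus a (fake) non-empty cache object
            return self.emb(x), [("k", "v")]

    wrapper = HiddenStatesExtractor(_ToyHeadless(), nn.Linear(8, 16, bias=False))

    prompt = torch.randint(0, 16, (1, 5))
    # prefill: no cache yet, so the wrapper resets hidden_states_output
    logits, cache = wrapper(prompt, past_key_value_states=None, use_cache=True)
    next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    # decode step: cache is present, so the new hidden state is appended
    wrapper(next_tok, past_key_value_states=cache, use_cache=True)
    # 5 prompt positions + 1 generated position accumulated for the speculator
    assert wrapper.hidden_states_output.size(-2) == 6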


# Stage 1 training
@@ -150,11 +89,12 @@ def stage1_loss(
Returns: scalar loss value, updated ddp stats, number of tokens in input
"""
with torch.no_grad():
_, embeds = model(
_ = model(
base_model_input[:, : -speculator.n_predict - 1],
include_embeds=True,
use_cache=False,
)
embeds = model.hidden_states_output
if cfg.sharding_strategy == "tp":
embeds = embeds.chunk(base_model_mesh["tp"].size())[
base_model_mesh["tp"].get_local_rank()
@@ -211,15 +151,16 @@ def stage2_loss(
base_model_input = base_model_input[
:, : cfg.stage2_prompt_length * grow_factor
].reshape(base_model_input.size(0) * grow_factor, cfg.stage2_prompt_length)
targs, embeds = generate(

targs = generate(
model,
base_model_input,
cfg.seq_length,
cfg.stage2_seq_length,
do_sample=True,
use_cache=True,
include_embeds=True,
)
embeds = model.hidden_states_output

if cfg.sharding_strategy == "tp":
targs = targs.chunk(base_model_mesh["tp"].size())[
@@ -431,138 +372,3 @@ def train_speculator(
tokens_seen=elapsed_tokens + n_tok,
is_compiled=cfg.use_torch_compile,
)


class EmbedGPTBigCode(GPTBigCode):
# Overrides the forward function of GPTBigCode to allow returning embedding vectors
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
only_last_token: bool = False,
attn_algorithm: Optional[str] = None,
include_embeds: bool = False,
):
output, cache = self.base_model(
x,
mask,
position_ids=position_ids,
past_key_value_states=past_key_value_states,
use_cache=use_cache,
attn_algorithm=attn_algorithm,
)

preds = self.head(output)

out = [preds]
if use_cache:
out.append(cache)
if include_embeds:
out.append(output)
if len(out) == 1:
return out[0]
return out


class EmbedLLaMA(LLaMA):
# Overrides the forward function of LLaMA to allow returning embedding vectors
def forward(
self,
x,
mask=None,
position_ids=None,
past_key_value_states=None,
use_cache=False,
only_last_token=False,
attn_algorithm=None,
include_embeds=False,
):
output, cache = self._helper(
x, mask, position_ids, past_key_value_states, use_cache, attn_algorithm
)

if only_last_token:
output = output[:, -1, :]
preds = self.shared(output, reverse=True)

out = [preds]
if use_cache:
out.append(cache)
if include_embeds:
out.append(output)
if len(out) == 1:
return out[0]
return out


class EmbedMixtral(Mixtral): # FMS impl of Mixtral
# Overrides the forward function of Mixtral to allow returning embedding vectors
def forward(
self,
x,
mask=None,
position_ids=None,
past_key_value_states=None,
use_cache=False,
only_last_token=False,
attn_algorithm=None,
include_embeds=False,
):
output, cache = self.base_model(
x, mask, position_ids, past_key_value_states, use_cache, attn_algorithm
)

if only_last_token:
output = output[:, -1, :]
preds = self.head(output)

out = [preds]
if use_cache:
out.append(cache)
if include_embeds:
out.append(output)
if len(out) == 1:
return out[0]
return out


def _gpt_bigcode_factory_factory(config):
def factory(**kwargs):
return EmbedGPTBigCode(config, **kwargs)

return factory


def _llama_factory_factory(config):
def factory(**kwargs):
return EmbedLLaMA(config, **kwargs)

return factory


def _mixtral_factory_factory(config):
def factory(**kwargs):
return EmbedMixtral(config, **kwargs)

return factory


# example model registrations
register_model(
"embedgpt_bigcode", "20b", _gpt_bigcode_factory_factory(_gpt_bigcode_20b_config)
)
serialization.register_adapter("embedgpt_bigcode", "hf", _gptbigcode_hf_sd_to_fms_sd)

register_model(
"embedllama", "7b", _llama_factory_factory(get_model_config("llama2_7b"))
)
register_model(
"embedllama", "8b", _llama_factory_factory(get_model_config("llama3_8b"))
)
serialization.register_adapter("embedllama", "hf", _llama_hf_sd_to_fms_sd)

register_model("embedmixtral", "8x7b", _mixtral_factory_factory(MixtralConfig()))
serialization.register_adapter("embedmixtral", "hf", _mixtral_hf_sd_to_fms_sd)