[RFC] TransformerDecoder refactor #1017
Comments
Some initial comments/opinions after my first pass through the RFC
|
Thanks for your thoughts @pbontrager!
I like how this looks - it makes sense since heads are just additional layers. Some contra points:
Are you suggesting something like this?

    import torch.nn as nn

    class MultiOutputHead(nn.Module):
        def __init__(self, layers: nn.ModuleDict):
            super().__init__()
            self.layers = layers

        def forward(self, decoder_output):
            # return dict or tuple or dataclass?
            ...

My thoughts here are similar to above: typing is now abstract, and we're trading some modularity here. Heads can now be composed together arbitrarily, but it might be difficult to deal with different heads using different inputs. I don't actually have a use-case in mind, so perhaps it's not worth trying to over-engineer here. I'm keeping in mind we've also been talking about flexibility in the outputs from
RE: Checkpointing - I currently handle checkpointing kinda similar to LoRA checkpointing - the base model with the language head gets saved by default, and value head weights are stored and loaded separately during training.
Sounds good to me! Hope my points make sense, and that I haven't misinterpreted or missed something obvious. |
This is an awesome RFC! Overall I like the direction here, just chiming in on a few miscellaneous points raised by both you and @pbontrager.
Is this part of the changes you're proposing here? Or would it be done separately? (If not strictly needed for PPO I'd say it's fine to save for a follow-up, but I know this was one of the things @kartikayk was looking at with this refactor.) If we do want to support it here, I feel like a bool
nitpicking but I assume you mean
This makes sense to me.
I don't like the For now my proposal for a compromise is to type
Are you talking about naming of just the head or the full model + head? If the full model I think
Some people are morally opposed to the usage of lambda functions, but I am not one of them. I also think it should be equivalent from the perspective of FSDP (as long as we are not creating a separate set of weights and tying them together it should be fine).
Another path is to just infer the version and remap during checkpoint load. So here we could add something like
A bit more of a black box but imo the checkpointer is kind of that already anyways. Separately we should come up with a proper definition of checkpoint versioning along with support/deprecation model. |
Thanks so much for your comments @ebsmothers.
I generally agree here. Your suggestion sounds straightforward and if @kartikayk is keen to see it here while I'm making other changes I'm happy to include. If I gather what you're suggesting, it's that we have the base I think this abstraction keeps things simple - I'd imagine typing in
The reason I raised this is because not every head results in a language model per se, but the added complexity of different model types perhaps isn't worth the nomenclature, unless we use something more general like
This makes sense to me.
I think keeping things as simple and easily-understandable as possible perhaps lends well to making the codebase extensible for user-specific purposes, rather than complex abstractions to minimise code changes, so I'd agree here.
I'd lean towards this solution as it feels slightly more user friendly. |
Let me express my humble opinion on the subject, as my issue was mentioned in the initial post. First of all, it seems to me that the issue of backward compatibility should not be essential here, since the project is in the very early stages of development and is not yet very popular. If there is a time to make backward-breaking changes, it is now. Secondly, this library is being developed as part of the pytorch project, so it seems reasonable to expect it to follow the conventions of its older big brother first, rather than outside projects like huggingface transformers. Based on the above, doesn't it seem that the |
@marcinwazny thanks for weighing in here. Re BC-breaking changes I completely agree, it's inevitable that they will occur at this stage in the project. At the same time we want to make sure that we do not leave the users we do have high and dry and without a clear path forward. Re your suggestion to move |
I think I'm on board with this plan! I just want to +1 again having an option for returning internal state. Also for the checkpoint BC problem, I'm fine with both solutions, the main advantage I see with forcing the user to upgrade their checkpoint is so we don't have to support the old checkpoints for very long. |
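For the option of returning internal state, one possible shape is sketched below. This is a toy and not torchtune code: LayeredToyDecoder, its nn.Linear stand-in layers, and the output_hidden_states argument are hypothetical names chosen for illustration. With the argument omitted, the call returns only the final hidden state; with it, the call returns the requested layers' outputs followed by the final one.

```python
from typing import List, Optional, Union

import torch
import torch.nn as nn


class LayeredToyDecoder(nn.Module):
    """Toy backbone (hypothetical): embeddings followed by a stack of layers."""

    def __init__(self, vocab_size: int, embed_dim: int, num_layers: int) -> None:
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
        # nn.Linear stands in for a real transformer layer.
        self.layers = nn.ModuleList(
            [nn.Linear(embed_dim, embed_dim) for _ in range(num_layers)]
        )

    def forward(
        self,
        tokens: torch.Tensor,
        output_hidden_states: Optional[List[int]] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        # [batch_size, sequence_len] -> [batch_size, sequence_len, embed_dim]
        h = self.tok_embeddings(tokens)
        hidden: List[torch.Tensor] = []
        for i, layer in enumerate(self.layers):
            h = torch.tanh(layer(h))
            # Collect the output of every layer index the caller asked for.
            if output_hidden_states is not None and i in output_hidden_states:
                hidden.append(h)
        if output_hidden_states is None:
            # Default: just the last hidden state.
            return h
        # Requested intermediate states first, final hidden state last.
        return hidden + [h]
```

Keeping the default return a single tensor means existing callers would be unaffected; only callers that pass the extra argument need to handle a list.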
Thanks @pbontrager! I'm happy with a simple solution like @ebsmothers suggested, using I see both sides RE: backwards compatibility. It'll likely only be one major-release worth of BC we'd need to support anyway, so I'd vote for the simplest option with the easiest UX. |
If all is well with the plan I'll put up a PR for this at some point soon. thanks for all the feedback : ) |
Closing since this will be addressed by #1224, and multi-output heads are no longer needed. |
TransformerDecoder Refactor
Authors:
with input from:
Summary
Refactoring TransformerDecoder to offer additional flexibility for new use-cases.
Motivation/Prior art
TransformerDecoder #968 - not sure if this is in scope.
Currently, TransformerDecoder can only be used for language-modelling tasks. There is interest in additional use-cases, such as:
* AutoModelForSequenceClassification
* AutoModelForCausalLMWithValueHead
* lm_human_preference_details
Such a refactor could allow users to easily adapt a transformer backbone for a variety of down-stream tasks; lm_human_preference_details demonstrates how HF's transformer backbone can be extended in just 8 lines. While this refactor initially targets recipes which will be provided within Torchtune, such as PPO, or sequence-classification training recipes (e.g. for reward models), it would allow users to write custom recipes for many fine-tuning tasks whilst utilising underlying Torchtune features.
Proposed Implementation
A small-scale implementation for Mistral models exists in this draft PR. In summary:
* TransformerDecoder will refer to the underlying transformer backbone agnostic to its downstream task. It will return, by default, the final hidden layer as an output of shape [batch_size, sequence_len, embed_dim].
* TransformerDecoder could support returning hidden states from arbitrary layers (or other useful outputs). Some input on how we allow users to specify this would be helpful. We probably just want to return the last hidden state by default.
* TransformerLM as so:
* TransformerLM instead of TransformerDecoder. Component builders look like:
* mistral_classifier should now return an instance of TransformerClassifier.
* Gemma models define a GemmaTransformerDecoder which has a unique output projection, but shares the underlying logic of a TransformerDecoder. We can go two routes here:
  * TransformerLM accepts a Union[nn.Module, Callable[torch.tensor]] (or even just Callable[torch.tensor]) as output. Then, the Gemma component builder is:
  * GemmaTransformerLM which looks like:
Input on how this affects FSDP would be appreciated.
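To make the first route concrete, here is a minimal sketch of a wrapper whose output can be either a module or a plain callable. It is not the draft PR's code: ToyDecoder is a stand-in backbone, and this TransformerLM constructor signature is an assumption made for illustration. The callable case reuses the embedding weight through F.linear, so no second set of weights is created.

```python
from typing import Callable, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyDecoder(nn.Module):
    """Stand-in for the task-agnostic backbone (hypothetical)."""

    def __init__(self, vocab_size: int, embed_dim: int) -> None:
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # [batch_size, sequence_len] -> [batch_size, sequence_len, embed_dim]
        return self.norm(self.tok_embeddings(tokens))


class TransformerLM(nn.Module):
    """Sketch of a backbone plus output projection (signature is assumed)."""

    def __init__(
        self,
        decoder: nn.Module,
        output: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]],
    ) -> None:
        super().__init__()
        self.decoder = decoder
        # An nn.Module is registered as a submodule; a plain callable is
        # stored as an ordinary attribute and contributes no parameters.
        self.output = output

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # [batch_size, sequence_len] -> [batch_size, sequence_len, out_dim]
        return self.output(self.decoder(tokens))


decoder = ToyDecoder(vocab_size=32, embed_dim=8)

# Standard head: a separate linear projection to the vocabulary.
lm = TransformerLM(decoder, nn.Linear(8, 32, bias=False))

# Tied head via a callable: project with the embedding matrix itself.
tied_lm = TransformerLM(
    decoder, lambda h: F.linear(h, decoder.tok_embeddings.weight)
)
```

In this sketch the state dict keys of lm come out as decoder.tok_embeddings.weight, decoder.norm.weight, decoder.norm.bias and output.weight, which is where the decoder prefix discussed for the weight-conversion mappings would come from.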
Components in the codebase I estimate will be impacted, and changes necessary, include:
* torchtune.models.convert_weights.py: _FROM_META and _FROM_HF should prepend decoder to destination keys e.g. "model.layers.{}.self_attn.q_proj.weight": "decoder.layers.{}.attn.q_proj.weight", instead of "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight".**
* TransformerDecoder as a complete language modelling transformer should now use TransformerLM. A quick search in VS Code shows ~100 references.
* TransformerLM - input appreciated.
* TransformerDecoder to test functionality without output projections should be added.
** A note on backwards compatibility
Users who have previously trained models with TransformerDecoders will have checkpoints saved with dict keys in the original format (without the decoder prefix). Am I right in thinking they're going to have issues loading these checkpoints into our new models? This could be a pretty disruptive change - some users will have spent a lot of resources fine-tuning their models.
Could we provide some well-documented deprecation support for converting state dicts until some release version?
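As a rough illustration of what such conversion support could look like, the sketch below detects an old-format state dict by the absence of any decoder-prefixed key and remaps it. The function names and HEAD_PREFIXES are hypothetical, and the assumption that only output.* keys stay at the top level is mine, not something the RFC specifies.

```python
from typing import Any, Dict

DECODER_PREFIX = "decoder."
# Assumption: keys under these prefixes belong to the head and are not moved.
HEAD_PREFIXES = ("output.",)


def is_legacy_state_dict(state_dict: Dict[str, Any]) -> bool:
    """Treat a state dict with no decoder-prefixed keys as the old format."""
    return not any(key.startswith(DECODER_PREFIX) for key in state_dict)


def upgrade_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy with backbone keys moved under the decoder prefix."""
    if not is_legacy_state_dict(state_dict):
        # Already in the new format: return an unchanged shallow copy.
        return dict(state_dict)
    upgraded: Dict[str, Any] = {}
    for key, value in state_dict.items():
        if key.startswith(HEAD_PREFIXES):
            upgraded[key] = value
        else:
            upgraded[DECODER_PREFIX + key] = value
    return upgraded


old = {
    "tok_embeddings.weight": "w_emb",
    "layers.0.attn.q_proj.weight": "w_q",
    "output.weight": "w_out",
}
new = upgrade_state_dict(old)
# Backbone keys gain the prefix, the head key is left alone, and
# upgrading an already-upgraded dict changes nothing.
```

A helper like this could run either as a one-off conversion script or inside checkpoint loading; the two options differ mainly in how long the old format has to be supported.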