Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable HF PretrainedModel loading for speculative model training #122

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

JRosenkranz
Copy link
Collaborator

This PR enables HF PretrainedModel loading. To use this feature, simply set the architecture to "hf_pretrained", and the variant to a huggingface variant (model_id). This was enabled my removing the need to create special adapters, by wrapping a model in a HiddenStatesExtractor (extracts hidden states from base model). With this new wrapper, the adapters and overridden model classes that used include_embeds were not required, as well as generate could be used from fms main

…te from fms directly to extract hidden states; removed unnecessary classes
@JRosenkranz JRosenkranz self-assigned this Oct 18, 2024
@daviswer
Copy link
Collaborator

Nice, this does make more sense once models are partitioned into headless/head components!

@sahilsuneja1
Copy link
Collaborator

sahilsuneja1 commented Oct 21, 2024

Couldn't follow the reset logic. Rest everything looks good!

@JRosenkranz
Copy link
Collaborator Author

Couldn't follow the reset logic. Rest everything looks good!

Resetting always occurs on prefill. Past_key_value_states=None on every prefill (stage 1 always has past_key_value_states=None, stage 2 sets past_key_value_states to None on the first call to the model). This way, we always get the latest hidden_states_output.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants