guard ckpt imports #2133

Merged 1 commit on Dec 9, 2024
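This PR moves the model-specific weight-conversion imports (CLIP, Phi-3, Qwen2, reward model) out of the top of _checkpointer.py and into the ModelType branches that actually call them, so importing the checkpointer no longer imports every conversion module as a side effect. A minimal sketch of that deferred-import pattern, with made-up module and function names rather than the real torchtune layout:

    def convert_checkpoint(model_type: str, state_dict: dict) -> dict:
        # Guarded import: the converter is only imported when its model type
        # is actually requested, so unrelated conversion modules never load.
        if model_type == "phi3_mini":
            # "example_converters" and "convert_phi3" are hypothetical names
            # used purely for illustration.
            from example_converters import convert_phi3

            return convert_phi3(state_dict)
        # Other model types pass through unchanged in this sketch.
        return state_dict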
19 changes: 15 additions & 4 deletions torchtune/training/checkpointing/_checkpointer.py
@@ -15,10 +15,7 @@

 from torchtune import training
 from torchtune.models import convert_weights
-from torchtune.models.clip._convert_weights import clip_text_hf_to_tune
-from torchtune.models.phi3._convert_weights import phi3_hf_to_tune, phi3_tune_to_hf
-from torchtune.models.qwen2._convert_weights import qwen2_hf_to_tune, qwen2_tune_to_hf
-from torchtune.rlhf.utils import reward_hf_to_tune, reward_tune_to_hf
 
 from torchtune.training.checkpointing._utils import (
     ADAPTER_CONFIG_FNAME,
     ADAPTER_MODEL_FNAME,
@@ -509,17 +506,23 @@ def load_checkpoint(self) -> Dict[str, Any]:
msg="Converting Phi-3 Mini weights from HF format."
"Note that conversion of adapter weights into PEFT format is not supported.",
)
from torchtune.models.phi3._convert_weights import phi3_hf_to_tune

converted_state_dict[training.MODEL_KEY] = phi3_hf_to_tune(
merged_state_dict
)
elif self._model_type == ModelType.REWARD:
from torchtune.rlhf.utils import reward_hf_to_tune

converted_state_dict[training.MODEL_KEY] = reward_hf_to_tune(
merged_state_dict,
num_heads=self._config["num_attention_heads"],
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
)
elif self._model_type == ModelType.QWEN2:
from torchtune.models.qwen2._convert_weights import qwen2_hf_to_tune

converted_state_dict[training.MODEL_KEY] = qwen2_hf_to_tune(
merged_state_dict,
num_heads=self._config["num_attention_heads"],
@@ -550,6 +553,8 @@ def load_checkpoint(self) -> Dict[str, Any]:
                 ),
             )
         elif self._model_type == ModelType.CLIP_TEXT:
+            from torchtune.models.clip._convert_weights import clip_text_hf_to_tune
+
             converted_state_dict[training.MODEL_KEY] = clip_text_hf_to_tune(
                 merged_state_dict,
             )
@@ -610,17 +615,23 @@ def save_checkpoint(
         # convert the state_dict back to hf format; do this inplace
         if not adapter_only:
             if self._model_type == ModelType.PHI3_MINI:
+                from torchtune.models.phi3._convert_weights import phi3_tune_to_hf
+
                 state_dict[training.MODEL_KEY] = phi3_tune_to_hf(
                     state_dict[training.MODEL_KEY]
                 )
             elif self._model_type == ModelType.REWARD:
+                from torchtune.rlhf.utils import reward_tune_to_hf
+
                 state_dict[training.MODEL_KEY] = reward_tune_to_hf(
                     state_dict[training.MODEL_KEY],
                     num_heads=self._config["num_attention_heads"],
                     num_kv_heads=self._config["num_key_value_heads"],
                     dim=self._config["hidden_size"],
                 )
             elif self._model_type == ModelType.QWEN2:
+                from torchtune.models.qwen2._convert_weights import qwen2_tune_to_hf
+
                 state_dict[training.MODEL_KEY] = qwen2_tune_to_hf(
                     state_dict[training.MODEL_KEY],
                     num_heads=self._config["num_attention_heads"],
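One practical effect of guarding these imports is that simply importing the checkpointer module no longer requires every conversion module to load up front. A rough way to inspect this from a fresh Python process with torchtune installed; the result depends on what torchtune's own package __init__ files still import eagerly, so treat it as an illustration rather than a guaranteed outcome:

    import sys

    from torchtune.training.checkpointing import _checkpointer  # noqa: F401

    # With the guarded imports, these conversion modules should only be loaded
    # once the matching ModelType branch runs during load/save, unless some
    # other torchtune import has already pulled them in.
    for name in (
        "torchtune.models.phi3._convert_weights",
        "torchtune.models.qwen2._convert_weights",
    ):
        print(name, "already imported" if name in sys.modules else "not imported yet")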