guard imports
Felipe Mello committed Dec 9, 2024
1 parent 06a8379 commit ff4b437
Showing 1 changed file with 15 additions and 4 deletions.
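This commit moves the model-specific weight-converter imports from module level into the load/save branches that actually use them, so importing the checkpointer no longer eagerly imports every converter. A minimal sketch of the deferred-import pattern follows, with a hypothetical convert_checkpoint function and the stdlib json module standing in for a model-specific converter (this is not the torchtune code itself):

from typing import Any, Dict


def convert_checkpoint(model_type: str, state_dict: Dict[str, Any]) -> Dict[str, Any]:
    if model_type == "phi3":
        # Deferred ("guarded") import: the module is loaded, and any
        # ImportError is raised, only when this branch runs, not when
        # this file is imported. json stands in for a heavy converter.
        import json

        return json.loads(json.dumps(state_dict))
    # Other model types never pay the import cost of the phi3 branch.
    return state_dict


print(convert_checkpoint("phi3", {"w": 1}))   # first call triggers the import
print(convert_checkpoint("other", {"w": 2}))  # this path never imports json

Python caches imported modules in sys.modules, so the function-level import is essentially free after the first call; the saving is at module-import time.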
torchtune/training/checkpointing/_checkpointer.py — 19 changes: 15 additions & 4 deletions
@@ -15,10 +15,7 @@

 from torchtune import training
 from torchtune.models import convert_weights
-from torchtune.models.clip._convert_weights import clip_text_hf_to_tune
-from torchtune.models.phi3._convert_weights import phi3_hf_to_tune, phi3_tune_to_hf
-from torchtune.models.qwen2._convert_weights import qwen2_hf_to_tune, qwen2_tune_to_hf
-from torchtune.rlhf.utils import reward_hf_to_tune, reward_tune_to_hf

 from torchtune.training.checkpointing._utils import (
     ADAPTER_CONFIG_FNAME,
     ADAPTER_MODEL_FNAME,
@@ -509,17 +506,23 @@ def load_checkpoint(self) -> Dict[str, Any]:
                 msg="Converting Phi-3 Mini weights from HF format."
                 "Note that conversion of adapter weights into PEFT format is not supported.",
             )
+            from torchtune.models.phi3._convert_weights import phi3_hf_to_tune
+
             converted_state_dict[training.MODEL_KEY] = phi3_hf_to_tune(
                 merged_state_dict
             )
         elif self._model_type == ModelType.REWARD:
+            from torchtune.rlhf.utils import reward_hf_to_tune
+
             converted_state_dict[training.MODEL_KEY] = reward_hf_to_tune(
                 merged_state_dict,
                 num_heads=self._config["num_attention_heads"],
                 num_kv_heads=self._config["num_key_value_heads"],
                 dim=self._config["hidden_size"],
             )
         elif self._model_type == ModelType.QWEN2:
+            from torchtune.models.qwen2._convert_weights import qwen2_hf_to_tune
+
             converted_state_dict[training.MODEL_KEY] = qwen2_hf_to_tune(
                 merged_state_dict,
                 num_heads=self._config["num_attention_heads"],
@@ -550,6 +553,8 @@ def load_checkpoint(self) -> Dict[str, Any]:
                 ),
             )
         elif self._model_type == ModelType.CLIP_TEXT:
+            from torchtune.models.clip._convert_weights import clip_text_hf_to_tune
+
             converted_state_dict[training.MODEL_KEY] = clip_text_hf_to_tune(
                 merged_state_dict,
             )
@@ -610,17 +615,23 @@ def save_checkpoint(
         # convert the state_dict back to hf format; do this inplace
         if not adapter_only:
             if self._model_type == ModelType.PHI3_MINI:
+                from torchtune.models.phi3._convert_weights import phi3_tune_to_hf
+
                 state_dict[training.MODEL_KEY] = phi3_tune_to_hf(
                     state_dict[training.MODEL_KEY]
                 )
             elif self._model_type == ModelType.REWARD:
+                from torchtune.rlhf.utils import reward_tune_to_hf
+
                 state_dict[training.MODEL_KEY] = reward_tune_to_hf(
                     state_dict[training.MODEL_KEY],
                     num_heads=self._config["num_attention_heads"],
                     num_kv_heads=self._config["num_key_value_heads"],
                     dim=self._config["hidden_size"],
                 )
             elif self._model_type == ModelType.QWEN2:
+                from torchtune.models.qwen2._convert_weights import qwen2_tune_to_hf
+
                 state_dict[training.MODEL_KEY] = qwen2_tune_to_hf(
                     state_dict[training.MODEL_KEY],
                     num_heads=self._config["num_attention_heads"],

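As an illustration of how one might observe the effect of the guard (not part of the commit, assuming torchtune is installed, and noting that other torchtune import paths may still load these modules eagerly):

import sys

import torchtune.training.checkpointing._checkpointer  # noqa: F401

for mod in (
    "torchtune.models.phi3._convert_weights",
    "torchtune.models.qwen2._convert_weights",
):
    # With the guard in place, these should appear in sys.modules only
    # after the corresponding load/save branch has actually run.
    print(mod, mod in sys.modules)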