From ff4b43740efc3c59c880bff8b2f1d31fe27ffc05 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 9 Dec 2024 08:30:27 -0800 Subject: [PATCH] guard imports --- .../training/checkpointing/_checkpointer.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index 298b0b113c..2c2bc1eb57 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -15,10 +15,7 @@ from torchtune import training from torchtune.models import convert_weights -from torchtune.models.clip._convert_weights import clip_text_hf_to_tune -from torchtune.models.phi3._convert_weights import phi3_hf_to_tune, phi3_tune_to_hf -from torchtune.models.qwen2._convert_weights import qwen2_hf_to_tune, qwen2_tune_to_hf -from torchtune.rlhf.utils import reward_hf_to_tune, reward_tune_to_hf + from torchtune.training.checkpointing._utils import ( ADAPTER_CONFIG_FNAME, ADAPTER_MODEL_FNAME, @@ -509,10 +506,14 @@ def load_checkpoint(self) -> Dict[str, Any]: msg="Converting Phi-3 Mini weights from HF format." "Note that conversion of adapter weights into PEFT format is not supported.", ) + from torchtune.models.phi3._convert_weights import phi3_hf_to_tune + converted_state_dict[training.MODEL_KEY] = phi3_hf_to_tune( merged_state_dict ) elif self._model_type == ModelType.REWARD: + from torchtune.rlhf.utils import reward_hf_to_tune + converted_state_dict[training.MODEL_KEY] = reward_hf_to_tune( merged_state_dict, num_heads=self._config["num_attention_heads"], @@ -520,6 +521,8 @@ def load_checkpoint(self) -> Dict[str, Any]: dim=self._config["hidden_size"], ) elif self._model_type == ModelType.QWEN2: + from torchtune.models.qwen2._convert_weights import qwen2_hf_to_tune + converted_state_dict[training.MODEL_KEY] = qwen2_hf_to_tune( merged_state_dict, num_heads=self._config["num_attention_heads"], @@ -550,6 +553,8 @@ def load_checkpoint(self) -> Dict[str, Any]: ), ) elif self._model_type == ModelType.CLIP_TEXT: + from torchtune.models.clip._convert_weights import clip_text_hf_to_tune + converted_state_dict[training.MODEL_KEY] = clip_text_hf_to_tune( merged_state_dict, ) @@ -610,10 +615,14 @@ def save_checkpoint( # convert the state_dict back to hf format; do this inplace if not adapter_only: if self._model_type == ModelType.PHI3_MINI: + from torchtune.models.phi3._convert_weights import phi3_tune_to_hf + state_dict[training.MODEL_KEY] = phi3_tune_to_hf( state_dict[training.MODEL_KEY] ) elif self._model_type == ModelType.REWARD: + from torchtune.rlhf.utils import reward_tune_to_hf + state_dict[training.MODEL_KEY] = reward_tune_to_hf( state_dict[training.MODEL_KEY], num_heads=self._config["num_attention_heads"], @@ -621,6 +630,8 @@ def save_checkpoint( dim=self._config["hidden_size"], ) elif self._model_type == ModelType.QWEN2: + from torchtune.models.qwen2._convert_weights import qwen2_tune_to_hf + state_dict[training.MODEL_KEY] = qwen2_tune_to_hf( state_dict[training.MODEL_KEY], num_heads=self._config["num_attention_heads"],