diff --git a/speculator/train_speculator_utils.py b/speculator/train_speculator_utils.py
index 5cf31f9..1403732 100644
--- a/speculator/train_speculator_utils.py
+++ b/speculator/train_speculator_utils.py
@@ -554,15 +554,18 @@ def factory(**kwargs):
 register_model(
     "embedgpt_bigcode", "20b", _gpt_bigcode_factory_factory(_gpt_bigcode_20b_config)
 )
-serialization.register_adapter("embedgpt_bigcode", "hf", _gptbigcode_hf_sd_to_fms_sd)
+serialization.register_adapter_step("embedgpt_bigcode", "hf_to_fms", _gptbigcode_hf_sd_to_fms_sd)
+serialization.register_adapter("embedgpt_bigcode", "hf", ["hf_to_fms"])
 
 register_model(
     "embedllama", "7b", _llama_factory_factory(get_model_config("llama2_7b"))
 )
 register_model(
     "embedllama", "8b", _llama_factory_factory(get_model_config("llama3_8b"))
 )
-serialization.register_adapter("embedllama", "hf", _llama_hf_sd_to_fms_sd)
+serialization.register_adapter_step("embedllama", "hf_to_fms", _llama_hf_sd_to_fms_sd)
+serialization.register_adapter("embedllama", "hf", ["hf_to_fms"])
 
 register_model("embedmixtral", "8x7b", _mixtral_factory_factory(MixtralConfig()))
-serialization.register_adapter("embedmixtral", "hf", _mixtral_hf_sd_to_fms_sd)
+serialization.register_adapter_step("embedmixtral", "hf_to_fms", _mixtral_hf_sd_to_fms_sd)
+serialization.register_adapter("embedmixtral", "hf", ["hf_to_fms"])