Skip to content

Commit

Permalink
Rename adapters to work correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
ani300 committed Dec 17, 2024
1 parent e4c4ff4 commit 44a3c6d
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions speculator/train_speculator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,15 +554,18 @@ def factory(**kwargs):
register_model(
"embedgpt_bigcode", "20b", _gpt_bigcode_factory_factory(_gpt_bigcode_20b_config)
)
serialization.register_adapter("embedgpt_bigcode", "hf", _gptbigcode_hf_sd_to_fms_sd)
serialization.register_adapter_step("embedgpt_bigcode", "hf_to_fms", _gptbigcode_hf_sd_to_fms_sd)
serialization.register_adapter("embedgpt_bigcode", "hf", ["hf_to_fms"])

register_model(
"embedllama", "7b", _llama_factory_factory(get_model_config("llama2_7b"))
)
register_model(
"embedllama", "8b", _llama_factory_factory(get_model_config("llama3_8b"))
)
serialization.register_adapter("embedllama", "hf", _llama_hf_sd_to_fms_sd)
)
serialization.register_adapter_step("embedllama", "hf_to_fms", _llama_hf_sd_to_fms_sd)
serialization.register_adapter("embedllama", "hf", ["hf_to_fms"])

register_model("embedmixtral", "8x7b", _mixtral_factory_factory(MixtralConfig()))
serialization.register_adapter("embedmixtral", "hf", _mixtral_hf_sd_to_fms_sd)
serialization.register_adapter_step("embedmixtral", "hf_to_fms", _mixtral_hf_sd_to_fms_sd)
serialization.register_adapter("embedmixtral", "hf", ["hf_to_fms"])

0 comments on commit 44a3c6d

Please sign in to comment.