diff --git a/scripts/run_inference.py b/scripts/run_inference.py
index 74da97c73..989aaa8ca 100644
--- a/scripts/run_inference.py
+++ b/scripts/run_inference.py
@@ -16,7 +16,7 @@
 # Third Party
 from peft import AutoPeftModelForCausalLM
 from tqdm import tqdm
-from transformers import AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 
@@ -156,16 +156,22 @@ def load(
         tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
         # Apply the configs to the adapter config of this model; if no overrides
         # are provided, then the context manager doesn't have any effect.
-        with AdapterConfigPatcher(checkpoint_path, overrides):
-            try:
-                peft_model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
-            except OSError as e:
-                print("Failed to initialize checkpoint model!")
-                raise e
+        try:
+            with AdapterConfigPatcher(checkpoint_path, overrides):
+                try:
+                    model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
+                except OSError as e:
+                    print("Failed to initialize checkpoint model!")
+                    raise e
+        except FileNotFoundError:
+            print("No adapter config found! Loading as a merged model...")
+            # Unable to find the adapter config; fall back to loading as a merged model
+            model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
+
         device = "cuda" if torch.cuda.is_available() else None
         print(f"Inferred device: {device}")
-        peft_model.to(device)
-        return cls(peft_model, tokenizer, device)
+        model.to(device)
+        return cls(model, tokenizer, device)
 
     def run(self, text: str, *, max_new_tokens: int) -> str:
         """Runs inference on an instance of this model.
@@ -198,7 +204,7 @@ def main():
         description="Loads a tuned model and runs an inference call(s) through it"
     )
     parser.add_argument(
-        "--model", help="Path to tuned model to be loaded", required=True
+        "--model", help="Path to tuned model / merged model to be loaded", required=True
     )
     parser.add_argument(
         "--out_file",
@@ -207,7 +213,7 @@
     )
     parser.add_argument(
         "--base_model_name_or_path",
-        help="Override for base model to be used [default: value in model adapter_config.json]",
+        help="Override for base model to be used for non-merged models [default: value in model adapter_config.json]",
         default=None,
     )
     parser.add_argument(
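
The net effect of the patch: `TunedCausalLM.load` now tolerates checkpoints that lack an `adapter_config.json`. When opening the adapter config inside the `with` block raises `FileNotFoundError`, the new outer `try` falls back to loading the checkpoint as a merged (full) model via `AutoModelForCausalLM`. Below is a minimal standalone sketch of that loading decision, assuming only the public `peft`/`transformers` APIs; the helper name is hypothetical, and `AdapterConfigPatcher` (defined in the script itself) is stood in for by a direct existence check:

import os

import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_tuned_or_merged(checkpoint_path: str):
    # Hypothetical helper mirroring the patched TunedCausalLM.load flow.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    adapter_config = os.path.join(checkpoint_path, "adapter_config.json")
    if os.path.isfile(adapter_config):
        # Adapter checkpoint: peft resolves the base model from adapter_config.json
        model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
    else:
        # No adapter config: treat the checkpoint as a merged (full) model
        model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
    device = "cuda" if torch.cuda.is_available() else None
    if device is not None:
        model.to(device)
    return model, tokenizer, device

Note that the patch itself takes the EAFP route instead (attempt the adapter load, catch `FileNotFoundError`) rather than checking for the file up front; that keeps the happy path identical to the old code, at the cost that a `FileNotFoundError` raised anywhere inside the `with` block, not just for the missing adapter config, is also routed to the merged-model fallback.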