
Commit

Merge pull request #43 from alex-jw-brooks/local_inference_merged_models
Local inference merged models
anhuong authored Feb 20, 2024
2 parents 6eba340 + c2b8c76 commit 24e7385
Showing 1 changed file with 17 additions and 11 deletions.
scripts/run_inference.py (17 additions & 11 deletions)
@@ -16,7 +16,7 @@
 # Third Party
 from peft import AutoPeftModelForCausalLM
 from tqdm import tqdm
-from transformers import AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch


@@ -156,16 +156,22 @@ def load(
         tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
         # Apply the configs to the adapter config of this model; if no overrides
         # are provided, then the context manager doesn't have any effect.
-        with AdapterConfigPatcher(checkpoint_path, overrides):
-            try:
-                peft_model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
-            except OSError as e:
-                print("Failed to initialize checkpoint model!")
-                raise e
+        try:
+            with AdapterConfigPatcher(checkpoint_path, overrides):
+                try:
+                    model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
+                except OSError as e:
+                    print("Failed to initialize checkpoint model!")
+                    raise e
+        except FileNotFoundError:
+            print("No adapter config found! Loading as a merged model...")
+            # Unable to find the adapter config; fall back to loading as a merged model
+            model = AutoModelForCausalLM.from_pretrained(checkpoint_path)

         device = "cuda" if torch.cuda.is_available() else None
         print(f"Inferred device: {device}")
-        peft_model.to(device)
-        return cls(peft_model, tokenizer, device)
+        model.to(device)
+        return cls(model, tokenizer, device)

     def run(self, text: str, *, max_new_tokens: int) -> str:
         """Runs inference on an instance of this model.
@@ -198,7 +204,7 @@ def main():
         description="Loads a tuned model and runs an inference call(s) through it"
     )
     parser.add_argument(
-        "--model", help="Path to tuned model to be loaded", required=True
+        "--model", help="Path to tuned model / merged model to be loaded", required=True
     )
     parser.add_argument(
         "--out_file",
@@ -207,7 +213,7 @@
     )
     parser.add_argument(
         "--base_model_name_or_path",
-        help="Override for base model to be used [default: value in model adapter_config.json]",
+        help="Override for base model to be used for non-merged models [default: value in model adapter_config.json]",
         default=None,
     )
     parser.add_argument(
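With this change the same flag serves both cases: for example (hypothetical paths), python scripts/run_inference.py --model path/to/checkpoint --out_file predictions.json loads an adapter checkpoint if it contains an adapter_config.json and a merged model otherwise, with --base_model_name_or_path only relevant in the non-merged case.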
