Skip to content

Commit

Permalink
Enable code formatting (#40)
Browse files Browse the repository at this point in the history
Adds a GitHub workflow so its called when commits pushed.
Also, changes for running the formatter locally are committed
so the workflow will run as expected.

Signed-off-by: Martin Hickey <[email protected]>
  • Loading branch information
hickeyma authored Feb 13, 2024
1 parent ac597d8 commit cef30ea
Show file tree
Hide file tree
Showing 10 changed files with 286 additions and 116 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/format.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright The Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: Format

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
uses: actions/setup-python@v4
with:
python-version: 3.9
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -r setup_requirements.txt
- name: Check Formatting
run: tox -e fmt

54 changes: 40 additions & 14 deletions scripts/run_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
If these things change in the future, we should consider breaking it up.
"""
# Standard
import argparse
import json
import os

# Third Party
from peft import AutoPeftModelForCausalLM
import torch
from tqdm import tqdm
from transformers import AutoTokenizer
import torch


### Utilities
Expand All @@ -30,10 +33,13 @@ class AdapterConfigPatcher:
# When loaded in this block, the config's base_model_name_or_path is "foo"
peft_model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
"""

def __init__(self, checkpoint_path: str, overrides: dict):
self.checkpoint_path = checkpoint_path
self.overrides = overrides
self.config_path = AdapterConfigPatcher._locate_adapter_config(self.checkpoint_path)
self.config_path = AdapterConfigPatcher._locate_adapter_config(
self.checkpoint_path
)
# Values that we will patch later on
self.patched_values = {}

Expand All @@ -58,7 +64,7 @@ def _locate_adapter_config(checkpoint_path: str) -> str:
def _apply_config_changes(self, overrides: dict) -> dict:
"""Applies a patch to a config with some override dict, returning the values
that we patched over so that they may be restored later.
Args:
overrides: dict
Overrides to write into the adapter_config.json. Currently, we
Expand Down Expand Up @@ -99,7 +105,9 @@ def _get_old_config_values(adapter_config: dict, overrides: dict) -> dict:
# For now, we only expect to patch the base model; this may change in the future,
# but ensure that anything we are patching is defined in the original config
if not set(overrides.keys()).issubset(set(adapter_config.keys())):
raise KeyError("Adapter config overrides must be set in the config being patched")
raise KeyError(
"Adapter config overrides must be set in the config being patched"
)
return {key: adapter_config[key] for key in overrides}

def __enter__(self):
Expand All @@ -119,7 +127,9 @@ def __init__(self, model, tokenizer, device):
self.device = device

@classmethod
def load(cls, checkpoint_path: str, base_model_name_or_path: str=None) -> "TunedCausalLM":
def load(
cls, checkpoint_path: str, base_model_name_or_path: str = None
) -> "TunedCausalLM":
"""Loads an instance of this model.
Args:
Expand All @@ -138,7 +148,11 @@ def load(cls, checkpoint_path: str, base_model_name_or_path: str=None) -> "Tuned
TunedCausalLM
An instance of this class on which we can run inference.
"""
overrides = {"base_model_name_or_path": base_model_name_or_path} if base_model_name_or_path is not None else {}
overrides = (
{"base_model_name_or_path": base_model_name_or_path}
if base_model_name_or_path is not None
else {}
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
# Apply the configs to the adapter config of this model; if no overrides
# are provided, then the context manager doesn't have any effect.
Expand All @@ -153,7 +167,6 @@ def load(cls, checkpoint_path: str, base_model_name_or_path: str=None) -> "Tuned
peft_model.to(device)
return cls(peft_model, tokenizer, device)


def run(self, text: str, *, max_new_tokens: int) -> str:
"""Runs inference on an instance of this model.
Expand All @@ -165,13 +178,17 @@ def run(self, text: str, *, max_new_tokens: int) -> str:
Returns:
str
Text generation result.
Text generation result.
"""
tok_res = self.tokenizer(text, return_tensors="pt")
input_ids = tok_res.input_ids.to(self.device)

peft_outputs = self.peft_model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens)
decoded_result = self.tokenizer.batch_decode(peft_outputs, skip_special_tokens=False)[0]
peft_outputs = self.peft_model.generate(
input_ids=input_ids, max_new_tokens=max_new_tokens
)
decoded_result = self.tokenizer.batch_decode(
peft_outputs, skip_special_tokens=False
)[0]
return decoded_result


Expand All @@ -180,7 +197,9 @@ def main():
parser = argparse.ArgumentParser(
description="Loads a tuned model and runs an inference call(s) through it"
)
parser.add_argument("--model", help="Path to tuned model to be loaded", required=True)
parser.add_argument(
"--model", help="Path to tuned model to be loaded", required=True
)
parser.add_argument(
"--out_file",
help="JSON file to write results to",
Expand All @@ -189,7 +208,7 @@ def main():
parser.add_argument(
"--base_model_name_or_path",
help="Override for base model to be used [default: value in model adapter_config.json]",
default=None
default=None,
)
parser.add_argument(
"--max_new_tokens",
Expand All @@ -199,7 +218,10 @@ def main():
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--text", help="Text to run inference on")
group.add_argument("--text_file", help="File to be processed where each line is a text to run inference on")
group.add_argument(
"--text_file",
help="File to be processed where each line is a text to run inference on",
)
args = parser.parse_args()
# If we passed a file, check if it exists before doing anything else
if args.text_file and not os.path.isfile(args.text_file):
Expand All @@ -220,7 +242,10 @@ def main():

# TODO: we should add batch inference support
results = [
{"input": text, "output": loaded_model.run(text, max_new_tokens=args.max_new_tokens)}
{
"input": text,
"output": loaded_model.run(text, max_new_tokens=args.max_new_tokens),
}
for text in tqdm(texts)
]

Expand All @@ -230,5 +255,6 @@ def main():

print(f"Exported results to: {args.out_file}")


if __name__ == "__main__":
main()
7 changes: 2 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
# Third Party
from setuptools import find_packages, setup

setup(
name="tuning",
version="0.0.1",
packages=find_packages()
)
setup(name="tuning", version="0.0.1", packages=find_packages())
14 changes: 10 additions & 4 deletions tuning/aim_loader.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
# Standard
import os

# Third Party
from aim.hugging_face import AimCallback


def get_aimstack_callback():
# Initialize a new run
aim_server = os.environ.get('AIMSTACK_SERVER')
aim_db = os.environ.get('AIMSTACK_DB')
aim_experiment = os.environ.get('AIMSTACK_EXPERIMENT')
aim_server = os.environ.get("AIMSTACK_SERVER")
aim_db = os.environ.get("AIMSTACK_DB")
aim_experiment = os.environ.get("AIMSTACK_EXPERIMENT")
if aim_experiment is None:
aim_experiment = ""

if aim_server:
aim_callback = AimCallback(repo='aim://'+aim_server+'/', experiment=aim_experiment)
aim_callback = AimCallback(
repo="aim://" + aim_server + "/", experiment=aim_experiment
)
if aim_db:
aim_callback = AimCallback(repo=aim_db, experiment=aim_experiment)
else:
Expand Down
34 changes: 25 additions & 9 deletions tuning/config/configs.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,46 @@
# Standard
from dataclasses import dataclass, field
from typing import Dict, Optional, Union

# Third Party
import torch
import transformers

DEFAULT_CONTEXT_LENGTH=4096
DEFAULT_OPTIMIZER="adamw_torch"
DEFAULT_CONTEXT_LENGTH = 4096
DEFAULT_OPTIMIZER = "adamw_torch"

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "<PAD>"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"


@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
use_flash_attn: bool = field(
default=True,
metadata={"help": "Use Flash attention v2 from transformers, default is True"}
metadata={"help": "Use Flash attention v2 from transformers, default is True"},
)
torch_dtype: Optional[Union[torch.dtype , str]] = torch.bfloat16
torch_dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16


@dataclass
class DataArguments:
data_path: str = field(default=None, metadata={"help": "Path to the training data in JSONL format."})
response_template: str = field(default=None, metadata={"help": "Response template, separator to train on completions only"})
dataset_text_field: str = field(default=None, metadata={"help": "Training dataset text field"})
validation_data_path: str = field(default=None, metadata={"help": "Path to the validation data in JSONL format."})
data_path: str = field(
default=None, metadata={"help": "Path to the training data in JSONL format."}
)
response_template: str = field(
default=None,
metadata={"help": "Response template, separator to train on completions only"},
)
dataset_text_field: str = field(
default=None, metadata={"help": "Training dataset text field"}
)
validation_data_path: str = field(
default=None, metadata={"help": "Path to the validation data in JSONL format."}
)


@dataclass
Expand All @@ -35,7 +49,9 @@ class TrainingArguments(transformers.TrainingArguments):
# optim: str = field(default=DEFAULT_OPTIMIZER)
model_max_length: int = field(
default=DEFAULT_CONTEXT_LENGTH,
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
metadata={
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
packing: bool = field(
default=False,
Expand Down
17 changes: 11 additions & 6 deletions tuning/config/peft_config.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
# Standard
from dataclasses import dataclass, field
from typing import List


@dataclass
class LoraConfig:
r: int = 8
lora_alpha: int = 32
target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"], metadata={
"help": "The names of the modules to apply LORA to. LORA selects modules which either completely match or "
"end with one of the strings. If the value is [\"all-linear\"], then LORA selects all linear and Conv1D "
"modules except for the output layer."
})
target_modules: List[str] = field(
default_factory=lambda: ["q_proj", "v_proj"],
metadata={
"help": "The names of the modules to apply LORA to. LORA selects modules which either completely match or "
'end with one of the strings. If the value is ["all-linear"], then LORA selects all linear and Conv1D '
"modules except for the output layer."
},
)
bias = "none"
lora_dropout: float = 0.05

Expand All @@ -19,4 +24,4 @@ class PromptTuningConfig:
prompt_tuning_init: str = "TEXT"
num_virtual_tokens: int = 8
prompt_tuning_init_text: str = "Classify if the tweet is a complaint or not:"
tokenizer_name_or_path: str = "llama-7b-hf"
tokenizer_name_or_path: str = "llama-7b-hf"
23 changes: 16 additions & 7 deletions tuning/data/tokenizer_data_utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
import transformers
# Standard
from typing import Dict, Sequence
import copy
from tuning.config import configs
import json
import logging

# Third Party
from torch.utils.data import Dataset
import logging
import json
import torch
import transformers

# Local
from tuning.config import configs


def tokenizer_and_embedding_resize(
special_tokens_dict: Dict,
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
):
"""Resize tokenizer and embedding.
TODO: In the future, make sure we can have vocab size divisible by 64.
"""
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
Expand All @@ -24,8 +29,12 @@ def tokenizer_and_embedding_resize(
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data

input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)

input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
Loading

0 comments on commit cef30ea

Please sign in to comment.