diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
new file mode 100644
index 000000000..d926b1220
--- /dev/null
+++ b/.github/workflows/format.yml
@@ -0,0 +1,38 @@
+# Copyright The Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: Format
+
+on:
+  push:
+    branches: [ "main" ]
+  pull_request:
+    branches: [ "main" ]
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python 3.9
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.9
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install -r setup_requirements.txt
+      - name: Check Formatting
+        run: tox -e fmt
+
diff --git a/scripts/run_inference.py b/scripts/run_inference.py
index 491986572..74da97c73 100644
--- a/scripts/run_inference.py
+++ b/scripts/run_inference.py
@@ -8,13 +8,16 @@
 If these things change in the future, we should consider breaking it up.
 """
 
+# Standard
 import argparse
 import json
 import os
+
+# Third Party
 from peft import AutoPeftModelForCausalLM
-import torch
 from tqdm import tqdm
 from transformers import AutoTokenizer
+import torch
 
 
 ### Utilities
@@ -30,10 +33,13 @@ class AdapterConfigPatcher:
         # When loaded in this block, the config's base_model_name_or_path is "foo"
         peft_model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
     """
+
     def __init__(self, checkpoint_path: str, overrides: dict):
         self.checkpoint_path = checkpoint_path
         self.overrides = overrides
-        self.config_path = AdapterConfigPatcher._locate_adapter_config(self.checkpoint_path)
+        self.config_path = AdapterConfigPatcher._locate_adapter_config(
+            self.checkpoint_path
+        )
         # Values that we will patch later on
         self.patched_values = {}
@@ -58,7 +64,7 @@ def _locate_adapter_config(checkpoint_path: str) -> str:
     def _apply_config_changes(self, overrides: dict) -> dict:
         """Applies a patch to a config with some override dict, returning the values
         that we patched over so that they may be restored later.
-        
+
         Args:
             overrides: dict
                 Overrides to write into the adapter_config.json. Currently, we
@@ -99,7 +105,9 @@ def _get_old_config_values(adapter_config: dict, overrides: dict) -> dict:
         # For now, we only expect to patch the base model; this may change in the future,
         # but ensure that anything we are patching is defined in the original config
         if not set(overrides.keys()).issubset(set(adapter_config.keys())):
-            raise KeyError("Adapter config overrides must be set in the config being patched")
+            raise KeyError(
+                "Adapter config overrides must be set in the config being patched"
+            )
         return {key: adapter_config[key] for key in overrides}
 
     def __enter__(self):
@@ -119,7 +127,9 @@ def __init__(self, model, tokenizer, device):
         self.device = device
 
     @classmethod
-    def load(cls, checkpoint_path: str, base_model_name_or_path: str=None) -> "TunedCausalLM":
+    def load(
+        cls, checkpoint_path: str, base_model_name_or_path: str = None
+    ) -> "TunedCausalLM":
         """Loads an instance of this model.
 
         Args:
@@ -138,7 +148,11 @@ def load(cls, checkpoint_path: str, base_model_name_or_path: str=None) -> "Tuned
             TunedCausalLM
                 An instance of this class on which we can run inference.
         """
-        overrides = {"base_model_name_or_path": base_model_name_or_path} if base_model_name_or_path is not None else {}
+        overrides = (
+            {"base_model_name_or_path": base_model_name_or_path}
+            if base_model_name_or_path is not None
+            else {}
+        )
         tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
         # Apply the configs to the adapter config of this model; if no overrides
         # are provided, then the context manager doesn't have any effect.
@@ -153,7 +167,6 @@ def load(cls, checkpoint_path: str, base_model_name_or_path: str=None) -> "Tuned
             peft_model.to(device)
         return cls(peft_model, tokenizer, device)
 
-
     def run(self, text: str, *, max_new_tokens: int) -> str:
         """Runs inference on an instance of this model.
 
@@ -165,13 +178,17 @@ def run(self, text: str, *, max_new_tokens: int) -> str:
 
         Returns:
             str
-                Text generation result. 
+                Text generation result.
         """
         tok_res = self.tokenizer(text, return_tensors="pt")
         input_ids = tok_res.input_ids.to(self.device)
 
-        peft_outputs = self.peft_model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens)
-        decoded_result = self.tokenizer.batch_decode(peft_outputs, skip_special_tokens=False)[0]
+        peft_outputs = self.peft_model.generate(
+            input_ids=input_ids, max_new_tokens=max_new_tokens
+        )
+        decoded_result = self.tokenizer.batch_decode(
+            peft_outputs, skip_special_tokens=False
+        )[0]
         return decoded_result
 
 
@@ -180,7 +197,9 @@ def main():
     parser = argparse.ArgumentParser(
         description="Loads a tuned model and runs an inference call(s) through it"
     )
-    parser.add_argument("--model", help="Path to tuned model to be loaded", required=True)
+    parser.add_argument(
+        "--model", help="Path to tuned model to be loaded", required=True
+    )
     parser.add_argument(
         "--out_file",
         help="JSON file to write results to",
@@ -189,7 +208,7 @@ def main():
     parser.add_argument(
         "--base_model_name_or_path",
         help="Override for base model to be used [default: value in model adapter_config.json]",
-        default=None
+        default=None,
     )
     parser.add_argument(
         "--max_new_tokens",
@@ -199,7 +218,10 @@ def main():
     )
     group = parser.add_mutually_exclusive_group(required=True)
     group.add_argument("--text", help="Text to run inference on")
-    group.add_argument("--text_file", help="File to be processed where each line is a text to run inference on")
+    group.add_argument(
+        "--text_file",
+        help="File to be processed where each line is a text to run inference on",
+    )
     args = parser.parse_args()
     # If we passed a file, check if it exists before doing anything else
     if args.text_file and not os.path.isfile(args.text_file):
@@ -220,7 +242,10 @@ def main():
 
     # TODO: we should add batch inference support
     results = [
-        {"input": text, "output": loaded_model.run(text, max_new_tokens=args.max_new_tokens)}
+        {
+            "input": text,
+            "output": loaded_model.run(text, max_new_tokens=args.max_new_tokens),
+        }
         for text in tqdm(texts)
     ]
 
@@ -230,5 +255,6 @@ def main():
 
     print(f"Exported results to: {args.out_file}")
 
+
 if __name__ == "__main__":
     main()
diff --git a/setup.py b/setup.py
index c34a011f4..ae71369c0 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,4 @@
+# Third Party
 from setuptools import find_packages, setup
 
-setup(
-    name="tuning",
-    version="0.0.1",
-    packages=find_packages()
-)
+setup(name="tuning", version="0.0.1", packages=find_packages())
diff --git a/tuning/aim_loader.py b/tuning/aim_loader.py
index 44aa46748..6ee617a42 100644
--- a/tuning/aim_loader.py
+++ b/tuning/aim_loader.py
@@ -1,16 +1,22 @@
+# Standard
 import os
+
+# Third Party
 from aim.hugging_face import AimCallback
 
+
 def get_aimstack_callback():
     # Initialize a new run
-    aim_server = os.environ.get('AIMSTACK_SERVER')
-    aim_db = os.environ.get('AIMSTACK_DB')
-    aim_experiment = os.environ.get('AIMSTACK_EXPERIMENT')
+    aim_server = os.environ.get("AIMSTACK_SERVER")
+    aim_db = os.environ.get("AIMSTACK_DB")
+    aim_experiment = os.environ.get("AIMSTACK_EXPERIMENT")
     if aim_experiment is None:
         aim_experiment = ""
 
     if aim_server:
-        aim_callback = AimCallback(repo='aim://'+aim_server+'/', experiment=aim_experiment)
+        aim_callback = AimCallback(
+            repo="aim://" + aim_server + "/", experiment=aim_experiment
+        )
     if aim_db:
         aim_callback = AimCallback(repo=aim_db, experiment=aim_experiment)
     else:
diff --git a/tuning/config/configs.py b/tuning/config/configs.py
index cd88ec672..0b0a8fb67 100644
--- a/tuning/config/configs.py
+++ b/tuning/config/configs.py
@@ -1,10 +1,13 @@
+# Standard
 from dataclasses import dataclass, field
 from typing import Dict, Optional, Union
+
+# Third Party
 import torch
 import transformers
 
-DEFAULT_CONTEXT_LENGTH=4096
-DEFAULT_OPTIMIZER="adamw_torch"
+DEFAULT_CONTEXT_LENGTH = 4096
+DEFAULT_OPTIMIZER = "adamw_torch"
 
 IGNORE_INDEX = -100
 DEFAULT_PAD_TOKEN = ""
@@ -12,21 +15,32 @@
 DEFAULT_BOS_TOKEN = ""
 DEFAULT_UNK_TOKEN = ""
 
+
 @dataclass
 class ModelArguments:
     model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
     use_flash_attn: bool = field(
         default=True,
-        metadata={"help": "Use Flash attention v2 from transformers, default is True"}
+        metadata={"help": "Use Flash attention v2 from transformers, default is True"},
     )
-    torch_dtype: Optional[Union[torch.dtype , str]] = torch.bfloat16
+    torch_dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16
+
 
 @dataclass
 class DataArguments:
-    data_path: str = field(default=None, metadata={"help": "Path to the training data in JSONL format."})
-    response_template: str = field(default=None, metadata={"help": "Response template, separator to train on completions only"})
-    dataset_text_field: str = field(default=None, metadata={"help": "Training dataset text field"})
-    validation_data_path: str = field(default=None, metadata={"help": "Path to the validation data in JSONL format."})
+    data_path: str = field(
+        default=None, metadata={"help": "Path to the training data in JSONL format."}
+    )
+    response_template: str = field(
+        default=None,
+        metadata={"help": "Response template, separator to train on completions only"},
+    )
+    dataset_text_field: str = field(
+        default=None, metadata={"help": "Training dataset text field"}
+    )
+    validation_data_path: str = field(
+        default=None, metadata={"help": "Path to the validation data in JSONL format."}
+    )
 
 
 @dataclass
@@ -35,7 +49,9 @@ class TrainingArguments(transformers.TrainingArguments):
     # optim: str = field(default=DEFAULT_OPTIMIZER)
     model_max_length: int = field(
         default=DEFAULT_CONTEXT_LENGTH,
-        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
+        metadata={
+            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+        },
     )
     packing: bool = field(
         default=False,
diff --git a/tuning/config/peft_config.py b/tuning/config/peft_config.py
index 6865603f3..a3d30c763 100644
--- a/tuning/config/peft_config.py
+++ b/tuning/config/peft_config.py
@@ -1,15 +1,20 @@
+# Standard
 from dataclasses import dataclass, field
 from typing import List
 
+
 @dataclass
 class LoraConfig:
     r: int = 8
     lora_alpha: int = 32
-    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"], metadata={
-        "help": "The names of the modules to apply LORA to. LORA selects modules which either completely match or "
-        "end with one of the strings. If the value is [\"all-linear\"], then LORA selects all linear and Conv1D "
-        "modules except for the output layer."
-    })
+    target_modules: List[str] = field(
+        default_factory=lambda: ["q_proj", "v_proj"],
+        metadata={
+            "help": "The names of the modules to apply LORA to. LORA selects modules which either completely match or "
+            'end with one of the strings. If the value is ["all-linear"], then LORA selects all linear and Conv1D '
+            "modules except for the output layer."
+        },
+    )
     bias = "none"
     lora_dropout: float = 0.05
 
@@ -19,4 +24,4 @@ class PromptTuningConfig:
     prompt_tuning_init: str = "TEXT"
     num_virtual_tokens: int = 8
     prompt_tuning_init_text: str = "Classify if the tweet is a complaint or not:"
-    tokenizer_name_or_path: str = "llama-7b-hf"
\ No newline at end of file
+    tokenizer_name_or_path: str = "llama-7b-hf"
diff --git a/tuning/data/tokenizer_data_utils.py b/tuning/data/tokenizer_data_utils.py
index e7d03a003..3a8a288f3 100644
--- a/tuning/data/tokenizer_data_utils.py
+++ b/tuning/data/tokenizer_data_utils.py
@@ -1,12 +1,17 @@
-import transformers
+# Standard
 from typing import Dict, Sequence
 import copy
-from tuning.config import configs
+import json
+import logging
 
+# Third Party
 from torch.utils.data import Dataset
-import logging
-import json
 import torch
+import transformers
+
+# Local
+from tuning.config import configs
+
 
 def tokenizer_and_embedding_resize(
     special_tokens_dict: Dict,
@@ -14,7 +19,7 @@ def tokenizer_and_embedding_resize(
     model: transformers.PreTrainedModel,
 ):
     """Resize tokenizer and embedding.
-    
+
     TODO: In the future, make sure we can have vocab size divisible by 64.
     """
     num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
@@ -24,8 +29,12 @@ def tokenizer_and_embedding_resize(
         input_embeddings = model.get_input_embeddings().weight.data
         output_embeddings = model.get_output_embeddings().weight.data
 
-        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
-        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+            dim=0, keepdim=True
+        )
+        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+            dim=0, keepdim=True
+        )
 
         input_embeddings[-num_new_tokens:] = input_embeddings_avg
         output_embeddings[-num_new_tokens:] = output_embeddings_avg
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 6d883a336..b8ec5df4f 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -1,36 +1,51 @@
-import os
+# Standard
 from typing import Optional, Union
+import os
 
+# Third Party
+from peft.utils.other import fsdp_auto_wrap_policy
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    GPT2Tokenizer,
+    GPTNeoXTokenizerFast,
+    LlamaTokenizer,
+    LlamaTokenizerFast,
+    TrainerCallback,
+)
+from transformers.utils import logging
+from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
 import datasets
 import fire
-from peft.utils.other import fsdp_auto_wrap_policy
 import torch
 import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaTokenizerFast, GPTNeoXTokenizerFast, GPT2Tokenizer
-from transformers.utils import logging
-from transformers import TrainerCallback
-from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
+
+# Local
 from tuning.aim_loader import get_aimstack_callback
 from tuning.config import configs, peft_config
 from tuning.data import tokenizer_data_utils
 from tuning.utils.config_utils import get_hf_peft_config
 from tuning.utils.data_type_utils import get_torch_dtype
 
+
 class PeftSavingCallback(TrainerCallback):
     def on_save(self, args, state, control, **kwargs):
-        checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
+        checkpoint_path = os.path.join(
+            args.output_dir, f"checkpoint-{state.global_step}"
+        )
         kwargs["model"].save_pretrained(checkpoint_path)
 
         if "pytorch_model.bin" in os.listdir(checkpoint_path):
            os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
 
-
 def train(
     model_args: configs.ModelArguments,
     data_args: configs.DataArguments,
     train_args: configs.TrainingArguments,
-    peft_config: Optional[Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]] = None,
+    peft_config: Optional[
+        Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
+    ] = None,
 ):
     """Call the SFTTrainer
@@ -48,15 +63,19 @@ def train(
     logger = logging.get_logger("sft_trainer")
 
     # Validate parameters
-    if (not isinstance(train_args.num_train_epochs, float)) or (train_args.num_train_epochs <= 0):
+    if (not isinstance(train_args.num_train_epochs, float)) or (
+        train_args.num_train_epochs <= 0
+    ):
         raise ValueError("num_train_epochs has to be an integer/float >= 1")
 
-    if (not isinstance(train_args.gradient_accumulation_steps , int)) or (train_args.gradient_accumulation_steps <= 0):
+    if (not isinstance(train_args.gradient_accumulation_steps, int)) or (
+        train_args.gradient_accumulation_steps <= 0
+    ):
         raise ValueError("gradient_accumulation_steps has to be an integer >= 1")
 
     # make sure to unset FSDP args when running on single gpu
     if not run_distributed:
         train_args.fsdp = ""
-        train_args.fsdp_config = {'xla':False}
+        train_args.fsdp_config = {"xla": False}
 
     task_type = "CAUSAL_LM"
     model = AutoModelForCausalLM.from_pretrained(
@@ -65,43 +84,53 @@ def train(
         torch_dtype=get_torch_dtype(model_args.torch_dtype),
         use_flash_attention_2=model_args.use_flash_attn,
     )
-    
+
     peft_config = get_hf_peft_config(task_type, peft_config)
 
     model.gradient_checkpointing_enable()
 
     # TODO: Move these to a config as well
     tokenizer = AutoTokenizer.from_pretrained(
-        model_args.model_name_or_path,
-        cache_dir=train_args.cache_dir,
-        use_fast = True
+        model_args.model_name_or_path, cache_dir=train_args.cache_dir, use_fast=True
    )
 
     # TODO: understand if we need to hardcode these here or just use defaults in model
-    if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, LlamaTokenizerFast):
-        tokenizer.add_special_tokens({
-            "bos_token": "",
-            "eos_token": "",
-            "unk_token": "",
-            "pad_token": "",
-        })
-    elif isinstance(tokenizer, GPTNeoXTokenizerFast) or isinstance(tokenizer, GPT2Tokenizer):
-        tokenizer.add_special_tokens({
-            "pad_token": "",
-        })
+    if isinstance(tokenizer, LlamaTokenizer) or isinstance(
+        tokenizer, LlamaTokenizerFast
+    ):
+        tokenizer.add_special_tokens(
+            {
+                "bos_token": "",
+                "eos_token": "",
+                "unk_token": "",
+                "pad_token": "",
+            }
+        )
+    elif isinstance(tokenizer, GPTNeoXTokenizerFast) or isinstance(
+        tokenizer, GPT2Tokenizer
+    ):
+        tokenizer.add_special_tokens(
+            {
+                "pad_token": "",
+            }
+        )
 
     """TODO: near term - how response template ids are parsed out needs to be cleaned.
     The [2:] here applies if response template has \n prefix, it is needed to strip \n,
     otherwise template is not found. We will create issue to clean this out after we discuss
     data formats and collators we will support
     """
-    response_template_ids = tokenizer.encode(data_args.response_template, add_special_tokens=False)[2:]
-    # TODO: This is actually max_seq_length and not model_max_length. we should not override model_max_length 
+    response_template_ids = tokenizer.encode(
+        data_args.response_template, add_special_tokens=False
+    )[2:]
+    # TODO: This is actually max_seq_length and not model_max_length. we should not override model_max_length
     # as in current main. We need to change name of this parameter we expose to users.
     model_max_length = min(train_args.model_max_length, tokenizer.model_max_length)
     logger.info(f"Model max length {model_max_length}")
     if train_args.model_max_length > tokenizer.model_max_length:
-        logger.warning(f"model_max_length {train_args.model_max_length} exceeds tokenizer.model_max_length {tokenizer.model_max_length}, using tokenizer.model_max_length {tokenizer.model_max_length}")
-    
+        logger.warning(
+            f"model_max_length {train_args.model_max_length} exceeds tokenizer.model_max_length {tokenizer.model_max_length}, using tokenizer.model_max_length {tokenizer.model_max_length}"
+        )
+
     # TODO: we need to change this, perhaps follow what open instruct does?
     special_tokens_dict = dict()
     if tokenizer.pad_token is None:
@@ -124,26 +153,29 @@ def train(
         tokenizer=tokenizer,
         model=model,
     )
-    
+
     # load the data by parsing JSON
     # TODO: update arg from data_path to training_data_path since we also have validation_data_path
     data_files = {"train": data_args.data_path}
     if data_args.validation_data_path:
         data_files["validation"] = data_args.validation_data_path
 
-    format_dataset = lambda example : {f"{data_args.dataset_text_field}" : example[f"{data_args.dataset_text_field}"] + tokenizer.eos_token}
+    format_dataset = lambda example: {
+        f"{data_args.dataset_text_field}": example[f"{data_args.dataset_text_field}"]
+        + tokenizer.eos_token
+    }
 
-    json_dataset = datasets.load_dataset('json', data_files=data_files)
-    formatted_train_dataset = json_dataset['train'].map(format_dataset)
+    json_dataset = datasets.load_dataset("json", data_files=data_files)
+    formatted_train_dataset = json_dataset["train"].map(format_dataset)
     logger.info(f"Training dataset length is {len(formatted_train_dataset)}")
 
     formatted_validation_dataset = None
     if data_args.validation_data_path:
-        formatted_validation_dataset = json_dataset['validation'].map(format_dataset)
+        formatted_validation_dataset = json_dataset["validation"].map(format_dataset)
         logger.info(f"Validation dataset length is {len(formatted_validation_dataset)}")
 
     aim_callback = get_aimstack_callback()
-    callbacks=[aim_callback,PeftSavingCallback()]
+    callbacks = [aim_callback, PeftSavingCallback()]
 
     if train_args.packing:
         logger.info("Packing is set to True")
@@ -152,14 +184,22 @@ def train(
     else:
         logger.info("Packing is set to False")
         if data_args.response_template is None:
-            logger.error("Error, response template is None, needs to be set for training")
+            logger.error(
+                "Error, response template is None, needs to be set for training"
+            )
             exit(-1)
-        
+
         if data_args.dataset_text_field is None:
-            logger.error("Error, dataset_text_field is None, needs to be set for training")
+            logger.error(
+                "Error, dataset_text_field is None, needs to be set for training"
+            )
             exit(-1)
-        
-        data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer, ignore_index=configs.IGNORE_INDEX)
+
+        data_collator = DataCollatorForCompletionOnlyLM(
+            response_template_ids,
+            tokenizer=tokenizer,
+            ignore_index=configs.IGNORE_INDEX,
+        )
         packing = False
 
     trainer = SFTTrainer(
@@ -177,25 +217,45 @@ def train(
     )
 
     if run_distributed and peft_config is not None:
-        trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)
+        trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
+            model
+        )
     trainer.train()
 
 
 def main(**kwargs):
-    parser = transformers.HfArgumentParser(dataclass_types=(configs.ModelArguments,
-                                                            configs.DataArguments,
-                                                            configs.TrainingArguments,
-                                                            peft_config.LoraConfig,
-                                                            peft_config.PromptTuningConfig))
-    parser.add_argument('--peft_method', type=str.lower, choices=['pt', 'lora', None, 'none'], default="pt")
-    model_args, data_args, training_args, lora_config, prompt_tuning_config, peft_method, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+    parser = transformers.HfArgumentParser(
+        dataclass_types=(
+            configs.ModelArguments,
+            configs.DataArguments,
+            configs.TrainingArguments,
+            peft_config.LoraConfig,
+            peft_config.PromptTuningConfig,
+        )
+    )
+    parser.add_argument(
+        "--peft_method",
+        type=str.lower,
+        choices=["pt", "lora", None, "none"],
+        default="pt",
+    )
+    (
+        model_args,
+        data_args,
+        training_args,
+        lora_config,
+        prompt_tuning_config,
+        peft_method,
+        _,
+    ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
 
     if peft_method.peft_method == "lora":
-        tune_config=lora_config
-    elif peft_method.peft_method =="pt":
-        tune_config=prompt_tuning_config
+        tune_config = lora_config
+    elif peft_method.peft_method == "pt":
+        tune_config = prompt_tuning_config
     else:
-        tune_config=None
+        tune_config = None
 
     train(model_args, data_args, training_args, tune_config)
 
+
 if __name__ == "__main__":
     fire.Fire(main)
diff --git a/tuning/utils/config_utils.py b/tuning/utils/config_utils.py
index d9579432d..58896c1f9 100644
--- a/tuning/utils/config_utils.py
+++ b/tuning/utils/config_utils.py
@@ -1,7 +1,12 @@
-from peft import LoraConfig, PromptTuningConfig
+# Standard
 from dataclasses import asdict
-from tuning.config import peft_config
 
+# Third Party
+from peft import LoraConfig, PromptTuningConfig
+
+# Local
+from tuning.config import peft_config
+
 
 def update_config(config, **kwargs):
     if isinstance(config, (tuple, list)):
@@ -21,17 +26,22 @@ def update_config(config, **kwargs):
                     # In case of specialized config we can warm user
                     print(f"Warning: {config_name} does not accept parameter: {k}")
 
+
 def create_tuning_config(peft_method, **kwargs):
     """Create peft_config Tuning config
-     Args:
-         peft_method: str
-            lora, pt or None
-         kawrgs: parameters to initialize library configs with
-     Return:
-         peft_config.LoraConfig | peft_config.PromptTuningConfig | None
+    Args:
+        peft_method: str
+            lora, pt or None
+        kawrgs: parameters to initialize library configs with
+    Return:
+        peft_config.LoraConfig | peft_config.PromptTuningConfig | None
     """
-    assert peft_method in [None, "lora", "pt", "None"], \
-        f"peft config {peft_method} not defined in peft.py"
+    assert peft_method in [
+        None,
+        "lora",
+        "pt",
+        "None",
+    ], f"peft config {peft_method} not defined in peft.py"
    if peft_method == "lora":
         tune_config = peft_config.LoraConfig()
         update_config(tune_config, **kwargs)
@@ -39,16 +49,16 @@ def create_tuning_config(peft_method, **kwargs):
         tune_config = peft_config.PromptTuningConfig()
         update_config(tune_config, **kwargs)
     else:
-        tune_config = None # full parameter tuning
+        tune_config = None  # full parameter tuning
     return tune_config
 
 
 def get_hf_peft_config(task_type, tuning_config):
     """Return HF PEFT config for tuning based on type of tuning config passed
-     Args:
-         task_type: str
-         tuning_config: peft_config.LoraConfig | peft_config.PromptTuningConfig | None
-     Return: HF PEFT config or None
+    Args:
+        task_type: str
+        tuning_config: peft_config.LoraConfig | peft_config.PromptTuningConfig | None
+    Return: HF PEFT config or None
     """
     if isinstance(tuning_config, peft_config.LoraConfig):
         lora_config = asdict(tuning_config)
@@ -56,7 +66,9 @@ def get_hf_peft_config(task_type, tuning_config):
             lora_config["target_modules"] = "all-linear"
         hf_peft_config = LoraConfig(task_type=task_type, **lora_config)
     elif isinstance(tuning_config, peft_config.PromptTuningConfig):
-        hf_peft_config = PromptTuningConfig(task_type=task_type, **asdict(tuning_config))
+        hf_peft_config = PromptTuningConfig(
+            task_type=task_type, **asdict(tuning_config)
+        )
     else:
         hf_peft_config = None  # full parameter tuning
 
diff --git a/tuning/utils/data_type_utils.py b/tuning/utils/data_type_utils.py
index 26c33438b..42b058cde 100644
--- a/tuning/utils/data_type_utils.py
+++ b/tuning/utils/data_type_utils.py
@@ -7,6 +7,7 @@
 
 logger = logging.get_logger("data_utils")
 
+
 def str_to_torch_dtype(dtype_str: str) -> torch.dtype:
     """Given a string representation of a Torch data type,
     convert it to the actual torch dtype.
@@ -41,4 +42,4 @@ def get_torch_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
         return dtype
     # TODO - If None/empty str was provided, read it from model config?
     # Otherwise convert it from a string
-    return str_to_torch_dtype(dtype)
\ No newline at end of file
+    return str_to_torch_dtype(dtype)
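
Reviewer note: the hunks above only restyle existing code, so behavior should be unchanged. A minimal smoke-test sketch of the TunedCausalLM helper reformatted in scripts/run_inference.py, using only the load()/run() API shown in this patch; the checkpoint path and token budget are hypothetical placeholders, the prompt simply reuses the prompt_tuning_init_text default from tuning/config/peft_config.py, and it assumes scripts/ is on PYTHONPATH:

    # Sketch only: placeholder checkpoint path and prompt, not part of this patch.
    from run_inference import TunedCausalLM

    tuned = TunedCausalLM.load(
        checkpoint_path="output/checkpoint-100",  # hypothetical tuned adapter directory
        base_model_name_or_path=None,  # keep the base model from adapter_config.json
    )
    print(tuned.run("Classify if the tweet is a complaint or not:", max_new_tokens=20))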