Commit e4927a2: Fix formatting
Signed-off-by: Thara Palanivel <[email protected]>
tharapalanivel committed Mar 5, 2024
1 parent 66ea6d7 commit e4927a2
Showing 2 changed files with 49 additions and 23 deletions.
38 changes: 28 additions & 10 deletions tests/helpers.py
@@ -4,21 +4,39 @@
 # Local
 from tuning.config import configs, peft_config


 def causal_lm_train_kwargs(train_kwargs):
     """Parse the kwargs for a valid train call to a Causal LM."""
     parser = transformers.HfArgumentParser(
         dataclass_types=(
             configs.ModelArguments,
             configs.DataArguments,
             configs.TrainingArguments,
             peft_config.LoraConfig,
             peft_config.PromptTuningConfig,
+        )
     )
-    model_args, data_args, training_args, lora_config, prompt_tuning_config = parser.parse_dict(train_kwargs, allow_extra_keys=True)
+    (
+        model_args,
+        data_args,
+        training_args,
+        lora_config,
+        prompt_tuning_config,
+    ) = parser.parse_dict(train_kwargs, allow_extra_keys=True)

     # TODO: target_modules doesn't get set probably due to the way dataclass handles
     # mutable defaults, needs investigation on better way to handle this
-    setattr(lora_config, "target_modules", lora_config.__dataclass_fields__.get("target_modules").default_factory())
+    setattr(
+        lora_config,
+        "target_modules",
+        lora_config.__dataclass_fields__.get("target_modules").default_factory(),
+    )

-    return model_args, data_args, training_args, lora_config if train_kwargs.get("peft_method")=="lora" else prompt_tuning_config
+    return (
+        model_args,
+        data_args,
+        training_args,
+        lora_config
+        if train_kwargs.get("peft_method") == "lora"
+        else prompt_tuning_config,
+    )
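Note on the TODO in this hunk: Python dataclasses cannot use a mutable object such as a list as a plain default, so fields like target_modules are declared with field(default_factory=...), and the helper works around the unset value by calling that factory directly. Below is a minimal sketch of the mechanics, using a hypothetical stub class and placeholder module names rather than the repo's real peft_config.LoraConfig:

# Illustrative only: a stand-in config class with made-up module names,
# not the repo's actual peft_config.LoraConfig.
from dataclasses import dataclass, field
from typing import List


@dataclass
class StubLoraConfig:
    r: int = 8
    # A mutable default (a list) must be declared via default_factory.
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])


cfg = StubLoraConfig()

# The same lookup the helper above performs: fetch the field's declared factory
# from __dataclass_fields__ and call it to rebuild the default value.
factory = StubLoraConfig.__dataclass_fields__["target_modules"].default_factory
setattr(cfg, "target_modules", factory())
print(cfg.target_modules)  # ['q_proj', 'v_proj']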
34 changes: 21 additions & 13 deletions tests/test_sft_trainer.py
@@ -12,26 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Unit Tests for SFT Trainer.
 """

 # Standard
 import os
 import tempfile

-# Local
-from tuning import sft_trainer
-from tests.helpers import causal_lm_train_kwargs
-from tests.fixtures import CAUSAL_LM_MODEL
+# First Party
 from tests.data import TWITTER_COMPLAINTS_DATA
+from tests.fixtures import CAUSAL_LM_MODEL
+from tests.helpers import causal_lm_train_kwargs
+
+# Local
+from tuning import sft_trainer

-HAPPY_PATH_KWARGS = {"model_name_or_path": CAUSAL_LM_MODEL,
+HAPPY_PATH_KWARGS = {
+    "model_name_or_path": CAUSAL_LM_MODEL,
     "data_path": TWITTER_COMPLAINTS_DATA,
     "num_train_epochs": 5,
     "per_device_train_batch_size": 4,
     "per_device_eval_batch_size": 4,
     "gradient_accumulation_steps": 4,
     "learning_rate": 0.00001,
     "weight_decay": 0,
     "warmup_ratio": 0.03,
     "lr_scheduler_type": "cosine",
@@ -49,29 +51,35 @@
     "num_virtual_tokens": 8,
     "prompt_tuning_init_text": "hello",
     "tokenizer_name_or_path": CAUSAL_LM_MODEL,
-    "save_strategy":"epoch"}
+    "save_strategy": "epoch",
+}


 def test_run_causallm_pt():
     """Check if we can bootstrap and run causallm models"""
     with tempfile.TemporaryDirectory() as tempdir:
         HAPPY_PATH_KWARGS["output_dir"] = tempdir
-        model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(HAPPY_PATH_KWARGS)
+        model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
+            HAPPY_PATH_KWARGS
+        )
         sft_trainer.train(model_args, data_args, training_args, tune_config)
         _validate_training(tempdir)


 def test_run_causallm_lora():
     """Check if we can bootstrap and run causallm models"""
     with tempfile.TemporaryDirectory() as tempdir:
         HAPPY_PATH_KWARGS["output_dir"] = tempdir
         HAPPY_PATH_KWARGS["peft_method"] = "lora"
-        model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(HAPPY_PATH_KWARGS)
+        model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
+            HAPPY_PATH_KWARGS
+        )
         sft_trainer.train(model_args, data_args, training_args, tune_config)
         _validate_training(tempdir)


 def _validate_training(tempdir):
-    assert any(x.startswith('checkpoint-') for x in os.listdir(tempdir))
+    assert any(x.startswith("checkpoint-") for x in os.listdir(tempdir))
     loss_file_path = "{}/train_loss.jsonl".format(tempdir)
     assert os.path.exists(loss_file_path) == True
     assert os.path.getsize(loss_file_path) > 0
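For context on how a single flat kwargs dict like HAPPY_PATH_KWARGS populates several config dataclasses at once, the tests go through the transformers.HfArgumentParser.parse_dict pattern shown in the helpers.py hunk above. A minimal sketch with stand-in dataclasses (not the repo's configs.* classes); allow_extra_keys=True is what lets keys such as "peft_method", which belong to no dataclass, pass through without raising:

# Sketch only: stand-in dataclasses, not the tuning.config classes used above.
from dataclasses import dataclass

import transformers


@dataclass
class StubModelArguments:
    model_name_or_path: str = "stub-model"


@dataclass
class StubTrainingArguments:
    num_train_epochs: int = 1
    learning_rate: float = 1e-5


parser = transformers.HfArgumentParser(
    dataclass_types=(StubModelArguments, StubTrainingArguments)
)
model_args, training_args = parser.parse_dict(
    {"model_name_or_path": "my-model", "learning_rate": 3e-4, "peft_method": "lora"},
    allow_extra_keys=True,  # "peft_method" matches no field, so it is ignored
)
print(model_args.model_name_or_path, training_args.learning_rate)  # my-model 0.0003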

