diff --git a/tests/acceleration/test_acceleration_framework.py b/tests/acceleration/test_acceleration_framework.py index 94198d52e..80a445304 100644 --- a/tests/acceleration/test_acceleration_framework.py +++ b/tests/acceleration/test_acceleration_framework.py @@ -57,7 +57,8 @@ # for some reason the CI will raise an import error if we try to import # these from tests.artifacts.testdata TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join( - os.path.dirname(__file__), "../artifacts/testdata/json/twitter_complaints_json.json" + os.path.dirname(__file__), + "../artifacts/testdata/json/twitter_complaints_small.json", ) TWITTER_COMPLAINTS_TOKENIZED = os.path.join( os.path.dirname(__file__), diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 32b8735cb..85058e098 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -329,7 +329,7 @@ def train( time.time() - data_preprocessing_time ) - if framework is not None and framework.requires_agumentation: + if framework is not None and framework.requires_augmentation: model, (peft_config,) = framework.augmentation( model, train_args, modifiable_args=(peft_config,) )