From d694e1804682ea61db17091b72109befa46b50e4 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Mon, 6 Jan 2025 15:07:51 -0500 Subject: [PATCH 1/3] fix: function name from requires_agumentation to requires_augmentation Signed-off-by: Will Johnson --- tuning/sft_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 2afdd2dac..fa9aedcfa 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -315,7 +315,7 @@ def train( time.time() - data_preprocessing_time ) - if framework is not None and framework.requires_agumentation: + if framework is not None and framework.requires_augmentation: model, (peft_config,) = framework.augmentation( model, train_args, modifiable_args=(peft_config,) ) From 37cf836a8663839ae814dd0ad148f188322f53d7 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 8 Jan 2025 09:46:42 -0500 Subject: [PATCH 2/3] fix: file path Signed-off-by: Will Johnson --- tests/acceleration/test_acceleration_framework.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/acceleration/test_acceleration_framework.py b/tests/acceleration/test_acceleration_framework.py index 94198d52e..fb84bfed9 100644 --- a/tests/acceleration/test_acceleration_framework.py +++ b/tests/acceleration/test_acceleration_framework.py @@ -57,7 +57,7 @@ # for some reason the CI will raise an import error if we try to import # these from tests.artifacts.testdata TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join( - os.path.dirname(__file__), "../artifacts/testdata/json/twitter_complaints_json.json" + os.path.dirname(__file__), "../artifacts/testdata/json/twitter_complaints_small.json" ) TWITTER_COMPLAINTS_TOKENIZED = os.path.join( os.path.dirname(__file__), From 3bbe06e8affdd6fd9abebcfda02c9ceeb578edc2 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 8 Jan 2025 09:50:41 -0500 Subject: [PATCH 3/3] fmt? Signed-off-by: Will Johnson --- tests/acceleration/test_acceleration_framework.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/acceleration/test_acceleration_framework.py b/tests/acceleration/test_acceleration_framework.py index fb84bfed9..80a445304 100644 --- a/tests/acceleration/test_acceleration_framework.py +++ b/tests/acceleration/test_acceleration_framework.py @@ -57,7 +57,8 @@ # for some reason the CI will raise an import error if we try to import # these from tests.artifacts.testdata TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join( - os.path.dirname(__file__), "../artifacts/testdata/json/twitter_complaints_small.json" + os.path.dirname(__file__), + "../artifacts/testdata/json/twitter_complaints_small.json", ) TWITTER_COMPLAINTS_TOKENIZED = os.path.join( os.path.dirname(__file__),