From aef77c8ef852a2f48e1746cf5ae33423ac9adba9 Mon Sep 17 00:00:00 2001 From: Anh Uong Date: Fri, 23 Aug 2024 12:26:05 -0600 Subject: [PATCH] refactor set log level - separate out the train args from python logger Signed-off-by: Anh Uong --- tuning/sft_trainer.py | 8 ++------ tuning/utils/logging.py | 24 ++++++++++++++++++------ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index b5e6cb62e..739e1ba79 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -62,7 +62,7 @@ USER_ERROR_EXIT_CODE, write_termination_log, ) -from tuning.utils.logging import set_log_level +from tuning.utils.logging import set_log_level, set_python_log_level from tuning.utils.preprocessing_utils import ( format_dataset, get_data_collator, @@ -377,11 +377,7 @@ def save(path: str, trainer: SFTTrainer, log_level="WARNING"): Optional threshold to set save save logger to, default warning. """ logger = logging.getLogger("sft_trainer_save") - # default value from TrainingArguments - if log_level == "passive": - log_level = "WARNING" - - logger.setLevel(log_level.upper()) + set_python_log_level(log_level, logger) if not os.path.exists(path): os.makedirs(path, exist_ok=True) diff --git a/tuning/utils/logging.py b/tuning/utils/logging.py index 1f1b6c73e..5c12bd01b 100644 --- a/tuning/utils/logging.py +++ b/tuning/utils/logging.py @@ -53,12 +53,24 @@ def set_log_level(train_args, logger_name=None): else os.environ.get("TRANSFORMERS_VERBOSITY") ) - logging.basicConfig( - format="%(levelname)s:%(filename)s:%(message)s", level=log_level.upper() - ) - + train_logger = logging.getLogger() if logger_name: train_logger = logging.getLogger(logger_name) - else: - train_logger = logging.getLogger() + + set_python_log_level(log_level, train_logger) + set_python_log_level(log_level) return train_args, train_logger + + +def set_python_log_level(log_level=None, logger=None): + # Configure Python native logger + # If CLI arg is passed, assign same log level to python native logger + if not log_level: + log_level = os.environ.get("LOG_LEVEL", "WARNING") + + if logger: + logger.setLevel(log_level.upper()) + else: + logging.basicConfig( + format="%(levelname)s:%(filename)s:%(message)s", level=log_level.upper() + )