Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: set log level to separate out train args #314

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
USER_ERROR_EXIT_CODE,
write_termination_log,
)
from tuning.utils.logging import set_log_level
from tuning.utils.logging import set_log_level, set_python_log_level
from tuning.utils.preprocessing_utils import (
format_dataset,
get_data_collator,
Expand Down Expand Up @@ -377,11 +377,7 @@ def save(path: str, trainer: SFTTrainer, log_level="WARNING"):
Optional threshold to set save save logger to, default warning.
"""
logger = logging.getLogger("sft_trainer_save")
# default value from TrainingArguments
if log_level == "passive":
log_level = "WARNING"

logger.setLevel(log_level.upper())
set_python_log_level(log_level, logger)

if not os.path.exists(path):
os.makedirs(path, exist_ok=True)
Expand Down
24 changes: 18 additions & 6 deletions tuning/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,24 @@ def set_log_level(train_args, logger_name=None):
else os.environ.get("TRANSFORMERS_VERBOSITY")
)

logging.basicConfig(
format="%(levelname)s:%(filename)s:%(message)s", level=log_level.upper()
)

train_logger = logging.getLogger()
if logger_name:
train_logger = logging.getLogger(logger_name)
else:
train_logger = logging.getLogger()

set_python_log_level(log_level, train_logger)
set_python_log_level(log_level)
Comment on lines +56 to +61
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice to have a generic function set_python_log_level to set log level of any logger passed to it from any module of the code. Thank you!

Regarding function set_log_level, according to me, it is called to set root log level in the code based on either CLI using train_argsor ENV variable. Hence, I believe, if we don't want to call set_python_log_level twice in the function set_log_level, we can modify the code in function set_log_level as below:

set_python_log_level(log_level)
train_logger = logging.getLogger()
if logger_name:
    train_logger = logging.getLogger(logger_name)

return train_args, train_logger

Call set_python_log_level(log_level), will set the root log level and train_logger with or without logger_name, will inherit the root log level.

return train_args, train_logger


def set_python_log_level(log_level=None, logger=None):
# Configure Python native logger
# If CLI arg is passed, assign same log level to python native logger
if not log_level:
log_level = os.environ.get("LOG_LEVEL", "WARNING")

if logger:
logger.setLevel(log_level.upper())
else:
logging.basicConfig(
format="%(levelname)s:%(filename)s:%(message)s", level=log_level.upper()
)
Loading