From 594694905cff490a480c0e006f32467b2f92cf05 Mon Sep 17 00:00:00 2001 From: Padmanabha V Seshadri Date: Fri, 6 Sep 2024 02:39:28 +0530 Subject: [PATCH 1/2] feat: Migrating the trainer controller to python logger (#309) * fix: Migrate transformer logging to python logging Signed-off-by: Padmanabha V Seshadri * fix: Migrate transformer logging to python logging Signed-off-by: Padmanabha V Seshadri * fix: Removed unwanted file Signed-off-by: Padmanabha V Seshadri * fix: Log levels obtained from reversing the dictionary Signed-off-by: Padmanabha V Seshadri * fix: Format issues Signed-off-by: Padmanabha V Seshadri * fix: Variable names made meaningful Signed-off-by: Padmanabha V Seshadri * fix: Removed unwanted log line Signed-off-by: Padmanabha V Seshadri * fix: Added name to getLogger Signed-off-by: Padmanabha V Seshadri * fix: Added default logging level to DEBUG Signed-off-by: Padmanabha V Seshadri * fix: Added default logging level to DEBUG Signed-off-by: Padmanabha V Seshadri * fix: Added default logging level to DEBUG Signed-off-by: Padmanabha V Seshadri * fix: Removed setLevel() calls from the packages Signed-off-by: Padmanabha V Seshadri * fix: Format issues resolved Signed-off-by: Padmanabha V Seshadri --------- Signed-off-by: Padmanabha V Seshadri --- tuning/trainercontroller/callback.py | 10 +++++----- .../controllermetrics/trainingstate.py | 3 --- tuning/trainercontroller/operations/hfcontrols.py | 4 +++- tuning/trainercontroller/operations/logcontrol.py | 12 ++++++------ tuning/trainercontroller/operations/operation.py | 6 ++---- tuning/trainercontroller/patience.py | 6 ++++-- 6 files changed, 20 insertions(+), 21 deletions(-) diff --git a/tuning/trainercontroller/callback.py b/tuning/trainercontroller/callback.py index fad1bbf70..a1b3397d7 100644 --- a/tuning/trainercontroller/callback.py +++ b/tuning/trainercontroller/callback.py @@ -18,6 +18,7 @@ # Standard from typing import Dict, List, Union import inspect +import logging import os import re @@ -29,7 
+30,6 @@ TrainerState, TrainingArguments, ) -from transformers.utils import logging import yaml # Local @@ -45,7 +45,7 @@ from tuning.trainercontroller.patience import PatienceControl from tuning.utils.evaluator import MetricUnavailableError, RuleEvaluator -logger = logging.get_logger(__name__) +logger = logging.getLogger(__name__) # Configuration keys CONTROLLER_METRICS_KEY = "controller_metrics" @@ -66,7 +66,7 @@ DEFAULT_OPERATIONS = {"operations": [{"name": "hfcontrols", "class": "HFControls"}]} DEFAULT_METRICS = {} DEFAULT_CONFIG = {} -DEFAULT_TRIGGER_LOG_LEVEL = "debug" +DEFAULT_TRIGGER_LOG_LEVEL = "DEBUG" # pylint: disable=too-many-instance-attributes class TrainerControllerCallback(TrainerCallback): @@ -305,7 +305,7 @@ def on_init_end( kwargs["state"] = state kwargs["control"] = control - log_levels = logging.get_log_levels_dict() + log_levels = dict((value, key) for key, value in logging._levelToName.items()) # Check if there any metrics listed in the configuration if ( CONTROLLER_METRICS_KEY not in self.trainer_controller_config @@ -407,7 +407,7 @@ def on_init_end( control.config = controller[CONTROLLER_CONFIG_KEY] config_log_level_str = control.config.get( CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL, config_log_level_str - ) + ).upper() if config_log_level_str not in log_levels: logger.warning( "Incorrect trigger log-level [%s] specified in the config." 
diff --git a/tuning/trainercontroller/controllermetrics/trainingstate.py b/tuning/trainercontroller/controllermetrics/trainingstate.py index 8dc276339..06da4035a 100644 --- a/tuning/trainercontroller/controllermetrics/trainingstate.py +++ b/tuning/trainercontroller/controllermetrics/trainingstate.py @@ -21,13 +21,10 @@ # Third Party from transformers import TrainerState -from transformers.utils import logging # Local from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler -logger = logging.get_logger(__name__) - class TrainingState(MetricHandler): """Implements the controller metric which exposes the trainer state""" diff --git a/tuning/trainercontroller/operations/hfcontrols.py b/tuning/trainercontroller/operations/hfcontrols.py index 0548b4c12..90988c16a 100644 --- a/tuning/trainercontroller/operations/hfcontrols.py +++ b/tuning/trainercontroller/operations/hfcontrols.py @@ -10,6 +10,8 @@ # Local from .operation import Operation +logger = logging.getLogger(__name__) + class HFControls(Operation): """Implements the control actions for the HuggingFace controls in @@ -37,7 +39,7 @@ def control_action(self, control: TrainerControl, **kwargs): control: TrainerControl. Data class for controls. 
kwargs: List of arguments (key, value)-pairs """ - logging.debug("Arguments passed to control_action: %s", repr(kwargs)) + logger.debug("Arguments passed to control_action: %s", repr(kwargs)) frame_info = inspect.currentframe().f_back arg_values = inspect.getargvalues(frame_info) setattr(control, arg_values.locals["action"], True) diff --git a/tuning/trainercontroller/operations/logcontrol.py b/tuning/trainercontroller/operations/logcontrol.py index 385de3b4d..eabb420c9 100644 --- a/tuning/trainercontroller/operations/logcontrol.py +++ b/tuning/trainercontroller/operations/logcontrol.py @@ -1,12 +1,13 @@ +# Standard +import logging + # Third Party from transformers import TrainingArguments -from transformers.utils import logging # Local from .operation import Operation -logger = logging.get_logger(__name__) -logger.setLevel(level=logging.DEBUG) +logger = logging.getLogger(__name__) class LogControl(Operation): @@ -20,12 +21,11 @@ def __init__(self, log_format: str, log_level: str, **kwargs): Args: kwargs: List of arguments (key, value)-pairs """ - log_levels = logging.get_log_levels_dict() - if log_level not in log_levels: + self.log_level = getattr(logging, log_level.upper(), None) + if not isinstance(self.log_level, int): raise ValueError( "Specified log_level [%s] is invalid for LogControl" % (log_level) ) - self.log_level = log_levels[log_level] self.log_format = log_format super().__init__(**kwargs) diff --git a/tuning/trainercontroller/operations/operation.py b/tuning/trainercontroller/operations/operation.py index 70805a015..f6b4884fc 100644 --- a/tuning/trainercontroller/operations/operation.py +++ b/tuning/trainercontroller/operations/operation.py @@ -1,12 +1,10 @@ # Standard import abc import inspect +import logging import re -# Third Party -from transformers.utils import logging - -logger = logging.get_logger(__name__) +logger = logging.getLogger(__name__) class Operation(metaclass=abc.ABCMeta): diff --git a/tuning/trainercontroller/patience.py 
b/tuning/trainercontroller/patience.py index ecdb0699a..bda91363c 100644 --- a/tuning/trainercontroller/patience.py +++ b/tuning/trainercontroller/patience.py @@ -31,6 +31,8 @@ # will be exceeded afer the fifth event. MODE_NO_RESET_ON_FAILURE = "no_reset_on_failure" +logger = logging.getLogger(__name__) + class PatienceControl: """Implements the patience control for every rule""" @@ -49,7 +51,7 @@ def should_tolerate( elif self._mode == MODE_RESET_ON_FAILURE: self._patience_counter = 0 if self._patience_counter <= self._patience_threshold: - logging.debug( + logger.debug( "Control {} triggered on event {}: " "Enforcing patience [patience_counter = {:.2f}, " "patience_threshold = {:.2f}]".format( @@ -60,7 +62,7 @@ def should_tolerate( ) ) return True - logging.debug( + logger.debug( "Control {} triggered on event {}: " "Exceeded patience [patience_counter = {:.2f}, " "patience_threshold = {:.2f}]".format( From 32b751c6a40c4cd5420e4df59bf3881ccc8cf00f Mon Sep 17 00:00:00 2001 From: Hari Date: Tue, 10 Sep 2024 20:54:34 +0530 Subject: [PATCH 2/2] fix: remove fire for handling CLI args (#324) Signed-off-by: Mehant Kammakomati Signed-off-by: Harikrishnan Balagopal Signed-off-by: Anh Uong Co-authored-by: Mehant Kammakomati --- pyproject.toml | 1 - tuning/sft_trainer.py | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e31192470..2675f49b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,6 @@ dependencies = [ "trl>=0.9.3,<1.0", "peft>=0.8.0,<0.13", "datasets>=2.15.0,<3.0", -"fire>=0.5.0,<1.0", "simpleeval>=0.9.13,<1.0", ] diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index bc1937c32..2ab8f7de0 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -37,7 +37,6 @@ ) from transformers.utils import is_accelerate_available from trl import SFTConfig, SFTTrainer -import fire import transformers # Local @@ -515,7 +514,7 @@ def parse_arguments(parser, json_config=None): ) -def 
main(**kwargs): # pylint: disable=unused-argument +def main(): parser = get_parser() logger = logging.getLogger() job_config = get_json_config() @@ -636,4 +635,4 @@ def main(**kwargs): # pylint: disable=unused-argument if __name__ == "__main__": - fire.Fire(main) + main()