Commit dea2238

Merge branch 'main' into add_protobuf

willmj authored Sep 10, 2024
2 parents 234c0e7 + 32b751c commit dea2238

Showing 8 changed files with 22 additions and 25 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -37,7 +37,6 @@ dependencies = [
"peft>=0.8.0,<0.13",
"protobuf>=5.28.0,<6.0.0",
"datasets>=2.15.0,<3.0",
-"fire>=0.5.0,<1.0",
"simpleeval>=0.9.13,<1.0",
]
5 changes: 2 additions & 3 deletions tuning/sft_trainer.py
@@ -37,7 +37,6 @@
)
from transformers.utils import is_accelerate_available
from trl import SFTConfig, SFTTrainer
-import fire
import transformers

# Local
@@ -515,7 +514,7 @@ def parse_arguments(parser, json_config=None):
)


-def main(**kwargs): # pylint: disable=unused-argument
+def main():
parser = get_parser()
logger = logging.getLogger()
job_config = get_json_config()
@@ -636,4 +635,4 @@ def main(**kwargs): # pylint: disable=unused-argument


if __name__ == "__main__":
-fire.Fire(main)
+main()
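
Note: with the python-fire dependency removed, the script's entry point now calls main() directly and argument parsing stays with the existing get_parser() machinery. Below is a minimal, illustrative sketch of that entry-point pattern; the parser arguments are assumptions, not the repository's actual options.

# Illustrative sketch: plain argparse-driven main() invoked directly,
# with no fire.Fire() wrapper. Argument names here are hypothetical.
import argparse
import logging


def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Example training entry point")
    parser.add_argument("--model_name_or_path", type=str, required=True)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    return parser


def main():
    logging.basicConfig(level=logging.INFO)
    args = get_parser().parse_args()
    logging.getLogger(__name__).info("Parsed arguments: %s", vars(args))


if __name__ == "__main__":
    main()  # previously wrapped as fire.Fire(main)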
10 changes: 5 additions & 5 deletions tuning/trainercontroller/callback.py
@@ -18,6 +18,7 @@
# Standard
from typing import Dict, List, Union
import inspect
+import logging
import os
import re

@@ -29,7 +30,6 @@
TrainerState,
TrainingArguments,
)
-from transformers.utils import logging
import yaml

# Local
@@ -45,7 +45,7 @@
from tuning.trainercontroller.patience import PatienceControl
from tuning.utils.evaluator import MetricUnavailableError, RuleEvaluator

-logger = logging.get_logger(__name__)
+logger = logging.getLogger(__name__)

# Configuration keys
CONTROLLER_METRICS_KEY = "controller_metrics"
@@ -66,7 +66,7 @@
DEFAULT_OPERATIONS = {"operations": [{"name": "hfcontrols", "class": "HFControls"}]}
DEFAULT_METRICS = {}
DEFAULT_CONFIG = {}
-DEFAULT_TRIGGER_LOG_LEVEL = "debug"
+DEFAULT_TRIGGER_LOG_LEVEL = "DEBUG"

# pylint: disable=too-many-instance-attributes
class TrainerControllerCallback(TrainerCallback):
@@ -305,7 +305,7 @@ def on_init_end(
kwargs["state"] = state
kwargs["control"] = control

-log_levels = logging.get_log_levels_dict()
+log_levels = dict((value, key) for key, value in logging._levelToName.items())
# Check if there any metrics listed in the configuration
if (
CONTROLLER_METRICS_KEY not in self.trainer_controller_config
@@ -407,7 +407,7 @@
control.config = controller[CONTROLLER_CONFIG_KEY]
config_log_level_str = control.config.get(
CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL, config_log_level_str
-)
+).upper()
if config_log_level_str not in log_levels:
logger.warning(
"Incorrect trigger log-level [%s] specified in the config."
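
Note: the callback now resolves trigger log levels with the standard logging module instead of transformers.utils.logging. The sketch below shows how the name-to-level map built from logging._levelToName (a private CPython detail the commit relies on) works together with the .upper() normalization; the helper name resolve_trigger_level is illustrative, not part of the repository.

# Sketch of the level-lookup pattern used above: invert logging's internal
# number -> name table into a name -> number map, then normalize the
# configured string with .upper() before checking membership.
import logging

log_levels = dict((value, key) for key, value in logging._levelToName.items())
DEFAULT_TRIGGER_LOG_LEVEL = "DEBUG"

logger = logging.getLogger(__name__)


def resolve_trigger_level(config_value: str) -> int:
    """Return the numeric level for a configured name, falling back to DEBUG."""
    name = config_value.upper()
    if name not in log_levels:
        logger.warning(
            "Incorrect trigger log-level [%s] specified in the config.", config_value
        )
        name = DEFAULT_TRIGGER_LOG_LEVEL
    return log_levels[name]


print(resolve_trigger_level("info"))   # 20
print(resolve_trigger_level("bogus"))  # warns, then falls back to 10 (DEBUG)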
3 changes: 0 additions & 3 deletions tuning/trainercontroller/controllermetrics/trainingstate.py
@@ -21,13 +21,10 @@

# Third Party
from transformers import TrainerState
-from transformers.utils import logging

# Local
from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler

-logger = logging.get_logger(__name__)
-

class TrainingState(MetricHandler):
"""Implements the controller metric which exposes the trainer state"""
4 changes: 3 additions & 1 deletion tuning/trainercontroller/operations/hfcontrols.py
@@ -10,6 +10,8 @@
# Local
from .operation import Operation

+logger = logging.getLogger(__name__)
+

class HFControls(Operation):
"""Implements the control actions for the HuggingFace controls in
@@ -37,7 +39,7 @@ def control_action(self, control: TrainerControl, **kwargs):
control: TrainerControl. Data class for controls.
kwargs: List of arguments (key, value)-pairs
"""
-logging.debug("Arguments passed to control_action: %s", repr(kwargs))
+logger.debug("Arguments passed to control_action: %s", repr(kwargs))
frame_info = inspect.currentframe().f_back
arg_values = inspect.getargvalues(frame_info)
setattr(control, arg_values.locals["action"], True)
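
Note: this change routes the debug message through a module-level logger named after the module (logging.getLogger(__name__)) rather than the root logger via logging.debug(). A small, self-contained sketch of the difference; the function name is illustrative.

# Module-level logger pattern: records are attributed to this module's
# logger, so handlers and levels can be tuned per module instead of
# configuring the root logger.
import logging

logger = logging.getLogger(__name__)


def control_action_example(**kwargs):
    # Goes to this module's logger, not the root logger.
    logger.debug("Arguments passed to control_action: %s", repr(kwargs))


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    control_action_example(should_training_stop=True)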
12 changes: 6 additions & 6 deletions tuning/trainercontroller/operations/logcontrol.py
@@ -1,12 +1,13 @@
+# Standard
+import logging
+
# Third Party
from transformers import TrainingArguments
-from transformers.utils import logging

# Local
from .operation import Operation

-logger = logging.get_logger(__name__)
-logger.setLevel(level=logging.DEBUG)
+logger = logging.getLogger(__name__)


class LogControl(Operation):
@@ -20,12 +21,11 @@ def __init__(self, log_format: str, log_level: str, **kwargs):
Args:
kwargs: List of arguments (key, value)-pairs
"""
-log_levels = logging.get_log_levels_dict()
-if log_level not in log_levels:
+self.log_level = getattr(logging, log_level.upper(), None)
+if not isinstance(self.log_level, int):
raise ValueError(
"Specified log_level [%s] is invalid for LogControl" % (log_level)
)
-self.log_level = log_levels[log_level]
self.log_format = log_format
super().__init__(**kwargs)
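
Note: LogControl now validates the requested level by resolving the upper-cased name through getattr() on the logging module and rejecting anything that is not an integer. A minimal standalone sketch of that validation; the class name is illustrative.

# Level validation via getattr(): logging.DEBUG/INFO/WARNING/... are module
# attributes holding integer levels, so an unknown name resolves to None.
import logging


class LogControlSketch:
    def __init__(self, log_format: str, log_level: str):
        self.log_level = getattr(logging, log_level.upper(), None)
        if not isinstance(self.log_level, int):
            raise ValueError(
                "Specified log_level [%s] is invalid for LogControl" % (log_level)
            )
        self.log_format = log_format


print(LogControlSketch("{name}: {msg}", "warning").log_level)  # 30
# LogControlSketch("{name}: {msg}", "chatty") would raise ValueError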
6 changes: 2 additions & 4 deletions tuning/trainercontroller/operations/operation.py
@@ -1,12 +1,10 @@
# Standard
import abc
import inspect
+import logging
import re

-# Third Party
-from transformers.utils import logging
-
-logger = logging.get_logger(__name__)
+logger = logging.getLogger(__name__)


class Operation(metaclass=abc.ABCMeta):
6 changes: 4 additions & 2 deletions tuning/trainercontroller/patience.py
@@ -31,6 +31,8 @@
# will be exceeded afer the fifth event.
MODE_NO_RESET_ON_FAILURE = "no_reset_on_failure"

+logger = logging.getLogger(__name__)
+

class PatienceControl:
"""Implements the patience control for every rule"""
@@ -49,7 +51,7 @@ def should_tolerate(
elif self._mode == MODE_RESET_ON_FAILURE:
self._patience_counter = 0
if self._patience_counter <= self._patience_threshold:
-logging.debug(
+logger.debug(
"Control {} triggered on event {}: "
"Enforcing patience [patience_counter = {:.2f}, "
"patience_threshold = {:.2f}]".format(
@@ -60,7 +62,7 @@
)
)
return True
-logging.debug(
+logger.debug(
"Control {} triggered on event {}: "
"Exceeded patience [patience_counter = {:.2f}, "
"patience_threshold = {:.2f}]".format(