diff --git a/tuning/data/data_config.py b/tuning/data/data_config.py
index b95b996fe..0c5521baf 100644
--- a/tuning/data/data_config.py
+++ b/tuning/data/data_config.py
@@ -21,6 +21,8 @@
 # Local
 from tuning.utils.utils import load_yaml_or_json
 
+logger = logging.getLogger(__name__)
+
 
 @dataclass
 class DataHandlerConfig:
@@ -82,9 +84,7 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
             assert isinstance(p, str), f"path {p} should be of the type string"
             if not os.path.isabs(p):
                 _p = os.path.abspath(p)
-                logging.warning(
-                    " Provided path %s is not absolute changing it to %s", p, _p
-                )
+                logger.warning(" Provided path %s is not absolute changing it to %s", p, _p)
                 p = _p
             c.data_paths.append(p)
     if "builder" in kwargs and kwargs["builder"] is not None:
diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py
index dbce7aedc..170bc2a81 100644
--- a/tuning/data/data_processors.py
+++ b/tuning/data/data_processors.py
@@ -29,6 +29,8 @@
 from tuning.data.data_handlers import AVAILABLE_DATA_HANDLERS
 from tuning.utils.utils import get_loader_for_filepath, validate_mergeable_datasets
 
+logger = logging.getLogger(__name__)
+
 
 class DataPreProcessor:
 
@@ -54,11 +56,11 @@ def register_data_handler(self, name: str, func: Callable):
         if not isinstance(name, str) or not callable(func):
             raise ValueError("Handlers should be of type Dict, str to callable")
         if name in self.registered_handlers:
-            logging.warning(
+            logger.warning(
                 "Handler name '%s' already exists and will be overwritten", name
             )
         self.registered_handlers[name] = func
-        logging.info("Registered new handler %s", name)
+        logger.info("Registered new handler %s", name)
 
     def register_data_handlers(self, handlers: Dict[str, Callable]):
         if handlers is None:
@@ -175,7 +177,7 @@ def _load_dataset(data_path=None, builder=None, data_files=None, data_dir=None):
             return all_datasets[0]
 
         raw_datasets = datasets.concatenate_datasets(all_datasets)
-        logging.info(
+        logger.info(
            "Datasets concatenated from %s .Concatenated dataset columns: %s",
            datasetconfig.name,
            list(raw_datasets.features.keys()),
@@ -207,25 +209,25 @@ def _process_dataset_configs(
             if sum(p for p in sampling_probabilities) != 1:
                 raise ValueError("Sampling probabilities don't sum to 1")
             sample_datasets = True
-            logging.info(
+            logger.info(
                 "Sampling ratios are specified; given datasets will be interleaved."
             )
         else:
-            logging.info(
+            logger.info(
                 "Sampling is not specified; if multiple datasets are provided,"
                 " the given datasets will be concatenated."
             )
             sample_datasets = False
 
-        logging.info("Starting DataPreProcessor...")
+        logger.info("Starting DataPreProcessor...")
         # Now Iterate over the multiple datasets provided to us to process
         for d in dataset_configs:
-            logging.info("Loading %s", d.name)
+            logger.info("Loading %s", d.name)
 
             # In future the streaming etc go as kwargs of this function
             raw_dataset = self.load_dataset(d, splitName)
 
-            logging.info("Loaded raw dataset : %s", str(raw_dataset))
+            logger.info("Loaded raw dataset : %s", str(raw_dataset))
 
             raw_datasets = DatasetDict()
 
@@ -266,7 +268,7 @@ def _process_dataset_configs(
                     kwargs["fn_kwargs"] = dict(kwargs["fn_kwargs"], **extra_kwargs)
 
-                    logging.info("Applying Handler: %s Args: %s", data_handler, kwargs)
+                    logger.info("Applying Handler: %s Args: %s", data_handler, kwargs)
 
                     raw_datasets = raw_datasets.map(handler, **kwargs)
 
@@ -285,7 +287,7 @@ def _process_dataset_configs(
         if sample_datasets:
             strategy = self.processor_config.sampling_stopping_strategy
             seed = self.processor_config.sampling_seed
-            logging.info(
+            logger.info(
                 "Interleaving datasets: strategy[%s] seed[%d] probabilities[%s]",
                 strategy,
                 seed,
@@ -316,7 +318,7 @@ def process_dataset_configs(
         if torch.distributed.is_available() and torch.distributed.is_initialized():
             if torch.distributed.get_rank() == 0:
-                logging.info("Processing data on rank 0...")
+                logger.info("Processing data on rank 0...")
                 train_dataset = self._process_dataset_configs(dataset_configs, **kwargs)
             else:
                 train_dataset = None
@@ -329,7 +331,7 @@ def process_dataset_configs(
             torch.distributed.broadcast_object_list(to_share, src=0)
             train_dataset = to_share[0]
         else:
-            logging.info("Processing data...")
+            logger.info("Processing data...")
             train_dataset = self._process_dataset_configs(dataset_configs, **kwargs)
 
         return train_dataset
diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py
index 2bae18bb7..b6f09c323 100644
--- a/tuning/data/setup_dataprocessor.py
+++ b/tuning/data/setup_dataprocessor.py
@@ -33,6 +33,8 @@
 from tuning.data.data_preprocessing_utils import get_data_collator
 from tuning.data.data_processors import get_datapreprocessor
 
+logger = logging.getLogger(__name__)
+
 # In future we may make the fields configurable
 DEFAULT_INPUT_COLUMN = "input"
 DEFAULT_OUTPUT_COLUMN = "output"
@@ -320,9 +322,9 @@ def process_dataargs(
     """
     max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
-    logging.info("Max sequence length is %s", max_seq_length)
+    logger.info("Max sequence length is %s", max_seq_length)
     if train_args.max_seq_length > tokenizer.model_max_length:
-        logging.warning(
+        logger.warning(
            "max_seq_length %s exceeds tokenizer.model_max_length \
            %s, using tokenizer.model_max_length %s",
            train_args.max_seq_length,
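
The change above is mechanical: each module creates a named logger with `logging.getLogger(__name__)` and routes messages through it instead of calling the root-level `logging.info()` / `logging.warning()` functions, so log records carry the module name and can be filtered or handled per module. Below is a minimal, self-contained sketch of that pattern; the module contents and messages are illustrative and not taken from the tuning package.

```python
import logging

# One logger per module, named after the module (e.g. "tuning.data.data_processors"),
# instead of logging through the root logger directly.
logger = logging.getLogger(__name__)


def load(path: str) -> str:
    # Records emitted here carry this module's name, so handlers and levels
    # can be configured per module from the application's logging setup.
    logger.info("Loading %s", path)
    return path


if __name__ == "__main__":
    # Basic root configuration; module loggers propagate to the root handler
    # by default, so their records show up here with the module name attached.
    logging.basicConfig(
        level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s"
    )
    load("/tmp/example.json")
```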