
Commit 62efeec
PR changes
Signed-off-by: Abhishek <[email protected]>
Abhishek-TAMU committed Dec 19, 2024
1 parent ba0e543 commit 62efeec
Showing 3 changed files with 21 additions and 17 deletions.
tuning/data/data_config.py (3 additions & 3 deletions)

@@ -21,6 +21,8 @@
 # Local
 from tuning.utils.utils import load_yaml_or_json

+logger = logging.getLogger(__name__)
+

 @dataclass
 class DataHandlerConfig:

@@ -82,9 +84,7 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
         assert isinstance(p, str), f"path {p} should be of the type string"
         if not os.path.isabs(p):
             _p = os.path.abspath(p)
-            logging.warning(
-                " Provided path %s is not absolute changing it to %s", p, _p
-            )
+            logger.warning(" Provided path %s is not absolute changing it to %s", p, _p)
             p = _p
         c.data_paths.append(p)
     if "builder" in kwargs and kwargs["builder"] is not None:

tuning/data/data_processors.py (14 additions & 12 deletions)

@@ -29,6 +29,8 @@
 from tuning.data.data_handlers import AVAILABLE_DATA_HANDLERS
 from tuning.utils.utils import get_loader_for_filepath, validate_mergeable_datasets

+logger = logging.getLogger(__name__)
+

 class DataPreProcessor:

@@ -54,11 +56,11 @@ def register_data_handler(self, name: str, func: Callable):
         if not isinstance(name, str) or not callable(func):
             raise ValueError("Handlers should be of type Dict, str to callable")
         if name in self.registered_handlers:
-            logging.warning(
+            logger.warning(
                 "Handler name '%s' already exists and will be overwritten", name
             )
         self.registered_handlers[name] = func
-        logging.info("Registered new handler %s", name)
+        logger.info("Registered new handler %s", name)

     def register_data_handlers(self, handlers: Dict[str, Callable]):
         if handlers is None:

@@ -175,7 +177,7 @@ def _load_dataset(data_path=None, builder=None, data_files=None, data_dir=None):
             return all_datasets[0]

         raw_datasets = datasets.concatenate_datasets(all_datasets)
-        logging.info(
+        logger.info(
             "Datasets concatenated from %s .Concatenated dataset columns: %s",
             datasetconfig.name,
             list(raw_datasets.features.keys()),

@@ -207,25 +209,25 @@ def _process_dataset_configs(
             if sum(p for p in sampling_probabilities) != 1:
                 raise ValueError("Sampling probabilities don't sum to 1")
             sample_datasets = True
-            logging.info(
+            logger.info(
                 "Sampling ratios are specified; given datasets will be interleaved."
             )
         else:
-            logging.info(
+            logger.info(
                 "Sampling is not specified; if multiple datasets are provided,"
                 " the given datasets will be concatenated."
             )
             sample_datasets = False

-        logging.info("Starting DataPreProcessor...")
+        logger.info("Starting DataPreProcessor...")
         # Now Iterate over the multiple datasets provided to us to process
         for d in dataset_configs:
-            logging.info("Loading %s", d.name)
+            logger.info("Loading %s", d.name)

             # In future the streaming etc go as kwargs of this function
             raw_dataset = self.load_dataset(d, splitName)

-            logging.info("Loaded raw dataset : %s", str(raw_dataset))
+            logger.info("Loaded raw dataset : %s", str(raw_dataset))

             raw_datasets = DatasetDict()

@@ -266,7 +268,7 @@ def _process_dataset_configs(

                 kwargs["fn_kwargs"] = dict(kwargs["fn_kwargs"], **extra_kwargs)

-                logging.info("Applying Handler: %s Args: %s", data_handler, kwargs)
+                logger.info("Applying Handler: %s Args: %s", data_handler, kwargs)

                 raw_datasets = raw_datasets.map(handler, **kwargs)

@@ -285,7 +287,7 @@ def _process_dataset_configs(
         if sample_datasets:
             strategy = self.processor_config.sampling_stopping_strategy
             seed = self.processor_config.sampling_seed
-            logging.info(
+            logger.info(
                 "Interleaving datasets: strategy[%s] seed[%d] probabilities[%s]",
                 strategy,
                 seed,

@@ -316,7 +318,7 @@ def process_dataset_configs(

         if torch.distributed.is_available() and torch.distributed.is_initialized():
             if torch.distributed.get_rank() == 0:
-                logging.info("Processing data on rank 0...")
+                logger.info("Processing data on rank 0...")
                 train_dataset = self._process_dataset_configs(dataset_configs, **kwargs)
             else:
                 train_dataset = None

@@ -329,7 +331,7 @@ def process_dataset_configs(
             torch.distributed.broadcast_object_list(to_share, src=0)
             train_dataset = to_share[0]
         else:
-            logging.info("Processing data...")
+            logger.info("Processing data...")
             train_dataset = self._process_dataset_configs(dataset_configs, **kwargs)

         return train_dataset

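Note: the hunks above only swap the logger calls, but the surrounding `process_dataset_configs` code they sit in (unchanged by this commit) uses a rank-0-then-broadcast pattern. The following is a minimal sketch of that pattern, not code from the repository; the wrapper name and the `build_fn` argument standing in for `_process_dataset_configs` are illustrative.

import torch


def build_on_rank0_then_broadcast(build_fn):
    """Run build_fn once on rank 0 and share the result with all other ranks."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        # Only rank 0 pays the cost of building the (possibly large) object.
        obj = build_fn() if torch.distributed.get_rank() == 0 else None
        # broadcast_object_list pickles the object on rank 0 and fills the
        # placeholder slot on every other rank.
        to_share = [obj]
        torch.distributed.broadcast_object_list(to_share, src=0)
        return to_share[0]
    # Single-process fallback: just build locally.
    return build_fn()

This keeps dataset preprocessing from being repeated on every worker while still handing each rank the same train_dataset object.
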
tuning/data/setup_dataprocessor.py (4 additions & 2 deletions)

@@ -33,6 +33,8 @@
 from tuning.data.data_preprocessing_utils import get_data_collator
 from tuning.data.data_processors import get_datapreprocessor

+logger = logging.getLogger(__name__)
+
 # In future we may make the fields configurable
 DEFAULT_INPUT_COLUMN = "input"
 DEFAULT_OUTPUT_COLUMN = "output"

@@ -320,9 +322,9 @@ def process_dataargs(
     """

     max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
-    logging.info("Max sequence length is %s", max_seq_length)
+    logger.info("Max sequence length is %s", max_seq_length)
     if train_args.max_seq_length > tokenizer.model_max_length:
-        logging.warning(
+        logger.warning(
             "max_seq_length %s exceeds tokenizer.model_max_length \
             %s, using tokenizer.model_max_length %s",
             train_args.max_seq_length,

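Taken together, the three files stop calling the root `logging` module directly and instead use a per-module logger created with `logging.getLogger(__name__)`. The following is a minimal sketch of that pattern, not code from the commit; the module, the `resolve_path` helper, and the `basicConfig` call are all illustrative assumptions.

# some_module.py -- hypothetical module, for illustration only
import logging
import os

# One named logger per module: the name follows the package hierarchy,
# so levels and handlers can be tuned for a package such as tuning.data
# without touching the root logger or other libraries' output.
logger = logging.getLogger(__name__)


def resolve_path(p: str) -> str:
    """Mirrors the path check in data_config.py, using the module logger."""
    if not os.path.isabs(p):
        _p = os.path.abspath(p)
        # Lazy %-style formatting, as in the diff: the message is only
        # interpolated when the WARNING level is actually enabled.
        logger.warning("Provided path %s is not absolute, changing it to %s", p, _p)
        return _p
    return p


if __name__ == "__main__":
    # Application-side configuration (not part of this commit): library
    # modules only emit records; the application decides where they go.
    logging.basicConfig(level=logging.INFO)
    resolve_path("relative/dir")

Compared with calling logging.warning(...) on the root logger, a named logger lets consumers filter or route these records per package rather than globally.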
