From c818b45c8c3f07a3132312b4fdf4991ce6d9972f Mon Sep 17 00:00:00 2001
From: Abhishek
Date: Thu, 19 Dec 2024 14:26:45 -0500
Subject: [PATCH] PR changes

Signed-off-by: Abhishek
---
 tests/data/test_data_preprocessing_utils.py | 80 +--------------------
 tests/utils/test_config_utils.py            | 30 +++++++-
 tuning/data/data_config.py                  |  5 +-
 tuning/data/data_processors.py              | 41 ++++++-----
 tuning/utils/utils.py                       | 24 ++++---
 5 files changed, 74 insertions(+), 106 deletions(-)

diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py
index 50b6a1065..5c79273fe 100644
--- a/tests/data/test_data_preprocessing_utils.py
+++ b/tests/data/test_data_preprocessing_utils.py
@@ -438,11 +438,6 @@ def test_load_dataset_with_datasetconfig_files_folders(
 @pytest.mark.parametrize(
     "data_paths, datasetconfigname, builder",
     [
-        (
-            [TWITTER_COMPLAINTS_DATA_DIR_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON],
-            "text_dataset_input_output_masking",
-            None,
-        ),
         (
             [
                 TWITTER_COMPLAINTS_DATA_DIR_JSON,
@@ -453,7 +448,7 @@ def test_load_dataset_with_datasetconfig_files_folders(
         ),
     ],
 )
-def test_load_dataset_with_datasetconfig_files_folders_incorrect_format(
+def test_load_dataset_with_datasetconfig_files_folders_incorrect_builder(
     data_paths, datasetconfigname, builder
 ):
     """Ensure that load_dataset with passing combination of files and folders does support mismatch in format"""
@@ -1048,79 +1043,6 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
     )


-@pytest.mark.parametrize(
-    "data_config_path, data_path_list",
-    [
-        (
-            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
-            [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON],
-        ),
-        (
-            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
-            [
-                TWITTER_COMPLAINTS_TOKENIZED_JSONL,
-                TWITTER_COMPLAINTS_TOKENIZED_ARROW,
-                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
-            ],
-        ),
-        (
-            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
-            [
-                TWITTER_COMPLAINTS_TOKENIZED_JSON,
-                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
-            ],
-        ),
-        (
-            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
-            [TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_JSON],
-        ),
-    ],
-)
-def test_process_dataconfig_multiple_files_varied_data_formats(
-    data_config_path, data_path_list
-):
-    """Ensure that datasets with multiple files with different formats raise assertion error when passed in config file."""
-    with open(data_config_path, "r") as f:
-        yaml_content = yaml.safe_load(f)
-    yaml_content["datasets"][0]["data_paths"] = data_path_list
-    datasets_name = yaml_content["datasets"][0]["name"]
-
-    # Modify input_field_name and output_field_name according to dataset
-    if datasets_name == "text_dataset_input_output_masking":
-        yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
-            "input_field_name": "input",
-            "output_field_name": "output",
-        }
-
-    # Modify dataset_text_field and template according to dataset
-    formatted_dataset_field = "formatted_data_field"
-    if datasets_name == "apply_custom_data_template":
-        template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
-        yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
-            "dataset_text_field": formatted_dataset_field,
-            "template": template,
-        }
-
-    with tempfile.NamedTemporaryFile(
-        "w", delete=False, suffix=".yaml"
-    ) as temp_yaml_file:
-        yaml.dump(yaml_content, temp_yaml_file)
-        temp_yaml_file_path = temp_yaml_file.name
-        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
-
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    with pytest.raises(
-        (
-            AssertionError,
-            ValueError,
-            datasets.exceptions.DatasetGenerationCastError,
-            pyarrow.lib.ArrowInvalid,
-            AttributeError,
-        )
-    ):
-        (_, _, _) = _process_dataconfig_file(data_args, tokenizer)
-
-
 @pytest.mark.parametrize(
     "data_args",
     [
diff --git a/tests/utils/test_config_utils.py b/tests/utils/test_config_utils.py
index 1cbbbaaa0..8e29750fb 100644
--- a/tests/utils/test_config_utils.py
+++ b/tests/utils/test_config_utils.py
@@ -17,10 +17,12 @@
 # Standard
 import base64
+import logging
 import os
 import pickle

 # Third Party
+from datasets import Dataset, Features, Value
 from peft import LoraConfig, PromptTuningConfig
 import pytest

@@ -29,7 +31,7 @@

 # Local
 from tuning.config import peft_config
-from tuning.utils import config_utils
+from tuning.utils import config_utils, utils


 def test_get_hf_peft_config_returns_None_for_tuning_config_None():
@@ -232,3 +234,29 @@ def test_get_json_config_can_load_from_envvar():
     job_config = config_utils.get_json_config()
     assert job_config is not None
     assert job_config["model_name_or_path"] == "foobar"
+
+
+def test_validate_datasets_logs_warnings_on_mismatch(caplog):
+    """Test that `validate_mergeable_datasets` logs warnings when
+    datasets have different columns or dtypes."""
+    # Create a reference dataset with columns col1:int64 and col2:string
+    ds1 = Dataset.from_dict(
+        {"col1": [1, 2], "col2": ["hello", "world"]},
+        features=Features({"col1": Value("int64"), "col2": Value("string")}),
+    )
+
+    # Create a second dataset with a different column set and a different dtype for col1
+    ds2 = Dataset.from_dict(
+        {"col1": [0.1, 0.2], "col3": ["hi", "there"]},
+        features=Features({"col1": Value("float64"), "col3": Value("string")}),
+    )
+
+    with caplog.at_level(logging.WARNING):
+        utils.validate_mergeable_datasets([ds1, ds2])
+
+    assert (
+        "different columns" in caplog.text
+    ), "Expected a warning about differing columns."
+    assert (
+        "expected int64" in caplog.text
+    ), "Expected a warning about mismatching column dtypes."
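Note: the new test above only exercises the mismatch path. A minimal companion sketch for the matching-schema path could look like the code below; it is not part of this patch, the test name is hypothetical, and it assumes `validate_mergeable_datasets` stays silent when columns and dtypes line up.

# Hypothetical follow-up test, illustrative only (not in this PR).
import logging

from datasets import Dataset, Features, Value

from tuning.utils import utils


def test_validate_datasets_no_warnings_on_match(caplog):
    features = Features({"col1": Value("int64"), "col2": Value("string")})
    ds1 = Dataset.from_dict({"col1": [1, 2], "col2": ["a", "b"]}, features=features)
    ds2 = Dataset.from_dict({"col1": [3, 4], "col2": ["c", "d"]}, features=features)

    with caplog.at_level(logging.WARNING):
        utils.validate_mergeable_datasets([ds1, ds2])

    # Matching schemas are expected to pass through without any warning records.
    assert not caplog.records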
diff --git a/tuning/data/data_config.py b/tuning/data/data_config.py
index 06dcfb8e1..b95b996fe 100644
--- a/tuning/data/data_config.py
+++ b/tuning/data/data_config.py
@@ -32,7 +32,7 @@ class DataHandlerConfig:
 class DataSetConfig:
     name: str
     data_paths: List[str]
-    builder: Optional[str] = None
+    builder: Optional[str] = None  # Referring to Hugging Face dataset builder
     sampling: Optional[float] = None
     data_handlers: Optional[List[DataHandlerConfig]] = None

@@ -91,7 +91,8 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
         builder = kwargs["builder"]
         assert isinstance(
             builder, str
-        ), f"builder: {builder} should be str with values in (json, text, parquet, arrow)"
+        ), ("builder should be a string representing a supported "
+            f"Hugging Face dataset builder, but got: {builder}")
         c.builder = builder
     if "sampling" in kwargs and kwargs["sampling"] is not None:
         ratio = kwargs["sampling"]
diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py
index 1f793097a..dbce7aedc 100644
--- a/tuning/data/data_processors.py
+++ b/tuning/data/data_processors.py
@@ -27,7 +27,7 @@
 # Local
 from tuning.data.data_config import DataConfig, DataPreProcessorConfig, DataSetConfig
 from tuning.data.data_handlers import AVAILABLE_DATA_HANDLERS
-from tuning.utils.utils import get_loader_for_filepath, validate_datasets
+from tuning.utils.utils import get_loader_for_filepath, validate_mergeable_datasets


 class DataPreProcessor:
@@ -81,13 +81,13 @@ def load_dataset(
         if (not datafile) and (not datasetconfig):
             raise ValueError("Either datafile or datasetconfig must be set")

-        def _load_dataset(path=None, builder=None, data_files=None, data_dir=None):
+        def _load_dataset(data_path=None, builder=None, data_files=None, data_dir=None):
             """
             Helper function to load a dataset using datasets.load_dataset
             with standardized exception handling.

             Args:
-                path: The path argument for load_dataset (could be a directory, file, builder, etc.)
+                data_path: The path argument for load_dataset (directory, file, pattern, dataset_id)
                 builder: Optional builder to use if provided.
                 data_files: Optional data_files list if loading from files.
                 data_dir: Optional data_dir if loading from a directory with a builder.
@@ -101,7 +101,7 @@ def _load_dataset(path=None, builder=None, data_files=None, data_dir=None):
                 load_kwargs["data_files"] = data_files

             # Determine the `path` parameter for load_dataset
-            load_path = builder if builder else path
+            load_path = builder if builder else data_path

             try:
                 return datasets.load_dataset(path=load_path, **load_kwargs)
@@ -110,7 +110,11 @@ def _load_dataset(path=None, builder=None, data_files=None, data_dir=None):
                 raise e
             except FileNotFoundError as e:
                 # Handle file/directory not found
-                context = f"builder {builder}" if builder else f"path {path}"
+                context = (
+                    f"path {data_path} with builder {builder}"
+                    if builder
+                    else f"path {data_path}"
+                )
                 raise ValueError(f"Data loading failed: invalid {context}.") from e
             except datasets.exceptions.DatasetGenerationError as e:
                 context = (
@@ -118,7 +122,7 @@ def _load_dataset(path=None, builder=None, data_files=None, data_dir=None):
                     if builder and data_dir
                     else f"builder {builder}"
                     if builder
-                    else f"path {path}"
+                    else f"path {data_path}"
                 )
                 raise ValueError(
                     f"Failed to generate the dataset from the provided {context}."
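Note: for reference, the three call shapes that `_load_dataset` forwards to `datasets.load_dataset` roughly correspond to the direct calls sketched below. The builder names are standard Hugging Face builders; the file and directory paths are made up for illustration.

# Rough sketch of the load patterns wrapped by _load_dataset (illustrative paths).
import datasets

# CASE 1: explicit builder over specific files
ds = datasets.load_dataset(path="parquet", data_files=["train.parquet"])

# CASE 2: explicit builder over a directory of files of that format
ds = datasets.load_dataset(path="json", data_dir="data/twitter_complaints/")

# CASE 3: no builder; datasets resolves the path itself
# (per the PR comment, this can be a file, folder, pattern, or HF dataset name)
ds = datasets.load_dataset(path="data/twitter_complaints/")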
@@ -143,7 +147,7 @@ def _load_dataset(path=None, builder=None, data_files=None, data_dir=None):
                     dataset = _load_dataset(builder=builder, data_dir=data_path)
                 else:
                     # Load directly from the directory
-                    dataset = _load_dataset(path=data_path)
+                    dataset = _load_dataset(data_path=data_path)
             else:
                 # Non-directory (file, pattern, HF dataset name)
                 # If no builder provided, attempt to infer one
@@ -158,25 +162,30 @@ def _load_dataset(path=None, builder=None, data_files=None, data_dir=None):
                     )
                 else:
                     # CASE 3: User passes files/folder/pattern/HF_dataset which has no builder
-                    dataset = _load_dataset(path=data_path)
+                    dataset = _load_dataset(data_path=data_path)

             all_datasets.append(dataset)

-        # Validate all datasets to have same columns
-        validate_datasets(all_datasets)
+        # Log a warning if the datasets have different columns or dtypes
+        validate_mergeable_datasets(all_datasets)

         # Concatenate all datasets
         try:
-            raw_datasets = (
-                datasets.concatenate_datasets(all_datasets)
-                if len(all_datasets) > 1
-                else all_datasets[0]
+            if len(all_datasets) == 1:
+                return all_datasets[0]
+
+            raw_datasets = datasets.concatenate_datasets(all_datasets)
+            logging.info(
+                "Datasets concatenated from %s. Concatenated dataset columns: %s",
+                datasetconfig.name,
+                list(raw_datasets.features.keys()),
             )
+            return raw_datasets
+
         except Exception as e:
             raise ValueError(
-                f"An error occurred while concatenating datasets: {e}"
+                f"An error occurred while concatenating datasets from {datasetconfig.name}: {e}"
             ) from e
-        return raw_datasets

     def _process_dataset_configs(
         self, dataset_configs: List[DataSetConfig], **extra_kwargs
diff --git a/tuning/utils/utils.py b/tuning/utils/utils.py
index 4b8a68bdb..8573aac60 100644
--- a/tuning/utils/utils.py
+++ b/tuning/utils/utils.py
@@ -14,6 +14,7 @@

 # Standard
 import json
+import logging
 import os

 # Third Party
@@ -48,7 +49,7 @@ def load_yaml_or_json(file_path: str) -> dict:
     return None


-def validate_datasets(datasets):
+def validate_mergeable_datasets(datasets):
     """Given list of datasets, validate if all datasets have same type and number of columns."""
     if len(datasets) > 1:
         ref_columns = datasets[0].features
@@ -62,15 +63,22 @@

         # Check same set of columns
         if set(ds_column_names) != set(ref_column_names):
-            raise ValueError(
-                f"Dataset {i} has different columns: {ds_column_names}. "
-                f"Expected columns: {ref_column_names}"
+            logging.warning(
+                "Dataset %d has different columns: %s. Columns in Dataset 1: %s",
+                i,
+                ds_column_names,
+                ref_column_names,
             )

         # Check column data types
         for col in ref_column_names:
-            if ds_column_types[col] != ref_column_types[col]:
-                raise ValueError(
-                    f"Column '{col}' in dataset {i} has type {ds_column_types[col]}, "
-                    f"expected {ref_column_types[col]}"
+            if (col in ds_column_types) and (
+                ds_column_types[col] != ref_column_types[col]
+            ):
+                logging.warning(
+                    "Column '%s' in dataset %d has type %s, expected %s",
+                    col,
+                    i,
+                    ds_column_types[col],
+                    ref_column_types[col],
                 )
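Note: end to end, the behavior change is that schema mismatches are now reported as warnings, and any hard failure comes from the concatenation itself. A rough sketch of that flow with made-up data (not taken from this patch; whether `concatenate_datasets` raises depends on how incompatible the features are):

# Illustrative flow: validate (warn only), then attempt concatenation.
import logging

import datasets
from datasets import Dataset, Features, Value

from tuning.utils.utils import validate_mergeable_datasets

logging.basicConfig(level=logging.WARNING)

ds1 = Dataset.from_dict({"col1": [1, 2]}, features=Features({"col1": Value("int64")}))
ds2 = Dataset.from_dict({"col1": ["a", "b"]}, features=Features({"col1": Value("string")}))

# Logs a dtype warning but no longer raises.
validate_mergeable_datasets([ds1, ds2])

try:
    merged = datasets.concatenate_datasets([ds1, ds2])
except Exception as err:  # int64 vs string features typically cannot be merged
    print(f"concatenation failed: {err}")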