PR changes
Signed-off-by: Abhishek <[email protected]>
Abhishek-TAMU committed Dec 19, 2024
1 parent 4bf1a31 commit c818b45
Showing 5 changed files with 74 additions and 106 deletions.
80 changes: 1 addition & 79 deletions tests/data/test_data_preprocessing_utils.py
@@ -438,11 +438,6 @@ def test_load_dataset_with_datasetconfig_files_folders(
@pytest.mark.parametrize(
"data_paths, datasetconfigname, builder",
[
(
[TWITTER_COMPLAINTS_DATA_DIR_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON],
"text_dataset_input_output_masking",
None,
),
(
[
TWITTER_COMPLAINTS_DATA_DIR_JSON,
@@ -453,7 +448,7 @@ def test_load_dataset_with_datasetconfig_files_folders(
),
],
)
def test_load_dataset_with_datasetconfig_files_folders_incorrect_format(
def test_load_dataset_with_datasetconfig_files_folders_incorrect_builder(
data_paths, datasetconfigname, builder
):
"""Ensure that load_dataset with passing combination of files and folders does support mismatch in format"""
@@ -1048,79 +1043,6 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
)


@pytest.mark.parametrize(
"data_config_path, data_path_list",
[
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[
TWITTER_COMPLAINTS_TOKENIZED_JSONL,
TWITTER_COMPLAINTS_TOKENIZED_ARROW,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[
TWITTER_COMPLAINTS_TOKENIZED_JSON,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_JSON],
),
],
)
def test_process_dataconfig_multiple_files_varied_data_formats(
data_config_path, data_path_list
):
"""Ensure that datasets with multiple files with different formats raise assertion error when passed in config file."""
with open(data_config_path, "r") as f:
yaml_content = yaml.safe_load(f)
yaml_content["datasets"][0]["data_paths"] = data_path_list
datasets_name = yaml_content["datasets"][0]["name"]

# Modify input_field_name and output_field_name according to dataset
if datasets_name == "text_dataset_input_output_masking":
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"input_field_name": "input",
"output_field_name": "output",
}

# Modify dataset_text_field and template according to dataset
formatted_dataset_field = "formatted_data_field"
if datasets_name == "apply_custom_data_template":
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"dataset_text_field": formatted_dataset_field,
"template": template,
}

with tempfile.NamedTemporaryFile(
"w", delete=False, suffix=".yaml"
) as temp_yaml_file:
yaml.dump(yaml_content, temp_yaml_file)
temp_yaml_file_path = temp_yaml_file.name
data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
with pytest.raises(
(
AssertionError,
ValueError,
datasets.exceptions.DatasetGenerationCastError,
pyarrow.lib.ArrowInvalid,
AttributeError,
)
):
(_, _, _) = _process_dataconfig_file(data_args, tokenizer)


@pytest.mark.parametrize(
"data_args",
[
30 changes: 29 additions & 1 deletion tests/utils/test_config_utils.py
@@ -17,10 +17,12 @@

# Standard
import base64
import logging
import os
import pickle

# Third Party
from datasets import Dataset, Features, Value
from peft import LoraConfig, PromptTuningConfig
import pytest

@@ -29,7 +31,7 @@

# Local
from tuning.config import peft_config
from tuning.utils import config_utils
from tuning.utils import config_utils, utils


def test_get_hf_peft_config_returns_None_for_tuning_config_None():
@@ -232,3 +234,29 @@ def test_get_json_config_can_load_from_envvar():
job_config = config_utils.get_json_config()
assert job_config is not None
assert job_config["model_name_or_path"] == "foobar"


def test_validate_datasets_logs_warnings_on_mismatch(caplog):
"""Test that `validate_mergeable_datasets` logs warnings when
datasets have different columns or dtypes."""
# Create a reference dataset with columns col1:int64 and col2:string
ds1 = Dataset.from_dict(
{"col1": [1, 2], "col2": ["hello", "world"]},
features=Features({"col1": Value("int64"), "col2": Value("string")}),
)

# Create a second dataset with a different column set and a different dtype for col1
ds2 = Dataset.from_dict(
{"col1": [0.1, 0.2], "col3": ["hi", "there"]},
features=Features({"col1": Value("float64"), "col3": Value("string")}),
)

with caplog.at_level(logging.WARNING):
utils.validate_mergeable_datasets([ds1, ds2])

assert (
"different columns" in caplog.text
), "Expected a warning about differing columns."
assert (
"expected int64" in caplog.text
), "Expected a warning about mismatching column dtypes."
5 changes: 3 additions & 2 deletions tuning/data/data_config.py
@@ -32,7 +32,7 @@ class DataHandlerConfig:
class DataSetConfig:
name: str
data_paths: List[str]
builder: Optional[str] = None
builder: Optional[str] = None # Referring to Hugging Face dataset builder
sampling: Optional[float] = None
data_handlers: Optional[List[DataHandlerConfig]] = None

@@ -91,7 +91,8 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
builder = kwargs["builder"]
assert isinstance(
builder, str
), f"builder: {builder} should be str with values in (json, text, parquet, arrow)"
), f"builder should be a string representing a supported \
Hugging Face dataset builder, but got: {builder}"
c.builder = builder
if "sampling" in kwargs and kwargs["sampling"] is not None:
ratio = kwargs["sampling"]
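For orientation, a minimal sketch (not part of the commit) of a dataset entry that sets the optional builder field; the name, paths, and sampling value below are hypothetical, and the keys mirror DataSetConfig above:

# Hypothetical dataset entry; "parquet" is one of the Hugging Face dataset
# builders the assertion message alludes to (e.g. json, text, parquet, arrow).
dataset_entry = {
    "name": "my_parquet_dataset",  # hypothetical dataset name
    "data_paths": ["/data/train_shards/"],  # hypothetical folder of Parquet shards
    "builder": "parquet",  # forwarded as the `path` argument to datasets.load_dataset
    "sampling": 0.5,  # optional sampling ratio
}

When a folder is paired with a builder like this, the loader in data_processors.py (next file) goes through _load_dataset(builder=builder, data_dir=data_path); without a builder it loads the directory path directly.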
41 changes: 25 additions & 16 deletions tuning/data/data_processors.py
@@ -27,7 +27,7 @@
# Local
from tuning.data.data_config import DataConfig, DataPreProcessorConfig, DataSetConfig
from tuning.data.data_handlers import AVAILABLE_DATA_HANDLERS
from tuning.utils.utils import get_loader_for_filepath, validate_datasets
from tuning.utils.utils import get_loader_for_filepath, validate_mergeable_datasets


class DataPreProcessor:
@@ -81,13 +81,13 @@ def load_dataset(
if (not datafile) and (not datasetconfig):
raise ValueError("Either datafile or datasetconfig must be set")

def _load_dataset(path=None, builder=None, data_files=None, data_dir=None):
def _load_dataset(data_path=None, builder=None, data_files=None, data_dir=None):
"""
Helper function to load a dataset using datasets.load_dataset
with standardized exception handling.
Args:
path: The path argument for load_dataset (could be a directory, file, builder, etc.)
data_path: The path argument for load_dataset (directory, file, pattern, dataset_id)
builder: Optional builder to use if provided.
data_files: Optional data_files list if loading from files.
data_dir: Optional data_dir if loading from a directory with a builder.
@@ -101,7 +101,7 @@ def _load_dataset(path=None, builder=None, data_files=None, data_dir=None):
load_kwargs["data_files"] = data_files

# Determine the `path` parameter for load_dataset
load_path = builder if builder else path
load_path = builder if builder else data_path

try:
return datasets.load_dataset(path=load_path, **load_kwargs)
@@ -110,15 +110,19 @@ def _load_dataset(path=None, builder=None, data_files=None, data_dir=None):
raise e
except FileNotFoundError as e:
# Handle file/directory not found
context = f"builder {builder}" if builder else f"path {path}"
context = (
f"path {data_path} with builder {builder}"
if builder
else f"path {data_path}"
)
raise ValueError(f"Data loading failed: invalid {context}.") from e
except datasets.exceptions.DatasetGenerationError as e:
context = (
f"builder {builder} and data_dir {data_dir}"
if builder and data_dir
else f"builder {builder}"
if builder
else f"path {path}"
else f"path {data_path}"
)
raise ValueError(
f"Failed to generate the dataset from the provided {context}."
@@ -143,7 +147,7 @@ def _load_dataset(path=None, builder=None, data_files=None, data_dir=None):
dataset = _load_dataset(builder=builder, data_dir=data_path)
else:
# Load directly from the directory
dataset = _load_dataset(path=data_path)
dataset = _load_dataset(data_path=data_path)
else:
# Non-directory (file, pattern, HF dataset name)
# If no builder provided, attempt to infer one
@@ -158,25 +162,30 @@ def _load_dataset(path=None, builder=None, data_files=None, data_dir=None):
)
else:
# CASE 3: User passes files/folder/pattern/HF_dataset which has no builder
dataset = _load_dataset(path=data_path)
dataset = _load_dataset(data_path=data_path)

all_datasets.append(dataset)

# Validate all datasets to have same columns
validate_datasets(all_datasets)
# Logs warning if datasets have different columns
validate_mergeable_datasets(all_datasets)

# Concatenate all datasets
try:
raw_datasets = (
datasets.concatenate_datasets(all_datasets)
if len(all_datasets) > 1
else all_datasets[0]
if len(all_datasets) == 1:
return all_datasets[0]

raw_datasets = datasets.concatenate_datasets(all_datasets)
logging.info(
"Datasets concatenated from %s .Concatenated dataset columns: %s",
datasetconfig.name,
list(raw_datasets.features.keys()),
)
return raw_datasets

except Exception as e:
raise ValueError(
f"An error occurred while concatenating datasets: {e}"
f"An error occurred while concatenating datasets from {datasetconfig.name}: {e}"
) from e
return raw_datasets

def _process_dataset_configs(
self, dataset_configs: List[DataSetConfig], **extra_kwargs
24 changes: 16 additions & 8 deletions tuning/utils/utils.py
@@ -14,6 +14,7 @@

# Standard
import json
import logging
import os

# Third Party
@@ -48,7 +49,7 @@ def load_yaml_or_json(file_path: str) -> dict:
return None


def validate_datasets(datasets):
def validate_mergeable_datasets(datasets):
"""Given list of datasets, validate if all datasets have same type and number of columns."""
if len(datasets) > 1:
ref_columns = datasets[0].features
@@ -62,15 +63,22 @@

# Check same set of columns
if set(ds_column_names) != set(ref_column_names):
raise ValueError(
f"Dataset {i} has different columns: {ds_column_names}. "
f"Expected columns: {ref_column_names}"
logging.warning(
"Dataset %d has different columns: %s. Columns in Dataset 1: %s",
i,
ds_column_names,
ref_column_names,
)

# Check column data types
for col in ref_column_names:
if ds_column_types[col] != ref_column_types[col]:
raise ValueError(
f"Column '{col}' in dataset {i} has type {ds_column_types[col]}, "
f"expected {ref_column_types[col]}"
if (col in ds_column_types) and (
ds_column_types[col] != ref_column_types[col]
):
logging.warning(
"Column '%s' in dataset %d has type %s, expected %s",
col,
i,
ds_column_types[col],
ref_column_types[col],
)
