Skip to content

Commit

Permalink
tests: Added test case for load_dataset function of class HFBasedData…
Browse files Browse the repository at this point in the history
…PreProcessor

Signed-off-by: Abhishek <[email protected]>
  • Loading branch information
Abhishek-TAMU committed Nov 23, 2024
1 parent 8aace79 commit 6cad504
Showing 1 changed file with 106 additions and 0 deletions.
106 changes: 106 additions & 0 deletions tests/data/test_data_preprocessing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@

# Local
from tuning.config import configs
from tuning.data.data_config import DataLoaderConfig, DataSetConfig
from tuning.data.data_preprocessing_utils import (
combine_sequence,
get_data_collator,
validate_data_args,
)
from tuning.data.data_processors import get_dataprocessor
from tuning.data.setup_dataprocessor import is_pretokenized_dataset, process_dataargs


Expand Down Expand Up @@ -60,6 +62,110 @@ def test_combine_sequence_adds_eos(input_element, output_element, expected_res):
assert comb_seq == expected_res


@pytest.mark.parametrize(
    "datafile, column_names",
    [
        (
            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
            {"ID", "Label", "input", "output"},
        ),
        (
            TWITTER_COMPLAINTS_TOKENIZED_JSONL,
            {
                "Tweet text",
                "ID",
                "Label",
                "text_label",
                "output",
                "input_ids",
                "labels",
            },
        ),
        (
            TWITTER_COMPLAINTS_DATA_JSONL,
            {"Tweet text", "ID", "Label", "text_label", "output"},
        ),
    ],
)
def test_load_dataset_with_datafile(datafile, column_names):
    """Ensure that a dataset is loaded when only a datafile is passed.

    Each case pairs a data file with the exact set of column names the
    loaded "train" split is expected to expose.
    """
    processor = get_dataprocessor(dataloaderconfig=DataLoaderConfig(), tokenizer=None)
    # No datasetconfig is supplied here; only the file path is given.
    loaded_dataset = processor.load_dataset(
        datasetconfig=None, splitName="train", datafile=datafile
    )
    # Compare as sets so that column order does not matter.
    assert set(loaded_dataset.column_names) == column_names


@pytest.mark.parametrize(
    "datafile, column_names, datasetconfigname",
    [
        (
            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
            {"ID", "Label", "input", "output"},
            "text_dataset_input_output_masking",
        ),
        (
            TWITTER_COMPLAINTS_TOKENIZED_JSONL,
            {
                "Tweet text",
                "ID",
                "Label",
                "text_label",
                "output",
                "input_ids",
                "labels",
            },
            "pretokenized_dataset",
        ),
        (
            TWITTER_COMPLAINTS_DATA_JSONL,
            {"Tweet text", "ID", "Label", "text_label", "output"},
            "apply_custom_data_template",
        ),
    ],
)
def test_load_dataset_with_datasetconfig(datafile, column_names, datasetconfigname):
    """Ensure that a dataset is loaded when only a datasetconfig is passed.

    The data file is wrapped in a ``DataSetConfig`` (via ``data_paths``) and
    ``datafile`` itself is passed as ``None``; the loaded "train" split must
    expose exactly the expected set of column names.
    """
    datasetconfig = DataSetConfig(name=datasetconfigname, data_paths=[datafile])
    processor = get_dataprocessor(dataloaderconfig=DataLoaderConfig(), tokenizer=None)
    loaded_dataset = processor.load_dataset(
        datasetconfig=datasetconfig, splitName="train", datafile=None
    )
    # Compare as sets so that column order does not matter.
    assert set(loaded_dataset.column_names) == column_names


@pytest.mark.parametrize(
    ("datafile", "datasetconfigname"),
    [
        (
            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
            "text_dataset_input_output_masking",
        ),
        (
            TWITTER_COMPLAINTS_TOKENIZED_JSONL,
            "pretokenized_dataset",
        ),
        (
            TWITTER_COMPLAINTS_DATA_JSONL,
            "apply_custom_data_template",
        ),
    ],
)
def test_load_dataset_with_dataconfig_and_datafile(datafile, datasetconfigname):
    """Supplying a datasetconfig together with a datafile must raise ValueError."""
    config = DataSetConfig(name=datasetconfigname, data_paths=[datafile])
    data_processor = get_dataprocessor(
        dataloaderconfig=DataLoaderConfig(), tokenizer=None
    )
    # Both sources are given at once, which the loader is expected to reject.
    conflicting_kwargs = {
        "datasetconfig": config,
        "splitName": "train",
        "datafile": datafile,
    }
    with pytest.raises(ValueError):
        data_processor.load_dataset(**conflicting_kwargs)


def test_load_dataset_without_dataconfig_and_datafile():
    """Supplying neither a datasetconfig nor a datafile must raise ValueError."""
    data_processor = get_dataprocessor(
        dataloaderconfig=DataLoaderConfig(), tokenizer=None
    )
    # Neither source is given, so the loader has nothing to load from.
    missing_sources = {"datasetconfig": None, "datafile": None}
    with pytest.raises(ValueError):
        data_processor.load_dataset(splitName="train", **missing_sources)


@pytest.mark.parametrize(
"data, result",
[
Expand Down

0 comments on commit 6cad504

Please sign in to comment.