Optimize load_dataset
Signed-off-by: Abhishek <[email protected]>
Abhishek-TAMU committed Dec 19, 2024
1 parent d8460d7 commit 26e522e
Showing 2 changed files with 91 additions and 111 deletions.
4 changes: 2 additions & 2 deletions tests/test_sft_trainer.py
@@ -1134,7 +1134,7 @@ def test_malformatted_data():
     data_args = copy.deepcopy(DATA_ARGS)
     data_args.training_data_path = MALFORMATTED_DATA
 
-    with pytest.raises(DatasetGenerationError):
+    with pytest.raises((DatasetGenerationError, ValueError)):
         sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS)
 

@@ -1143,7 +1143,7 @@ def test_empty_data():
    data_args = copy.deepcopy(DATA_ARGS)
     data_args.training_data_path = EMPTY_DATA
 
-    with pytest.raises(DatasetGenerationError):
+    with pytest.raises((DatasetGenerationError, ValueError)):
         sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS)


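Note: pytest.raises accepts a tuple of exception types, so the loosened
assertion passes if either DatasetGenerationError or ValueError is raised;
the refactored loader below rewraps most loading failures as ValueError.
A minimal sketch of the pattern, using a hypothetical loader and file path:

    import pytest
    from datasets.exceptions import DatasetGenerationError

    def test_malformed_file_rejected():
        # Passes if the code under test raises any exception in the tuple.
        with pytest.raises((DatasetGenerationError, ValueError)):
            load_data("bad.json")  # hypothetical loader and path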
198 changes: 89 additions & 109 deletions tuning/data/data_processors.py
@@ -81,122 +81,102 @@ def load_dataset(
         if (not datafile) and (not datasetconfig):
             raise ValueError("Either datafile or datasetconfig must be set")
 
-        if datafile:
-            files = [datafile]
-            loader = get_loader_for_filepath(file_path=datafile)
-            if loader in (None, ""):
-                raise ValueError(f"data path is invalid [{', '.join(files)}]")
+        def _load_dataset(path=None, builder=None, data_files=None, data_dir=None):
+            """
+            Helper function to load a dataset using datasets.load_dataset
+            with standardized exception handling.
+            Args:
+                path: The path argument for load_dataset
+                    (could be a directory, file, builder, etc.)
+                builder: Optional builder to use if provided.
+                data_files: Optional data_files list if loading from files.
+                data_dir: Optional data_dir if loading from a directory with a builder.
+            Returns: dataset
+            """
+
+            load_kwargs = {**kwargs, "split": splitName}
+            if data_dir is not None:
+                load_kwargs["data_dir"] = data_dir
+            if data_files is not None:
+                load_kwargs["data_files"] = data_files
+
+            # Determine the `path` parameter for load_dataset
+            load_path = builder if builder else path
+
             try:
-                return datasets.load_dataset(
-                    loader, data_files=files, split=splitName, **kwargs
-                )
+                return datasets.load_dataset(path=load_path, **load_kwargs)
             except DatasetNotFoundError as e:
+                # Reraise with a more context-specific message if needed
                 raise e
             except FileNotFoundError as e:
-                raise ValueError(f"data path is invalid [{', '.join(files)}]") from e
-        elif datasetconfig:
-            data_paths = datasetconfig.data_paths
-            all_datasets = []
-            for _, data_path in enumerate(data_paths):
-                builder = datasetconfig.builder
-                # CASE 1: User passes directory
-                if os.path.isdir(data_path):  # Checks if path exists and isdir
-                    if builder:
-                        # From given directory takes datafiles with builder
-                        try:
-                            dataset = datasets.load_dataset(
-                                path=builder,
-                                data_dir=data_path,
-                                split=splitName,
-                                **kwargs,
-                            )
-                        except DatasetNotFoundError as e:
-                            raise e
-                        except FileNotFoundError as e:
-                            raise ValueError(
-                                f"data path of directory {data_path} \
-                                with builder {builder} is invalid."
-                            ) from e
-                        except datasets.exceptions.DatasetGenerationError as e:
-                            raise ValueError(
-                                f"failed to generate the dataset from \
-                                the provided builder {builder} and directory {data_path}."
-                            ) from e
-                    else:
-                        # Pass directory to HF directly
-                        try:
-                            dataset = datasets.load_dataset(
-                                path=data_path, split=splitName, **kwargs
-                            )
-                        except DatasetNotFoundError as e:
-                            raise e
-                        except FileNotFoundError as e:
-                            raise ValueError(
-                                f"data path of directory {data_path} is invalid."
-                            ) from e
-                        except datasets.exceptions.DatasetGenerationError as e:
-                            raise ValueError(
-                                f"failed to generate the dataset \
-                                from the provided directory {data_path}."
-                            ) from e
-                else:
-                    # CASE OF NON-DIRECTORY
-                    # If user did not pass builder
-                    if builder is None:
-                        builder = get_loader_for_filepath(data_path)
-
-                    # CASE 2: Files passed with builder
-                    if builder:
-                        try:
-                            dataset = datasets.load_dataset(
-                                path=builder,
-                                data_files=[data_path],
-                                split=splitName,
-                                **kwargs,
-                            )
-                        except DatasetNotFoundError as e:
-                            raise e
-                        except FileNotFoundError as e:
-                            raise ValueError(
-                                f"data path {data_path} of files with builder {builder} is invalid"
-                            ) from e
-                        except datasets.exceptions.DatasetGenerationError as e:
-                            raise ValueError(
-                                f"failed to generate the dataset \
-                                from the builder {builder} and data files {data_path}."
-                            ) from e
-                    else:
-                        # CASE 3: User passes files/folder/pattern/HF_dataset without builder
-                        try:
-                            dataset = datasets.load_dataset(
-                                path=data_path, split=splitName, **kwargs
-                            )
-                        except DatasetNotFoundError as e:
-                            raise e
-                        except FileNotFoundError as e:
-                            raise ValueError(f"data path {data_path} is invalid") from e
-                        except datasets.exceptions.DatasetGenerationError as e:
-                            raise ValueError(
-                                f"failed to generate the dataset \
-                                from the provided data path {data_path}."
-                            ) from e
-                all_datasets.append(dataset)
-
-            # Validate all datasets to have same columns
-            validate_datasets(all_datasets)
-
-            # Concatenate all datasets
-            try:
-                raw_datasets = (
-                    datasets.concatenate_datasets(all_datasets)
-                    if len(all_datasets) > 1
-                    else all_datasets[0]
-                )
-            except Exception as e:
-                raise ValueError(
-                    f"An error occurred while concatenating datasets: {e}"
-                ) from e
-            return raw_datasets
+                # Handle file/directory not found
+                context = f"builder {builder}" if builder else f"path {path}"
+                raise ValueError(f"Data loading failed: invalid {context}.") from e
+            except datasets.exceptions.DatasetGenerationError as e:
+                context = (
+                    f"builder {builder} and data_dir {data_dir}"
+                    if builder and data_dir
+                    else f"builder {builder}"
+                    if builder
+                    else f"path {path}"
+                )
+                raise ValueError(
+                    f"Failed to generate the dataset from the provided {context}."
+                ) from e
+
+        if datafile:
+            loader = get_loader_for_filepath(file_path=datafile)
+            if loader in (None, ""):
+                raise ValueError(f"data path is invalid [{datafile}]")
+            return _load_dataset(builder=loader, data_files=[datafile])
+
+        data_paths = datasetconfig.data_paths
+        builder = datasetconfig.builder
+        all_datasets = []
+
+        for data_path in data_paths:
+            # CASE 1: User passes a directory
+            if os.path.isdir(data_path):  # Checks if path exists and is a directory
+                if builder:
+                    # Load from the directory using the given builder
+                    dataset = _load_dataset(builder=builder, data_dir=data_path)
+                else:
+                    # Load directly from the directory
+                    dataset = _load_dataset(path=data_path)
+            else:
+                # Non-directory (file, pattern, HF dataset name)
+                # If no builder provided, attempt to infer one
+                effective_builder = (
+                    builder if builder else get_loader_for_filepath(data_path)
+                )
+
+                if effective_builder:
+                    # CASE 2: Files passed with a builder:
+                    # load using the builder and the given files
+                    dataset = _load_dataset(
+                        builder=effective_builder, data_files=[data_path]
+                    )
+                else:
+                    # CASE 3: Files/folder/pattern/HF dataset without a builder
+                    dataset = _load_dataset(path=data_path)
+
+            all_datasets.append(dataset)
+
+        # Validate all datasets to have same columns
+        validate_datasets(all_datasets)
+
+        # Concatenate all datasets
+        try:
+            raw_datasets = (
+                datasets.concatenate_datasets(all_datasets)
+                if len(all_datasets) > 1
+                else all_datasets[0]
+            )
+        except Exception as e:
+            raise ValueError(
+                f"An error occurred while concatenating datasets: {e}"
+            ) from e
+        return raw_datasets

     def _process_dataset_configs(
         self, dataset_configs: List[DataSetConfig], **extra_kwargs
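Note: the three CASE branches above reduce to direct datasets.load_dataset
calls of the following shapes. This is a sketch of the Hugging Face datasets
semantics the _load_dataset helper relies on; the file paths and the Hub
dataset name are hypothetical placeholders:

    import datasets

    # CASE 1: directory plus a builder ("json", "csv", "parquet", ...):
    # the builder is passed as `path` and the directory as `data_dir`.
    ds = datasets.load_dataset(path="json", data_dir="my_data_dir/", split="train")

    # CASE 2: concrete file(s) with an explicit or inferred builder:
    # the files are passed via `data_files`.
    ds = datasets.load_dataset(path="json", data_files=["train.jsonl"], split="train")

    # CASE 3: no builder: the raw path (a directory, glob pattern, or
    # Hub dataset name) goes straight to load_dataset.
    ds = datasets.load_dataset(path="stanfordnlp/imdb", split="train")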

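Note: builder inference is delegated to the repo's get_loader_for_filepath,
whose body this diff does not show. A plausible extension-based sketch,
hypothetical and not the repo's actual implementation:

    def infer_builder(file_path: str) -> str:
        # Map common file extensions to Hugging Face dataset builders.
        if file_path.endswith((".json", ".jsonl")):
            return "json"
        if file_path.endswith(".csv"):
            return "csv"
        if file_path.endswith(".parquet"):
            return "parquet"
        return ""  # no builder could be inferred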

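Note: the final validate-then-concatenate step leans on
datasets.concatenate_datasets, which requires all inputs to share the same
features; that is why columns are validated before merging. A minimal,
self-contained sketch:

    import datasets

    a = datasets.Dataset.from_dict({"text": ["x", "y"]})
    b = datasets.Dataset.from_dict({"text": ["z"]})

    # Succeeds because both datasets have the identical "text" column;
    # datasets with mismatched features would raise an error instead.
    combined = datasets.concatenate_datasets([a, b])
    print(len(combined))  # 3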