diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index dab68386c..a0dc674b8 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -1134,7 +1134,7 @@ def test_malformatted_data():
     data_args = copy.deepcopy(DATA_ARGS)
     data_args.training_data_path = MALFORMATTED_DATA
 
-    with pytest.raises(DatasetGenerationError):
+    with pytest.raises((DatasetGenerationError, ValueError)):
         sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS)
 
 
@@ -1143,7 +1143,7 @@ def test_empty_data():
     data_args = copy.deepcopy(DATA_ARGS)
     data_args.training_data_path = EMPTY_DATA
 
-    with pytest.raises(DatasetGenerationError):
+    with pytest.raises((DatasetGenerationError, ValueError)):
         sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS)
 
diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py
index 4e772e6d9..92b70526d 100644
--- a/tuning/data/data_processors.py
+++ b/tuning/data/data_processors.py
@@ -81,122 +81,102 @@ def load_dataset(
         if (not datafile) and (not datasetconfig):
             raise ValueError("Either datafile or datasetconfig must be set")
 
-        if datafile:
-            files = [datafile]
-            loader = get_loader_for_filepath(file_path=datafile)
-            if loader in (None, ""):
-                raise ValueError(f"data path is invalid [{', '.join(files)}]")
+        def _load_dataset(path=None, builder=None, data_files=None, data_dir=None):
+            """
+            Helper function to load a dataset using datasets.load_dataset
+            with standardized exception handling.
+
+            Args:
+                path: The path argument for load_dataset (a directory, file, builder, etc.)
+                builder: Optional builder to use if provided.
+                data_files: Optional data_files list if loading from files.
+                data_dir: Optional data_dir if loading from a directory with a builder.
+            Returns: the loaded dataset.
+            """
+
+            load_kwargs = {**kwargs, "split": splitName}
+            if data_dir is not None:
+                load_kwargs["data_dir"] = data_dir
+            if data_files is not None:
+                load_kwargs["data_files"] = data_files
+
+            # Determine the `path` parameter for load_dataset
+            load_path = builder if builder else path
+
             try:
-                return datasets.load_dataset(
-                    loader, data_files=files, split=splitName, **kwargs
-                )
+                return datasets.load_dataset(path=load_path, **load_kwargs)
             except DatasetNotFoundError as e:
+                # Re-raise Hub "dataset not found" errors unchanged
                 raise e
             except FileNotFoundError as e:
-                raise ValueError(f"data path is invalid [{', '.join(files)}]") from e
-        elif datasetconfig:
-            data_paths = datasetconfig.data_paths
-            all_datasets = []
-            for _, data_path in enumerate(data_paths):
-                builder = datasetconfig.builder
-                # CASE 1: User passes directory
-                if os.path.isdir(data_path):  # Checks if path exists and isdir
-                    if builder:
-                        # From given directory takes datafiles with builder
-                        try:
-                            dataset = datasets.load_dataset(
-                                path=builder,
-                                data_dir=data_path,
-                                split=splitName,
-                                **kwargs,
-                            )
-                        except DatasetNotFoundError as e:
-                            raise e
-                        except FileNotFoundError as e:
-                            raise ValueError(
-                                f"data path of directory {data_path} \
-                                with builder {builder} is invalid."
-                            ) from e
-                        except datasets.exceptions.DatasetGenerationError as e:
-                            raise ValueError(
-                                f"failed to generate the dataset from \
-                                the provided builder {builder} and directory {data_path}."
-                            ) from e
-                    else:
-                        # Pass directory to HF directly
-                        try:
-                            dataset = datasets.load_dataset(
-                                path=data_path, split=splitName, **kwargs
-                            )
-                        except DatasetNotFoundError as e:
-                            raise e
-                        except FileNotFoundError as e:
-                            raise ValueError(
-                                f"data path of directory {data_path} is invalid."
-                            ) from e
-                        except datasets.exceptions.DatasetGenerationError as e:
-                            raise ValueError(
-                                f"failed to generate the dataset \
-                                from the provided directory {data_path}."
-                            ) from e
-                else:
-                    # CASE OF NON-DIRECTORY
-                    # If user did not pass builder
-                    if builder is None:
-                        builder = get_loader_for_filepath(data_path)
-
-                    # CASE 2: Files passed with builder
-                    if builder:
-                        try:
-                            dataset = datasets.load_dataset(
-                                path=builder,
-                                data_files=[data_path],
-                                split=splitName,
-                                **kwargs,
-                            )
-                        except DatasetNotFoundError as e:
-                            raise e
-                        except FileNotFoundError as e:
-                            raise ValueError(
-                                f"data path {data_path} of files with builder {builder} is invalid"
-                            ) from e
-                        except datasets.exceptions.DatasetGenerationError as e:
-                            raise ValueError(
-                                f"failed to generate the dataset \
-                                from the builder {builder} and data files {data_path}."
-                            ) from e
-                    else:
-                        # CASE 3: User passes files/folder/pattern/HF_dataset without builder
-                        try:
-                            dataset = datasets.load_dataset(
-                                path=data_path, split=splitName, **kwargs
-                            )
-                        except DatasetNotFoundError as e:
-                            raise e
-                        except FileNotFoundError as e:
-                            raise ValueError(f"data path {data_path} is invalid") from e
-                        except datasets.exceptions.DatasetGenerationError as e:
-                            raise ValueError(
-                                f"failed to generate the dataset \
-                                from the provided data path {data_path}."
-                            ) from e
-                all_datasets.append(dataset)
-
-        # Validate all datasets to have same columns
-        validate_datasets(all_datasets)
-
-        # Concatenate all datasets
-        try:
-            raw_datasets = (
-                datasets.concatenate_datasets(all_datasets)
-                if len(all_datasets) > 1
-                else all_datasets[0]
+                # Handle file/directory not found
+                context = f"builder {builder}" if builder else f"path {path}"
+                raise ValueError(f"Data loading failed: invalid {context}.") from e
+            except datasets.exceptions.DatasetGenerationError as e:
+                context = (
+                    f"builder {builder} and data_dir {data_dir}"
+                    if builder and data_dir
+                    else f"builder {builder}"
+                    if builder
+                    else f"path {path}"
                 )
-        except Exception as e:
                 raise ValueError(
-                f"An error occurred while concatenating datasets: {e}"
+                    f"Failed to generate the dataset from the provided {context}."
                 ) from e
-        return raw_datasets
+
+        if datafile:
+            loader = get_loader_for_filepath(file_path=datafile)
+            if loader in (None, ""):
+                raise ValueError(f"data path is invalid [{datafile}]")
+            return _load_dataset(builder=loader, data_files=[datafile])
+
+        data_paths = datasetconfig.data_paths
+        builder = datasetconfig.builder
+        all_datasets = []
+
+        for data_path in data_paths:
+            # CASE 1: User passes a directory
+            if os.path.isdir(data_path):  # Checks that the path exists and is a directory
+                if builder:
+                    # Load using the builder with a data_dir
+                    dataset = _load_dataset(builder=builder, data_dir=data_path)
+                else:
+                    # Load directly from the directory
+                    dataset = _load_dataset(path=data_path)
+            else:
+                # Non-directory (file, pattern, HF dataset name)
+                # If no builder was provided, attempt to infer one
+                effective_builder = (
+                    builder if builder else get_loader_for_filepath(data_path)
+                )
+
+                if effective_builder:
+                    # CASE 2: Files passed with a builder; load via the
+                    # builder and the specific data files
+                    dataset = _load_dataset(
+                        builder=effective_builder, data_files=[data_path]
+                    )
+                else:
+                    # CASE 3: Files/folder/pattern/HF dataset for which no
+                    # builder could be inferred
+                    dataset = _load_dataset(path=data_path)
+
+            all_datasets.append(dataset)
+
+        # Validate that all datasets have the same columns
+        validate_datasets(all_datasets)
+
+        # Concatenate all datasets
+        try:
+            raw_datasets = (
+                datasets.concatenate_datasets(all_datasets)
+                if len(all_datasets) > 1
+                else all_datasets[0]
+            )
+        except Exception as e:
+            raise ValueError(
+                f"An error occurred while concatenating datasets: {e}"
+            ) from e
+        return raw_datasets
 
     def _process_dataset_configs(
         self, dataset_configs: List[DataSetConfig], **extra_kwargs
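
A note on the `data_processors.py` change: every load path now funnels through the nested `_load_dataset` helper, which passes either a builder name or a raw path to `datasets.load_dataset` and maps loader failures onto a uniform `ValueError`, while Hub `DatasetNotFoundError`s propagate unchanged. Below is a condensed, standalone sketch of that normalization (my own simplification, assuming only the `datasets` library; it is not the repo's exact code):

```python
import datasets
from datasets.exceptions import DatasetGenerationError


def load_or_value_error(load_path, **load_kwargs):
    """Load a dataset, normalizing loader failures into ValueError."""
    try:
        return datasets.load_dataset(path=load_path, **load_kwargs)
    except FileNotFoundError as e:
        # Missing file or directory -> uniform ValueError
        raise ValueError(f"Data loading failed: invalid path {load_path}.") from e
    except DatasetGenerationError as e:
        # Malformed rows or a failing builder -> uniform ValueError
        raise ValueError(
            f"Failed to generate the dataset from the provided path {load_path}."
        ) from e
```

This normalization is also why the two tests now expect `(DatasetGenerationError, ValueError)`: depending on where a malformed or empty file is detected, callers may see the raw `datasets` exception or the wrapped `ValueError`.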
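
For a concrete, self-contained illustration of the widened test expectation (the temp file and test name here are hypothetical, not from the suite): a malformed JSON file makes the `json` builder raise `DatasetGenerationError`, which the wrapper above converts to `ValueError`, so the assertion accepts either type:

```python
import os
import tempfile

import datasets
import pytest
from datasets.exceptions import DatasetGenerationError


def test_malformed_json_raises_either():
    with tempfile.TemporaryDirectory() as d:
        bad = os.path.join(d, "bad.json")
        with open(bad, "w") as f:
            f.write("{not valid json")  # deliberately malformed
        with pytest.raises((DatasetGenerationError, ValueError)):
            datasets.load_dataset("json", data_files=[bad], split="train")
```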
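
For reference, the three CASE branches in the loop reduce to three shapes of `datasets.load_dataset` call. The `json` builder name, file paths, and Hub id below are hypothetical examples, not values from the repo:

```python
import datasets

# CASE 1: directory plus an explicit builder -> path=<builder>, data_dir=<dir>
ds1 = datasets.load_dataset("json", data_dir="train_data/", split="train")

# CASE 2: file with a passed or inferred builder -> path=<builder>, data_files=[...]
ds2 = datasets.load_dataset("json", data_files=["train_data/part1.jsonl"], split="train")

# CASE 3: everything else (glob pattern, saved dataset, HF Hub id) -> path as-is
ds3 = datasets.load_dataset("stanfordnlp/imdb", split="train")
```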
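
Finally, `get_loader_for_filepath` is the repo's own helper and is not shown in this diff. As a rough mental model only (a hypothetical stand-in, not the real implementation), it can be thought of as inferring a `datasets` builder name from the file extension and returning a falsy value when it cannot, which is what routes a path into CASE 3:

```python
import os

_EXT_TO_BUILDER = {
    ".json": "json", ".jsonl": "json",
    ".csv": "csv", ".parquet": "parquet", ".arrow": "arrow",
}


def infer_builder(data_path: str):
    """Hypothetical stand-in for get_loader_for_filepath."""
    _, ext = os.path.splitext(data_path.lower())
    return _EXT_TO_BUILDER.get(ext)  # None -> caller falls through to CASE 3
```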