
Commit

Test commit
Signed-off-by: Abhishek <[email protected]>
Abhishek-TAMU committed Dec 18, 2024
1 parent 4441948 commit 220efb6
Showing 3 changed files with 63 additions and 25 deletions.
4 changes: 4 additions & 0 deletions tuning/data/data_config.py
@@ -32,6 +32,7 @@ class DataHandlerConfig:
class DataSetConfig:
name: str
data_paths: List[str]
builder: Optional[str] = None
sampling: Optional[float] = None
data_handlers: Optional[List[DataHandlerConfig]] = None

@@ -87,6 +88,9 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
)
p = _p
c.data_paths.append(p)
if "builder" in kwargs and kwargs["builder"] is not None:
builder = kwargs["builder"]
assert isinstance(builder, str), f"builder: {ratio} should be str with values in (json, text, parquet, arrow)"
if "sampling" in kwargs and kwargs["sampling"] is not None:
ratio = kwargs["sampling"]
assert isinstance(ratio, float) and (
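
For reference, a minimal sketch of how a dataset entry might use the new builder field, assuming the DataSetConfig dataclass above (trimmed to the fields relevant here); the config values and the validate_builder helper are illustrative, not part of the repository.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class DataSetConfig:
    name: str
    data_paths: List[str]
    builder: Optional[str] = None   # e.g. "json", "text", "parquet", "arrow"
    sampling: Optional[float] = None

def validate_builder(kwargs: dict) -> Optional[str]:
    # Mirrors the check added in _validate_dataset_config: when present,
    # builder must be a string naming a Hugging Face packaged builder.
    builder = kwargs.get("builder")
    if builder is not None:
        assert isinstance(
            builder, str
        ), f"builder: {builder} should be str with values in (json, text, parquet, arrow)"
    return builder

# A directory of parquet shards, read with an explicit "parquet" builder.
cfg = DataSetConfig(
    name="example_dataset",
    data_paths=["/data/train"],
    builder=validate_builder({"builder": "parquet"}),
)
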
80 changes: 57 additions & 23 deletions tuning/data/data_processors.py
@@ -84,30 +84,64 @@ def load_dataset(
if datafile:
files = [datafile]
loader = get_loader_for_filepath(file_path=datafile)
try:
return datasets.load_dataset(
loader,
data_files=files,
split=splitName,
**kwargs
)
except DatasetNotFoundError as e:
raise e
except FileNotFoundError as e:
raise ValueError(f"data path is invalid [{', '.join(files)}]") from e
elif datasetconfig:
files = datasetconfig.data_paths
name = datasetconfig.name
# simple check to make sure all files are of same type.
extns = [get_extension(f) for f in files]
assert extns.count(extns[0]) == len(
extns
), f"All files in the dataset {name} should have the same extension"
loader = get_loader_for_filepath(file_path=files[0])

if loader in (None, ""):
raise ValueError(f"data path is invalid [{', '.join(files)}]")

try:
return datasets.load_dataset(
loader,
data_files=files,
split=splitName,
**kwargs,
)
except DatasetNotFoundError as e:
raise e
except FileNotFoundError as e:
raise ValueError(f"data path is invalid [{', '.join(files)}]") from e
data_paths = datasetconfig.data_paths
all_datasets = []
for data_path in data_paths:
    # CASE 1: User passes a directory
    if os.path.isdir(data_path):  # Checks that the path exists and is a directory
        if builder:
            # Load matching data files from the directory via the given builder
            dataset = datasets.load_dataset(
                path=builder, data_dir=data_path, split=splitName, **kwargs
            )
        else:
            # Pass the directory to HF directly
            dataset = datasets.load_dataset(
                path=data_path, split=splitName, **kwargs
            )
    else:
        # CASE 2: User passes a file or an HF dataset ID.
        # If the user did not pass a builder, infer one from the file extension;
        # keep it local so one path's inferred builder does not leak to the next path.
        path_builder = builder if builder is not None else get_loader_for_filepath(data_path)
        if path_builder:
            # Path is a file; load it with the (given or inferred) builder
            dataset = datasets.load_dataset(
                path=path_builder, data_files=[data_path], split=splitName, **kwargs
            )
        else:
            # Path is an HF dataset ID
            dataset = datasets.load_dataset(path=data_path, split=splitName, **kwargs)
    all_datasets.append(dataset)

# Concatenate all per-path datasets into one
return (
    datasets.concatenate_datasets(all_datasets)
    if len(all_datasets) > 1
    else all_datasets[0]
)

# name = datasetconfig.name
# # simple check to make sure all files are of same type.
# extns = [get_extension(f) for f in files]
# assert extns.count(extns[0]) == len(
# extns
# ), f"All files in the dataset {name} should have the same extension"
# loader = get_loader_for_filepath(file_path=files[0])

# if loader in (None, ""):
# raise ValueError(f"data path is invalid [{', '.join(files)}]")

# try:
# return datasets.load_dataset(
# loader,
# data_files=files,
# split=splitName,
# **kwargs,
# )
# except DatasetNotFoundError as e:
# raise e
# except FileNotFoundError as e:
# raise ValueError(f"data path is invalid [{', '.join(files)}]") from e

def _process_dataset_configs(
self, dataset_configs: List[DataSetConfig], **extra_kwargs
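
For readability, here is a standalone sketch of the branching introduced above, assuming builder comes from the dataset config and splitName is the requested split; load_single_path is a hypothetical helper written for illustration, not code from this commit.

import os
from typing import Optional

import datasets

from tuning.utils.utils import get_loader_for_filepath

def load_single_path(data_path: str, builder: Optional[str], split_name: str, **kwargs):
    """Resolve one data_path the same way the loop above does."""
    if os.path.isdir(data_path):
        if builder:
            # Directory + explicit builder: the builder picks matching files in the directory.
            return datasets.load_dataset(
                path=builder, data_dir=data_path, split=split_name, **kwargs
            )
        # Directory without a builder: hand the directory to HF directly.
        return datasets.load_dataset(path=data_path, split=split_name, **kwargs)
    # File or HF hub ID: infer a builder from the extension when none was given.
    inferred = builder or get_loader_for_filepath(data_path)
    if inferred:
        # Path is a single file; load it with the inferred builder.
        return datasets.load_dataset(
            path=inferred, data_files=[data_path], split=split_name, **kwargs
        )
    # No extension match: treat the path as a Hugging Face dataset ID.
    return datasets.load_dataset(path=data_path, split=split_name, **kwargs)

With this shape, a mixed data_paths list containing a directory, a local train.jsonl file, and a hub dataset ID would hit the three branches respectively, and the per-path datasets are then concatenated exactly as in the loop above.
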
4 changes: 2 additions & 2 deletions tuning/utils/utils.py
@@ -31,9 +31,9 @@ def get_loader_for_filepath(file_path: str) -> str:
return "text"
if ext in (".json", ".jsonl"):
return "json"
if ext in (".arrow"):
if ext in (".arrow",):
return "arrow"
if ext in (".parquet"):
if ext in (".parquet",):
return "parquet"
return ext

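
The trailing commas matter because (".arrow") is just a parenthesised string, not a tuple, so the old membership test fell back to substring matching. A quick illustration:

# (".arrow") is a plain string, so `in` does substring matching (the old bug);
# (".arrow",) is a one-element tuple, so `in` does exact membership (the fix).
ext = ".a"
print(ext in (".arrow"))    # True  -- ".a" is a substring of ".arrow"
print(ext in (".arrow",))   # False -- not equal to ".arrow"

ext = ".arrow"
print(ext in (".arrow",))   # True  -- exact match, as intended
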
