Skip to content

Commit

Permalink
Custom datasets (#2115)
Browse files Browse the repository at this point in the history
* draft

* custom dataset pass in func name

* merge main, dataclass types

* comments

* clean up

* lint

* comments

* Fixing custom evolved dataset

* quality fix

* custom path

* Update src/sparseml/transformers/finetune/data/data_args.py

Co-authored-by: Rahul Tuli <[email protected]>

* Update src/sparseml/transformers/finetune/data/custom.py

Co-authored-by: Rahul Tuli <[email protected]>

* com

---------

Co-authored-by: Abhinav Agarwalla <[email protected]>
Co-authored-by: abhinavnmagic <[email protected]>
Co-authored-by: Rahul Tuli <[email protected]>
Co-authored-by: Sara Adkins <[email protected]>
  • Loading branch information
5 people authored Mar 13, 2024
1 parent f80943e commit 708a341
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def lm_eval_harness(
:param kwargs: additional keyword arguments to pass to the
lm-evaluation-harness. For example, `limit`
"""

kwargs["limit"] = int(limit) if (limit := kwargs.get("limit")) else None

tokenizer = SparseAutoTokenizer.from_pretrained(model_path)
Expand Down
19 changes: 18 additions & 1 deletion src/sparseml/transformers/finetune/data/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
from datasets.dataset_dict import Dataset, DatasetDict

from sparseml.transformers.finetune.data import TextGenerationDataset
from sparseml.transformers.utils.preprocessing_functions import (
PreprocessingFunctionRegistry,
)
from sparsezoo.utils.helpers import import_from_path


@TextGenerationDataset.register(name="custom", alias=["json", "csv"])
Expand Down Expand Up @@ -55,9 +59,21 @@ def get_raw_dataset(self, *_ignore, **__ignore) -> Union[DatasetDict, Dataset]:
raw_dataset = super().get_raw_dataset()

if self.preprocessing_func is not None:

if callable(self.preprocessing_func):
func = self.preprocessing_func
elif ":" in self.preprocessing_func:
# load func_name from "/path/to/file.py:func_name"
func = import_from_path(self.preprocessing_func)
else:
# load from the registry
func = PreprocessingFunctionRegistry.get_value_from_registry(
name=self.preprocessing_func
)

raw_dataset = self.map(
raw_dataset,
function=self.preprocessing_func,
function=func,
batched=False,
num_proc=self.data_args.preprocessing_num_workers,
desc="Applying custom func to the custom dataset",
Expand All @@ -82,6 +98,7 @@ def get_remove_columns_from_dataset(
self, raw_dataset: Union[DatasetDict, Dataset]
) -> List[str]:
"""Remove redandant columns from the dataset for processing"""

remove_columns = raw_dataset.column_names
if isinstance(remove_columns, Dict):
remove_columns = raw_dataset[list(raw_dataset.keys())[0]].column_names
Expand Down
11 changes: 9 additions & 2 deletions src/sparseml/transformers/finetune/data/data_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,15 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments):
metadata={"help": "Column names to remove after preprocessing custom datasets"},
)

preprocessing_func: Optional[Callable] = field(
default=None, metadata={"help": "The preprcessing function to apply"}
preprocessing_func: Union[None, str, Callable] = field(
default=None,
metadata={
"help": (
"The preprocessing function to apply ",
"or the preprocessing func name in "
"src/sparseml/transformers/utils/preprocessing_functions.py",
)
},
)


Expand Down
33 changes: 33 additions & 0 deletions src/sparseml/transformers/finetune/data/data_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def get_raw_dataset(
:return: the requested dataset
"""

raw_datasets = load_dataset(
data_args.dataset,
data_args.dataset_config_name,
Expand Down Expand Up @@ -125,6 +126,7 @@ def make_dataset_splits(
tokenized_datasets = {"train": tokenized_datasets}

train_split = eval_split = predict_split = calib_split = None

if do_train:
if "train" not in tokenized_datasets:
raise ValueError("--do_train requires a train dataset")
Expand Down Expand Up @@ -218,4 +220,35 @@ def get_custom_datasets_from_path(path: str, ext: str = "json") -> Dict[str, str
if dir_dataset:
data_files[dir_name] = dir_dataset

return transform_dataset_keys(data_files)


def transform_dataset_keys(data_files: Dict[str, Any]) -> Dict[str, Any]:
    """
    Transform dict keys to `train`, `val` or `test` for the given input dict
    if matches exist with the existing keys. Note that there can only be one
    matching file name.

    Ex. Folder(train_eval.json) -> Folder(train.json)
        Folder(train1.json, train2.json) -> Same

    :param data_files: the dict whose keys will be transformed in place
    :return: the same dict, with each uniquely-matching key renamed to
        `train`, `val` or `test`
    """
    dataset_keys = ("train", "val", "test")
    for candidate in dataset_keys:
        # Recompute matches against the *current* keys on every pass: an
        # earlier rename (e.g. "train_eval" -> "train") must not leave a
        # stale snapshot behind that a later candidate ("val" is a substring
        # of "train_eval") would try to pop again, raising KeyError.
        matches = [key for key in data_files if candidate in key]
        # Only rename when exactly one key matches and it isn't already
        # the canonical name.
        if len(matches) == 1 and matches[0] != candidate:
            data_files[candidate] = data_files.pop(matches[0])

    return data_files
1 change: 1 addition & 0 deletions src/sparseml/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .helpers import *
from .load_task_dataset import *
from .metrics import *
from .preprocessing_functions import *
from .sparse_config import *
from .sparse_model import *
from .sparse_tokenizer import *
29 changes: 29 additions & 0 deletions src/sparseml/transformers/utils/preprocessing_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

from sparsezoo.utils.registry import RegistryMixin


class PreprocessingFunctionRegistry(RegistryMixin):
    """
    Registry of named dataset preprocessing functions. Functions register
    themselves with ``@PreprocessingFunctionRegistry.register()`` and are
    later retrieved by name via ``get_value_from_registry(name=...)``
    (e.g. when a dataset's ``preprocessing_func`` is given as a string).
    """

    ...


@PreprocessingFunctionRegistry.register()
def custom_evolved_codealpaca_dataset(data: Dict):
    """
    Format a single evolved-codealpaca sample for fine-tuning.

    Fills ``data["prompt"]`` from the sample's ``instruction`` field and
    sets ``data["text"]`` to the prompt followed by ``data["output"]``.

    :param data: one sample dict with at least ``instruction`` and
        ``output`` keys; mutated in place
    :return: the same dict with ``prompt`` and ``text`` keys added
    """
    template = "[Instruction]:\n{instruction}\n\n[Response]:"
    prompt = template.format_map(data)
    data["prompt"] = prompt
    data["text"] = prompt + data["output"]
    return data

0 comments on commit 708a341

Please sign in to comment.