diff --git a/src/sparseml/transformers/finetune/data/base.py b/src/sparseml/transformers/finetune/data/base.py
index 354dfcccbe1..e7bc2d5b286 100644
--- a/src/sparseml/transformers/finetune/data/base.py
+++ b/src/sparseml/transformers/finetune/data/base.py
@@ -41,6 +41,7 @@ class TextGenerationDataset(RegistryMixin):
     """
 
     PROMPT_KEY = "prompt"
+    MASK_KEY = "mask"
 
     def __init__(
         self,
@@ -125,6 +126,7 @@ def tokenize_fn(data):
                 padding=self.padding,
                 max_length=self.max_seq_length,
                 truncation=True,
+                return_offsets_mapping=True,
             )
 
             # store unpadded prompt so we can mask out correct number of elements
@@ -156,16 +158,29 @@ def group_text_fn(data):
         def label_fn(data):
             # if the dataset uses prompts, mask them out so they don't contribute
             # to the loss calculation
+            labels = data["input_ids"].copy()
+            if "offset_mapping" in data:
+                offset_mapping = data["offset_mapping"]
+                # get the character level mask
+                mask = data.get("mask")
+                if mask is not None:
+                    for i, (start, end) in enumerate(offset_mapping):
+                        # if any char is to be filtered
+                        if "0" in mask[start:end]:
+                            labels[i] = LABELS_MASK_VALUE
+
             prompt_len = 0
             if self.PROMPT_KEY in data:
                 prompt_len = len(data[self.PROMPT_KEY])
-            data["labels"] = data["input_ids"].copy()
+
+            data["labels"] = labels
             data["labels"][:prompt_len] = [LABELS_MASK_VALUE] * prompt_len
 
             # mask out padding in the labels as well
             padding = len(data["attention_mask"]) - sum(data["attention_mask"])
             if padding > 0:
                 data["labels"][-padding:] = [LABELS_MASK_VALUE] * padding
+
             return data
 
         if raw_dataset is None:
@@ -209,8 +224,6 @@ def label_fn(data):
             load_from_cache_file=not self.data_args.overwrite_cache,
             desc="Adding labels",
         )
-        print(dataset.column_names)
-
         return dataset
 
     def map(
@@ -229,5 +242,4 @@ def map(
             kwargs.pop("num_proc", None)
             kwargs.pop("load_from_cache_file", None)
             kwargs.pop("desc", None)
-
         return dataset.map(**kwargs)
diff --git a/src/sparseml/transformers/finetune/data/custom.py b/src/sparseml/transformers/finetune/data/custom.py
index f1bdcb6085f..55586e7d562 100644
--- a/src/sparseml/transformers/finetune/data/custom.py
+++ b/src/sparseml/transformers/finetune/data/custom.py
@@ -91,7 +91,6 @@ def get_raw_dataset(self, *_ignore, **__ignore) -> Union[DatasetDict, Dataset]:
                 num_proc=self.data_args.preprocessing_num_workers,
                 desc="Removing unneeded columns",
             )
-
         return raw_dataset
 
     def get_remove_columns_from_dataset(
@@ -108,5 +107,7 @@ def get_remove_columns_from_dataset(
             remove_columns.remove(self.text_column)
         if self.PROMPT_KEY in remove_columns:
            remove_columns.remove(self.PROMPT_KEY)
+        if self.MASK_KEY in remove_columns:
+            remove_columns.remove(self.MASK_KEY)
 
         return list(remove_columns)
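
[Reviewer sketch, not part of the patch] The new label_fn branch above turns a
character-level mask into token-level label masking via the tokenizer's
offset_mapping. A minimal standalone version of that logic, assuming
LABELS_MASK_VALUE is -100 (the usual Hugging Face ignore index):

    LABELS_MASK_VALUE = -100  # assumed ignore index

    def mask_labels(input_ids, offset_mapping, mask):
        # drop any token whose character span overlaps a "0" in the mask
        labels = input_ids.copy()
        for i, (start, end) in enumerate(offset_mapping):
            if "0" in mask[start:end]:
                labels[i] = LABELS_MASK_VALUE
        return labels

    # token 0 covers chars 0-5 (all masked); token 1 covers chars 5-10 (all kept)
    print(mask_labels([101, 102], [(0, 5), (5, 10)], "0000011111"))  # [-100, 102]
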
diff --git a/src/sparseml/transformers/utils/helpers.py b/src/sparseml/transformers/utils/helpers.py
index cb95d376a75..388fce18e73 100644
--- a/src/sparseml/transformers/utils/helpers.py
+++ b/src/sparseml/transformers/utils/helpers.py
@@ -54,6 +54,7 @@
     "ALL_TASK_NAMES",
     "create_fake_dataloader",
     "POSSIBLE_TOKENIZER_FILES",
+    "generate_mask",
     "download_repo_from_huggingface_hub",
     "download_model_directory",
 ]
@@ -544,6 +545,62 @@ def fetch_recipe_path(target: str):
     return recipe_path
 
 
+def generate_mask(string: str, response: str, prompt: str = "") -> str:
+    """
+    Generate a character-level mask for the input string from the given
+    prompt and response markers. Characters in a prompt section are masked
+    out ("0" = remove) and characters in a response section are kept
+    ("1" = keep).
+    By default, text that is not wrapped in a response section is masked
+    with "0".
+
+    :param string: the input string to be masked
+    :param response: the response marker identifying characters to keep
+        visible
+    :param prompt: the prompt marker identifying characters to obscure
+
+    :return: a string representing the mask, where "1" indicates visible
+        characters and "0" indicates obscured characters
+    """
+
+    mask = ["1"] * len(string)
+    # track whether we are currently inside a prompt (masked) section
+    is_prompt = not string.startswith(response)
+    # counter tracks how many characters of a marker have been matched so far
+    counter = 0
+    for i, char in enumerate(string):
+        if is_prompt:
+            mask[i] = "0"
+
+        # extend a partial marker match with the current character
+        if counter > 0:
+            if not is_prompt and len(prompt) > 1 and char == prompt[counter]:
+                counter += 1
+            elif is_prompt and char == response[counter]:
+                counter += 1
+            else:
+                counter = 0
+
+        # a full prompt marker was matched: mask it and enter a prompt section
+        if len(prompt) > 0 and counter == len(prompt) and not is_prompt:
+            mask[i - counter + 1 : i + 1] = ["0"] * counter
+
+            counter = 0
+            is_prompt = True
+
+        # a full response marker was matched: keep it and leave the prompt section
+        if counter == len(response) and is_prompt:
+            mask[i - counter + 1 : i + 1] = ["1"] * counter
+
+            counter = 0
+            is_prompt = False
+
+        # the current character may start a new marker match
+        if prompt.startswith(char) or response.startswith(char):
+            counter = 1
+    return "".join(mask)
+
+
 def download_repo_from_huggingface_hub(repo_id, **kwargs):
     """
     Download relevant model files from the Hugging Face Hub
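
[Reviewer sketch, not part of the patch] A worked example of the new helper,
mirroring the simplest parametrized case in the unit tests at the bottom of
this diff:

    from sparseml.transformers.utils.helpers import generate_mask

    text = "[foo]hello\n\n" "[bar]world"
    mask = generate_mask(text, response="[bar]", prompt="[foo]")

    # the 12 prompt chars "[foo]hello\n\n" are masked ("0"); the 10 response
    # chars "[bar]world" are kept ("1")
    assert mask == "000000000000" + "1111111111"
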
diff --git a/src/sparseml/transformers/utils/preprocessing_functions.py b/src/sparseml/transformers/utils/preprocessing_functions.py
index 8b019094377..b26bc3f587c 100644
--- a/src/sparseml/transformers/utils/preprocessing_functions.py
+++ b/src/sparseml/transformers/utils/preprocessing_functions.py
@@ -14,6 +14,7 @@
 
 from typing import Dict
 
+from sparseml.transformers.utils.helpers import generate_mask
 from sparsezoo.utils.registry import RegistryMixin
 
 
@@ -26,4 +27,7 @@ def custom_evolved_codealpaca_dataset(data: Dict):
     PROMPT_DICT = """[Instruction]:\n{instruction}\n\n[Response]:"""
     data["prompt"] = PROMPT_DICT.format_map(data)
     data["text"] = data["prompt"] + data["output"]
+    data["mask"] = generate_mask(
+        data["text"], prompt="[Instruction]", response="[Response]"
+    )
     return data
diff --git a/tests/sparseml/transformers/finetune/test_finetune.py b/tests/sparseml/transformers/finetune/test_finetune.py
index 9ecfe220122..7a290398527 100644
--- a/tests/sparseml/transformers/finetune/test_finetune.py
+++ b/tests/sparseml/transformers/finetune/test_finetune.py
@@ -33,6 +33,7 @@
     oneshot,
     train,
 )
+from sparseml.transformers.utils.helpers import generate_mask
 
 
 def test_oneshot_and_finetune(tmp_path: Path):
@@ -322,3 +323,35 @@ def test_oneshot_with_modifier_object(tmp_path: Path):
         splits=splits,
         oneshot_device=device,
     )
+
+
+def test_finetune_wout_recipe_with_mask(tmp_path: Path):
+    recipe_str = None
+    model = "Xenova/llama2.c-stories15M"
+    device = "cuda:0"
+    if not torch.cuda.is_available():
+        device = "cpu"
+    dataset = "open_platypus"
+    concatenate_data = False
+    output_dir = tmp_path
+    max_steps = 50
+    splits = "train"
+
+    def preprocessing_func(example):
+        example["text"] = "[foo]" + example["text"] + "[bar] mask this"
+        example["mask"] = generate_mask(
+            example["text"], response="[bar]", prompt="[foo]"
+        )
+        return example
+
+    train(
+        model=model,
+        dataset=dataset,
+        output_dir=output_dir,
+        recipe=recipe_str,
+        max_steps=max_steps,
+        concatenate_data=concatenate_data,
+        splits=splits,
+        oneshot_device=device,
+        preprocessing_func=preprocessing_func,
+    )
diff --git a/tests/sparseml/transformers/utils/test_helpers.py b/tests/sparseml/transformers/utils/test_helpers.py
index 5e0ab5e93da..04ebf735d7e 100644
--- a/tests/sparseml/transformers/utils/test_helpers.py
+++ b/tests/sparseml/transformers/utils/test_helpers.py
@@ -21,6 +21,7 @@
 from accelerate import init_empty_weights
 from sparseml.transformers.utils.helpers import (
     create_fake_dataloader,
+    generate_mask,
     infer_recipe_from_model_path,
     is_transformer_model,
     resolve_recipe_file,
@@ -166,3 +167,55 @@ def test_save_zoo_directory(tmp_path, stub):
     assert zoo_model.validate(minimal_validation=True, validate_onnxruntime=False)
     shutil.rmtree(path_to_training_outputs)
     shutil.rmtree(save_dir)
+
+
+@pytest.mark.parametrize(
+    "string, response, prompt, expected_mask",
+    [
+        (
+            ("[foo]hello\n\n" "[bar]world"),
+            "[bar]",
+            "[foo]",
+            ("000000000000" "1111111111"),
+        ),
+        (
+            (
+                "[Instruction]python is\n\n"  # 24
+                "[Response]great\n\n"  # 17
+                "[Instruction]What about Java"  # 28
+                "[Response]Meh"  # 13
+            ),
+            "[Response]",
+            "[Instruction]",
+            (
+                "000000000000000000000000"  # 24
+                "11111111111111111"  # 17
+                "0000000000000000000000000000"  # 28
+                "1111111111111"  # 13
+            ),
+        ),
+        (
+            ("[foo]hello\n\n" "[bar]world"),
+            "[bar]",
+            None,
+            ("000000000000" "1111111111"),
+        ),
+        (
+            ("hello\n\n" "[bar]world"),
+            "[bar]",
+            None,
+            ("0000000" "1111111111"),
+        ),
+        (
+            ("[bar]world" "[foo]hello\n\n" "[bar]world"),
+            "[bar]",
+            "[foo]",
+            ("1111111111" "000000000000" "1111111111"),
+        ),
+    ],
+)
+def test_generate_mask(string, response, prompt, expected_mask):
+    if prompt is not None:
+        assert generate_mask(string, response, prompt) == expected_mask
+    else:
+        assert generate_mask(string, response) == expected_mask
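
[Reviewer sketch, not part of the patch] End to end, the pieces above let a
dataset carry a character-level mask that survives tokenization. A sketch of
the full flow; the checkpoint name is borrowed from the tests, and any fast
tokenizer that supports offset mappings would do:

    from transformers import AutoTokenizer

    from sparseml.transformers.utils.helpers import generate_mask

    LABELS_MASK_VALUE = -100  # assumed ignore index

    text = "[Instruction]python is\n\n[Response]great\n\n"
    mask = generate_mask(text, response="[Response]", prompt="[Instruction]")

    tokenizer = AutoTokenizer.from_pretrained("Xenova/llama2.c-stories15M")
    encoded = tokenizer(text, return_offsets_mapping=True)

    labels = list(encoded["input_ids"])
    for i, (start, end) in enumerate(encoded["offset_mapping"]):
        if "0" in mask[start:end]:  # token overlaps a masked character
            labels[i] = LABELS_MASK_VALUE

    # tokens overlapping the "[Instruction]..." span are now -100; tokens in
    # the "[Response]..." span keep their original ids
    print(labels)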