Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[transformers] Prompt masking #2192

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
19 changes: 15 additions & 4 deletions src/sparseml/transformers/finetune/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def tokenize_fn(data):
padding=self.padding,
max_length=self.max_seq_length,
truncation=True,
return_offsets_mapping=True,
)

# store unpadded prompt so we can mask out correct number of elements
Expand Down Expand Up @@ -156,16 +157,29 @@ def group_text_fn(data):
def label_fn(data):
    """
    Attach a ``labels`` entry to a single tokenized sample.

    The labels start as a copy of ``input_ids``; positions that must not
    contribute to the loss are overwritten with ``LABELS_MASK_VALUE``:

    1. tokens whose character span touches a ``"0"`` in the optional
       character-level ``mask`` string (needs ``offset_mapping``),
    2. the leading ``len(data[self.PROMPT_KEY])`` positions, when a prompt
       entry is present,
    3. the trailing positions that ``attention_mask`` marks as padding.

    :param data: one sample holding at least ``input_ids`` and
        ``attention_mask``; ``offset_mapping``, ``mask`` and the prompt entry
        are optional
    :return: the same ``data`` object with ``labels`` set
    """
    # work on a copy so input_ids themselves are never modified
    labels = data["input_ids"].copy()

    if "offset_mapping" in data:
        # character-level mask, one "0"/"1" flag per character of the text
        mask = data.get("mask")
        if mask is not None:
            for i, (start, end) in enumerate(data["offset_mapping"]):
                # drop the token if any character it covers is censored;
                # an empty span (start == end) yields "" and is left as-is
                if "0" in mask[start:end]:
                    labels[i] = LABELS_MASK_VALUE

    # NOTE(review): presumably the prompt entry holds the unpadded prompt
    # tokens, so its length is the number of leading positions to hide
    # -- confirm against tokenize_fn
    prompt_len = 0
    if self.PROMPT_KEY in data:
        prompt_len = len(data[self.PROMPT_KEY])
    labels[:prompt_len] = [LABELS_MASK_VALUE] * prompt_len

    # mask out padding in the labels as well; the slice below treats the
    # padded positions as sitting at the end of the sequence
    padding = len(data["attention_mask"]) - sum(data["attention_mask"])
    if padding > 0:
        labels[-padding:] = [LABELS_MASK_VALUE] * padding

    # store once, after all masking, instead of assigning an intermediate
    # copy of input_ids that is immediately overwritten
    data["labels"] = labels
    return data

dataset = self.map(
Expand Down Expand Up @@ -206,8 +220,6 @@ def label_fn(data):
load_from_cache_file=not self.data_args.overwrite_cache,
desc="Adding labels",
)
print(dataset.column_names)

return dataset

def map(
Expand All @@ -226,5 +238,4 @@ def map(
kwargs.pop("num_proc", None)
kwargs.pop("load_from_cache_file", None)
kwargs.pop("desc", None)

return dataset.map(**kwargs)
3 changes: 2 additions & 1 deletion src/sparseml/transformers/finetune/data/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def get_raw_dataset(self, *_ignore, **__ignore) -> Union[DatasetDict, Dataset]:
num_proc=self.data_args.preprocessing_num_workers,
desc="Removing unneeded columns",
)

return raw_dataset

def get_remove_columns_from_dataset(
Expand All @@ -108,5 +107,7 @@ def get_remove_columns_from_dataset(
remove_columns.remove(self.text_column)
if self.PROMPT_KEY in remove_columns:
remove_columns.remove(self.PROMPT_KEY)
if "mask" in remove_columns:
horheynm marked this conversation as resolved.
Show resolved Hide resolved
remove_columns.remove("mask")

return list(remove_columns)
47 changes: 47 additions & 0 deletions src/sparseml/transformers/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"ALL_TASK_NAMES",
"create_fake_dataloader",
"POSSIBLE_TOKENIZER_FILES",
"generate_mask",
]


Expand Down Expand Up @@ -554,3 +555,49 @@ def fetch_recipe_path(target: str):
recipe_path = hf_hub_download(repo_id=target, filename=DEFAULT_RECIPE_NAME)

return recipe_path


def generate_mask(string: str, prompt: str, censor: str) -> str:
    """
    Generate a character-level mask for ``string`` by alternating between
    visible and obscured regions delimited by the ``prompt`` and ``censor``
    markers.

    The text starts out visible. Each occurrence of ``censor`` found while
    visible opens an obscured region (the marker itself is obscured); each
    occurrence of ``prompt`` found while obscured opens a visible region (the
    marker itself is visible). A marker that shows up while its own region is
    already active has no effect.

    :param string: The input string to be masked.
    :param prompt: Marker that switches the mask back to visible.
    :param censor: Marker that switches the mask to obscured.
    :return: A string of the same length as ``string`` where '1' indicates
        visible characters and '0' indicates obscured characters. An empty
        marker never matches, so the region that would be ended by it runs
        to the end of the string.
    """
    total = len(string)
    pieces = []
    visible = True
    pos = 0
    # length of the marker that opened the current region; the search for the
    # next marker starts after it so the two markers can never overlap
    skip = 0

    while pos < total:
        # only the marker that ends the *current* region is relevant
        marker = censor if visible else prompt
        # guard the empty marker: str.find("") would match immediately
        hit = string.find(marker, pos + skip) if marker else -1
        end = total if hit == -1 else hit

        pieces.append(("1" if visible else "0") * (end - pos))

        # the next region begins exactly at the marker that was just found
        pos = end
        skip = len(marker)
        visible = not visible

    return "".join(pieces)
4 changes: 4 additions & 0 deletions src/sparseml/transformers/utils/preprocessing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from typing import Dict

from sparseml.transformers.utils.helpers import generate_mask
from sparsezoo.utils.registry import RegistryMixin


Expand All @@ -26,4 +27,7 @@ def custom_evolved_codealpaca_dataset(data: Dict):
PROMPT_DICT = """[Instruction]:\n{instruction}\n\n[Response]:"""
data["prompt"] = PROMPT_DICT.format_map(data)
data["text"] = data["prompt"] + data["output"]
data["mask"] = generate_mask(
data["text"], prompt="[Instruction]", censor="[Response]"
)
return data
27 changes: 27 additions & 0 deletions tests/sparseml/transformers/utils/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from accelerate import init_empty_weights
from sparseml.transformers.utils.helpers import (
create_fake_dataloader,
generate_mask,
infer_recipe_from_model_path,
is_transformer_model,
resolve_recipe_file,
Expand Down Expand Up @@ -166,3 +167,29 @@ def test_save_zoo_directory(tmp_path, stub):
assert zoo_model.validate(minimal_validation=True, validate_onnxruntime=False)
shutil.rmtree(path_to_training_outputs)
shutil.rmtree(save_dir)


@pytest.mark.parametrize(
    "string, prompt, censor, expected_mask",
    [
        # single turn: everything from the censor marker onwards is obscured
        ("[foo]hello\n\n[bar]world", "[foo]", "[bar]", "1111111111110000000000"),
        # two turns: the mask flips at every marker, markers included
        (
            (
                "[Instruction]python is\n\n"  # 24 chars, visible
                "[Response]great\n\n"  # 17 chars, obscured
                "[Instruction]What about Java"  # 28 chars, visible
                "[Response]Meh"  # 13 chars, obscured
            ),
            "[Instruction]",
            "[Response]",
            (
                "111111111111111111111111"  # 24
                "00000000000000000"  # 17
                "1111111111111111111111111111"  # 28
                "0000000000000"  # 13
            ),
        ),
    ],
)
def test_generate_mask(string, prompt, censor, expected_mask):
    """Check that generate_mask returns the expected '1'/'0' mask string."""
    actual_mask = generate_mask(string, prompt, censor)
    assert actual_mask == expected_mask
Loading