Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[transformers] Prompt masking #2192

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions src/sparseml/transformers/finetune/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class TextGenerationDataset(RegistryMixin):
"""

PROMPT_KEY = "prompt"
MASK_KEY = "mask"

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions src/sparseml/transformers/finetune/data/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def get_remove_columns_from_dataset(
remove_columns.remove(self.text_column)
if self.PROMPT_KEY in remove_columns:
remove_columns.remove(self.PROMPT_KEY)
if "mask" in remove_columns:
remove_columns.remove("mask")
if self.MASK_KEY in remove_columns:
remove_columns.remove(self.MASK_KEY)

return list(remove_columns)
32 changes: 19 additions & 13 deletions src/sparseml/transformers/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,47 +559,53 @@ def fetch_recipe_path(target: str):
return recipe_path


def generate_mask(string: str, prompt: str, censor: str) -> str:
def generate_mask(string: str, response: str, prompt: Optional[str] = None) -> str:
"""
Generate a mask based on provided prompt and censor strings to obscure
characters in the input string.
Generate a mask based on provided prompt and response strings to obscure
characters in the input string. Prompt will be masked and string in response
will be kept represented by 0 - remove and 1 - keep.
By default, non-reponse wrapped strings will be matched with 0

Args:
:param string: The input string to be masked.
:param prompt: The prompt string to identify characters to keep visible.
:param censor: The censor string to identify characters to obscure.
:param prompt: The prompt string to identify characters to obscure.
:param response: The response string to identify characters to keep visible.

Returns:
str: A string representing the mask where '1' indicates visible
characters and '0' indicates obscured characters.

"""
if prompt is None:
prompt = ""
horheynm marked this conversation as resolved.
Show resolved Hide resolved

mask = ["1"] * len(string)
Satrat marked this conversation as resolved.
Show resolved Hide resolved
is_prompt = True
counter = 0
for i, char in enumerate(string):
if not is_prompt:
if is_prompt:
mask[i] = "0"

if counter > 0:
if not is_prompt and char == prompt[counter]:
counter += 1
elif is_prompt and char == censor[counter]:
elif is_prompt and char == response[counter]:
counter += 1
else:
counter = 0

if counter == len(prompt) and not is_prompt:
mask[i - counter + 1 : i + 1] = ["1"] * counter
if len(prompt) > 0 and counter == len(prompt) and not is_prompt:
mask[i - counter + 1 : i + 1] = ["0"] * counter

counter = 0
is_prompt = True

if counter == len(censor) and is_prompt:
mask[i - counter + 1 : i + 1] = ["0"] * counter
if counter == len(response) and is_prompt:
mask[i - counter + 1 : i + 1] = ["1"] * counter

counter = 0
is_prompt = False

if prompt.startswith(char) or censor.startswith(char):
if prompt.startswith(char) or response.startswith(char):
counter = 1

return "".join(mask)
35 changes: 26 additions & 9 deletions tests/sparseml/transformers/utils/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,26 +170,43 @@ def test_save_zoo_directory(tmp_path, stub):


@pytest.mark.parametrize(
"string, prompt, censor, expected_mask",
"string, response, prompt, expected_mask",
[
("[foo]hello\n\n[bar]world", "[foo]", "[bar]", "1111111111110000000000"),
(
("[foo]hello\n\n" "[bar]world"),
"[bar]",
"[foo]",
("000000000000" "1111111111"),
horheynm marked this conversation as resolved.
Show resolved Hide resolved
),
(
(
"[Instruction]python is\n\n" # 24
"[Response]great\n\n" # 17
"[Instruction]What about Java" # 28
"[Response]Meh" # 13
),
"[Instruction]",
"[Response]",
"[Instruction]",
(
"111111111111111111111111" # 24
"00000000000000000" # 17
"1111111111111111111111111111" # 28
"0000000000000" # 13
"000000000000000000000000" # 24
"11111111111111111" # 17
"0000000000000000000000000000" # 28
"1111111111111" # 13
),
),
(
("[foo]hello\n\n" "[bar]world"),
"[bar]",
None,
("000000000000" "1111111111"),
),
(
("hello\n\n" "[bar]world"),
"[bar]",
None,
("0000000" "1111111111"),
),
],
)
def test_generate_mask(string, prompt, censor, expected_mask):
assert generate_mask(string, prompt, censor) == expected_mask
def test_generate_mask(string, response, prompt, expected_mask):
horheynm marked this conversation as resolved.
Show resolved Hide resolved
assert generate_mask(string, response, prompt) == expected_mask
Loading