Skip to content

Commit

Permalink
SFT dataset - model transform validation (#1470)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewldesousa authored Sep 4, 2024
1 parent 01bff03 commit d31649e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
21 changes: 21 additions & 0 deletions tests/torchtune/datasets/test_sft_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
return {"messages": messages}


class DummyTokenizerInvalidModelTransform(DummyTokenizer):
    """Tokenizer stand-in whose output deliberately omits the "tokens" key.

    Produces an otherwise-normal tokenized sample but strips "tokens" from
    it, so tests can exercise SFTDataset's validation of model_transform
    output.
    """

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        tokenized = super().__call__(sample)
        tokenized.pop("tokens")
        return tokenized


class TestSFTDataset:
@pytest.fixture
def dialogue(self):
Expand Down Expand Up @@ -135,3 +142,17 @@ def test_error_for_invalid_messages(self, mock_load_dataset, invalid_dialogue):
msg = "system messages must come first"
with pytest.raises(ValueError, match=msg):
ds[0]

@mock.patch("torchtune.datasets._sft.load_dataset")
def test_error_for_invalid_tokenized_dict(self, mock_load_dataset, dialogue):
    """A model_transform whose output lacks the 'tokens' key must raise
    a ValueError naming the keys it did return."""
    mock_load_dataset.return_value = dialogue

    ds = SFTDataset(
        source="iam/agoofy/goober",
        message_transform=ToDummyMessages(),
        model_transform=DummyTokenizerInvalidModelTransform(),
    )

    # pytest.raises(match=...) interprets the string as a regex; escape it
    # so the literal '.' and quotes must match exactly rather than acting
    # as metacharacters (an unescaped '.' matches any character).
    import re

    msg = re.escape(
        "model_transform returned the following keys: mask. "
        "Must return 'tokens' and 'mask' as keys."
    )
    with pytest.raises(ValueError, match=msg):
        ds[0]
8 changes: 8 additions & 0 deletions torchtune/datasets/_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:

tokenized_dict = self._model_transform(transformed_sample)

if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
keys_str = ", ".join(tokenized_dict.keys())
error_message = (
"model_transform returned the following keys: "
f"{keys_str}. Must return 'tokens' and 'mask' as keys."
)
raise ValueError(error_message)

# Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens
tokenized_dict["labels"] = list(
np.where(
Expand Down

0 comments on commit d31649e

Please sign in to comment.