Add multi turn chat support.
Signed-off-by: Dushyant Behl <[email protected]>
dushyantbehl committed Dec 17, 2024
1 parent 4441948 commit 613527d
Showing 6 changed files with 94 additions and 12 deletions.
4 changes: 2 additions & 2 deletions .pylintrc
@@ -280,8 +280,8 @@ ignored-parents=
# Maximum number of arguments for function / method.
max-args=5

# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Maximum number of attributes for a class (custom).
max-attributes=10

# Maximum number of boolean expressions in an if statement (see R0916).
max-bool-expr=5
35 changes: 34 additions & 1 deletion tests/data/test_data_preprocessing_utils.py
@@ -280,7 +280,8 @@ def test_is_pretokenized_data(data, result):


@pytest.mark.parametrize(
"packing, response_template, formatted_train_dataset, max_seq_length, expected_collator",
"packing, response_template, formatted_train_dataset,\
max_seq_length, instruction_template, expected_collator",
[
(
False,
@@ -291,6 +292,35 @@ def test_is_pretokenized_data(data, result):
split="train",
),
1024,
None,
DataCollatorForCompletionOnlyLM,
),
(
False,
None,
Dataset.from_list(
[
{
"input_ids": [9437, 29, 210],
"attention_mask": [1, 1, 1],
"labels": [1, 20, 30],
}
]
),
1024,
None,
DataCollatorForSeq2Seq,
),
(
False,
"\n### Label:",
datasets.load_dataset(
"json",
data_files=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
split="train",
),
1024,
"\n### Text:",
DataCollatorForCompletionOnlyLM,
),
(
@@ -306,6 +336,7 @@ def test_is_pretokenized_data(data, result):
]
),
1024,
"\n### Text:",
DataCollatorForSeq2Seq,
),
],
@@ -315,6 +346,7 @@ def test_get_data_collator(
response_template,
formatted_train_dataset,
max_seq_length,
instruction_template,
expected_collator,
):
"""Ensure that the correct collator type is fetched based on the data args"""
@@ -324,6 +356,7 @@ def test_is_pretokenized_data(data, result):
AutoTokenizer.from_pretrained(MODEL_NAME),
is_pretokenized_dataset(formatted_train_dataset),
max_seq_length,
instruction_template,
)
assert isinstance(collator, expected_collator)

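For readers skimming the diff, the essence of the new parametrized cases is the call below. This is a minimal sketch, assuming the argument order (packing, response_template, tokenizer, is_traindata_tokenized, max_seq_length, instruction_template) implied by the test body, and using a placeholder tokenizer instead of MODEL_NAME:

from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

from tuning.data.data_preprocessing_utils import get_data_collator

# When both templates are provided, a completion-only collator is expected so
# that only the assistant/label portions contribute to the loss.
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model for illustration
collator = get_data_collator(
    False,            # packing
    "\n### Label:",   # response_template
    tokenizer,
    False,            # is_traindata_tokenized
    1024,             # max_seq_length
    "\n### Text:",    # instruction_template
)
assert isinstance(collator, DataCollatorForCompletionOnlyLM)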
15 changes: 15 additions & 0 deletions tuning/config/configs.py
@@ -102,6 +102,21 @@ class DataArguments:
Supports both JSON and YAML based config files."
},
)
chat_template: str = field(
default=None,
metadata={
"help": "chat template to use for tokenization. \
No need to pass this if the tokenizer already has a chat_template \
if passed, it will overwrite tokenizer.chat_template if it exists"
},
)
instruction_template: str = field(
default=None,
metadata={
"help": "Should be provided for chat training. \
Piece of text that determines the start of human response"
},
)


@dataclass
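The two new fields slot in next to the existing data arguments. A minimal sketch constructing the dataclass directly; the template strings are illustrative placeholders, not values from this commit:

from tuning.config.configs import DataArguments

# Only the chat-related fields are shown; other data arguments keep their defaults.
data_args = DataArguments(
    chat_template=(
        "{% for message in messages %}"
        "<|{{ message['role'] }}|>\n{{ message['content'] }}\n"
        "{% endfor %}"
    ),                                  # optional; overrides tokenizer.chat_template
    instruction_template="<|user|>",    # marks the start of a human (user) turn
    response_template="<|assistant|>",  # marks the start of an assistant turn
)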
17 changes: 17 additions & 0 deletions tuning/data/data_preprocessing_utils.py
@@ -28,6 +28,7 @@ def get_data_collator(
tokenizer: AutoTokenizer,
is_traindata_tokenized: bool,
max_seq_length: int,
instruction_template: Optional[str],
) -> Callable:
"""Create and return the the appropriate collator type based on the configuration for packing,
response_template, and dataset_text_field.
@@ -43,12 +44,28 @@
Whether train Dataset is tokenized or not
max_seq_length: int
Max sequence length expected
instruction_template: str
str marking the start of the human (user) turn in a chat template
Returns:
Callable
Callable collator to be leveraged by the trainer.
"""

if response_template and instruction_template:
# response_template_ids = tokenizer.encode(
# response_template, add_special_tokens=False
# )[2:]
# instruction_template_ids = tokenizer.encode(
# instruction_template, add_special_tokens=False
# )[2:]
return DataCollatorForCompletionOnlyLM(
response_template=response_template,
instruction_template=instruction_template,
tokenizer=tokenizer,
ignore_index=configs.IGNORE_INDEX,
)

if not packing:
# TODO: near term - how response template ids are parsed out needs to be cleaned.
# The [2:] here applies if response template has \n prefix, it is needed to strip \n,
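To make the new branch concrete, here is a self-contained sketch of what the returned collator does. The templates, model, and example text are placeholders, and the ignore index of -100 is an assumed value for configs.IGNORE_INDEX (its definition is not shown in this diff):

from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
tokenizer.pad_token = tokenizer.eos_token          # gpt2 has no pad token by default

collator = DataCollatorForCompletionOnlyLM(
    response_template="\n### Label:",    # start of the assistant/label turn
    instruction_template="\n### Text:",  # start of the human/user turn
    tokenizer=tokenizer,
    ignore_index=-100,                   # assumed value of configs.IGNORE_INDEX
)

# Tokenize a tiny example and collate it: labels for everything outside the
# "### Label:" span are set to the ignore index, so the loss is computed only
# on the assistant responses.
example = tokenizer("\n### Text: I love this product\n### Label: no complaint")
batch = collator([example])
print(batch["labels"])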
25 changes: 16 additions & 9 deletions tuning/data/setup_dataprocessor.py
@@ -34,8 +34,8 @@
from tuning.data.data_processors import get_datapreprocessor

# In future we may make the fields configurable
DEFAULT_JSON_INPUT_KEY = "input"
DEFAULT_JSON_OUTPUT_KEY = "output"
DEFAULT_INPUT_COLUMN = "input"
DEFAULT_OUTPUT_COLUMN = "output"

# check if the provided dataset is pretokenized or not
# the check is taken from trl
@@ -151,12 +151,12 @@ def _get_dataset_formatting_handlers(data_args, packing):
return [handler], dataset_text_field


### Data format 3
def _get_default_json_dataset_handlers(data_args, tokenizer_kwargs):
### Default Data format
def _get_default_dataset_handlers(data_args, tokenizer_kwargs):

fn_kwargs = {}
fn_kwargs["input_field_name"] = DEFAULT_JSON_INPUT_KEY
fn_kwargs["output_field_name"] = DEFAULT_JSON_OUTPUT_KEY
fn_kwargs["input_field_name"] = DEFAULT_INPUT_COLUMN
fn_kwargs["output_field_name"] = DEFAULT_OUTPUT_COLUMN
fn_kwargs["tokenizer_kwargs"] = tokenizer_kwargs

kwargs = {
@@ -177,7 +177,9 @@ def _get_default_json_dataset_handlers(data_args, tokenizer_kwargs):
# If a text field is specified, append the tokenizer's EOS token to it.
# If a formatter template is provided, apply it and save the result.
# Data remains un-tokenized.
# Data Format 3: JSON Dataset with Input/Output Fields
# Data Format 3: Chat datasets
# User provides response_template and instruction_template.
# Default Data Format: Dataset with Input/Output Fields
# Combine input and output fields, tokenize the data, and apply input attention masking.
# Requires both input and output fields; throws an error if missing.
def _process_raw_data_args(
@@ -239,9 +241,13 @@ def _process_raw_data_args(
handlers, dataset_text_field = _get_dataset_formatting_handlers(
data_args, packing
)
elif data_args.instruction_template and data_args.response_template:
# Data Format 3: Chat dataset with instruction and response template
# No extra processing is applied to chat datasets
handlers, dataset_text_field = [], None
else:
# Data Format 3: JSON Dataset with Input/Output Fields
handlers, dataset_text_field = _get_default_json_dataset_handlers(
# Default Data Format: Dataset with Input/Output Fields
handlers, dataset_text_field = _get_default_dataset_handlers(
data_args, tokenizer_kwargs
)

@@ -329,6 +335,7 @@ def process_dataargs(
tokenizer,
is_tokenized_dataset,
max_seq_length,
data_args.instruction_template,
)

dataset_kwargs = {}
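The new elif above routes chat data around every handler, so whatever the dataset already contains is what reaches the collator. The exact schema is not pinned down in this diff; purely as an assumption, a record for this path would carry text already rendered with the chat template and containing both marker strings literally:

# Hypothetical pre-rendered multi-turn record (schema assumed, not taken from
# this commit): the instruction/response markers must appear in the text so
# the completion-only collator can locate the assistant turns.
chat_record = {
    "text": (
        "<|user|>\nWhat is the capital of France?\n"
        "<|assistant|>\nParis.\n"
        "<|user|>\nAnd of Italy?\n"
        "<|assistant|>\nRome.\n"
    )
}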
10 changes: 10 additions & 0 deletions tuning/sft_trainer.py
@@ -287,6 +287,16 @@ def train(
multiple_of=model_args.embedding_size_multiple_of,
)

if data_args.chat_template:
logger.info("adding chat_template to the tokenizer")
if tokenizer.chat_template:
logger.warning(
"replacing existing chat_template %s with the given chat_template %s",
tokenizer.chat_template,
data_args.chat_template,
)
tokenizer.chat_template = data_args.chat_template

# Configure the collator and validate args related to packing prior to formatting the dataset
data_collator = None
logger.info("Packing is set to %s ", train_args.packing)
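The block above only swaps the Jinja template stored on the tokenizer; rendering still goes through the standard transformers API. A small sketch of the effect, with a placeholder model and template (neither comes from this commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model

# Equivalent of the new override: a template supplied through the chat_template
# data argument replaces whatever template the tokenizer shipped with.
tokenizer.chat_template = (
    "{% for message in messages %}"
    "<|{{ message['role'] }}|>\n{{ message['content'] }}\n"
    "{% endfor %}"
)

messages = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
]
# Render the conversation as text (tokenize=False returns the formatted string).
print(tokenizer.apply_chat_template(messages, tokenize=False))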
