From 613527da328d3e0c5c7e235a679ea70cfe92235b Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Thu, 12 Dec 2024 12:44:51 +0530 Subject: [PATCH] Add multi turn chat support. Signed-off-by: Dushyant Behl --- .pylintrc | 4 +-- tests/data/test_data_preprocessing_utils.py | 35 ++++++++++++++++++++- tuning/config/configs.py | 15 +++++++++ tuning/data/data_preprocessing_utils.py | 17 ++++++++++ tuning/data/setup_dataprocessor.py | 25 +++++++++------ tuning/sft_trainer.py | 10 ++++++ 6 files changed, 94 insertions(+), 12 deletions(-) diff --git a/.pylintrc b/.pylintrc index 222bdf6cb..41f7e4e73 100644 --- a/.pylintrc +++ b/.pylintrc @@ -280,8 +280,8 @@ ignored-parents= # Maximum number of arguments for function / method. max-args=5 -# Maximum number of attributes for a class (see R0902). -max-attributes=7 +# Maximum number of attributes for a class (custom). +max-attributes=10 # Maximum number of boolean expressions in an if statement (see R0916). max-bool-expr=5 diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index 5559ac8ec..fed73f0e3 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -280,7 +280,8 @@ def test_is_pretokenized_data(data, result): @pytest.mark.parametrize( - "packing, response_template, formatted_train_dataset, max_seq_length, expected_collator", + "packing, response_template, formatted_train_dataset,\ + max_seq_length, instruction_template, expected_collator", [ ( False, @@ -291,6 +292,35 @@ def test_is_pretokenized_data(data, result): split="train", ), 1024, + None, + DataCollatorForCompletionOnlyLM, + ), + ( + False, + None, + Dataset.from_list( + [ + { + "input_ids": [9437, 29, 210], + "attention_mask": [1, 1, 1], + "labels": [1, 20, 30], + } + ] + ), + 1024, + None, + DataCollatorForSeq2Seq, + ), + ( + False, + "\n### Label:", + datasets.load_dataset( + "json", + data_files=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + split="train", + ), + 1024, + "\n### Text:", DataCollatorForCompletionOnlyLM, ), ( @@ -306,6 +336,7 @@ def test_is_pretokenized_data(data, result): ] ), 1024, + "\n### Text:", DataCollatorForSeq2Seq, ), ], @@ -315,6 +346,7 @@ def test_get_data_collator( response_template, formatted_train_dataset, max_seq_length, + instruction_template, expected_collator, ): """Ensure that the correct collator type is fetched based on the data args""" @@ -324,6 +356,7 @@ def test_get_data_collator( AutoTokenizer.from_pretrained(MODEL_NAME), is_pretokenized_dataset(formatted_train_dataset), max_seq_length, + instruction_template, ) assert isinstance(collator, expected_collator) diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 222bf4424..6786d5410 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -102,6 +102,21 @@ class DataArguments: Supports both JSON and YAML based config files." }, ) + chat_template: str = field( + default=None, + metadata={ + "help": "chat template to use for tokenization. \ + No need to pass this if the tokenizer already has a chat_template \ + if passed, it will overwrite tokenizer.chat_template if it exists" + }, + ) + instruction_template: str = field( + default=None, + metadata={ + "help": "Should be provided for chat training. \ + Piece of text that determines the start of human response" + }, + ) @dataclass diff --git a/tuning/data/data_preprocessing_utils.py b/tuning/data/data_preprocessing_utils.py index 589e4c9ef..1fa05d40d 100644 --- a/tuning/data/data_preprocessing_utils.py +++ b/tuning/data/data_preprocessing_utils.py @@ -28,6 +28,7 @@ def get_data_collator( tokenizer: AutoTokenizer, is_traindata_tokenized: bool, max_seq_length: int, + instruction_template: Optional[str], ) -> Callable: """Create and return the the appropriate collator type based on the configuration for packing, response_template, and dataset_text_field. @@ -43,12 +44,28 @@ def get_data_collator( Whether train Dataset is tokenized or not max_seq_length: int Max sequence length expected + instruction_template: str + str representing the human response in a chat template Returns: Callable Callable collator to be leveraged by the trainer. """ + if response_template and instruction_template: + # response_template_ids = tokenizer.encode( + # response_template, add_special_tokens=False + # )[2:] + # intruction_template_ids = tokenizer.encode( + # instruction_template, add_special_tokens=False + # )[2:] + return DataCollatorForCompletionOnlyLM( + response_template=response_template, + instruction_template=instruction_template, + tokenizer=tokenizer, + ignore_index=configs.IGNORE_INDEX, + ) + if not packing: # TODO: near term - how response template ids are parsed out needs to be cleaned. # The [2:] here applies if response template has \n prefix, it is needed to strip \n, diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py index f9be9a23e..3ba3c9e5f 100644 --- a/tuning/data/setup_dataprocessor.py +++ b/tuning/data/setup_dataprocessor.py @@ -34,8 +34,8 @@ from tuning.data.data_processors import get_datapreprocessor # In future we may make the fields configurable -DEFAULT_JSON_INPUT_KEY = "input" -DEFAULT_JSON_OUTPUT_KEY = "output" +DEFAULT_INPUT_COLUMN = "input" +DEFAULT_OUTPUT_COLUMN = "output" # check if the provided dataset is pretokenized or not # the check is taken from trl @@ -151,12 +151,12 @@ def _get_dataset_formatting_handlers(data_args, packing): return [handler], dataset_text_field -### Data format 3 -def _get_default_json_dataset_handlers(data_args, tokenizer_kwargs): +### Default Data format +def _get_default_dataset_handlers(data_args, tokenizer_kwargs): fn_kwargs = {} - fn_kwargs["input_field_name"] = DEFAULT_JSON_INPUT_KEY - fn_kwargs["output_field_name"] = DEFAULT_JSON_OUTPUT_KEY + fn_kwargs["input_field_name"] = DEFAULT_INPUT_COLUMN + fn_kwargs["output_field_name"] = DEFAULT_OUTPUT_COLUMN fn_kwargs["tokenizer_kwargs"] = tokenizer_kwargs kwargs = { @@ -177,7 +177,9 @@ def _get_default_json_dataset_handlers(data_args, tokenizer_kwargs): # If a text field is specified, append the tokenizer's EOS token to it. # If a formatter template is provided, apply it and save the result. # Data remains un-tokenized. -# Data Format 3: JSON Dataset with Input/Output Fields +# Data Format 3: Chat datasets +# User provides response_template and instruction_template. +# Default Data Format: Dataset with Input/Output Fields # Combine input and output fields, tokenize the data, and apply input attention masking. # Requires both input and output fields; throws an error if missing. def _process_raw_data_args( @@ -239,9 +241,13 @@ def _process_raw_data_args( handlers, dataset_text_field = _get_dataset_formatting_handlers( data_args, packing ) + elif data_args.instruction_template and data_args.response_template: + # Data Format 3: Chat dataset with instruction and response template + # We don't do processing for chat dataset + handlers, dataset_text_field = [], None else: - # Data Format 3: JSON Dataset with Input/Output Fields - handlers, dataset_text_field = _get_default_json_dataset_handlers( + # Default Data Format: Dataset with Input/Output Fields + handlers, dataset_text_field = _get_default_dataset_handlers( data_args, tokenizer_kwargs ) @@ -329,6 +335,7 @@ def process_dataargs( tokenizer, is_tokenized_dataset, max_seq_length, + data_args.instruction_template, ) dataset_kwargs = {} diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index dc3e3733e..6a0c45b8d 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -287,6 +287,16 @@ def train( multiple_of=model_args.embedding_size_multiple_of, ) + if data_args.chat_template: + logger.info("adding chat_template to the tokenizer") + if tokenizer.chat_template: + logger.warning( + "replacing existing chat_template %s with the given chat_template %s", + tokenizer.chat_template, + data_args.chat_template, + ) + tokenizer.chat_template = data_args.chat_template + # Configure the collator and validate args related to packing prior to formatting the dataset data_collator = None logger.info("Packing is set to %s ", train_args.packing)