Commit

keeping JSON keys constant
Signed-off-by: Sukriti-Sharma4 <[email protected]>
Ssukriti committed Jul 30, 2024
1 parent 007e3e5 commit 6faa4e5
Showing 1 changed file with 9 additions and 6 deletions.
tuning/utils/preprocessing_utils.py (15 changes: 9 additions & 6 deletions)
@@ -28,6 +28,9 @@
 
 logger = logging.get_logger("sft_trainer_preprocessing")
 
+# In future we may make the fields configurable
+JSON_INPUT_KEY = "input"
+JSON_OUTPUT_KEY = "output"
 
 def validate_data_args(data_args: configs.DataArguments, packing: bool):
 
@@ -65,12 +68,12 @@ def validate_data_args(data_args: configs.DataArguments, packing: bool):
         json_dataset = datasets.load_dataset(
             "json", data_files=data_args.training_data_path
         )
-        if "input" not in json_dataset["train"].column_names:
+        if JSON_INPUT_KEY not in json_dataset["train"].column_names:
             raise ValueError(
                 "JSON should contain input field if no dataset_text_field or \
                     data_formatter_template specified"
             )
-        if "output" not in json_dataset["train"].column_names:
+        if JSON_OUTPUT_KEY not in json_dataset["train"].column_names:
             raise ValueError(
                 "JSON should contain output field if no dataset_text_field or \
                     data_formatter_template specified"
@@ -172,16 +175,16 @@ def format_dataset(
             data_args.training_data_path,
             tokenizer,
             max_seq_length,
-            input_field_name="input",
-            output_field_name="output",
+            input_field_name=JSON_INPUT_KEY,
+            output_field_name=JSON_OUTPUT_KEY,
         )
         if data_args.validation_data_path:
             eval_dataset = get_preprocessed_dataset(
                 data_args.validation_data_path,
                 tokenizer,
                 max_seq_length,
-                input_field_name="input",
-                output_field_name="output",
+                input_field_name=JSON_INPUT_KEY,
+                output_field_name=JSON_OUTPUT_KEY,
             )
 
     return train_dataset, eval_dataset, dataset_text_field
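
For reference, here is a minimal sketch (not part of the commit) of the JSON layout that the validation above expects when neither dataset_text_field nor data_formatter_template is given. The file path and record contents below are invented for illustration, and only the Hugging Face datasets package already used by this module is assumed:

    import json
    import tempfile

    import datasets

    JSON_INPUT_KEY = "input"
    JSON_OUTPUT_KEY = "output"

    # Each record carries the two expected keys; the values are placeholders.
    records = [
        {JSON_INPUT_KEY: "Translate to French: Hello", JSON_OUTPUT_KEY: "Bonjour"},
        {JSON_INPUT_KEY: "Translate to French: Thanks", JSON_OUTPUT_KEY: "Merci"},
    ]

    # Write a throwaway JSON-lines file and load it the same way the code above does.
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        f.write("\n".join(json.dumps(r) for r in records))
        training_data_path = f.name

    json_dataset = datasets.load_dataset("json", data_files=training_data_path)

    # Mirrors the checks in validate_data_args: both keys must show up as columns.
    assert JSON_INPUT_KEY in json_dataset["train"].column_names
    assert JSON_OUTPUT_KEY in json_dataset["train"].column_names

Keeping the key names in JSON_INPUT_KEY and JSON_OUTPUT_KEY means that making the fields configurable later only needs to touch these two constants rather than every call site.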
