Skip to content

Commit

Permalink
Updates to enable ultrachat200k
Browse files Browse the repository at this point in the history
Ultrachat200k has two training splits, one for SFT and another for DPO, so it does not have a split named "train". This PR allows the "train_sft" split to be used as a fallback when no "train" split is present.
  • Loading branch information
anmarques authored Apr 1, 2024
1 parent f784980 commit 769abb3
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/sparseml/transformers/finetune/data/data_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,12 @@ def make_dataset_splits(
train_split = eval_split = predict_split = calib_split = None

if do_train:
if "train" not in tokenized_datasets:
if "train" in tokenized_datasets:
train_split = tokenized_datasets["train"]
elif "train_sft" in tokenized_datasets:
train_split = tokenized_datasets["train_sft"]
else:
raise ValueError("--do_train requires a train dataset")
train_split = tokenized_datasets["train"]
if do_eval:
if "validation" not in tokenized_datasets:
raise ValueError("--do_eval requires a validation dataset")
Expand All @@ -142,7 +145,11 @@ def make_dataset_splits(
if do_oneshot:
calib_split = tokenized_datasets.get("calibration")
if calib_split is None:
if "train" not in tokenized_datasets:
if "train" in tokenized_datasets:
train_split = tokenized_datasets["train"]
elif "train_sft" in tokenized_datasets:
train_split = tokenized_datasets["train_sft"]
else:
raise ValueError("--do_oneshot requires a calibration dataset")
calib_split = tokenized_datasets["train"]

Expand Down

0 comments on commit 769abb3

Please sign in to comment.