diff --git a/src/sparseml/transformers/finetune/data/data_helpers.py b/src/sparseml/transformers/finetune/data/data_helpers.py index a92d492bb06..d4794e80220 100644 --- a/src/sparseml/transformers/finetune/data/data_helpers.py +++ b/src/sparseml/transformers/finetune/data/data_helpers.py @@ -49,9 +49,17 @@ def format_calibration_data( :param accelerator: optional accelerator for if preparing in FSDP mode :return: list of trimmed calibration data tensors """ - num_calibration_samples = num_calibration_samples or len(tokenized_dataset) + safe_calibration_samples = len(tokenized_dataset) + if num_calibration_samples is not None: + safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples) + if safe_calibration_samples != num_calibration_samples: + LOGGER.warning( + f"Requested {num_calibration_samples} calibration samples but " + f"the provided dataset only has {safe_calibration_samples}. " + ) + shuffled_calibration = tokenized_dataset.shuffle() - shuffled_calibration = shuffled_calibration.select(range(num_calibration_samples)) + shuffled_calibration = shuffled_calibration.select(range(safe_calibration_samples)) dataloader_params = { "batch_size": 1,