Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add early stopping #134

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 96 additions & 31 deletions training/run_distillation.py
Original file line number Diff line number Diff line change
@@ -58,7 +58,7 @@
WhisperForConditionalGeneration,
WhisperProcessor,
WhisperTokenizerFast,
get_scheduler
get_scheduler,
)
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.whisper.english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
@@ -470,6 +470,49 @@ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> D
return batch


class EarlyStopping:
"""
Monitor the total eval loss and stop training when it stops improving.
Args:
patience (:obj: `int`)
Number of checks / epochs with no improvement after which training will be \
stopped.
min_delta (:obj: `float`)
Minimum change in the monitored total eval loss to qualify as an \
improvement, i.e. an absolute change of less than or equal to \
min_delta, will count as no improvement.
"""

def __init__(self, patience: int = 3, min_delta: float = 0.001):
self.patience: int = patience
self.min_delta: float = min_delta
self.counter: int = 0
self.best_loss: Optional[float] = None
self.early_stop: bool = False

def __call__(self, val_loss: float, epoch: int):
"""
Call this method if cur_step % eval_steps == 0 or cur_step == total_train_steps\
with its corresponding validation loss.

Args:
val_loss (:obj: 'float'): Current epoch's validation loss.
epoch (:obj: 'int'): Current epoch number.
"""
if self.best_loss is None:
self.best_loss = val_loss
elif val_loss < self.best_loss - self.min_delta:
self.best_loss = val_loss
self.counter = 0
else:
self.counter += 1
logger.info(f"Increased early stopping counter at epoch {epoch}: {self.counter}/{self.patience}.")

if self.counter >= self.patience:
logger.info(f"Early stopping at epoch {epoch}")
self.early_stop = True


def log_metric(
accelerator,
metrics: Dict,
@@ -656,7 +699,7 @@ def load_multiple_datasets(
if use_pseudo_labels:
if "whisper_transcript" not in dataset_features:
raise ValueError(
f"Pseudo-label column `whisper_transcript` not found in dataset {dataset_dict['name']}. Ensure"
f"Pseudo-label column `whisper_transcript` not found in dataset {dataset_dict['name']}. Ensure "
"pseudo-labels are present in the dataset under this column name, or train directly on the text "
"labels by setting `--use_pseudo_labels=False` and defining the appropriate `--text_column_name`."
)
@@ -790,11 +833,7 @@ def main():

accelerator.init_trackers(
project_name=data_args.wandb_project,
init_kwargs={
"wandb": {"name": data_args.wandb_name,
"dir": data_args.wandb_dir}
}

init_kwargs={"wandb": {"name": data_args.wandb_name, "dir": data_args.wandb_dir}},
)

# 3. Set-up basic logging
@@ -999,13 +1038,12 @@ def set_trainable_parameters(module, requires_grad=False):
if training_args.freeze_encoder:
set_trainable_parameters(student_model.model.encoder, requires_grad=False)
student_model.model.encoder.gradient_checkpointing = False

if training_args.freeze_decoder:
set_trainable_parameters(student_model.model.decoder, requires_grad=False)
student_model.model.decoder.gradient_checkpointing = False
# un-freeze LM head parameters (and consequently word embeddings), frozen when frozing decoder since tied word embedding and LM head
set_trainable_parameters(student_model.proj_out, requires_grad=True)

set_trainable_parameters(student_model.proj_out, requires_grad=True)

if training_args.freeze_embed_positions:
# set_trainable_parameters(student_model.model.decoder.embed_tokens, requires_grad=False)
@@ -1014,7 +1052,7 @@ def set_trainable_parameters(module, requires_grad=False):
logger.info(
"Disabling gradient checkpointing in the decoder since it's incompatible with `freeze_embed_positions`."
)

logger.info(
f"Number of trainable parameters: {sum(p.numel() for p in student_model.parameters() if p.requires_grad):.3e}"
)
@@ -1349,12 +1387,12 @@ def compute_metrics(preds, labels):
eval_steps = training_args.eval_steps

# 13. Define optimizer, LR scheduler, collator

forbidden_module = [
module
for module, flag in [
(student_model.model.encoder, training_args.freeze_encoder),
(student_model.model.decoder, training_args.freeze_decoder)
(student_model.model.decoder, training_args.freeze_decoder),
]
if flag
] or None
@@ -1503,6 +1541,26 @@ def generate_step(batch):
output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
return output_ids

def push_model_to_hub(
training_args: DistillationTrainingArguments,
repo_name: str,
cur_step: int,
) -> None:
upload_folder(
folder_path=training_args.output_dir,
repo_id=repo_name,
repo_type="model",
commit_message=f"Saving final weights of step {cur_step}",
)

def unwrap_and_save(
training_args: DistillationTrainingArguments,
accelerator: Accelerator,
student_model: WhisperForConditionalGeneration,
) -> None:
student_model = accelerator.unwrap_model(student_model)
student_model.save_pretrained(training_args.output_dir)

logger.info("***** Running training *****")
logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
if not data_args.streaming:
@@ -1559,6 +1617,8 @@ def generate_step(batch):
else:
resume_step = None

early_stopping = EarlyStopping()

for epoch in range(epochs_trained, num_epochs):
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
train_dataloader = DataLoader(
@@ -1592,6 +1652,7 @@ def generate_step(batch):
if accelerator.sync_gradients:
steps_trained_progress_bar.update(1)
cur_step += 1
best_model_tag = ""

if cur_step % training_args.logging_steps == 0:
steps_trained_progress_bar.write(
@@ -1611,19 +1672,19 @@ def generate_step(batch):

# save checkpoint and weights after each save_steps and at the end of training
if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
if early_stopping.counter == 0:
best_model_tag = "-best"

intermediate_dir = os.path.join(
training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}{best_model_tag}"
)
accelerator.save_state(output_dir=intermediate_dir)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir)

if training_args.push_to_hub:
upload_folder(
folder_path=training_args.output_dir,
repo_id=repo_name,
repo_type="model",
commit_message=f"Saving train state of step {cur_step}",
)
push_model_to_hub(training_args, repo_name, cur_step)

if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
train_time += time.time() - train_start
@@ -1709,20 +1770,24 @@ def generate_step(batch):
# flush the train metrics
train_start = time.time()

# break condition
if cur_step == total_train_steps:
# Check early stopping condition
early_stopping(float(eval_metrics["loss"]), epoch)

if early_stopping.early_stop:
if training_args.push_to_hub:
push_model_to_hub(training_args, repo_name, cur_step)

unwrap_and_save(training_args, accelerator, student_model)

# un-wrap student model for save
student_model = accelerator.unwrap_model(student_model)
student_model.save_pretrained(training_args.output_dir)
continue_training = False
break

# break condition
if cur_step == total_train_steps:
if training_args.push_to_hub:
upload_folder(
folder_path=training_args.output_dir,
repo_id=repo_name,
repo_type="model",
commit_message=f"Saving final weights of step {cur_step}",
)
push_model_to_hub(training_args, repo_name, cur_step)

unwrap_and_save(training_args, accelerator, student_model)

continue_training = False
break