Skip to content

Commit

Permalink
refactor flan t5 finetune to use modal.Volume
Browse files Browse the repository at this point in the history
  • Loading branch information
thundergolfer committed Aug 2, 2023
1 parent 87895a0 commit c15f360
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
1 change: 1 addition & 0 deletions 06_gpu_and_ml/dreambooth/dreambooth_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ class Model:
def __enter__(self):
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline

# Reload the modal.Volume to ensure the latest state is accessible.
stub.app.volume.reload()

Expand Down
29 changes: 22 additions & 7 deletions 06_gpu_and_ml/flan_t5/flan_t5_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from pathlib import Path

from modal import Image, NetworkFileSystem, Stub, method, wsgi_app
from modal import Image, Volume, Stub, method, wsgi_app

VOL_MOUNT_PATH = Path("/vol")

Expand All @@ -37,7 +37,8 @@
)

stub = Stub(name="example-news-summarizer", image=image)
output_vol = NetworkFileSystem.persisted("finetune-vol")
output_vol = Volume.persisted("finetune-volume")
stub.volume = output_vol

# ## Finetuning Flan-T5 on XSum dataset
#
Expand All @@ -47,7 +48,7 @@
@stub.function(
gpu="A10g",
timeout=7200,
network_file_systems={VOL_MOUNT_PATH: output_vol},
volumes={VOL_MOUNT_PATH: output_vol},
)
def finetune(num_train_epochs: int = 1, size_percentage: int = 10):
from datasets import load_dataset
Expand All @@ -57,6 +58,7 @@ def finetune(num_train_epochs: int = 1, size_percentage: int = 10):
DataCollatorForSeq2Seq,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
TrainerCallback,
)

# Use the size percentage to retrieve a subset of the dataset so we can iterate faster
Expand Down Expand Up @@ -121,6 +123,17 @@ def preprocess(batch):
pad_to_multiple_of=batch_size,
)

class CheckpointCallback(TrainerCallback):
    """Trainer callback that commits a volume each time a checkpoint is saved.

    The ``volume`` object only needs to expose a ``commit()`` method; the log
    message below refers to it as a modal.Volume.
    """

    def __init__(self, volume):
        # Keep a handle on the volume so on_save can commit it later.
        self.volume = volume

    def on_save(self, args, state, control, **kwargs):
        """Event hook called after a checkpoint save; commits the volume.

        The hook arguments are not used here; they are accepted only to
        match the callback signature.
        """
        message = "running commit on modal.Volume after model checkpoint"
        print(message)
        self.volume.commit()

training_args = Seq2SeqTrainingArguments(
# Save checkpoints to the mounted volume
output_dir=str(VOL_MOUNT_PATH / "model"),
Expand All @@ -142,6 +155,7 @@ def preprocess(batch):
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
callbacks=[CheckpointCallback(stub.app.volume)],
data_collator=data_collator,
train_dataset=tokenized_xsum_train,
eval_dataset=tokenized_xsum_test,
Expand All @@ -152,14 +166,15 @@ def preprocess(batch):
# Save the trained model and tokenizer to the mounted volume
model.save_pretrained(str(VOL_MOUNT_PATH / "model"))
tokenizer.save_pretrained(str(VOL_MOUNT_PATH / "tokenizer"))
stub.app.volume.commit()


# ## Monitoring Finetuning with Tensorboard
#
# Tensorboard is an application for visualizing training loss. In this example we
# serve it as a Modal WSGI app.
#
@stub.function(network_file_systems={VOL_MOUNT_PATH: output_vol})
@stub.function(volumes={VOL_MOUNT_PATH: output_vol})
@wsgi_app()
def monitor():
import tensorboard
Expand All @@ -181,7 +196,7 @@ def monitor():
#


@stub.cls(network_file_systems={VOL_MOUNT_PATH: output_vol})
@stub.cls(volumes={VOL_MOUNT_PATH: output_vol})
class Summarizer:
def __enter__(self):
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
Expand Down Expand Up @@ -228,14 +243,14 @@ def main():
# Invoke model finetuning using the command provided below
#
# ```bash
# modal run --detach finetune.py::finetune --num-train-epochs=1 --size-percentage=10
# modal run --detach flan_t5_finetune.py::finetune --num-train-epochs=1 --size-percentage=10
# View the tensorboard logs at https://<username>--example-news-summarizer-monitor-dev.modal.run
# ```
#
# Invoke finetuned model inference via local entrypoint
#
# ```bash
# modal run finetune.py
# modal run flan_t5_finetune.py
# World number one Tiger Woods missed the cut at the US Open as he failed to qualify for the final round of the event in Los Angeles.
# ```
#

0 comments on commit c15f360

Please sign in to comment.