
refactor dreambooth example to use modal.Volume
thundergolfer committed Aug 2, 2023
1 parent c60bb35 commit 87895a0
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions 06_gpu_and_ml/dreambooth/dreambooth_app.py
@@ -15,7 +15,7 @@
# It demonstrates a simple, productive, and cost-effective pathway
# to building on large pretrained models
# by using Modal's building blocks, like
- # [GPU-accelerated](https://modal.com/docs/guide/gpu#using-a100-gpus-alpha) Modal Functions, [network file systems](/docs/guide/network-file-systems#network-file-systems) for caching, and [Modal webhooks](https://modal.com/docs/guide/webhooks#webhook).
+ # [GPU-accelerated](https://modal.com/docs/guide/gpu#using-a100-gpus-alpha) Modal Functions, [network volumes](/docs/guide/network-volumes#network-volumes) for caching, and [Modal webhooks](https://modal.com/docs/guide/webhooks#webhook).
#
# And with some light customization, you can use it to generate images of your pet!
#
@@ -33,8 +33,8 @@
from modal import (
Image,
Mount,
- NetworkFileSystem,
Secret,
+ Volume,
Stub,
asgi_app,
method,
@@ -73,11 +73,12 @@
)
)

- # A persisted network file system will store model artefacts across Modal app runs.
+ # A persisted `modal.Volume` will store model artefacts across Modal app runs.
# This is crucial as finetuning runs are separate from the Gradio app we run as a webhook.

- volume = NetworkFileSystem.persisted("dreambooth-finetuning-vol")
+ volume = Volume.persisted("dreambooth-finetuning-volume")
MODEL_DIR = Path("/model")
+ stub.volume = volume
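
For orientation, here is a minimal standalone sketch (not part of this commit) of the pattern the new code relies on: a persisted `modal.Volume` attached to a function through the `volumes=` parameter, so files written under `MODEL_DIR` survive across app runs. The stub name and the `list_model_files` function are hypothetical.

```python
from pathlib import Path

from modal import Stub, Volume

stub = Stub("dreambooth-volume-sketch")  # hypothetical stub name
volume = Volume.persisted("dreambooth-finetuning-volume")  # same named volume as above
MODEL_DIR = Path("/model")
stub.volume = volume


@stub.function(volumes={str(MODEL_DIR): volume})
def list_model_files():
    # Any function that mounts the volume sees files persisted by earlier runs.
    return sorted(p.name for p in MODEL_DIR.iterdir())
```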

# ## Config
#
Expand Down Expand Up @@ -189,7 +190,7 @@ def load_images(image_urls):
@stub.function(
image=image,
gpu="A100", # finetuning is VRAM hungry, so this should be an A100
- network_file_systems={
+ volumes={
str(
MODEL_DIR
): volume, # fine-tuned model will be stored at `MODEL_DIR`
@@ -258,24 +259,31 @@ def train(instance_example_urls):
print(exc.stderr.decode())
raise

+ # The trained model artefacts have been output to the volume mounted at `MODEL_DIR`.
+ # To persist these artefacts for use in future inference function calls, we 'commit' the changes
+ # to the volume.
+ stub.app.volume.commit()
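
As a hedged illustration of the commit step just added, this sketch shows the write-then-commit pattern in isolation; the stub name, the `save_artifact` function, and the file name are hypothetical, while `stub.app.volume.commit()` mirrors the call used in this diff.

```python
from pathlib import Path

from modal import Stub, Volume

stub = Stub("dreambooth-commit-sketch")  # hypothetical stub name
volume = Volume.persisted("dreambooth-finetuning-volume")
MODEL_DIR = Path("/model")
stub.volume = volume


@stub.function(volumes={str(MODEL_DIR): volume})
def save_artifact(text: str):
    # Write into the mounted volume directory...
    (MODEL_DIR / "notes.txt").write_text(text)
    # ...then commit so the change is persisted for later runs and other functions.
    stub.app.volume.commit()
```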


# ## The inference function.
#
# To generate images from prompts using our fine-tuned model, we define a function called `inference`.
# In order to initialize the model just once on container startup, we use Modal's [container
# lifecycle](https://modal.com/docs/guide/lifecycle-functions) feature, which requires the function to be part
- # of a class. The network file system is mounted at `MODEL_DIR`, so that the fine-tuned model created by `train` is then available to `inference`.
+ # of a class. The `modal.Volume` is mounted at `MODEL_DIR`, so that the fine-tuned model created by `train` is then available to `inference`.


@stub.cls(
image=image,
gpu="A100",
- network_file_systems={str(MODEL_DIR): volume},
+ volumes={str(MODEL_DIR): volume},
)
class Model:
def __enter__(self):
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline
+ # Reload the modal.Volume to ensure the latest state is accessible.
+ stub.app.volume.reload()

# set up a hugging face inference pipeline using our model
ddim = DDIMScheduler.from_pretrained(MODEL_DIR, subfolder="scheduler")
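To show how the reload fits into the inference side, here is a compact sketch (not the actual file contents, which are truncated above) of a container-lifecycle class that reloads the `modal.Volume` and builds the pipeline once per container. The class name `SketchModel`, the prompt handling, and the sampler parameters are illustrative, and `stub`, `image`, `volume`, and `MODEL_DIR` are assumed to be defined as in the diff.

```python
@stub.cls(
    image=image,
    gpu="A100",
    volumes={str(MODEL_DIR): volume},
)
class SketchModel:
    def __enter__(self):
        import torch
        from diffusers import DDIMScheduler, StableDiffusionPipeline

        # Reload so this container sees the latest committed state of the volume.
        stub.app.volume.reload()

        # Load the fine-tuned weights from the volume once per container.
        scheduler = DDIMScheduler.from_pretrained(MODEL_DIR, subfolder="scheduler")
        self.pipe = StableDiffusionPipeline.from_pretrained(
            MODEL_DIR, scheduler=scheduler, torch_dtype=torch.float16
        ).to("cuda")

    @method()
    def inference(self, prompt: str):
        # Generate a single image for the prompt with the fine-tuned pipeline.
        return self.pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
```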
