Update stable diffusion example (#953)
* Update Stable Diffusion CLI example with Stable Diffusion 3.5

* Update some of the language

* Update

* Update

* Update

* Update

* Add latencies file

* Update

* Add model revision id

* text fixes, drop cuda image, faster transfer, faster batching

* remove inference time plot

---------

Co-authored-by: Charles Frye <[email protected]>
yirenlu92 and charlesfrye authored Nov 1, 2024
1 parent 2ae5f07 commit 87237d0
Showing 2 changed files with 118 additions and 124 deletions.
242 changes: 118 additions & 124 deletions 06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py
@@ -1,179 +1,173 @@
# ---
# output-directory: "/tmp/stable-diffusion"
# args: ["--prompt", "A 1600s oil painting of the New York City skyline"]
# runtimes: ["runc", "gvisor"]
# tags: ["use-case-image-video-3d"]
# ---
# # Stable Diffusion CLI
#
# This example shows Stable Diffusion 1.5 with a number of optimizations
# that make it run faster on Modal. The example takes about 10s to cold start
# and about 1.0s per image generated.
#
# To use the XL 1.0 model, see the example posted [here](/docs/examples/stable_diffusion_xl).
#
# For instance, here are 9 images produced by the prompt
# `A 1600s oil painting of the New York City skyline`
#
# ![stable diffusion montage](./stable_diffusion_montage.png)
#
# As mentioned, we use a few optimizations to run this faster:
#
# * Use [run_function](/docs/reference/modal.Image#run_function) to download the model while building the container image
# * Use a [container lifecycle method](https://modal.com/docs/guide/lifecycle-functions) to initialize the model on container startup
# * Use A10G GPUs
# * Use 16 bit floating point math

# # Run Stable Diffusion 3.5 Large Turbo from the command line

# This example shows how to run [Stable Diffusion 3.5 Large Turbo](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo) on Modal
# and generate images from your local command line.

# Inference takes about one minute to cold start,
# at which point images are generated at a rate of one image every 1-2 seconds
# for batch sizes between one and 16.

# Below are four images produced by the prompt
# "A princess riding on a pony".

# ![stable diffusion montage](./stable-diffusion-montage-princess.jpg)

# ## Basic setup
from __future__ import annotations

import io
import random
import time
from pathlib import Path

import modal

# All Modal programs need a [`App`](/docs/reference/modal.App) — an object that acts as a recipe for
MINUTES = 60

# All Modal programs need an [`App`](https://modal.com/docs/reference/modal.App) — an object that acts as a recipe for
# the application. Let's give it a friendly name.

app = modal.App("stable-diffusion-cli")

# ## Model dependencies
#
# Your model will be running remotely inside a container. We will be installing
# all the model dependencies in the next step. We will also be "baking the model"
# into the image by running a Python function as a part of building the image.
# This lets us start containers much faster, since all the data that's needed is
# already inside the image.

model_id = "Jiali/stable-diffusion-1.5"

image = modal.Image.debian_slim(python_version="3.10").pip_install(
"accelerate==0.29.2",
"diffusers==0.15.1",
"ftfy==6.2.0",
"safetensors==0.4.2",
"torch==2.2.2",
"torchvision",
"transformers~=4.25.1",
"triton~=2.2.0",
"xformers==0.0.25post1",
app = modal.App("example-stable-diffusion-cli")

# ## Configuring dependencies

# The model runs remotely inside a [container](https://modal.com/docs/guide/custom-container).
# That means we need to install the necessary dependencies in that container's image.

# Below, we start from a lightweight base Linux image
# and then install our Python dependencies, like Hugging Face's `diffusers` library and `torch`.

image = (
modal.Image.debian_slim(python_version="3.12")
.pip_install(
"accelerate==0.33.0",
"diffusers==0.31.0",
"huggingface-hub[hf_transfer]==0.25.2",
"sentencepiece==0.2.0",
"torch==2.5.1",
"torchvision==0.20.1",
"transformers~=4.44.0",
)
.entrypoint([]) # deactivate default entrypoint to reduce log verbosity
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) # faster downloads
)
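
# If you want to experiment with extra dependencies, you can layer more build steps
# on top of this image. The variant below is a hedged sketch and is not used anywhere
# in this example; `ffmpeg` and `pillow` are just illustrative additions.

image_with_extras = image.apt_install("ffmpeg").pip_install("pillow==10.4.0")  # not used below

# The `with image.imports():` block below defers the heavyweight `diffusers` and
# `torch` imports so they only run inside containers built from this image,
# not on your local machine.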

with image.imports():
import diffusers
import torch

# ## Implementing SD3.5 Large Turbo inference on Modal

# We wrap inference in a Modal [Cls](https://modal.com/docs/guide/lifecycle-methods)
# that ensures models are downloaded when we `build` our container image (just like our dependencies)
# and that models are loaded and then moved to the GPU when a new container starts.
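
# In outline, the class below has this shape (a hedged sketch, not separate code):
#
#     @app.cls(gpu=..., image=image)
#     class Model:
#         @modal.build()   # runs while the image is built: download the model weights
#         @modal.enter()   # runs on each container start: load the weights
#         def setup(self): ...
#
#         @modal.method()  # runs once per remote call
#         def generate(self, prompt): ...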

# ## Using container lifecycle methods
#
# Modal lets you implement code that runs every time a container starts. This
# can be a huge optimization when you're calling a function multiple times,
# since Modal reuses the same containers when possible.
#
# The way to implement this is to turn the Modal function into a method on a
# class that also has lifecycle methods (decorated with `@enter()` and/or `@exit()`).
#
# We have also applied a few model optimizations to make the model run
# faster. On an A10G, the model takes about 6.5s to load into memory, and then
# 1.6s per generation on average. On a T4, it takes 13s to load and 3.7s per
# generation. Other optimizations are also available [here](https://huggingface.co/docs/diffusers/optimization/fp16#memory-and-speed).
# The `run_inference` function just wraps a `diffusers` pipeline.
# It sends the output image back to the client as bytes.
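
# Since the payload is plain PNG bytes, the client can also turn it back into an
# image object instead of writing it straight to disk. A hedged sketch (assumes
# Pillow is installed locally; `png_bytes_to_image` is a hypothetical helper and is
# not used elsewhere in this example):


def png_bytes_to_image(png_bytes: bytes):
    from PIL import Image  # local import: only needed if you actually call this helper

    return Image.open(io.BytesIO(png_bytes))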

# This is our Modal function. It runs inference with the `StableDiffusionPipeline`.
# It sends the PIL image back to our CLI, where we save the resulting image to a local file.
model_id = "adamo1139/stable-diffusion-3.5-large-turbo-ungated"
model_revision_id = "9ad870ac0b0e5e48ced156bb02f85d324b7275d2"


@app.cls(image=image, gpu="A10G")
@app.cls(
image=image,
gpu="H100",
timeout=10 * MINUTES,
)
class StableDiffusion:
@modal.build()
@modal.enter()
def initialize(self):
scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
solver_order=2,
prediction_type="epsilon",
thresholding=False,
algorithm_type="dpmsolver++",
solver_type="midpoint",
denoise_final=True, # important if steps are <= 10
low_cpu_mem_usage=True,
device_map="auto",
)
self.pipe = diffusers.StableDiffusionPipeline.from_pretrained(
self.pipe = diffusers.StableDiffusion3Pipeline.from_pretrained(
model_id,
scheduler=scheduler,
low_cpu_mem_usage=True,
device_map="auto",
revision=model_revision_id,
torch_dtype=torch.bfloat16,
)
self.pipe.enable_xformers_memory_efficient_attention()

@modal.enter()
def move_to_gpu(self):
self.pipe.to("cuda")

@modal.method()
def run_inference(
self, prompt: str, steps: int = 20, batch_size: int = 4
self, prompt: str, batch_size: int = 4, seed: int = None
) -> list[bytes]:
with torch.inference_mode():
with torch.autocast("cuda"):
images = self.pipe(
[prompt] * batch_size,
num_inference_steps=steps,
guidance_scale=7.0,
).images

# Convert to PNG bytes
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
print("seeding RNG with", seed)
torch.manual_seed(seed)
images = self.pipe(
prompt,
num_images_per_prompt=batch_size, # outputting multiple images per prompt is much cheaper than separate calls
num_inference_steps=4, # turbo is tuned to run in four steps
guidance_scale=0.0, # turbo doesn't use CFG
max_sequence_length=512, # T5-XXL text encoder supports longer sequences, more complex prompts
).images

image_output = []
for image in images:
with io.BytesIO() as buf:
image.save(buf, format="PNG")
image_output.append(buf.getvalue())
torch.cuda.empty_cache() # reduce fragmentation
return image_output
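

# Because each remote call runs in its own container, you can also fan out over
# several prompts at once. A hedged sketch (`generate_many` is a hypothetical helper,
# not called anywhere below; it must be invoked from a running app context, e.g.
# from a local entrypoint):


def generate_many(prompts: list[str]) -> list[list[bytes]]:
    sd = StableDiffusion()
    # .spawn() kicks off each remote call without blocking; .get() collects the results.
    calls = [sd.run_inference.spawn(prompt) for prompt in prompts]
    return [call.get() for call in calls]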


# This is the command we'll use to generate images. It takes a `prompt`,
# `samples` (the number of images you want to generate), `steps` which
# configures the number of inference steps the model will make, and `batch_size`
# which determines how many images to generate for a given prompt.
# ## Generating images from the command line

# This is the command we'll use to generate images. It takes a text `prompt`,
# a `batch_size` that determines the number of images to generate per prompt,
# and the number of times to run image generation (`samples`).

# You can also provide a `seed` to make sampling more deterministic.
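
# For reference, a typical invocation looks like
# `modal run stable_diffusion_cli.py --prompt "a watercolor lighthouse" --batch-size 8 --seed 42`
# (underscored parameter names become dashed CLI flags). You can also drive the app
# from Python rather than the CLI; a hedged sketch, assuming this file is importable
# as `stable_diffusion_cli`:
#
#     import stable_diffusion_cli as sdc
#
#     with sdc.app.run():
#         images = sdc.StableDiffusion().run_inference.remote("a watercolor lighthouse")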


@app.local_entrypoint()
def entrypoint(
prompt: str = "A 1600s oil painting of the New York City skyline",
samples: int = 5,
steps: int = 10,
batch_size: int = 1,
samples: int = 4,
prompt: str = "A princess riding on a pony",
batch_size: int = 4,
seed: int = None,
):
print(
f"prompt => {prompt}, steps => {steps}, samples => {samples}, batch_size => {batch_size}"
f"prompt => {prompt}",
f"samples => {samples}",
f"batch_size => {batch_size}",
f"seed => {seed}",
sep="\n",
)

dir = Path("/tmp/stable-diffusion")
if not dir.exists():
dir.mkdir(exist_ok=True, parents=True)
output_dir = Path("/tmp/stable-diffusion")
output_dir.mkdir(exist_ok=True, parents=True)

sd = StableDiffusion()
for i in range(samples):
t0 = time.time()
images = sd.run_inference.remote(prompt, steps, batch_size)
total_time = time.time() - t0
print(
f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)."
)
for j, image_bytes in enumerate(images):
output_path = dir / f"output_{j}_{i}.png"
print(f"Saving it to {output_path}")
with open(output_path, "wb") as f:
f.write(image_bytes)


# And this is our entrypoint, where the CLI is invoked. Explore CLI options
# with: `modal run stable_diffusion_cli.py --help`
#
# ## Performance
#
# This example can generate pictures in about a second, with a startup time of about 10s for the first picture.
#
# See the distribution of latencies below. This data was gathered by running 500 requests in sequence (meaning only
# the first request incurs a cold start). As you can see, the 90th percentile is 1.2s and the 99th percentile is 2.30s.
#
# ![latencies](./stable_diffusion_latencies.png)

for sample_idx in range(samples):
start = time.time()
images = sd.run_inference.remote(prompt, batch_size, seed)
duration = time.time() - start
print(f"Run {sample_idx+1} took {duration:.3f}s")
if sample_idx:
print(
f"\tGenerated {len(images)} image(s) at {(duration)/len(images):.3f}s / image."
)
for batch_idx, image_bytes in enumerate(images):
output_path = (
output_dir
/ f"output_{slugify(prompt)[:64]}_{str(sample_idx).zfill(2)}_{str(batch_idx).zfill(2)}.png"
)
if not batch_idx:
print("Saving outputs", end="\n\t")
print(
output_path,
end="\n" + ("\t" if batch_idx < len(images) - 1 else ""),
)
output_path.write_bytes(image_bytes)


def slugify(s: str) -> str:
return "".join(c if c.isalnum() else "-" for c in s).strip("-")
