Skip to content

Commit

Permalink
Add TPU example for Ray Serve with Stable Diffusion
Browse files Browse the repository at this point in the history
  • Loading branch information
kevin85421 authored Jun 24, 2024
2 parents 5a10842 + f6f3577 commit d16752e
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 0 deletions.
Binary file added stable_diffusion/diffusion_results.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
144 changes: 144 additions & 0 deletions stable_diffusion/stable_diffusion_tpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""Ray Serve Stable Diffusion example."""
from io import BytesIO
from typing import List
from fastapi import FastAPI
from fastapi.responses import Response
import logging
import ray
from ray import serve
import time

app = FastAPI()
_MAX_BATCH_SIZE = 64

logger = logging.getLogger("ray.serve")

@serve.deployment(num_replicas=1)
@serve.ingress(app)
class APIIngress:
def __init__(self, diffusion_model_handle) -> None:
self.handle = diffusion_model_handle

@app.get(
"/imagine",
responses={200: {"content": {"image/png": {}}}},
response_class=Response,
)
async def generate(self, prompt: str):
assert len(prompt), "prompt parameter cannot be empty"

image = await self.handle.generate.remote(prompt)
return image


@serve.deployment(
ray_actor_options={
"resources": {"TPU": 4},
},
)
class StableDiffusion:
"""FLAX Stable Diffusion Ray Serve deployment running on TPUs.
Attributes:
run_with_profiler: Whether or not to run with the profiler. Note that
this saves the profile to the separate TPU VM.
"""

def __init__(
self, run_with_profiler: bool = False, warmup: bool = False,
warmup_batch_size: int = _MAX_BATCH_SIZE):
from diffusers import FlaxStableDiffusionPipeline
from flax.jax_utils import replicate
import jax
import jax.numpy as jnp
from jax import pmap

model_id = "CompVis/stable-diffusion-v1-4"

self._pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
model_id,
revision="bf16",
dtype=jnp.bfloat16)

self._p_params = replicate(params)
self._p_generate = pmap(self._pipeline._generate)
self._run_with_profiler = run_with_profiler
self._profiler_dir = "/tmp/tensorboard"

if warmup:
logger.info("Sending warmup requests.")
warmup_prompts = ["A warmup request"] * warmup_batch_size
self.generate_tpu(warmup_prompts)

def generate_tpu(self, prompts: List[str]):
"""Generates a batch of images from Diffusion from a list of prompts.
Args:
prompts: a list of strings. Should be a factor of 4.
Returns:
A list of PIL Images.
"""
from flax.training.common_utils import shard
import jax
import numpy as np

rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, jax.device_count())

assert prompts, "prompt parameter cannot be empty"
logger.info("Prompts: %s", prompts)
prompt_ids = self._pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)
logger.info("Sharded prompt ids has shape: %s", prompt_ids.shape)
if self._run_with_profiler:
jax.profiler.start_trace(self._profiler_dir)

time_start = time.time()
images = self._p_generate(prompt_ids, self._p_params, rng)
images = images.block_until_ready()
elapsed = time.time() - time_start
if self._run_with_profiler:
jax.profiler.stop_trace()

logger.info("Inference time (in seconds): %f", elapsed)
logger.info("Shape of the predictions: %s", images.shape)
images = images.reshape(
(images.shape[0] * images.shape[1],) + images.shape[-3:])
logger.info("Shape of images afterwards: %s", images.shape)
return self._pipeline.numpy_to_pil(np.array(images))

@serve.batch(batch_wait_timeout_s=10, max_batch_size=_MAX_BATCH_SIZE)
async def batched_generate_handler(self, prompts: List[str]):
"""Sends a batch of prompts to the TPU model server.
This takes advantage of @serve.batch, Ray Serve's built-in batching
mechanism.
Args:
prompts: A list of input prompts
Returns:
A list of responses which contents are raw PNG.
"""
logger.info("Number of input prompts: %d", len(prompts))
num_to_pad = _MAX_BATCH_SIZE - len(prompts)
prompts += ["Scratch request"] * num_to_pad

images = self.generate_tpu(prompts)
results = []
for image in images[: _MAX_BATCH_SIZE - num_to_pad]:
file_stream = BytesIO()
image.save(file_stream, "PNG")
results.append(
Response(content=file_stream.getvalue(), media_type="image/png")
)
return results

async def generate(self, prompt):
return await self.batched_generate_handler(prompt)


diffusion_bound = StableDiffusion.bind()
deployment = APIIngress.bind(diffusion_bound)
132 changes: 132 additions & 0 deletions stable_diffusion/stable_diffusion_tpu_req.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import argparse
from concurrent import futures
import functools
from io import BytesIO
import numpy as np
from PIL import Image
import requests
from tqdm import tqdm


_PROMPTS = [
"Labrador in the style of Hokusai",
"Painting of a squirrel skating in New York",
"HAL-9000 in the style of Van Gogh",
"Times Square under water, with fish and a dolphin swimming around",
"Ancient Roman fresco showing a man working on his laptop",
"Armchair in the shape of an avocado",
"Clown astronaut in space, with Earth in the background",
"A cat sitting on a windowsill",
"A dog playing fetch in a park",
"A city skyline at night",
"A field of flowers in bloom",
"A tropical beach with palm trees",
"A snowy mountain range",
"A waterfall cascading into a pool",
"A forest at sunset",
"A desert landscape with cacti",
"A volcano erupting",
"A lightning storm in the distance",
"A rainbow over a rainbow",
"A unicorn grazing in a meadow",
"A dragon flying through the sky",
"A mermaid swimming in the ocean",
"A robot walking down the street",
"A UFO landing in a field",
"A portal to another dimension",
"A time traveler from the future",
"A talking cat",
"A bowl of fruit on a table",
"A group of friends laughing",
"A family sitting down for dinner",
"A couple kissing in the rain",
"A child playing with a toy",
"A musician playing an instrument",
"A painter painting a picture",
"A writer writing a book",
"A scientist conducting an experiment",
"A construction worker building a house",
"A doctor operating on a patient",
"A teacher teaching a class",
"A police officer arresting a suspect",
"A firefighter putting out a fire",
"A soldier fighting in a war",
"A farmer working in a field",
"A pilot flying a plane",
"An astronaut in space",
"A unicorn eating a rainbow"
]


def send_request_and_receive_image(prompt: str, url: str) -> BytesIO:
"""Sends a single prompt request and returns the Image."""
try:
inputs = "%20".join(prompt.split(" "))
resp = requests.get(f"{url}?prompt={inputs}")
resp.raise_for_status()
return BytesIO(resp.content)
except requests.RequestException as e:
print(f"An error occurred while sending the request: {e}")


def image_grid(imgs, rows, cols):
w, h = imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid


def send_requests(num_requests: int, batch_size: int, save_pictures: bool,
url: str = "http://localhost:8000/imagine"):
"""Sends a list of requests and processes the responses."""
print("num_requests: ", num_requests)
print("batch_size: ", batch_size)
print("url: ", url)
print("save_pictures: ", save_pictures)

prompts = _PROMPTS
if num_requests > len(_PROMPTS):
# Repeat until larger than num_requests
prompts = _PROMPTS * int(np.ceil(num_requests / len(_PROMPTS)))

prompts = np.random.choice(
prompts, num_requests, replace=False)

with futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
raw_images = list(
tqdm(
executor.map(
functools.partial(send_request_and_receive_image, url=url),
prompts,
),
total=len(prompts),
)
)

if save_pictures:
print("Saving pictures to diffusion_results.png")
images = [Image.open(raw_image) for raw_image in raw_images]
grid = image_grid(images, 2, num_requests // 2)
grid.save("./diffusion_results.png")


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Sends requests to Diffusion.")
parser.add_argument(
"--num_requests", help="Number of requests to send.",
default=8)
parser.add_argument(
"--batch_size", help="The number of requests to send at a time.",
default=8)
parser.add_argument(
"--save_pictures", default=False, action="store_true",
help="Whether to save the generated pictures to disk.")
parser.add_argument(
"--ip", help="The IP address to send the requests to.")

args = parser.parse_args()

send_requests(
num_requests=int(args.num_requests), batch_size=int(args.batch_size),
save_pictures=bool(args.save_pictures))

0 comments on commit d16752e

Please sign in to comment.