-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add TPU example for Ray Serve with Stable Diffusion
- Loading branch information
Showing
3 changed files
with
276 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
"""Ray Serve Stable Diffusion example.""" | ||
from io import BytesIO | ||
from typing import List | ||
from fastapi import FastAPI | ||
from fastapi.responses import Response | ||
import logging | ||
import ray | ||
from ray import serve | ||
import time | ||
|
||
app = FastAPI() | ||
_MAX_BATCH_SIZE = 64 | ||
|
||
logger = logging.getLogger("ray.serve") | ||
|
||
@serve.deployment(num_replicas=1) | ||
@serve.ingress(app) | ||
class APIIngress: | ||
def __init__(self, diffusion_model_handle) -> None: | ||
self.handle = diffusion_model_handle | ||
|
||
@app.get( | ||
"/imagine", | ||
responses={200: {"content": {"image/png": {}}}}, | ||
response_class=Response, | ||
) | ||
async def generate(self, prompt: str): | ||
assert len(prompt), "prompt parameter cannot be empty" | ||
|
||
image = await self.handle.generate.remote(prompt) | ||
return image | ||
|
||
|
||
@serve.deployment( | ||
ray_actor_options={ | ||
"resources": {"TPU": 4}, | ||
}, | ||
) | ||
class StableDiffusion: | ||
"""FLAX Stable Diffusion Ray Serve deployment running on TPUs. | ||
Attributes: | ||
run_with_profiler: Whether or not to run with the profiler. Note that | ||
this saves the profile to the separate TPU VM. | ||
""" | ||
|
||
def __init__( | ||
self, run_with_profiler: bool = False, warmup: bool = False, | ||
warmup_batch_size: int = _MAX_BATCH_SIZE): | ||
from diffusers import FlaxStableDiffusionPipeline | ||
from flax.jax_utils import replicate | ||
import jax | ||
import jax.numpy as jnp | ||
from jax import pmap | ||
|
||
model_id = "CompVis/stable-diffusion-v1-4" | ||
|
||
self._pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( | ||
model_id, | ||
revision="bf16", | ||
dtype=jnp.bfloat16) | ||
|
||
self._p_params = replicate(params) | ||
self._p_generate = pmap(self._pipeline._generate) | ||
self._run_with_profiler = run_with_profiler | ||
self._profiler_dir = "/tmp/tensorboard" | ||
|
||
if warmup: | ||
logger.info("Sending warmup requests.") | ||
warmup_prompts = ["A warmup request"] * warmup_batch_size | ||
self.generate_tpu(warmup_prompts) | ||
|
||
def generate_tpu(self, prompts: List[str]): | ||
"""Generates a batch of images from Diffusion from a list of prompts. | ||
Args: | ||
prompts: a list of strings. Should be a factor of 4. | ||
Returns: | ||
A list of PIL Images. | ||
""" | ||
from flax.training.common_utils import shard | ||
import jax | ||
import numpy as np | ||
|
||
rng = jax.random.PRNGKey(0) | ||
rng = jax.random.split(rng, jax.device_count()) | ||
|
||
assert prompts, "prompt parameter cannot be empty" | ||
logger.info("Prompts: %s", prompts) | ||
prompt_ids = self._pipeline.prepare_inputs(prompts) | ||
prompt_ids = shard(prompt_ids) | ||
logger.info("Sharded prompt ids has shape: %s", prompt_ids.shape) | ||
if self._run_with_profiler: | ||
jax.profiler.start_trace(self._profiler_dir) | ||
|
||
time_start = time.time() | ||
images = self._p_generate(prompt_ids, self._p_params, rng) | ||
images = images.block_until_ready() | ||
elapsed = time.time() - time_start | ||
if self._run_with_profiler: | ||
jax.profiler.stop_trace() | ||
|
||
logger.info("Inference time (in seconds): %f", elapsed) | ||
logger.info("Shape of the predictions: %s", images.shape) | ||
images = images.reshape( | ||
(images.shape[0] * images.shape[1],) + images.shape[-3:]) | ||
logger.info("Shape of images afterwards: %s", images.shape) | ||
return self._pipeline.numpy_to_pil(np.array(images)) | ||
|
||
@serve.batch(batch_wait_timeout_s=10, max_batch_size=_MAX_BATCH_SIZE) | ||
async def batched_generate_handler(self, prompts: List[str]): | ||
"""Sends a batch of prompts to the TPU model server. | ||
This takes advantage of @serve.batch, Ray Serve's built-in batching | ||
mechanism. | ||
Args: | ||
prompts: A list of input prompts | ||
Returns: | ||
A list of responses which contents are raw PNG. | ||
""" | ||
logger.info("Number of input prompts: %d", len(prompts)) | ||
num_to_pad = _MAX_BATCH_SIZE - len(prompts) | ||
prompts += ["Scratch request"] * num_to_pad | ||
|
||
images = self.generate_tpu(prompts) | ||
results = [] | ||
for image in images[: _MAX_BATCH_SIZE - num_to_pad]: | ||
file_stream = BytesIO() | ||
image.save(file_stream, "PNG") | ||
results.append( | ||
Response(content=file_stream.getvalue(), media_type="image/png") | ||
) | ||
return results | ||
|
||
async def generate(self, prompt): | ||
return await self.batched_generate_handler(prompt) | ||
|
||
|
||
diffusion_bound = StableDiffusion.bind() | ||
deployment = APIIngress.bind(diffusion_bound) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
import argparse | ||
from concurrent import futures | ||
import functools | ||
from io import BytesIO | ||
import numpy as np | ||
from PIL import Image | ||
import requests | ||
from tqdm import tqdm | ||
|
||
|
||
_PROMPTS = [ | ||
"Labrador in the style of Hokusai", | ||
"Painting of a squirrel skating in New York", | ||
"HAL-9000 in the style of Van Gogh", | ||
"Times Square under water, with fish and a dolphin swimming around", | ||
"Ancient Roman fresco showing a man working on his laptop", | ||
"Armchair in the shape of an avocado", | ||
"Clown astronaut in space, with Earth in the background", | ||
"A cat sitting on a windowsill", | ||
"A dog playing fetch in a park", | ||
"A city skyline at night", | ||
"A field of flowers in bloom", | ||
"A tropical beach with palm trees", | ||
"A snowy mountain range", | ||
"A waterfall cascading into a pool", | ||
"A forest at sunset", | ||
"A desert landscape with cacti", | ||
"A volcano erupting", | ||
"A lightning storm in the distance", | ||
"A rainbow over a rainbow", | ||
"A unicorn grazing in a meadow", | ||
"A dragon flying through the sky", | ||
"A mermaid swimming in the ocean", | ||
"A robot walking down the street", | ||
"A UFO landing in a field", | ||
"A portal to another dimension", | ||
"A time traveler from the future", | ||
"A talking cat", | ||
"A bowl of fruit on a table", | ||
"A group of friends laughing", | ||
"A family sitting down for dinner", | ||
"A couple kissing in the rain", | ||
"A child playing with a toy", | ||
"A musician playing an instrument", | ||
"A painter painting a picture", | ||
"A writer writing a book", | ||
"A scientist conducting an experiment", | ||
"A construction worker building a house", | ||
"A doctor operating on a patient", | ||
"A teacher teaching a class", | ||
"A police officer arresting a suspect", | ||
"A firefighter putting out a fire", | ||
"A soldier fighting in a war", | ||
"A farmer working in a field", | ||
"A pilot flying a plane", | ||
"An astronaut in space", | ||
"A unicorn eating a rainbow" | ||
] | ||
|
||
|
||
def send_request_and_receive_image(prompt: str, url: str) -> BytesIO: | ||
"""Sends a single prompt request and returns the Image.""" | ||
try: | ||
inputs = "%20".join(prompt.split(" ")) | ||
resp = requests.get(f"{url}?prompt={inputs}") | ||
resp.raise_for_status() | ||
return BytesIO(resp.content) | ||
except requests.RequestException as e: | ||
print(f"An error occurred while sending the request: {e}") | ||
|
||
|
||
def image_grid(imgs, rows, cols): | ||
w, h = imgs[0].size | ||
grid = Image.new("RGB", size=(cols * w, rows * h)) | ||
for i, img in enumerate(imgs): | ||
grid.paste(img, box=(i % cols * w, i // cols * h)) | ||
return grid | ||
|
||
|
||
def send_requests(num_requests: int, batch_size: int, save_pictures: bool, | ||
url: str = "http://localhost:8000/imagine"): | ||
"""Sends a list of requests and processes the responses.""" | ||
print("num_requests: ", num_requests) | ||
print("batch_size: ", batch_size) | ||
print("url: ", url) | ||
print("save_pictures: ", save_pictures) | ||
|
||
prompts = _PROMPTS | ||
if num_requests > len(_PROMPTS): | ||
# Repeat until larger than num_requests | ||
prompts = _PROMPTS * int(np.ceil(num_requests / len(_PROMPTS))) | ||
|
||
prompts = np.random.choice( | ||
prompts, num_requests, replace=False) | ||
|
||
with futures.ThreadPoolExecutor(max_workers=batch_size) as executor: | ||
raw_images = list( | ||
tqdm( | ||
executor.map( | ||
functools.partial(send_request_and_receive_image, url=url), | ||
prompts, | ||
), | ||
total=len(prompts), | ||
) | ||
) | ||
|
||
if save_pictures: | ||
print("Saving pictures to diffusion_results.png") | ||
images = [Image.open(raw_image) for raw_image in raw_images] | ||
grid = image_grid(images, 2, num_requests // 2) | ||
grid.save("./diffusion_results.png") | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description="Sends requests to Diffusion.") | ||
parser.add_argument( | ||
"--num_requests", help="Number of requests to send.", | ||
default=8) | ||
parser.add_argument( | ||
"--batch_size", help="The number of requests to send at a time.", | ||
default=8) | ||
parser.add_argument( | ||
"--save_pictures", default=False, action="store_true", | ||
help="Whether to save the generated pictures to disk.") | ||
parser.add_argument( | ||
"--ip", help="The IP address to send the requests to.") | ||
|
||
args = parser.parse_args() | ||
|
||
send_requests( | ||
num_requests=int(args.num_requests), batch_size=int(args.batch_size), | ||
save_pictures=bool(args.save_pictures)) |