Skip to content

Commit

Permalink
(shortfin-sd) Usability and logging improvements. (#491)
Browse files Browse the repository at this point in the history
- Port configuration fixes
- Reduce need for topology footguns by implementing simple topology
config artifacts for 4 setups (cpx:single/multi, spx:single/multi).
These set server/device topologies to "known good" configurations.
- Fix example client script (send_request.py -> simple_client.py) and
include in python package, arg problems fixed as well (--save)
- Remove need for source code
- Safer failures for invalid output dims
- Don't report server startup under uvicorn.error
- Updates sd README with new example CLI inputs
- Adds help for client CLI args
  • Loading branch information
monorimet authored Nov 13, 2024
1 parent 9c5c8cc commit 2fa1f92
Show file tree
Hide file tree
Showing 14 changed files with 536 additions and 144 deletions.
12 changes: 4 additions & 8 deletions shortfin/python/shortfin/support/logging_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,15 @@ def __init__(self):
native_handler.setFormatter(NativeFormatter())

# TODO: Source from env vars.
logger.setLevel(logging.INFO)
logger.setLevel(logging.DEBUG)
logger.addHandler(native_handler)


def configure_main_logger(module_suffix: str = "__main__") -> logging.Logger:
"""Configures logging from a main entrypoint.
Returns a logger that can be used for the main module itself.
"""
logging.root.addHandler(native_handler)
logging.root.setLevel(logging.WARNING) # TODO: source from env vars
main_module = sys.modules["__main__"]
logging.root.setLevel(logging.INFO)
logger = logging.getLogger(f"{main_module.__package__}.{module_suffix}")
logger.setLevel(logging.INFO)
logger.addHandler(native_handler)

return logger
return logging.getLogger(f"{main_module.__package__}.{module_suffix}")
11 changes: 4 additions & 7 deletions shortfin/python/shortfin_apps/sd/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,10 @@ cd shortfin/
The server will prepare runtime artifacts for you.

```
python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --flagfile=./python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt --build_preference=compile
python -m shortfin_apps.sd.server --device=amdgpu --device_ids=0 --build_preference=precompiled --topology="spx_single"
```
- Run with splat(empty) weights:
```
python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --splat --flagfile=./python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt --build_preference=compile
```
- Run a request in a separate shell:

- Run a CLI client in a separate shell:
```
python shortfin/python/shortfin_apps/sd/examples/send_request.py --file=shortfin/python/shortfin_apps/sd/examples/sdxl_request.json
python -m shortfin_apps.sd.simple_client --interactive --save
```
10 changes: 7 additions & 3 deletions shortfin/python/shortfin_apps/sd/components/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
sfnp.bfloat16: "bf16",
}

ARTIFACT_VERSION = "11022024"
ARTIFACT_VERSION = "11132024"
SDXL_BUCKET = (
f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/"
)
Expand All @@ -51,7 +51,9 @@ def get_mlir_filenames(model_params: ModelParams, model=None):
return filter_by_model(mlir_filenames, model)


def get_vmfb_filenames(model_params: ModelParams, model=None, target: str = "gfx942"):
def get_vmfb_filenames(
model_params: ModelParams, model=None, target: str = "amdgpu-gfx942"
):
vmfb_filenames = []
file_stems = get_file_stems(model_params)
for stem in file_stems:
Expand Down Expand Up @@ -216,6 +218,8 @@ def sdxl(

mlir_bucket = SDXL_BUCKET + "mlir/"
vmfb_bucket = SDXL_BUCKET + "vmfbs/"
if "gfx" in target:
target = "amdgpu-" + target

mlir_filenames = get_mlir_filenames(model_params, model)
mlir_urls = get_url_map(mlir_filenames, mlir_bucket)
Expand Down Expand Up @@ -247,7 +251,7 @@ def sdxl(
params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET)
for f, url in params_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
if update or needs_file(f, ctx):
if needs_file(f, ctx):
fetch_http(name=f, url=url)
filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames]
return filenames
Expand Down
123 changes: 123 additions & 0 deletions shortfin/python/shortfin_apps/sd/components/config_artifacts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from iree.build import *
from iree.build.executor import FileNamespace
import itertools
import os
import shortfin.array as sfnp
import copy

from shortfin_apps.sd.components.config_struct import ModelParams

this_dir = os.path.dirname(os.path.abspath(__file__))
parent = os.path.dirname(this_dir)

dtype_to_filetag = {
sfnp.float16: "fp16",
sfnp.float32: "fp32",
sfnp.int8: "i8",
sfnp.bfloat16: "bf16",
}

ARTIFACT_VERSION = "11132024"
SDXL_CONFIG_BUCKET = f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/configs/"


def get_url_map(filenames: list[str], bucket: str):
file_map = {}
for filename in filenames:
file_map[filename] = f"{bucket}{filename}"
return file_map


def needs_update(ctx):
stamp = ctx.allocate_file("version.txt")
stamp_path = stamp.get_fs_path()
if os.path.exists(stamp_path):
with open(stamp_path, "r") as s:
ver = s.read()
if ver != ARTIFACT_VERSION:
return True
else:
with open(stamp_path, "w") as s:
s.write(ARTIFACT_VERSION)
return True
return False


def needs_file(filename, ctx, namespace=FileNamespace.GEN):
out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path()
if os.path.exists(out_file):
needed = False
else:
# name_path = "bin" if namespace == FileNamespace.BIN else ""
# if name_path:
# filename = os.path.join(name_path, filename)
filekey = os.path.join(ctx.path, filename)
ctx.executor.all[filekey] = None
needed = True
return needed


@entrypoint(description="Retreives a set of SDXL configuration files.")
def sdxlconfig(
target=cl_arg(
"target",
default="gfx942",
help="IREE target architecture.",
),
model=cl_arg("model", type=str, default="sdxl", help="Model architecture"),
topology=cl_arg(
"topology",
type=str,
default="spx_single",
help="System topology configfile keyword",
),
):
ctx = executor.BuildContext.current()
update = needs_update(ctx)

model_config_filenames = [f"{model}_config_i8.json"]
model_config_urls = get_url_map(model_config_filenames, SDXL_CONFIG_BUCKET)
for f, url in model_config_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)

topology_config_filenames = [f"topology_config_{topology}.txt"]
topology_config_urls = get_url_map(topology_config_filenames, SDXL_CONFIG_BUCKET)
for f, url in topology_config_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)

flagfile_filenames = [f"{model}_flagfile_{target}.txt"]
flagfile_urls = get_url_map(flagfile_filenames, SDXL_CONFIG_BUCKET)
for f, url in flagfile_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)

tuning_filenames = (
[f"attention_and_matmul_spec_{target}.mlir"] if target == "gfx942" else []
)
tuning_urls = get_url_map(tuning_filenames, SDXL_CONFIG_BUCKET)
for f, url in tuning_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)
filenames = [
*model_config_filenames,
*topology_config_filenames,
*flagfile_filenames,
*tuning_filenames,
]
return filenames


if __name__ == "__main__":
iree_build_main()
2 changes: 1 addition & 1 deletion shortfin/python/shortfin_apps/sd/components/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .service import GenerateService
from .metrics import measure

logger = logging.getLogger(__name__)
logger = logging.getLogger("shortfin-sd.generate")


class GenerateImageProcess(sf.Process):
Expand Down
7 changes: 7 additions & 0 deletions shortfin/python/shortfin_apps/sd/components/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,10 @@ def post_init(self):
raise ValueError("The rid should be a list.")
if self.output_type is None:
self.output_type = ["base64"] * self.num_output_images
# Temporary restrictions
heights = [self.height] if not isinstance(self.height, list) else self.height
widths = [self.width] if not isinstance(self.width, list) else self.width
if any(dim != 1024 for dim in [*heights, *widths]):
raise ValueError(
"Currently, only 1024x1024 output image size is supported."
)
4 changes: 2 additions & 2 deletions shortfin/python/shortfin_apps/sd/components/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import shortfin as sf
from shortfin.interop.support.device_setup import get_selected_devices

logger = logging.getLogger(__name__)
logger = logging.getLogger("shortfin-sd.manager")


class SystemManager:
Expand All @@ -25,7 +25,7 @@ def __init__(self, device="local-task", device_ids=None, async_allocs=True):
sb.visible_devices = sb.available_devices
sb.visible_devices = get_selected_devices(sb, device_ids)
self.ls = sb.create_system()
logger.info(f"Created local system with {self.ls.device_names} devices")
logging.info(f"Created local system with {self.ls.device_names} devices")
# TODO: Come up with an easier bootstrap thing than manually
# running a thread.
self.t = threading.Thread(target=lambda: self.ls.run(self.run()))
Expand Down
2 changes: 1 addition & 1 deletion shortfin/python/shortfin_apps/sd/components/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from .io_struct import GenerateReqInput

logger = logging.getLogger(__name__)
logger = logging.getLogger("shortfin-sd.messages")


class InferencePhase(Enum):
Expand Down
2 changes: 1 addition & 1 deletion shortfin/python/shortfin_apps/sd/components/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Callable, Any
import functools

logger = logging.getLogger(__name__)
logger = logging.getLogger("shortfin-sd.metrics")


def measure(fn=None, type="exec", task=None, num_items=None, freq=1, label="items"):
Expand Down
37 changes: 26 additions & 11 deletions shortfin/python/shortfin_apps/sd/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from .metrics import measure


logger = logging.getLogger(__name__)
logger = logging.getLogger("shortfin-sd.service")
logger.setLevel(logging.DEBUG)

prog_isolations = {
"none": sf.ProgramIsolation.NONE,
Expand Down Expand Up @@ -119,26 +120,29 @@ def load_inference_parameters(

def start(self):
# Initialize programs.
# This can work if we only initialize one set of programs per service, as our programs
# in SDXL are stateless and
for component in self.inference_modules:
component_modules = [
sf.ProgramModule.parameter_provider(
self.sysman.ls, *self.inference_parameters.get(component, [])
),
*self.inference_modules[component],
]

for worker_idx, worker in enumerate(self.workers):
worker_devices = self.fibers[
worker_idx * (self.fibers_per_worker)
].raw_devices

logger.info(
f"Loading inference program: {component}, worker index: {worker_idx}, device: {worker_devices}"
)
self.inference_programs[worker_idx][component] = sf.Program(
modules=component_modules,
devices=worker_devices,
isolation=self.prog_isolation,
trace_execution=self.trace_execution,
)
logger.info("Program loaded.")

for worker_idx, worker in enumerate(self.workers):
self.inference_functions[worker_idx]["encode"] = {}
for bs in self.model_params.clip_batch_sizes:
Expand Down Expand Up @@ -170,7 +174,6 @@ def start(self):
] = self.inference_programs[worker_idx]["vae"][
f"{self.model_params.vae_module_name}.decode"
]
# breakpoint()
self.batcher.launch()

def shutdown(self):
Expand Down Expand Up @@ -212,8 +215,8 @@ class BatcherProcess(sf.Process):
into batches.
"""

STROBE_SHORT_DELAY = 0.1
STROBE_LONG_DELAY = 0.25
STROBE_SHORT_DELAY = 0.5
STROBE_LONG_DELAY = 1

def __init__(self, service: GenerateService):
super().__init__(fiber=service.fibers[0])
Expand Down Expand Up @@ -356,7 +359,6 @@ async def run(self):
logger.error("Executor process recieved disjoint batch.")
phase = req.phase
phases = self.exec_requests[0].phases

req_count = len(self.exec_requests)
device0 = self.service.fibers[self.fiber_index].device(0)
if phases[InferencePhase.PREPARE]["required"]:
Expand Down Expand Up @@ -424,8 +426,12 @@ async def _prepare(self, device, requests):
async def _encode(self, device, requests):
req_bs = len(requests)
entrypoints = self.service.inference_functions[self.worker_index]["encode"]
if req_bs not in list(entrypoints.keys()):
for request in requests:
await self._encode(device, [request])
return
for bs, fn in entrypoints.items():
if bs >= req_bs:
if bs == req_bs:
break

# Prepare tokenized input ids for CLIP inference
Expand Down Expand Up @@ -462,6 +468,7 @@ async def _encode(self, device, requests):
fn,
"".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]),
)
await device
pe, te = await fn(*clip_inputs, fiber=self.fiber)

for i in range(req_bs):
Expand All @@ -477,8 +484,12 @@ async def _denoise(self, device, requests):
cfg_mult = 2 if self.service.model_params.cfg_mode else 1
# Produce denoised latents
entrypoints = self.service.inference_functions[self.worker_index]["denoise"]
if req_bs not in list(entrypoints.keys()):
for request in requests:
await self._denoise(device, [request])
return
for bs, fns in entrypoints.items():
if bs >= req_bs:
if bs == req_bs:
break

# Get shape of batched latents.
Expand Down Expand Up @@ -613,8 +624,12 @@ async def _decode(self, device, requests):
req_bs = len(requests)
# Decode latents to images
entrypoints = self.service.inference_functions[self.worker_index]["decode"]
if req_bs not in list(entrypoints.keys()):
for request in requests:
await self._decode(device, [request])
return
for bs, fn in entrypoints.items():
if bs >= req_bs:
if bs == req_bs:
break

latents_shape = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
" a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal",
" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo"
],
"neg_prompt": [
Expand Down
Loading

0 comments on commit 2fa1f92

Please sign in to comment.