# Add PixArt support (#731)
# What does this PR do?

- [x] PixArt-α export support via CLI

* Tiny

```bash
optimum-cli export neuron --model hf-internal-testing/tiny-pixart-alpha-pipe --batch_size 1 --height 64 --width 64 --num_images_per_prompt 1 --torch_dtype bfloat16 --sequence_length 32 pixart_alpha_neuron_tiny/
```

* Regular

```bash
optimum-cli export neuron --model PixArt-alpha/PixArt-XL-2-512x512 --batch_size 1 --height 512 --width 512 --num_images_per_prompt 1 --torch_dtype bfloat16 --sequence_length 120 pixart_alpha_neuron_512/
```
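
Note that Neuron export compiles with static shapes: the `--batch_size`, `--height`, `--width`, `--num_images_per_prompt` and `--sequence_length` values passed here are fixed into the compiled artifacts, and inference inputs must match them.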

- [x] PixArt-α export support via API

```python
import torch
from optimum.neuron import NeuronPixArtAlphaPipeline

# Compile
model_id = "PixArt-alpha/PixArt-XL-2-512x512"
compiler_args = {"auto_cast": "none"}
input_shapes = {"batch_size": 1, "height": 512, "width": 512, "sequence_length": 120}

neuron_model = NeuronPixArtAlphaPipeline.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, export=True, disable_neuron_cache=True, **compiler_args, **input_shapes
)

# Save locally or upload to the Hugging Face Hub
save_directory = "pixart_alpha_neuron_512/"
neuron_model.save_pretrained(save_directory)
```
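
The comment above mentions uploading to the Hub; a hedged sketch of that path, assuming `NeuronPixArtAlphaPipeline` inherits the `push_to_hub` helper available on other Optimum Neuron models (the repository id is hypothetical):

```python
# Assumption: push_to_hub is exposed by the pipeline, as on other Optimum Neuron models.
neuron_model.push_to_hub(
    save_directory,  # folder written by save_pretrained above
    repository_id="my-user/pixart_alpha_neuron_512",  # hypothetical Hub repo id
)
```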

- [x] Caching support
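
A minimal sketch of the cached flow, assuming caching is enabled by default and `disable_neuron_cache=True` in the compile example above is the opt-out (variables as defined there):

```python
# Re-export with the Neuron cache enabled (disable_neuron_cache omitted, i.e. False):
# matching compiled artifacts should be fetched from the cache instead of recompiled.
neuron_model = NeuronPixArtAlphaPipeline.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, export=True, **compiler_args, **input_shapes
)
```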
- [x] PixArt-α inference support

```python
from optimum.neuron import NeuronPixArtAlphaPipeline

# Inference
neuron_model = NeuronPixArtAlphaPipeline.from_pretrained("pixart_alpha_neuron_512/")
prompt = "An astronaut riding a green horse"
image = neuron_model(prompt=prompt).images[0]
image.save("out.png")
```
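
Assuming the Neuron pipeline forwards the standard diffusers `PixArtAlphaPipeline` generation arguments (an assumption, not verified in this diff), the call can be tuned:

```python
# Hedged sketch: kwargs mirror diffusers' PixArtAlphaPipeline; pass-through is assumed.
image = neuron_model(
    prompt=prompt,
    negative_prompt="blurry, low quality",  # steer generation away from artifacts
    num_inference_steps=20,                 # fewer denoising steps, faster generation
    guidance_scale=4.5,                     # classifier-free guidance strength
).images[0]
```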

- [x] Tests
- [x] Documentation


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?

---------

Co-authored-by: Michael Benayoun <[email protected]>
JingyaHuang and michaelbenayoun authored Dec 22, 2024
1 parent 62b6674 · commit f44211c
Showing 37 changed files with 622 additions and 209 deletions.
9 changes: 8 additions & 1 deletion .github/workflows/doc-build.yml
```diff
@@ -29,6 +29,11 @@ jobs:
         with:
           node-version: '18'
           cache-dependency-path: "kit/package-lock.json"
 
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.11'
+
       - name: Set environment variables
         run: |
@@ -46,7 +51,9 @@
       - name: Setup environment
         run: |
-          pip install ".[quality, diffusers]"
+          python -m ensurepip --upgrade
+          python -m pip install --upgrade setuptools
+          python -m pip install ".[quality, diffusers]"
       - name: Make documentation
         shell: bash
```
6 changes: 6 additions & 0 deletions .github/workflows/doc-pr-build.yml
```diff
@@ -29,8 +29,14 @@ jobs:
           node-version: '18'
           cache-dependency-path: "kit/package-lock.json"
 
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.11'
+
       - name: Setup environment
         run: |
+          pip install --upgrade pip
           pip install ".[quality, diffusers]"
       - name: Make documentation
```
6 changes: 5 additions & 1 deletion .github/workflows/test_inf2.yml
```diff
@@ -60,4 +60,8 @@ jobs:
       - name: Run other generation tests
         run: |
           source aws_neuron_venv_pytorch/bin/activate
-          HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_NEURON_CI }} pytest -m is_inferentia_test tests/generation
+          HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_NEURON_CI }} pytest -m is_inferentia_test --ignore=tests/generation/test_parallel.py tests/generation
+      - name: Run parallel tests
+        run: |
+          source aws_neuron_venv_pytorch/bin/activate
+          HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_NEURON_CI }} pytest -m is_inferentia_test tests/generation/test_parallel.py
```
```diff
@@ -3,7 +3,6 @@
 
 
 def get_node_results(node_url):
-
     metrics = requests.get(node_url + "/metrics").text
 
     counters = {
```
26 changes: 26 additions & 0 deletions docs/source/inference_tutorials/stable_diffusion.mdx
````diff
@@ -746,4 +746,30 @@ images[0].save("hug_lab.png")
     alt="stable diffusion xl generated image with controlnet."
 />
 
+
+## PixArt-α
+
+### Compile
+
+```bash
+optimum-cli export neuron --model PixArt-alpha/PixArt-XL-2-512x512 --batch_size 1 --height 512 --width 512 --num_images_per_prompt 1 --torch_dtype bfloat16 --sequence_length 120 pixart_alpha_neuron_512/
+```
+
+### Text-to-Image
+
+```python
+from optimum.neuron import NeuronPixArtAlphaPipeline
+
+neuron_model = NeuronPixArtAlphaPipeline.from_pretrained("pixart_alpha_neuron_512/")
+prompt = "Oppenheimer sits on the beach on a chair, watching a nuclear exposition with a huge mushroom cloud, 120mm."
+image = neuron_model(prompt=prompt).images[0]
+```
+
+<img
+  src="https://huggingface.co/datasets/Jingya/document_images/resolve/main/optimum/neuron/pixart-alpha-oppenheimer.png"
+  width="256"
+  height="256"
+  alt="PixArt-α generated image."
+/>
+
 Are there any other stable diffusion features that you want us to support in 🤗`Optimum-neuron`? Please file an issue to [`Optimum-neuron` Github repo](https://github.com/huggingface/optimum-neuron) or discuss with us on [HuggingFace’s community forum](https://discuss.huggingface.co/c/optimum/), cheers 🤗 !
````
5 changes: 5 additions & 0 deletions docs/source/package_reference/modeling.mdx
```diff
@@ -132,6 +132,11 @@ The following Neuron model classes are available for stable diffusion tasks.
 [[autodoc]] modeling_diffusion.NeuronStableDiffusionControlNetPipeline
 - __call__
 
+### NeuronPixArtAlphaPipeline
+
+[[autodoc]] modeling_diffusion.NeuronPixArtAlphaPipeline
+- __call__
+
 ### NeuronStableDiffusionXLPipeline
 
 [[autodoc]] modeling_diffusion.NeuronStableDiffusionXLPipeline
```
1 change: 1 addition & 0 deletions docs/source/package_reference/supported_models.mdx
```diff
@@ -74,6 +74,7 @@ limitations under the License.
 | Stable Diffusion XL Refiner | image-to-image, inpaint |
 | SDXL Turbo | text-to-image, image-to-image, inpaint |
 | LCM | text-to-image |
+| PixArt-α | text-to-image |
 
 ## Sentence Transformers
```
8 changes: 8 additions & 0 deletions optimum/commands/export/neuronx.py
```diff
@@ -28,6 +28,7 @@
 import os
 
 os.environ["NEURON_FUSE_SOFTMAX"] = "1"
+os.environ["NEURON_CUSTOM_SILU"] = "1"
 
 if TYPE_CHECKING:
     from argparse import ArgumentParser, Namespace, _SubParsersAction
@@ -112,6 +113,13 @@ def parse_args_neuronx(parser: "ArgumentParser"):
         choices=["bf16", "fp16", "tf32"],
         help='The data type to cast FP32 operations to when auto-cast mode is enabled. Can be `"bf16"`, `"fp16"` or `"tf32"`.',
     )
+    optional_group.add_argument(
+        "--torch_dtype",
+        type=str,
+        default=None,
+        choices=["bfloat16", "float16", "float32"],
+        help="Override the default `torch.dtype` and load the model under this dtype. If `None` is passed, the dtype will be automatically derived from the model's weights.",
+    )
     optional_group.add_argument(
         "--tensor_parallel_size",
         type=int,
```
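
The new `--torch_dtype` flag is the one exercised by the export commands in the PR description above, e.g.:

```bash
# Load the checkpoint in bfloat16 before tracing; omitting --torch_dtype derives
# the dtype from the model's weights (the default, per the help text above).
optimum-cli export neuron --model PixArt-alpha/PixArt-XL-2-512x512 \
  --torch_dtype bfloat16 --batch_size 1 --height 512 --width 512 \
  --num_images_per_prompt 1 --sequence_length 120 pixart_alpha_neuron_512/
```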
74 changes: 53 additions & 21 deletions optimum/exporters/neuron/__main__.py
@@ -21,6 +21,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import torch
from requests.exceptions import ConnectionError as RequestsConnectionError
from transformers import AutoConfig, AutoTokenizer, PretrainedConfig

@@ -29,13 +30,15 @@
DIFFUSION_MODEL_CONTROLNET_NAME,
DIFFUSION_MODEL_TEXT_ENCODER_2_NAME,
DIFFUSION_MODEL_TEXT_ENCODER_NAME,
DIFFUSION_MODEL_TRANSFORMER_NAME,
DIFFUSION_MODEL_UNET_NAME,
DIFFUSION_MODEL_VAE_DECODER_NAME,
DIFFUSION_MODEL_VAE_ENCODER_NAME,
ENCODER_NAME,
NEURON_FILE_NAME,
is_neuron_available,
is_neuronx_available,
map_torch_dtype,
)
from ...neuron.utils.misc import maybe_save_preprocessors
from ...neuron.utils.version_utils import (
@@ -50,8 +53,8 @@
from .utils import (
build_stable_diffusion_components_mandatory_shapes,
check_mandatory_input_shapes,
get_diffusion_models_for_export,
get_encoder_decoder_models_for_export,
get_stable_diffusion_models_for_export,
replace_stable_diffusion_submodels,
)

@@ -185,10 +188,10 @@ def normalize_stable_diffusion_input_shapes(
) -> Dict[str, Dict[str, int]]:
args = vars(args) if isinstance(args, argparse.Namespace) else args
mandatory_axes = set(getattr(inspect.getfullargspec(build_stable_diffusion_components_mandatory_shapes), "args"))
# Remove `sequence_length` as diffusers will pad it to the max and remove number of channels.
mandatory_axes = mandatory_axes - {
"sequence_length",
"unet_num_channels",
"sequence_length", # `sequence_length` is optional, diffusers will pad it to the max if not provided.
# remove number of channels.
"unet_or_transformer_num_channels",
"vae_encoder_num_channels",
"vae_decoder_num_channels",
"num_images_per_prompt", # default to 1
Expand All @@ -199,6 +202,7 @@ def normalize_stable_diffusion_input_shapes(
)
mandatory_shapes = {name: args[name] for name in mandatory_axes}
mandatory_shapes["num_images_per_prompt"] = args.get("num_images_per_prompt", 1)
mandatory_shapes["sequence_length"] = args.get("sequence_length", None)
input_shapes = build_stable_diffusion_components_mandatory_shapes(**mandatory_shapes)
return input_shapes

@@ -209,32 +213,45 @@ def infer_stable_diffusion_shapes_from_diffusers(
has_controlnets: bool,
):
if model.tokenizer is not None:
sequence_length = model.tokenizer.model_max_length
max_sequence_length = model.tokenizer.model_max_length
elif hasattr(model, "tokenizer_2") and model.tokenizer_2 is not None:
sequence_length = model.tokenizer_2.model_max_length
max_sequence_length = model.tokenizer_2.model_max_length
else:
raise AttributeError(f"Cannot infer sequence_length from {type(model)} as there is no tokenizer as attribute.")
unet_num_channels = model.unet.config.in_channels
raise AttributeError(
f"Cannot infer max sequence_length from {type(model)} as there is no tokenizer as attribute."
)
vae_encoder_num_channels = model.vae.config.in_channels
vae_decoder_num_channels = model.vae.config.latent_channels
vae_scale_factor = 2 ** (len(model.vae.config.block_out_channels) - 1) or 8
height = input_shapes["unet"]["height"]
height = input_shapes["unet_or_transformer"]["height"]
scaled_height = height // vae_scale_factor
width = input_shapes["unet"]["width"]
width = input_shapes["unet_or_transformer"]["width"]
scaled_width = width // vae_scale_factor

input_shapes["text_encoder"].update({"sequence_length": sequence_length})
# Text encoders
if input_shapes["text_encoder"].get("sequence_length") is None:
input_shapes["text_encoder"].update({"sequence_length": max_sequence_length})
if hasattr(model, "text_encoder_2"):
input_shapes["text_encoder_2"] = input_shapes["text_encoder"]
input_shapes["unet"].update(

# UNet or Transformer
unet_or_transformer_name = "transformer" if hasattr(model, "transformer") else "unet"
unet_or_transformer_num_channels = getattr(model, unet_or_transformer_name).config.in_channels
input_shapes["unet_or_transformer"].update(
{
"sequence_length": sequence_length,
"num_channels": unet_num_channels,
"num_channels": unet_or_transformer_num_channels,
"height": scaled_height,
"width": scaled_width,
}
)
input_shapes["unet"]["vae_scale_factor"] = vae_scale_factor
if input_shapes["unet_or_transformer"].get("sequence_length") is None:
input_shapes["unet_or_transformer"]["sequence_length"] = max_sequence_length
input_shapes["unet_or_transformer"]["vae_scale_factor"] = vae_scale_factor
input_shapes[unet_or_transformer_name] = input_shapes.pop("unet_or_transformer")
if unet_or_transformer_name == "transformer":
input_shapes[unet_or_transformer_name]["encoder_hidden_size"] = model.text_encoder.config.hidden_size

# VAE
input_shapes["vae_encoder"].update({"num_channels": vae_encoder_num_channels, "height": height, "width": width})
input_shapes["vae_decoder"].update(
{"num_channels": vae_decoder_num_channels, "height": scaled_height, "width": scaled_width}
Expand All @@ -246,9 +263,9 @@ def infer_stable_diffusion_shapes_from_diffusers(
if hasattr(model, "text_encoder_2"):
encoder_hidden_size += model.text_encoder_2.config.hidden_size
input_shapes["controlnet"] = {
"batch_size": input_shapes["unet"]["batch_size"],
"sequence_length": sequence_length,
"num_channels": unet_num_channels,
"batch_size": input_shapes[unet_or_transformer_name]["batch_size"],
"sequence_length": input_shapes[unet_or_transformer_name]["sequence_length"],
"num_channels": unet_or_transformer_num_channels,
"height": scaled_height,
"width": scaled_width,
"vae_scale_factor": vae_scale_factor,
@@ -383,17 +400,20 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
model.tokenizer.save_pretrained(output.joinpath("tokenizer"))
if getattr(model, "tokenizer_2", None) is not None:
model.tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))
if getattr(model, "tokenizer_3", None) is not None:
model.tokenizer_3.save_pretrained(output.joinpath("tokenizer_3"))
if getattr(model, "feature_extractor", None) is not None:
model.feature_extractor.save_pretrained(output.joinpath("feature_extractor"))
model.save_config(output)

lora_model_ids, lora_weight_names, lora_adapter_names, lora_scales = _normalize_lora_params(
lora_model_ids, lora_weight_names, lora_adapter_names, lora_scales
)
models_and_neuron_configs = get_stable_diffusion_models_for_export(
models_and_neuron_configs = get_diffusion_models_for_export(
pipeline=model,
text_encoder_input_shapes=input_shapes["text_encoder"],
unet_input_shapes=input_shapes["unet"],
unet_input_shapes=input_shapes.get("unet", None),
transformer_input_shapes=input_shapes.get("transformer", None),
vae_encoder_input_shapes=input_shapes["vae_encoder"],
vae_decoder_input_shapes=input_shapes["vae_decoder"],
dynamic_batch_size=dynamic_batch_size,
Expand All @@ -406,7 +426,6 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
controlnet_input_shapes=input_shapes.get("controlnet", None),
)
output_model_names = {
DIFFUSION_MODEL_UNET_NAME: os.path.join(DIFFUSION_MODEL_UNET_NAME, NEURON_FILE_NAME),
DIFFUSION_MODEL_VAE_ENCODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_ENCODER_NAME, NEURON_FILE_NAME),
DIFFUSION_MODEL_VAE_DECODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_DECODER_NAME, NEURON_FILE_NAME),
}
@@ -418,6 +437,12 @@
output_model_names[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME] = os.path.join(
DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, NEURON_FILE_NAME
)
if getattr(model, "unet", None) is not None:
output_model_names[DIFFUSION_MODEL_UNET_NAME] = os.path.join(DIFFUSION_MODEL_UNET_NAME, NEURON_FILE_NAME)
if getattr(model, "transformer", None) is not None:
output_model_names[DIFFUSION_MODEL_TRANSFORMER_NAME] = os.path.join(
DIFFUSION_MODEL_TRANSFORMER_NAME, NEURON_FILE_NAME
)

# ControlNet models
if controlnet_ids:
@@ -488,6 +513,7 @@ def load_models_and_neuron_configs(
lora_weight_names: Optional[Union[str, List[str]]],
lora_adapter_names: Optional[Union[str, List[str]]],
lora_scales: Optional[Union[float, List[float]]],
torch_dtype: Optional[Union[str, torch.dtype]] = None,
tensor_parallel_size: int = 1,
controlnet_ids: Optional[Union[str, List[str]]] = None,
output_attentions: bool = False,
@@ -506,6 +532,7 @@
"trust_remote_code": trust_remote_code,
"framework": "pt",
"library_name": library_name,
"torch_dtype": torch_dtype,
}
if model is None:
model = TasksManager.get_model_from_task(**model_kwargs)
@@ -537,6 +564,7 @@ def main_export(
model_name_or_path: str,
output: Union[str, Path],
compiler_kwargs: Dict[str, Any],
torch_dtype: Optional[Union[str, torch.dtype]] = None,
tensor_parallel_size: int = 1,
model: Optional[Union["PreTrainedModel", "ModelMixin"]] = None,
task: str = "auto",
@@ -566,6 +594,7 @@
**input_shapes,
):
output = Path(output)
torch_dtype = map_torch_dtype(torch_dtype)
if not output.parent.exists():
output.parent.mkdir(parents=True)

@@ -579,6 +608,7 @@
model_name_or_path=model_name_or_path,
output=output,
model=model,
torch_dtype=torch_dtype,
tensor_parallel_size=tensor_parallel_size,
task=task,
dynamic_batch_size=dynamic_batch_size,
@@ -604,6 +634,7 @@
_, neuron_outputs = export_models(
models_and_neuron_configs=models_and_neuron_configs,
output_dir=output,
torch_dtype=torch_dtype,
disable_neuron_cache=disable_neuron_cache,
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
@@ -721,6 +752,7 @@ def main():
model_name_or_path=args.model,
output=args.output,
compiler_kwargs=compiler_kwargs,
torch_dtype=args.torch_dtype,
tensor_parallel_size=args.tensor_parallel_size,
task=task,
dynamic_batch_size=args.dynamic_batch_size,
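
For context on the `map_torch_dtype` helper imported and called above, a hypothetical sketch of its behavior (mapping the CLI's string choices onto `torch.dtype` objects and passing real dtypes through); this is an illustration, not the actual optimum-neuron implementation:

```python
import torch

# Hypothetical sketch; not the actual optimum.neuron.utils implementation.
_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}

def map_torch_dtype(dtype):
    """Map a string such as 'bfloat16' to a torch.dtype; pass torch.dtype (or None) through."""
    if isinstance(dtype, str):
        return _DTYPE_MAP.get(dtype)
    return dtype
```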