diff --git a/.github/workflows/doc-build.yml b/.github/workflows/doc-build.yml
index 29b67559b..c4f67a17d 100644
--- a/.github/workflows/doc-build.yml
+++ b/.github/workflows/doc-build.yml
@@ -29,6 +29,11 @@ jobs:
         with:
           node-version: '18'
           cache-dependency-path: "kit/package-lock.json"
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.11'
 
       - name: Set environment variables
         run: |
@@ -46,7 +51,9 @@ jobs:
 
       - name: Setup environment
         run: |
-          pip install ".[quality, diffusers]"
+          python -m ensurepip --upgrade
+          python -m pip install --upgrade setuptools
+          python -m pip install ".[quality, diffusers]"
 
       - name: Make documentation
         shell: bash
diff --git a/.github/workflows/doc-pr-build.yml b/.github/workflows/doc-pr-build.yml
index a206771b5..877cf5900 100644
--- a/.github/workflows/doc-pr-build.yml
+++ b/.github/workflows/doc-pr-build.yml
@@ -29,8 +29,14 @@ jobs:
           node-version: '18'
           cache-dependency-path: "kit/package-lock.json"
 
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.11'
+
       - name: Setup environment
         run: |
+          pip install --upgrade pip
           pip install ".[quality, diffusers]"
 
       - name: Make documentation
diff --git a/.github/workflows/test_inf2.yml b/.github/workflows/test_inf2.yml
index d71e47670..7135c8d7d 100644
--- a/.github/workflows/test_inf2.yml
+++ b/.github/workflows/test_inf2.yml
@@ -60,4 +60,8 @@ jobs:
       - name: Run other generation tests
         run: |
           source aws_neuron_venv_pytorch/bin/activate
-          HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_NEURON_CI }} pytest -m is_inferentia_test tests/generation
+          HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_NEURON_CI }} pytest -m is_inferentia_test --ignore=tests/generation/test_parallel.py tests/generation
+      - name: Run parallel tests
+        run: |
+          source aws_neuron_venv_pytorch/bin/activate
+          HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_NEURON_CI }} pytest -m is_inferentia_test tests/generation/test_parallel.py
diff --git a/benchmark/text-generation-inference/performance/tgi_live_metrics.py b/benchmark/text-generation-inference/performance/tgi_live_metrics.py
index d988b857e..cf0b0d2d5 100644
--- a/benchmark/text-generation-inference/performance/tgi_live_metrics.py
+++ b/benchmark/text-generation-inference/performance/tgi_live_metrics.py
@@ -3,7 +3,6 @@
 
 
 def get_node_results(node_url):
-
     metrics = requests.get(node_url + "/metrics").text
 
     counters = {
diff --git a/docs/source/inference_tutorials/stable_diffusion.mdx b/docs/source/inference_tutorials/stable_diffusion.mdx
index b0e614504..01ff55394 100644
--- a/docs/source/inference_tutorials/stable_diffusion.mdx
+++ b/docs/source/inference_tutorials/stable_diffusion.mdx
@@ -746,4 +746,30 @@ images[0].save("hug_lab.png")
     alt="stable diffusion xl generated image with controlnet."
 />
 
+
+## PixArt-α
+
+### Compile
+
+```bash
+optimum-cli export neuron --model PixArt-alpha/PixArt-XL-2-512x512 --batch_size 1 --height 512 --width 512 --num_images_per_prompt 1 --torch_dtype bfloat16 --sequence_length 120 pixart_alpha_neuron_512/
+```
+
+### Text-to-Image
+
+```python
+from optimum.neuron import NeuronPixArtAlphaPipeline
+
+neuron_model = NeuronPixArtAlphaPipeline.from_pretrained("pixart_alpha_neuron_512/")
+prompt = "Oppenheimer sits on the beach on a chair, watching a nuclear explosion with a huge mushroom cloud, 120mm."
+image = neuron_model(prompt=prompt).images[0]
+image.save("pixart_oppenheimer.png")
+```
+
+PixArt-α generated image.
+
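+The compilation can also be done on the fly when loading the pipeline from Python. The snippet below is a sketch: it assumes `NeuronPixArtAlphaPipeline` accepts the same `export=True` keywords as the other Neuron diffusion pipelines in this guide, with `torch_dtype` mirroring the `--torch_dtype` CLI flag above:
+
+```python
+from optimum.neuron import NeuronPixArtAlphaPipeline
+
+# Static shapes mirroring the CLI flags above (assumed to be accepted as keyword arguments).
+input_shapes = {"batch_size": 1, "height": 512, "width": 512, "num_images_per_prompt": 1, "sequence_length": 120}
+neuron_model = NeuronPixArtAlphaPipeline.from_pretrained(
+    "PixArt-alpha/PixArt-XL-2-512x512",
+    export=True,
+    torch_dtype="bfloat16",  # assumed to match the `--torch_dtype bfloat16` CLI flag
+    **input_shapes,
+)
+neuron_model.save_pretrained("pixart_alpha_neuron_512/")
+```
+
 Are there any other stable diffusion features that you want us to support in 🤗`Optimum-neuron`? Please file an issue to [`Optimum-neuron` Github repo](https://github.com/huggingface/optimum-neuron) or discuss with us on [HuggingFace’s community forum](https://discuss.huggingface.co/c/optimum/), cheers 🤗 !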
diff --git a/docs/source/package_reference/modeling.mdx b/docs/source/package_reference/modeling.mdx
index 76d8e2fa9..863887173 100644
--- a/docs/source/package_reference/modeling.mdx
+++ b/docs/source/package_reference/modeling.mdx
@@ -132,6 +132,11 @@ The following Neuron model classes are available for stable diffusion tasks.
 [[autodoc]] modeling_diffusion.NeuronStableDiffusionControlNetPipeline
     - __call__
 
+### NeuronPixArtAlphaPipeline
+
+[[autodoc]] modeling_diffusion.NeuronPixArtAlphaPipeline
+    - __call__
+
 ### NeuronStableDiffusionXLPipeline
 
 [[autodoc]] modeling_diffusion.NeuronStableDiffusionXLPipeline
diff --git a/docs/source/package_reference/supported_models.mdx b/docs/source/package_reference/supported_models.mdx
index 98064fb5d..d44c2e613 100644
--- a/docs/source/package_reference/supported_models.mdx
+++ b/docs/source/package_reference/supported_models.mdx
@@ -74,6 +74,7 @@ limitations under the License.
 | Stable Diffusion XL Refiner | image-to-image, inpaint                |
 | SDXL Turbo                  | text-to-image, image-to-image, inpaint |
 | LCM                         | text-to-image                          |
+| PixArt-α                    | text-to-image                          |
 
 ## Sentence Transformers
diff --git a/optimum/commands/export/neuronx.py b/optimum/commands/export/neuronx.py
index e14dd99df..86dd22ef1 100644
--- a/optimum/commands/export/neuronx.py
+++ b/optimum/commands/export/neuronx.py
@@ -28,6 +28,7 @@
 import os
 
 os.environ["NEURON_FUSE_SOFTMAX"] = "1"
+os.environ["NEURON_CUSTOM_SILU"] = "1"
 
 if TYPE_CHECKING:
     from argparse import ArgumentParser, Namespace, _SubParsersAction
@@ -112,6 +113,13 @@ def parse_args_neuronx(parser: "ArgumentParser"):
         choices=["bf16", "fp16", "tf32"],
         help='The data type to cast FP32 operations to when auto-cast mode is enabled. Can be `"bf16"`, `"fp16"` or `"tf32"`.',
     )
+    optional_group.add_argument(
+        "--torch_dtype",
+        type=str,
+        default=None,
+        choices=["bfloat16", "float16", "float32"],
+        help="Override the default `torch.dtype` and load the model under this dtype. If `None` is passed, the dtype will be automatically derived from the model's weights.",
+    )
     optional_group.add_argument(
         "--tensor_parallel_size",
         type=int,
diff --git a/optimum/exporters/neuron/__main__.py b/optimum/exporters/neuron/__main__.py
index d458ded0d..3e2d00ef6 100644
--- a/optimum/exporters/neuron/__main__.py
+++ b/optimum/exporters/neuron/__main__.py
@@ -21,6 +21,7 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
+import torch
 from requests.exceptions import ConnectionError as RequestsConnectionError
 from transformers import AutoConfig, AutoTokenizer, PretrainedConfig
 
@@ -29,6 +30,7 @@
     DIFFUSION_MODEL_CONTROLNET_NAME,
     DIFFUSION_MODEL_TEXT_ENCODER_2_NAME,
     DIFFUSION_MODEL_TEXT_ENCODER_NAME,
+    DIFFUSION_MODEL_TRANSFORMER_NAME,
     DIFFUSION_MODEL_UNET_NAME,
     DIFFUSION_MODEL_VAE_DECODER_NAME,
     DIFFUSION_MODEL_VAE_ENCODER_NAME,
@@ -36,6 +38,7 @@
     NEURON_FILE_NAME,
     is_neuron_available,
     is_neuronx_available,
+    map_torch_dtype,
 )
 from ...neuron.utils.misc import maybe_save_preprocessors
 from ...neuron.utils.version_utils import (
@@ -50,8 +53,8 @@
 from .utils import (
     build_stable_diffusion_components_mandatory_shapes,
     check_mandatory_input_shapes,
+    get_diffusion_models_for_export,
     get_encoder_decoder_models_for_export,
-    get_stable_diffusion_models_for_export,
     replace_stable_diffusion_submodels,
 )
 
@@ -185,10 +188,10 @@ def normalize_stable_diffusion_input_shapes(
 ) -> Dict[str, Dict[str, int]]:
     args = vars(args) if isinstance(args, argparse.Namespace) else args
     mandatory_axes = set(getattr(inspect.getfullargspec(build_stable_diffusion_components_mandatory_shapes), "args"))
-    # Remove `sequence_length` as diffusers will pad it to the max and remove number of channels.
     mandatory_axes = mandatory_axes - {
-        "sequence_length",
-        "unet_num_channels",
+        "sequence_length",  # `sequence_length` is optional, diffusers will pad it to the max if not provided.
+        # Remove the number of channels.
+        "unet_or_transformer_num_channels",
         "vae_encoder_num_channels",
         "vae_decoder_num_channels",
         "num_images_per_prompt",  # default to 1
@@ -199,6 +202,7 @@ def normalize_stable_diffusion_input_shapes(
     )
     mandatory_shapes = {name: args[name] for name in mandatory_axes}
     mandatory_shapes["num_images_per_prompt"] = args.get("num_images_per_prompt", 1)
+    mandatory_shapes["sequence_length"] = args.get("sequence_length", None)
     input_shapes = build_stable_diffusion_components_mandatory_shapes(**mandatory_shapes)
     return input_shapes
 
@@ -209,32 +213,45 @@ def infer_stable_diffusion_shapes_from_diffusers(
     has_controlnets: bool,
 ):
     if model.tokenizer is not None:
-        sequence_length = model.tokenizer.model_max_length
+        max_sequence_length = model.tokenizer.model_max_length
     elif hasattr(model, "tokenizer_2") and model.tokenizer_2 is not None:
-        sequence_length = model.tokenizer_2.model_max_length
+        max_sequence_length = model.tokenizer_2.model_max_length
     else:
-        raise AttributeError(f"Cannot infer sequence_length from {type(model)} as there is no tokenizer as attribute.")
-    unet_num_channels = model.unet.config.in_channels
+        raise AttributeError(
+            f"Cannot infer the max sequence_length from {type(model)} as it has no tokenizer attribute."
+        )
     vae_encoder_num_channels = model.vae.config.in_channels
     vae_decoder_num_channels = model.vae.config.latent_channels
     vae_scale_factor = 2 ** (len(model.vae.config.block_out_channels) - 1) or 8
-    height = input_shapes["unet"]["height"]
+    height = input_shapes["unet_or_transformer"]["height"]
     scaled_height = height // vae_scale_factor
-    width = input_shapes["unet"]["width"]
+    width = input_shapes["unet_or_transformer"]["width"]
     scaled_width = width // vae_scale_factor
-    input_shapes["text_encoder"].update({"sequence_length": sequence_length})
+
+    # Text encoders
+    if input_shapes["text_encoder"].get("sequence_length") is None:
+        input_shapes["text_encoder"].update({"sequence_length": max_sequence_length})
     if hasattr(model, "text_encoder_2"):
         input_shapes["text_encoder_2"] = input_shapes["text_encoder"]
-    input_shapes["unet"].update(
+
+    # UNet or Transformer
+    unet_or_transformer_name = "transformer" if hasattr(model, "transformer") else "unet"
+    unet_or_transformer_num_channels = getattr(model, unet_or_transformer_name).config.in_channels
+    input_shapes["unet_or_transformer"].update(
         {
-            "sequence_length": sequence_length,
-            "num_channels": unet_num_channels,
+            "num_channels": unet_or_transformer_num_channels,
             "height": scaled_height,
             "width": scaled_width,
         }
     )
-    input_shapes["unet"]["vae_scale_factor"] = vae_scale_factor
+    if input_shapes["unet_or_transformer"].get("sequence_length") is None:
+        input_shapes["unet_or_transformer"]["sequence_length"] = max_sequence_length
+    input_shapes["unet_or_transformer"]["vae_scale_factor"] = vae_scale_factor
+    input_shapes[unet_or_transformer_name] = input_shapes.pop("unet_or_transformer")
+    if unet_or_transformer_name == "transformer":
+        input_shapes[unet_or_transformer_name]["encoder_hidden_size"] = model.text_encoder.config.hidden_size
+
+    # VAE
     input_shapes["vae_encoder"].update({"num_channels": vae_encoder_num_channels, "height": height, "width": width})
     input_shapes["vae_decoder"].update(
         {"num_channels": vae_decoder_num_channels, "height": scaled_height, "width": scaled_width}
@@ -246,9 +263,9 @@ def infer_stable_diffusion_shapes_from_diffusers(
         if hasattr(model, "text_encoder_2"):
             encoder_hidden_size += model.text_encoder_2.config.hidden_size
         input_shapes["controlnet"] = {
-            "batch_size": input_shapes["unet"]["batch_size"],
-            "sequence_length": sequence_length,
-            "num_channels": unet_num_channels,
+            "batch_size": input_shapes[unet_or_transformer_name]["batch_size"],
+            "sequence_length": input_shapes[unet_or_transformer_name]["sequence_length"],
+            "num_channels": unet_or_transformer_num_channels,
             "height": scaled_height,
             "width": scaled_width,
             "vae_scale_factor": vae_scale_factor,
@@ -383,6 +400,8 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
         model.tokenizer.save_pretrained(output.joinpath("tokenizer"))
     if getattr(model, "tokenizer_2", None) is not None:
         model.tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))
+    if getattr(model, "tokenizer_3", None) is not None:
+        model.tokenizer_3.save_pretrained(output.joinpath("tokenizer_3"))
     if getattr(model, "feature_extractor", None) is not None:
         model.feature_extractor.save_pretrained(output.joinpath("feature_extractor"))
     model.save_config(output)
@@ -390,10 +409,11 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
     lora_model_ids, lora_weight_names, lora_adapter_names, lora_scales = _normalize_lora_params(
         lora_model_ids, lora_weight_names, lora_adapter_names, lora_scales
     )
-    models_and_neuron_configs = get_stable_diffusion_models_for_export(
+    models_and_neuron_configs = get_diffusion_models_for_export(
         pipeline=model,
         text_encoder_input_shapes=input_shapes["text_encoder"],
-        unet_input_shapes=input_shapes["unet"],
+        unet_input_shapes=input_shapes.get("unet", None),
+        transformer_input_shapes=input_shapes.get("transformer", None),
         vae_encoder_input_shapes=input_shapes["vae_encoder"],
         vae_decoder_input_shapes=input_shapes["vae_decoder"],
         dynamic_batch_size=dynamic_batch_size,
@@ -406,7 +426,6 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
         controlnet_input_shapes=input_shapes.get("controlnet", None),
     )
     output_model_names = {
-        DIFFUSION_MODEL_UNET_NAME: os.path.join(DIFFUSION_MODEL_UNET_NAME, NEURON_FILE_NAME),
         DIFFUSION_MODEL_VAE_ENCODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_ENCODER_NAME, NEURON_FILE_NAME),
         DIFFUSION_MODEL_VAE_DECODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_DECODER_NAME, NEURON_FILE_NAME),
     }
@@ -418,6 +437,12 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
         output_model_names[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME] = os.path.join(
             DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, NEURON_FILE_NAME
         )
+    if getattr(model, "unet", None) is not None:
+        output_model_names[DIFFUSION_MODEL_UNET_NAME] = os.path.join(DIFFUSION_MODEL_UNET_NAME, NEURON_FILE_NAME)
+    if getattr(model, "transformer", None) is not None:
+        output_model_names[DIFFUSION_MODEL_TRANSFORMER_NAME] = os.path.join(
+            DIFFUSION_MODEL_TRANSFORMER_NAME, NEURON_FILE_NAME
+        )
 
     # ControlNet models
     if controlnet_ids:
@@ -488,6 +513,7 @@ def load_models_and_neuron_configs(
     lora_weight_names: Optional[Union[str, List[str]]],
     lora_adapter_names: Optional[Union[str, List[str]]],
     lora_scales: Optional[Union[float, List[float]]],
+    torch_dtype: Optional[Union[str, torch.dtype]] = None,
     tensor_parallel_size: int = 1,
     controlnet_ids: Optional[Union[str, List[str]]] = None,
     output_attentions: bool = False,
@@ -506,6 +532,7 @@ def load_models_and_neuron_configs(
         "trust_remote_code": trust_remote_code,
         "framework": "pt",
         "library_name": library_name,
+        "torch_dtype": torch_dtype,
     }
     if model is None:
         model = TasksManager.get_model_from_task(**model_kwargs)
@@ -537,6 +564,7 @@ def main_export(
     model_name_or_path: str,
     output: Union[str, Path],
     compiler_kwargs: Dict[str, Any],
+    torch_dtype: Optional[Union[str, torch.dtype]] = None,
     tensor_parallel_size: int = 1,
     model: Optional[Union["PreTrainedModel", "ModelMixin"]] = None,
     task: str = "auto",
@@ -566,6 +594,7 @@ def main_export(
     **input_shapes,
 ):
     output = Path(output)
+    torch_dtype = map_torch_dtype(torch_dtype)
     if not output.parent.exists():
         output.parent.mkdir(parents=True)
 
@@ -579,6 +608,7 @@ def main_export(
         model_name_or_path=model_name_or_path,
         output=output,
         model=model,
+        torch_dtype=torch_dtype,
         tensor_parallel_size=tensor_parallel_size,
         task=task,
         dynamic_batch_size=dynamic_batch_size,
@@ -604,6 +634,7 @@ def main_export(
     _, neuron_outputs = export_models(
         models_and_neuron_configs=models_and_neuron_configs,
         output_dir=output,
+        torch_dtype=torch_dtype,
         disable_neuron_cache=disable_neuron_cache,
         compiler_workdir=compiler_workdir,
         inline_weights_to_neff=inline_weights_to_neff,
@@ -721,6 +752,7 @@ def main():
         model_name_or_path=args.model,
         output=args.output,
         compiler_kwargs=compiler_kwargs,
+        torch_dtype=args.torch_dtype,
         tensor_parallel_size=args.tensor_parallel_size,
         task=task,
         dynamic_batch_size=args.dynamic_batch_size,
diff --git a/optimum/exporters/neuron/base.py b/optimum/exporters/neuron/base.py
index 948b8e830..62e0d91a8 100644
--- a/optimum/exporters/neuron/base.py
+++ b/optimum/exporters/neuron/base.py
@@ -21,9 +21,10 @@
 
 import torch
 
+from optimum.utils import logging
+
 from ...exporters.base import ExportConfig
 from ...neuron.utils import is_neuron_available, is_transformers_neuronx_available
-from ...utils import logging
 
 
 if TYPE_CHECKING:
@@ -168,9 +169,8 @@ def __init__(
         encoder_hidden_size: Optional[int] = None,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
-        # TODO: add custom dtype after optimum 1.13 release
-        # int_dtype: str = "int64",
-        # float_dtype: str = "fp32",
+        int_dtype: Union[str, torch.dtype] = "int64",
+        float_dtype: Union[str, torch.dtype] = "fp32",
     ):
         self._config = config
         self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
@@ -179,6 +179,8 @@ def __init__(
         self.task = task
         self._axes: Dict[str, int] = {}
         self.dynamic_batch_size = dynamic_batch_size
+        self.int_dtype = int_dtype
+        self.float_dtype = float_dtype
 
         if self.dynamic_batch_size is True and is_neuron_available():
             logger.info("Overwriting batch size to 1 for neuron dynamic batch size support.")
@@ -316,9 +318,15 @@ def generate_dummy_inputs(
             input_was_inserted = False
             for dummy_input_gen in dummy_inputs_generators:
                 if dummy_input_gen.supports_input(input_name):
-                    dummy_inputs[input_name] = dummy_input_gen.generate(input_name, framework="pt")
-                    # TODO: add custom dtype after optimum 1.13 release
-                    # dummy_inputs[input_name] = dummy_input_gen.generate(input_name, framework="pt", int_dtype=self.int_dtype, float_dtype=self.float_dtype)
+                    # TODO: remove the mapper and use the torch float dtype directly once this Optimum PR makes its way into a release: https://github.com/huggingface/optimum/pull/2117
+                    mapper = {torch.float32: "fp32", torch.float16: "fp16", torch.bfloat16: "bf16"}
+                    if isinstance(self.float_dtype, torch.dtype):
+                        float_dtype = mapper[self.float_dtype]
+                    else:
+                        float_dtype = self.float_dtype
+                    dummy_inputs[input_name] = dummy_input_gen.generate(
+                        input_name, framework="pt", int_dtype=self.int_dtype, float_dtype=float_dtype
+                    )
                     input_was_inserted = True
                     break
             if not input_was_inserted:
diff --git a/optimum/exporters/neuron/config.py b/optimum/exporters/neuron/config.py
index a5dd4202b..82e842954 100644
--- a/optimum/exporters/neuron/config.py
+++ b/optimum/exporters/neuron/config.py
@@ -88,61 +88,6 @@ class TextSeq2SeqNeuronConfig(NeuronDefaultConfig):
         DummySeq2SeqPastKeyValuesGenerator,
     )
 
-    @property
-    def inputs(self) -> List[str]:
-        common_inputs = []
-        # encoder + decoder without past
-        if "encoder" in self.MODEL_TYPE:
-            common_inputs = ["input_ids", "attention_mask"]
-        # decoder with past
-        if "decoder" in self.MODEL_TYPE:
-            common_inputs = [
-                "decoder_input_ids",
-                "decoder_attention_mask",
-                "encoder_hidden_states",
-                "attention_mask",  # TODO: replace with `encoder_attention_mask` after optimum 1.14 release
-            ]
-
-        return common_inputs
-
-    @property
-    def outputs(self) -> List[str]:
-        common_outputs = []
-        # encoder + decoder without past
-        if "encoder" in self.MODEL_TYPE:
-            common_outputs = (
-                [f"present.{idx}.self.key" for idx in range(self._config.num_decoder_layers)]
-                + [f"present.{idx}.self.value" for idx in range(self._config.num_decoder_layers)]
-                + [f"present.{idx}.cross.key" for idx in range(self._config.num_decoder_layers)]
-                + [f"present.{idx}.cross.value" for idx in range(self._config.num_decoder_layers)]
-            )
-        # decoder with past
-        if "decoder" in self.MODEL_TYPE:
-            beam_outputs = (
-                ["next_token_scores", "next_tokens", "next_indices"] if self.num_beams > 1 else ["next_tokens"]
-            )
-            common_outputs = (
-                beam_outputs
-                + [f"past.{idx}.self.key" for idx in range(self._config.num_decoder_layers)]
-                + [f"past.{idx}.self.value" for idx in range(self._config.num_decoder_layers)]
-                + [f"past.{idx}.cross.key" for idx in range(self._config.num_decoder_layers)]
-                + [f"past.{idx}.cross.value" for idx in range(self._config.num_decoder_layers)]
-            )
-
-            if self.output_hidden_states:
-                # Flatten hidden states of all layers
-                common_outputs += [
-                    f"decoder_hidden_state.{idx}" for idx in range(self._config.num_decoder_layers + 1)
-                ]  # +1 for the embedding layer
-
-            if self.output_attentions:
-                # Flatten attentions tensors of all attention layers
-                common_outputs += [f"decoder_attention.{idx}" for idx in range(self._config.num_decoder_layers)]
-                if getattr(self._config, "is_encoder_decoder", False) is True:
-                    common_outputs += [f"cross_attention.{idx}" for idx in range(self._config.num_decoder_layers)]
-
-        return common_outputs
-
     def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
         dummy_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0](
             self.task, self._normalized_config, **kwargs
diff --git a/optimum/exporters/neuron/convert.py b/optimum/exporters/neuron/convert.py
index af13e9012..b9b1a79c5 100644
--- a/optimum/exporters/neuron/convert.py
+++ b/optimum/exporters/neuron/convert.py
@@ -106,7 +106,7 @@ def validate_models_outputs(
     """
     if len(neuron_named_outputs) != len(models_and_neuron_configs.keys()):
         raise ValueError(
-            f"Invalid number of Neuron named outputs. Required {len(models_and_neuron_configs.keys())}, Provided {len(neuron_named_outputs)}"
+            f"Invalid number of Neuron named outputs. Expected outputs for {models_and_neuron_configs.keys()}, but got {neuron_named_outputs}"
        )
 
     if neuron_named_outputs is not None and len(neuron_named_outputs) != len(models_and_neuron_configs):
@@ -186,7 +186,6 @@ def validate_model_outputs(
         inputs = config.generate_dummy_inputs(return_tuple=False, **input_shapes)
         ref_inputs = config.unflatten_inputs(inputs)
     if hasattr(reference_model, "config") and getattr(reference_model.config, "is_encoder_decoder", False):
-
         reference_model = config.patch_model_for_export(reference_model, device="cpu", **input_shapes)
     if "SentenceTransformer" in reference_model.__class__.__name__:
         reference_model = config.patch_model_for_export(reference_model, ref_inputs)
@@ -199,7 +198,9 @@ def validate_model_outputs(
         ref_inputs = tuple(ref_inputs.values())
         ref_outputs = reference_model(*ref_inputs)
         neuron_inputs = tuple(inputs.values())
-    elif "controlnet" in getattr(config._config, "_class_name", "").lower():
+    elif any(
+        pattern in getattr(config._config, "_class_name", "").lower() for pattern in ["controlnet", "transformer"]
+    ):
         reference_model = config.patch_model_for_export(reference_model, ref_inputs)
         neuron_inputs = ref_inputs = tuple(ref_inputs.values())
         ref_outputs = reference_model(*ref_inputs)
@@ -248,14 +249,14 @@ def validate_model_outputs(
     value_failures = []
     for i, (name, neuron_output) in enumerate(zip(neuron_output_names_list, neuron_outputs)):
         if isinstance(neuron_output, torch.Tensor):
-            ref_output = ref_outputs[name].numpy() if isinstance(ref_outputs, dict) else ref_outputs[i].numpy()
-            neuron_output = neuron_output.numpy()
+            ref_output = ref_outputs[name] if isinstance(ref_outputs, dict) else ref_outputs[i]
+            neuron_output = neuron_output
         elif isinstance(neuron_output, tuple):  # eg. `hidden_states` of `AutoencoderKL` is a tuple of tensors;
-            ref_output = torch.stack(ref_outputs[name]).numpy()
-            neuron_output = torch.stack(neuron_output).numpy()
+            ref_output = torch.stack(ref_outputs[name])
+            neuron_output = torch.stack(neuron_output)
         elif isinstance(neuron_output, list):
-            ref_output = [output.numpy() for output in ref_outputs[name]]
-            neuron_output = [output.numpy() for output in neuron_output]
+            ref_output = ref_outputs[name]
+            neuron_output = neuron_output
 
         logger.info(f'\t- Validating Neuron Model output "{name}":')
 
@@ -272,8 +273,8 @@ def validate_model_outputs(
             logger.info(f"\t\t-[✓] {output.shape} matches {ref_output.shape}")
 
         # Values
-        if not np.allclose(ref_output, output, atol=atol):
-            max_diff = np.amax(np.abs(ref_output - output))
+        if not torch.allclose(ref_output, output.to(ref_output.dtype), atol=atol):
+            max_diff = torch.max(torch.abs(ref_output - output))
             logger.error(f"\t\t-[x] values not close enough, max diff: {max_diff} (atol: {atol})")
             value_failures.append((name, max_diff))
         else:
@@ -296,6 +297,7 @@ def export_models(
         str, Tuple[Union["PreTrainedModel", "ModelMixin", torch.nn.Module], "NeuronDefaultConfig"]
     ],
     output_dir: Path,
+    torch_dtype: Optional[Union[str, torch.dtype]] = None,
    disable_neuron_cache: Optional[bool] = False,
     compiler_workdir: Optional[Path] = None,
     inline_weights_to_neff: bool = True,
@@ -312,6 +314,8 @@ def export_models(
             A dictionnary containing the models to export and their corresponding neuron configs.
         output_dir (`Path`):
             Output directory to store the exported Neuron models.
+        torch_dtype (`Optional[Union[str, torch.dtype]]`, defaults to `None`):
+            Override the default `torch.dtype` and load the model under this dtype. If `None` is passed, the dtype will be automatically derived from the model's weights.
         disable_neuron_cache (`Optional[bool]`, defaults to `False`):
             Whether to disable automatic caching of AOT compiled models (not applicable for JIT compilation).
         compiler_workdir (`Optional[Path]`, defaults to `None`):
@@ -534,7 +538,7 @@ def export_neuronx(
     tensor_parallel_size = config.tensor_parallel_size
     if isinstance(config, TextSeq2SeqNeuronConfig):
         checked_model = config.patch_model_for_export(model_or_path, **input_shapes)
-        if tensor_parallel_size == 1:
+        if tensor_parallel_size == 1 and hasattr(config, "generate_io_aliases"):
             aliases = config.generate_io_aliases(checked_model)
     else:
         checked_model = config.patch_model_for_export(model_or_path, dummy_inputs)
@@ -608,24 +612,27 @@ def add_stable_diffusion_compiler_args(config, compiler_args):
     sd_components = ["text_encoder", "vae", "vae_encoder", "vae_decoder", "controlnet"]
     if any(component in identifier for component in sd_components):
         compiler_args.append("--enable-fast-loading-neuron-binaries")
-    # unet or controlnet
-    if "unet" in identifier or "controlnet" in identifier:
+    # unet or transformer or controlnet
+    if any(model_type in identifier for model_type in ["unet", "transformer", "controlnet"]):
         # SDXL unet doesn't support fast loading neuron binaries(sdk 2.19.1)
         if not getattr(config, "is_sdxl", False):
             compiler_args.append("--enable-fast-loading-neuron-binaries")
-        compiler_args.append("--model-type=unet-inference")
+        if "unet" in identifier or "controlnet" in identifier:
+            compiler_args.append("--model-type=unet-inference")
+        if "transformer" in identifier:
+            compiler_args.append("--model-type=transformer")
     return compiler_args
 
 
 def improve_stable_diffusion_loading(config, neuron_model):
-    # Combine the model name and its path to identify which is the subcomponent in Stable Diffusion pipeline
+    # Combine the model name and its path to identify the subcomponent in the diffusion pipeline
     identifier = getattr(config._config, "_name_or_path", "") + " " + getattr(config._config, "_class_name", "")
     identifier = identifier.lower()
-    sd_components = ["text_encoder", "unet", "vae", "vae_encoder", "vae_decoder", "controlnet"]
+    sd_components = ["text_encoder", "unet", "transformer", "vae", "vae_encoder", "vae_decoder", "controlnet"]
     if any(component in identifier for component in sd_components):
         neuronx.async_load(neuron_model)
     # unet
-    if "unet" in identifier or "controlnet" in identifier:
+    if any(model_type in identifier for model_type in ["unet", "transformer", "controlnet"]):
         neuronx.lazy_load(neuron_model)
diff --git a/optimum/exporters/neuron/model_configs/traced_configs.py b/optimum/exporters/neuron/model_configs/traced_configs.py
index c936f00f8..fc1005a3c 100644
--- a/optimum/exporters/neuron/model_configs/traced_configs.py
+++ b/optimum/exporters/neuron/model_configs/traced_configs.py
@@ -54,9 +54,11 @@
 from ..model_wrappers import (
     ControlNetNeuronWrapper,
     NoCacheModelWrapper,
+    PixartTransformerNeuronWrapper,
     SentenceTransformersCLIPNeuronWrapper,
     SentenceTransformersTransformerNeuronWrapper,
     T5DecoderWrapper,
+    T5EncoderForSeq2SeqLMWrapper,
     T5EncoderWrapper,
     UnetNeuronWrapper,
 )
@@ -666,6 +668,49 @@ def with_controlnet(self, with_controlnet: bool):
         self._with_controlnet = with_controlnet
 
 
+@register_in_tasks_manager("pixart-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
+class PixartTransformerNeuronConfig(VisionNeuronConfig):
+    ATOL_FOR_VALIDATION = 1e-3
+    INPUT_ARGS = (
+        "batch_size",
+        "sequence_length",
+        "num_channels",
+        "width",
+        "height",
+        "vae_scale_factor",
+        "encoder_hidden_size",
+    )
+    MODEL_TYPE = "pixart-transformer-2d"
+    CUSTOM_MODEL_WRAPPER = PixartTransformerNeuronWrapper
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
+        height="height",
+        width="width",
+        num_channels="in_channels",
+        hidden_size="cross_attention_dim",
+        vocab_size="norm_num_groups",
+        allow_new=True,
+    )
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        DummyVisionInputGenerator,
+        DummyControNetInputGenerator,
+        DummyTextInputGenerator,
+        DummySeq2SeqDecoderTextInputGenerator,
+    )
+
+    @property
+    def inputs(self) -> List[str]:
+        common_inputs = ["sample", "encoder_hidden_states", "timestep", "encoder_attention_mask"]
+        return common_inputs
+
+    @property
+    def outputs(self) -> List[str]:
+        return ["out_hidden_states"]
+
+    def patch_model_for_export(self, model, dummy_inputs):
+        return self.CUSTOM_MODEL_WRAPPER(model, list(dummy_inputs.keys()))
+
+
 @register_in_tasks_manager("controlnet", *["semantic-segmentation"], library_name="diffusers")
 class ControlNetNeuronConfig(VisionNeuronConfig):
     ATOL_FOR_VALIDATION = 1e-3
@@ -768,12 +813,8 @@ def patch_model_for_export(
         return super().patch_model_for_export(model=model, dummy_inputs=dummy_inputs, forward_with_tuple=True)
 
 
-@register_in_tasks_manager("t5-encoder", "text2text-generation")
-class T5EncoderNeuronConfig(TextSeq2SeqNeuronConfig):
+class T5EncoderBaseNeuronConfig(TextSeq2SeqNeuronConfig):
     ATOL_FOR_VALIDATION = 1e-3
-    INPUT_ARGS = ("batch_size", "sequence_length", "num_beams")
-    MODEL_TYPE = "t5-encoder"
-    CUSTOM_MODEL_WRAPPER = T5EncoderWrapper
     NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
         hidden_size="d_model",
         num_attention_heads="num_heads",
@@ -784,8 +825,38 @@ class T5EncoderBaseNeuronConfig(TextSeq2SeqNeuronConfig):
     )
 
     @property
-    def is_decoder(self) -> bool:
-        return False
+    def inputs(self) -> List[str]:
+        return ["input_ids", "attention_mask"]
+
+
+@register_in_tasks_manager("t5", *["feature-extraction"], library_name="diffusers")
+class T5EncoderForDiffusersNeuronConfig(T5EncoderBaseNeuronConfig):
+    CUSTOM_MODEL_WRAPPER = T5EncoderWrapper
+    INPUT_ARGS = ("batch_size", "sequence_length")
+
+    @property
+    def outputs(self) -> List[str]:
+        return ["last_hidden_state"]
+
+    def patch_model_for_export(self, model_or_path, **input_shapes):
+        return self.CUSTOM_MODEL_WRAPPER(model_or_path, **input_shapes)
+
+
+@register_in_tasks_manager("t5-encoder", *["text2text-generation"])
+class T5EncoderForTransformersNeuronConfig(T5EncoderBaseNeuronConfig):
+    CUSTOM_MODEL_WRAPPER = T5EncoderForSeq2SeqLMWrapper
+    INPUT_ARGS = ("batch_size", "sequence_length", "num_beams")
+    MODEL_TYPE = "t5-encoder"
+
+    @property
+    def outputs(self) -> List[str]:
+        common_outputs = (
+            [f"present.{idx}.self.key" for idx in range(self._config.num_decoder_layers)]
+            + [f"present.{idx}.self.value" for idx in range(self._config.num_decoder_layers)]
+            + [f"present.{idx}.cross.key" for idx in range(self._config.num_decoder_layers)]
+            + [f"present.{idx}.cross.value" for idx in range(self._config.num_decoder_layers)]
+        )
+        return common_outputs
 
     def patch_model_for_export(self, model_or_path, device="xla", **kwargs):
         num_beams = kwargs.pop("num_beams", 1)
@@ -861,15 +932,43 @@ class T5DecoderNeuronConfig(TextSeq2SeqNeuronConfig):
     CUSTOM_MODEL_WRAPPER = T5DecoderWrapper
     NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig
 
-    @property
-    def is_decoder(self) -> bool:
-        return True
-
     @property
     def inputs(self) -> List[str]:
-        common_inputs = super().inputs + ["beam_idx", "beam_scores"]
+        common_inputs = [
+            "decoder_input_ids",
+            "decoder_attention_mask",
+            "encoder_hidden_states",
+            "attention_mask",  # TODO: replace with `encoder_attention_mask` after optimum 1.14 release
+            "beam_idx",
+            "beam_scores",
+        ]
         return common_inputs
 
+    @property
+    def outputs(self) -> List[str]:
+        beam_outputs = ["next_token_scores", "next_tokens", "next_indices"] if self.num_beams > 1 else ["next_tokens"]
+        common_outputs = (
+            beam_outputs
+            + [f"past.{idx}.self.key" for idx in range(self._config.num_decoder_layers)]
+            + [f"past.{idx}.self.value" for idx in range(self._config.num_decoder_layers)]
+            + [f"past.{idx}.cross.key" for idx in range(self._config.num_decoder_layers)]
+            + [f"past.{idx}.cross.value" for idx in range(self._config.num_decoder_layers)]
+        )
+
+        if self.output_hidden_states:
+            # Flatten hidden states of all layers
+            common_outputs += [
+                f"decoder_hidden_state.{idx}" for idx in range(self._config.num_decoder_layers + 1)
+            ]  # +1 for the embedding layer
+
+        if self.output_attentions:
+            # Flatten attentions tensors of all attention layers
+            common_outputs += [f"decoder_attention.{idx}" for idx in range(self._config.num_decoder_layers)]
+            if getattr(self._config, "is_encoder_decoder", False) is True:
+                common_outputs += [f"cross_attention.{idx}" for idx in range(self._config.num_decoder_layers)]
+
+        return common_outputs
+
     def generate_dummy_inputs(self, **kwargs):
         batch_size = kwargs.pop("batch_size") * kwargs.get("num_beams")
         dummy_inputs = super().generate_dummy_inputs(batch_size=batch_size, **kwargs)
diff --git a/optimum/exporters/neuron/model_wrappers.py b/optimum/exporters/neuron/model_wrappers.py
index d451c2b88..f1fb2995e 100644
--- a/optimum/exporters/neuron/model_wrappers.py
+++ b/optimum/exporters/neuron/model_wrappers.py
@@ -79,6 +79,40 @@ def forward(self, *inputs):
         return out_tuple
 
 
+class PixartTransformerNeuronWrapper(torch.nn.Module):
+    def __init__(self, model, input_names: List[str]):
+        super().__init__()
+        self.model = model
+        self.dtype = model.dtype
+        self.input_names = input_names
+
+    def forward(self, *inputs):
+        if len(inputs) != len(self.input_names):
+            raise ValueError(
+                f"The model needs {len(self.input_names)} inputs: {self.input_names}."
+                f" But only {len(inputs)} inputs are passed."
+            )
+
+        ordered_inputs = dict(zip(self.input_names, inputs))
+
+        sample = ordered_inputs.pop("sample", None)
+        encoder_hidden_states = ordered_inputs.pop("encoder_hidden_states", None)
+        timestep = ordered_inputs.pop("timestep", None)
+        encoder_attention_mask = ordered_inputs.pop("encoder_attention_mask", None)
+
+        # Additional conditions: the PixArt transformer expects `added_cond_kwargs`,
+        # even when the resolution and aspect ratio embeddings are unused.
+        out_tuple = self.model(
+            hidden_states=sample,
+            timestep=timestep,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            added_cond_kwargs={"resolution": None, "aspect_ratio": None},
+            return_dict=False,
+        )
+
+        return out_tuple
+
+
 class ControlNetNeuronWrapper(torch.nn.Module):
     def __init__(self, model, input_names: List[str]):
         super().__init__()
@@ -121,8 +155,32 @@ def forward(self, *inputs):
         return out_tuple
 
 
-# Adapted from https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/torch-neuronx/t5-inference-tutorial.html
+# Adapted from https://github.com/aws-neuron/aws-neuron-samples/blob/master/torch-neuronx/inference/hf_pretrained_pixart_alpha_inference_on_inf2.ipynb
+# For text encoding
 class T5EncoderWrapper(torch.nn.Module):
+    def __init__(
+        self, model: "PreTrainedModel", sequence_length: int, batch_size: Optional[int] = None, device: str = "cpu"
+    ):
+        super().__init__()
+        self.model = model
+        self.config = model.config
+        self.sequence_length = sequence_length
+        self.batch_size = batch_size
+        self.device = device
+        # Use the tanh approximation of GELU and precompute the encoder's relative
+        # position bias so that the traced graph stays static and Neuron-friendly.
+        for block in self.model.encoder.block:
+            block.layer[1].DenseReluDense.act = torch.nn.GELU(approximate="tanh")
+        precomputed_bias = (
+            self.model.encoder.block[0].layer[0].SelfAttention.compute_bias(self.sequence_length, self.sequence_length)
+        )
+        self.model.encoder.block[0].layer[0].SelfAttention.compute_bias = lambda *args, **kwargs: precomputed_bias
+
+    def forward(self, input_ids, attention_mask):
+        return self.model(input_ids, attention_mask=attention_mask)
+
+
+# Adapted from https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/torch-neuronx/t5-inference-tutorial.html
+# For text encoding + KV cache initialization
+class T5EncoderForSeq2SeqLMWrapper(torch.nn.Module):
     """Wrapper to trace the encoder and the kv cache initialization in the decoder."""
 
     def __init__(
diff --git a/optimum/exporters/neuron/utils.py b/optimum/exporters/neuron/utils.py
index 690ea2bdf..248de9bb7 100644
--- a/optimum/exporters/neuron/utils.py
+++ b/optimum/exporters/neuron/utils.py
@@ -27,12 +27,14 @@
     DIFFUSION_MODEL_CONTROLNET_NAME,
     DIFFUSION_MODEL_TEXT_ENCODER_2_NAME,
     DIFFUSION_MODEL_TEXT_ENCODER_NAME,
+    DIFFUSION_MODEL_TRANSFORMER_NAME,
     DIFFUSION_MODEL_UNET_NAME,
     DIFFUSION_MODEL_VAE_DECODER_NAME,
     DIFFUSION_MODEL_VAE_ENCODER_NAME,
     ENCODER_NAME,
     get_attention_scores_sd,
     get_attention_scores_sdxl,
+    neuron_scaled_dot_product_attention,
 )
 from ...utils import (
     DIFFUSERS_MINIMUM_VERSION,
@@ -55,8 +57,8 @@
 )
 from diffusers import (
     ControlNetModel,
+    DiffusionPipeline,
     ModelMixin,
-    StableDiffusionPipeline,
     StableDiffusionXLImg2ImgPipeline,
     StableDiffusionXLInpaintPipeline,
     StableDiffusionXLPipeline,
@@ -74,7 +76,7 @@
 def build_stable_diffusion_components_mandatory_shapes(
     batch_size: Optional[int] = None,
     sequence_length: Optional[int] = None,
-    unet_num_channels: Optional[int] = None,
+    unet_or_transformer_num_channels: Optional[int] = None,
     vae_encoder_num_channels: Optional[int] = None,
     vae_decoder_num_channels: Optional[int] = None,
     height: Optional[int] = None,
@@ -94,17 +96,17 @@ def build_stable_diffusion_components_mandatory_shapes(
         "height": height,
"width": width, } - unet_input_shapes = { + unet_or_transformer_input_shapes = { "batch_size": batch_size * num_images_per_prompt, "sequence_length": sequence_length, - "num_channels": unet_num_channels, + "num_channels": unet_or_transformer_num_channels, "height": height, "width": width, } components_shapes = { "text_encoder": text_encoder_input_shapes, - "unet": unet_input_shapes, + "unet_or_transformer": unet_or_transformer_input_shapes, "vae_encoder": vae_encoder_input_shapes, "vae_decoder": vae_decoder_input_shapes, } @@ -112,10 +114,11 @@ def build_stable_diffusion_components_mandatory_shapes( return components_shapes -def get_stable_diffusion_models_for_export( - pipeline: Union["StableDiffusionPipeline", "StableDiffusionXLPipeline"], +def get_diffusion_models_for_export( + pipeline: "DiffusionPipeline", text_encoder_input_shapes: Dict[str, int], unet_input_shapes: Dict[str, int], + transformer_input_shapes: Dict[str, int], vae_encoder_input_shapes: Dict[str, int], vae_decoder_input_shapes: Dict[str, int], dynamic_batch_size: Optional[bool] = False, @@ -134,12 +137,14 @@ def get_stable_diffusion_models_for_export( performance benefit (CLIP text encoder, VAE encoder, VAE decoder, Unet). Args: - pipeline ([`Union["StableDiffusionPipeline", "StableDiffusionXLPipeline"]`]): + pipeline ([`"DiffusionPipeline"`]): The model to export. text_encoder_input_shapes (`Dict[str, int]`): Static shapes used for compiling text encoder. unet_input_shapes (`Dict[str, int]`): Static shapes used for compiling unet. + transformer_input_shapes (`Dict[str, int]`): + Static shapes used for compiling diffusion transformer. vae_encoder_input_shapes (`Dict[str, int]`): Static shapes used for compiling vae encoder. vae_decoder_input_shapes (`Dict[str, int]`): @@ -166,7 +171,7 @@ def get_stable_diffusion_models_for_export( `Dict[str, Tuple[Union[`PreTrainedModel`, `ModelMixin`], `NeuronDefaultConfig`]`: A Dict containing the model and Neuron configs for the different components of the model. 
""" - models_for_export = get_submodels_for_export_stable_diffusion( + models_for_export = get_submodels_for_export_diffusion( pipeline=pipeline, lora_model_ids=lora_model_ids, lora_weight_names=lora_weight_names, @@ -213,28 +218,52 @@ def get_stable_diffusion_models_for_export( models_for_export[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME] = (text_encoder_2, text_encoder_neuron_config_2) # U-NET - unet = models_for_export[DIFFUSION_MODEL_UNET_NAME] - unet_neuron_config_constructor = TasksManager.get_exporter_config_constructor( - model=unet, - exporter="neuron", - task="semantic-segmentation", - model_type="unet", - library_name=library_name, - ) - unet_neuron_config = unet_neuron_config_constructor( - unet.config, - task="semantic-segmentation", - dynamic_batch_size=dynamic_batch_size, - **unet_input_shapes, - ) - is_stable_diffusion_xl = isinstance( - pipeline, (StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline) - ) - unet_neuron_config.is_sdxl = is_stable_diffusion_xl + if DIFFUSION_MODEL_UNET_NAME in models_for_export: + unet = models_for_export[DIFFUSION_MODEL_UNET_NAME] + unet_neuron_config_constructor = TasksManager.get_exporter_config_constructor( + model=unet, + exporter="neuron", + task="semantic-segmentation", + model_type="unet", + library_name=library_name, + ) + unet_neuron_config = unet_neuron_config_constructor( + unet.config, + task="semantic-segmentation", + dynamic_batch_size=dynamic_batch_size, + float_dtype=unet.dtype, + **unet_input_shapes, + ) + is_stable_diffusion_xl = isinstance( + pipeline, (StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline) + ) + unet_neuron_config.is_sdxl = is_stable_diffusion_xl + + unet_neuron_config.with_controlnet = True if controlnet_ids else False - unet_neuron_config.with_controlnet = True if controlnet_ids else False + models_for_export[DIFFUSION_MODEL_UNET_NAME] = (unet, unet_neuron_config) - models_for_export[DIFFUSION_MODEL_UNET_NAME] = (unet, unet_neuron_config) + # Diffusion Transformer + transformer = None + if DIFFUSION_MODEL_TRANSFORMER_NAME in models_for_export: + transformer = models_for_export[DIFFUSION_MODEL_TRANSFORMER_NAME] + model_type = get_diffusers_submodel_type(transformer) + transformer_neuron_config_constructor = TasksManager.get_exporter_config_constructor( + model=transformer, + exporter="neuron", + task="semantic-segmentation", + model_type=model_type, + library_name=library_name, + ) + transformer.config.export_model_type = model_type + transformer_neuron_config = transformer_neuron_config_constructor( + transformer.config, + task="semantic-segmentation", + dynamic_batch_size=dynamic_batch_size, + float_dtype=transformer.dtype, + **transformer_input_shapes, + ) + models_for_export[DIFFUSION_MODEL_TRANSFORMER_NAME] = (transformer, transformer_neuron_config) # VAE Encoder vae_encoder = models_for_export[DIFFUSION_MODEL_VAE_ENCODER_NAME] @@ -249,6 +278,7 @@ def get_stable_diffusion_models_for_export( vae_encoder.config, task="semantic-segmentation", dynamic_batch_size=dynamic_batch_size, + float_dtype=vae_encoder.dtype, **vae_encoder_input_shapes, ) models_for_export[DIFFUSION_MODEL_VAE_ENCODER_NAME] = (vae_encoder, vae_encoder_neuron_config) @@ -266,6 +296,7 @@ def get_stable_diffusion_models_for_export( vae_decoder.config, task="semantic-segmentation", dynamic_batch_size=dynamic_batch_size, + float_dtype=transformer.dtype if transformer else vae_decoder.dtype, **vae_decoder_input_shapes, ) 
     models_for_export[DIFFUSION_MODEL_VAE_DECODER_NAME] = (vae_decoder, vae_decoder_neuron_config)
@@ -288,6 +319,7 @@ def get_diffusion_models_for_export(
             controlnet.config,
             task="semantic-segmentation",
             dynamic_batch_size=dynamic_batch_size,
+            float_dtype=controlnet.dtype,
             **controlnet_input_shapes,
         )
         models_for_export[controlnet_name] = (
@@ -299,7 +331,7 @@ def get_diffusion_models_for_export(
 
 
 def _load_lora_weights_to_pipeline(
-    pipeline: Union["StableDiffusionPipeline", "StableDiffusionXLPipeline"],
+    pipeline: "DiffusionPipeline",
     lora_model_ids: Optional[Union[str, List[str]]] = None,
     weight_names: Optional[Union[str, List[str]]] = None,
     adapter_names: Optional[Union[str, List[str]]] = None,
@@ -352,8 +384,8 @@ def load_controlnets(controlnet_ids: Optional[Union[str, List[str]]] = None):
     return contronets
 
 
-def get_submodels_for_export_stable_diffusion(
-    pipeline: Union["StableDiffusionPipeline", "StableDiffusionXLPipeline"],
+def get_submodels_for_export_diffusion(
+    pipeline: "DiffusionPipeline",
     output_hidden_states: bool = False,
     lora_model_ids: Optional[Union[str, List[str]]] = None,
     lora_weight_names: Optional[Union[str, List[str]]] = None,
@@ -378,10 +410,6 @@ def get_submodels_for_export_diffusion(
     )
 
     models_for_export = []
-    if hasattr(pipeline, "text_encoder_2"):
-        projection_dim = pipeline.text_encoder_2.config.projection_dim
-    else:
-        projection_dim = pipeline.text_encoder.config.projection_dim
 
     # Text encoders
     if pipeline.text_encoder is not None:
@@ -394,28 +422,52 @@ def get_submodels_for_export_diffusion(
         text_encoder_2.config.output_hidden_states = True
         text_encoder_2.text_model.config.output_hidden_states = True
         models_for_export.append((DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, copy.deepcopy(text_encoder_2)))
+        projection_dim = getattr(pipeline.text_encoder_2.config, "projection_dim", None)
+    else:
+        projection_dim = getattr(pipeline.text_encoder.config, "projection_dim", None)
 
     # U-NET
-    pipeline.unet.config.text_encoder_projection_dim = projection_dim
-    # The U-NET time_ids inputs shapes depends on the value of `requires_aesthetics_score`
-    # https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L571
-    pipeline.unet.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
-
-    # Replace original cross-attention module with custom cross-attention module for better performance
-    # For applying optimized attention score, we need to set env variable `NEURON_FUSE_SOFTMAX=1`
-    if os.environ.get("NEURON_FUSE_SOFTMAX") == "1":
-        if is_stable_diffusion_xl:
-            logger.info("Applying optimized attention score computation for sdxl.")
-            Attention.get_attention_scores = get_attention_scores_sdxl
+    unet = getattr(pipeline, "unet", None)
+    if unet is not None:
+        # The U-NET time_ids input shapes depend on the value of `requires_aesthetics_score`
+        # https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L571
+        unet.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
+        unet.config.text_encoder_projection_dim = projection_dim
+
+        # Replace the original cross-attention module with a custom cross-attention module for better performance
+        # Applying the optimized attention score requires the env variable `NEURON_FUSE_SOFTMAX=1`
+        if os.environ.get("NEURON_FUSE_SOFTMAX") == "1":
+            if is_stable_diffusion_xl:
+                logger.info("Applying optimized attention score computation for sdxl.")
+                Attention.get_attention_scores = get_attention_scores_sdxl
+            else:
+                logger.info("Applying optimized attention score computation for stable diffusion.")
+                Attention.get_attention_scores = get_attention_scores_sd
         else:
-            logger.info("Applying optimized attention score computation for stable diffusion.")
-            Attention.get_attention_scores = get_attention_scores_sd
-    else:
-        logger.warning(
-            "You are not applying optimized attention score computation. If you want better performance, please"
-            " set the environment variable with `export NEURON_FUSE_SOFTMAX=1` and recompile the unet model."
-        )
-    models_for_export.append((DIFFUSION_MODEL_UNET_NAME, copy.deepcopy(pipeline.unet)))
+            logger.warning(
+                "You are not applying optimized attention score computation. If you want better performance, please"
+                " set the environment variable with `export NEURON_FUSE_SOFTMAX=1` and recompile the unet model."
+            )
+        models_for_export.append((DIFFUSION_MODEL_UNET_NAME, copy.deepcopy(unet)))
+
+    # Diffusion transformer
+    transformer = getattr(pipeline, "transformer", None)
+    if transformer is not None:
+        transformer.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
+        transformer.config.text_encoder_projection_dim = projection_dim
+        # Apply the optimized scaled_dot_product_attention when no attention mask is passed
+        sdpa_original = torch.nn.functional.scaled_dot_product_attention
+
+        def attention_wrapper(query, key, value, attn_mask=None, dropout_p=None, is_causal=None):
+            if attn_mask is not None:
+                return sdpa_original(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
+            else:
+                return neuron_scaled_dot_product_attention(
+                    query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
+                )
+
+        torch.nn.functional.scaled_dot_product_attention = attention_wrapper
+        models_for_export.append((DIFFUSION_MODEL_TRANSFORMER_NAME, copy.deepcopy(transformer)))
 
     if pipeline.vae.config.get("force_upcast", None) is True:
         pipeline.vae.to(dtype=torch.float32)
@@ -427,7 +479,10 @@ def get_submodels_for_export_diffusion(
 
     # VAE Decoder
     vae_decoder = copy.deepcopy(pipeline.vae)
+    unet_or_transformer = unet or transformer
     vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample)
+    if vae_decoder.dtype is torch.float32 and unet_or_transformer.dtype is not torch.float32:
+        vae_decoder = apply_fp32_wrapper_to_vae_decoder(vae_decoder)
     models_for_export.append((DIFFUSION_MODEL_VAE_DECODER_NAME, vae_decoder))
 
     # ControlNets
@@ -463,6 +518,37 @@ def replace_stable_diffusion_submodels(pipeline, submodels):
     return pipeline
 
 
+def apply_fp32_wrapper_to_vae_decoder(model):
+    class f32Wrapper(torch.nn.Module):
+        def __init__(self, model):
+            super().__init__()
+            self.original = model
+
+        def forward(self, x):
+            # Upcast the latents to fp32 before decoding to preserve image quality
+            y = x.to(torch.float32)
+            output = self.original(y)
+            return output
+
+        def __getattr__(self, name):
+            # Delegate attribute/method lookup to the wrapped model if not found in this wrapper
+            if name == "original":
+                return super().__getattr__(name)
+            return getattr(self.original, name)
+
+    model = f32Wrapper(model)
+    return model
+
+
+# TODO: get it into https://github.com/huggingface/optimum/blob/4a7cb298140ee9bed968d98a780a950d15bb2935/optimum/exporters/utils.py#L77
+_DIFFUSERS_CLASS_NAME_TO_SUBMODEL_TYPE = {
+    "PixArtTransformer2DModel": "pixart-transformer-2d",
+}
+
+
+def get_diffusers_submodel_type(submodel):
+    return _DIFFUSERS_CLASS_NAME_TO_SUBMODEL_TYPE.get(submodel.__class__.__name__)
+
+
 def get_encoder_decoder_models_for_export(
     model: "PreTrainedModel",
     task: str,
diff --git a/optimum/neuron/__init__.py b/optimum/neuron/__init__.py
index 67263fd68..6900143d9 100644
--- a/optimum/neuron/__init__.py
+++ b/optimum/neuron/__init__.py
@@ -48,7 +48,7 @@
         "NeuronModelForXVector",
     ],
     "modeling_diffusion": [
-        "NeuronStableDiffusionPipelineBase",
+        "NeuronDiffusionPipelineBase",
         "NeuronStableDiffusionPipeline",
         "NeuronStableDiffusionImg2ImgPipeline",
         "NeuronStableDiffusionInpaintPipeline",
@@ -59,6 +59,7 @@
         "NeuronStableDiffusionXLInpaintPipeline",
         "NeuronStableDiffusionControlNetPipeline",
         "NeuronStableDiffusionXLControlNetPipeline",
+        "NeuronPixArtAlphaPipeline",
     ],
     "modeling_decoder": ["NeuronDecoderModel"],
     "modeling_seq2seq": ["NeuronModelForSeq2SeqLM"],
@@ -94,13 +95,14 @@
     )
     from .modeling_decoder import NeuronDecoderModel
     from .modeling_diffusion import (
+        NeuronDiffusionPipelineBase,
         NeuronLatentConsistencyModelPipeline,
+        NeuronPixArtAlphaPipeline,
         NeuronStableDiffusionControlNetPipeline,
         NeuronStableDiffusionImg2ImgPipeline,
         NeuronStableDiffusionInpaintPipeline,
         NeuronStableDiffusionInstructPix2PixPipeline,
         NeuronStableDiffusionPipeline,
-        NeuronStableDiffusionPipelineBase,
         NeuronStableDiffusionXLControlNetPipeline,
         NeuronStableDiffusionXLImg2ImgPipeline,
         NeuronStableDiffusionXLInpaintPipeline,
diff --git a/optimum/neuron/distributed/decoder_models.py b/optimum/neuron/distributed/decoder_models.py
index cbc9968cc..1e8add2fa 100644
--- a/optimum/neuron/distributed/decoder_models.py
+++ b/optimum/neuron/distributed/decoder_models.py
@@ -449,7 +449,6 @@ def attention_forward(
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
         if self.config.pretraining_tp > 1:
             key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
             query_slices = self.q_proj.weight.split(
diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py
index 682c43272..7fbd8467b 100644
--- a/optimum/neuron/distributed/utils.py
+++ b/optimum/neuron/distributed/utils.py
@@ -1388,7 +1388,6 @@ def parameter_can_be_initialized(model: torch.nn.Module, parent_module: torch.nn
 
 
 def create_wrapper_for_resize_token_embedding(orig_resize_token_embeddings):
-
     @functools.wraps(orig_resize_token_embeddings)
     def wrapper(
         self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
diff --git a/optimum/neuron/modeling_base.py b/optimum/neuron/modeling_base.py
index a5c0cdd2a..0fb93247c 100644
--- a/optimum/neuron/modeling_base.py
+++ b/optimum/neuron/modeling_base.py
@@ -25,7 +25,6 @@
 
 
 class NeuronModel(OptimizedModel):
-
     def __init__(self, model: "PreTrainedModel", config: "PretrainedConfig"):
         super().__init__(model, config)
         if hasattr(model, "device"):
diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py
index 57081834e..b3bec452c 100644
--- a/optimum/neuron/modeling_diffusion.py
+++ b/optimum/neuron/modeling_diffusion.py
@@ -89,7 +89,7 @@
         StableDiffusionXLPipeline,
     )
     from diffusers.configuration_utils import FrozenDict
-    from diffusers.image_processor import VaeImageProcessor
+    from diffusers.image_processor import PixArtImageProcessor, VaeImageProcessor
     from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
     from diffusers.models.controlnet import ControlNetOutput
     from diffusers.models.modeling_outputs import AutoencoderKLOutput
@@ -105,6 +105,11 @@
         NeuronStableDiffusionXLPipelineMixin,
     )
 
+    os.environ["NEURON_FUSE_SOFTMAX"] = "1"
+    os.environ["NEURON_CUSTOM_SILU"] = "1"
+else:
+    raise ModuleNotFoundError("`diffusers` python package is not installed.")
+
 
 if TYPE_CHECKING:
     from ..exporters.neuron import NeuronDefaultConfig
@@ -316,7 +321,7 @@ def __init__(
 
         # change lcm scheduler which extends the denoising procedure
         self.is_lcm = False
-        if NeuronDiffusionPipelineBase.is_lcm(self.unet.config):
+        if self.unet and NeuronDiffusionPipelineBase.is_lcm(self.unet.config):
             self.is_lcm = True
             self.scheduler = LCMScheduler.from_config(self.scheduler.config)
 
@@ -358,13 +363,14 @@ def __init__(
         else:
             self.vae_scale_factor = 8
 
-        unet_batch_size = self.neuron_configs["unet"].batch_size
+        unet_or_transformer = "transformer" if self.transformer else "unet"
+        unet_or_transformer_batch_size = self.neuron_configs[unet_or_transformer].batch_size
         if "text_encoder" in self.neuron_configs:
             text_encoder_batch_size = self.neuron_configs["text_encoder"].batch_size
-            self.num_images_per_prompt = unet_batch_size // text_encoder_batch_size
+            self.num_images_per_prompt = unet_or_transformer_batch_size // text_encoder_batch_size
         elif "text_encoder_2" in self.neuron_configs:
             text_encoder_batch_size = self.neuron_configs["text_encoder_2"].batch_size
-            self.num_images_per_prompt = unet_batch_size // text_encoder_batch_size
+            self.num_images_per_prompt = unet_or_transformer_batch_size // text_encoder_batch_size
         else:
             self.num_images_per_prompt = 1
 
@@ -771,11 +777,12 @@ def _from_pretrained(
             data_parallel_mode=data_parallel_mode,
             text_encoder_path=model_and_config_save_paths["text_encoder"][0],
             unet_path=model_and_config_save_paths["unet"][0],
+            transformer_path=model_and_config_save_paths["transformer"][0],
             vae_decoder_path=model_and_config_save_paths["vae_decoder"][0],
             vae_encoder_path=model_and_config_save_paths["vae_encoder"][0],
             text_encoder_2_path=model_and_config_save_paths["text_encoder_2"][0],
             controlnet_paths=model_and_config_save_paths["controlnet"][0],
-            dynamic_batch_size=neuron_configs[DIFFUSION_MODEL_UNET_NAME].dynamic_batch_size,
+            dynamic_batch_size=neuron_configs[DIFFUSION_MODEL_TEXT_ENCODER_NAME].dynamic_batch_size,
             to_neuron=not inline_weights_to_neff,
         )
 
@@ -814,6 +821,7 @@ def _export(
         cls,
         model_id: Union[str, Path],
         config: Dict[str, Any],
+        torch_dtype: Optional[Union[str, torch.dtype]] = None,
         unet_id: Optional[Union[str, Path]] = None,
         token: Optional[Union[bool, str]] = None,
         revision: str = "main",
@@ -851,6 +859,8 @@ def _export(
             config (`Dict[str, Any]`):
                 A config dictionary from which the model components will be instantiated. Make sure to only load
                 configuration files of compatible classes.
+            torch_dtype (`Optional[Union[str, torch.dtype]]`, defaults to `None`):
+                Override the default `torch.dtype` and load the model under this dtype. If `None` is passed, the dtype will be automatically derived from the model's weights.
             unet_id (`Optional[Union[str, Path]]`, defaults to `None`):
                 A string or a path point to the U-NET model to replace the one in the original pipeline.
             token (`Optional[Union[bool, str]]`, defaults to `None`):
@@ -933,6 +943,7 @@ def _export(
             subfolder=subfolder,
             revision=revision,
             framework="pt",
+            torch_dtype=torch_dtype,
             library_name=cls.library_name,
             cache_dir=cache_dir,
             token=token,
@@ -969,6 +980,7 @@ def _export(
             lora_weight_names=lora_weight_names,
             lora_adapter_names=lora_adapter_names,
             lora_scales=lora_scales,
+            torch_dtype=torch_dtype,
             controlnet_ids=controlnet_ids,
             **input_shapes_copy,
         )
@@ -1025,6 +1037,7 @@ def _export(
             model_name_or_path=model_id,
             output=save_dir_path,
             compiler_kwargs=compiler_kwargs,
+            torch_dtype=torch_dtype,
             task=task,
             dynamic_batch_size=dynamic_batch_size,
             cache_dir=cache_dir,
@@ -1099,13 +1112,17 @@ def do_classifier_free_guidance(self):
     )
 
     def __call__(self, *args, **kwargs):
-        # Height and width to unet (static shapes)
-        height = self.unet.config.neuron["static_height"] * self.vae_scale_factor
-        width = self.unet.config.neuron["static_width"] * self.vae_scale_factor
+        # Height and width to unet/transformer (static shapes)
+        unet_or_transformer = self.unet or self.transformer
+        height = unet_or_transformer.config.neuron["static_height"] * self.vae_scale_factor
+        width = unet_or_transformer.config.neuron["static_width"] * self.vae_scale_factor
         kwargs.pop("height", None)
         kwargs.pop("width", None)
         if kwargs.get("image", None):
            kwargs["image"] = self.image_processor.preprocess(kwargs["image"], height=height, width=width)
+        # Override the default `max_sequence_length`, e.g. for PixArt
+        if "max_sequence_length" in inspect.signature(self.auto_model_class.__call__).parameters:
+            kwargs["max_sequence_length"] = self.text_encoder.config.neuron.get("static_sequence_length", None)
         return self.auto_model_class.__call__(self, height=height, width=width, *args, **kwargs)
 
@@ -1162,18 +1179,16 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = True,
     ):
-        if attention_mask is not None:
-            assert torch.equal(
-                torch.ones_like(attention_mask), attention_mask
-            ), "attention_mask is expected to be only all ones."
         if output_hidden_states:
             assert (
                 self.config.output_hidden_states or self.config.neuron.get("output_hidden_states")
             ) == output_hidden_states, "output_hidden_states is expected to be False since the model was compiled without hidden_states as output."
@@ -1162,18 +1179,16 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = True,
     ):
-        if attention_mask is not None:
-            assert torch.equal(
-                torch.ones_like(attention_mask), attention_mask
-            ), "attention_mask is expected to be only all ones."
         if output_hidden_states:
             assert (
                 self.config.output_hidden_states or self.config.neuron.get("output_hidden_states")
             ) == output_hidden_states, "output_hidden_states is expected to be False since the model was compiled without hidden_states as output."
         input_ids = input_ids.to(torch.long)  # dummy generator uses long int for tracing
         inputs = (input_ids,)
+        if attention_mask is not None and not torch.all(attention_mask == 1):
+            inputs += (attention_mask,)
+
         outputs = self.model(*inputs)
         if return_dict:

@@ -1253,7 +1268,11 @@ def forward(
         encoder_attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):
-        pass
+        inputs = (hidden_states, encoder_hidden_states, timestep, encoder_attention_mask)
+        outputs = self.model(*inputs)
+        if return_dict:
+            outputs = ModelOutput(dict(zip(self.neuron_config.outputs, outputs)))
+        return outputs


 class NeuronModelVaeEncoder(_NeuronDiffusionModelPart):
@@ -1480,6 +1499,10 @@ class NeuronPixArtAlphaPipeline(NeuronDiffusionPipelineBase, PixArtAlphaPipeline
     main_input_name = "prompt"
     auto_model_class = PixArtAlphaPipeline

+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
+

 class NeuronStableDiffusionXLPipeline(
     NeuronStableDiffusionXLPipelineMixin, NeuronDiffusionPipelineBase, StableDiffusionXLPipeline
diff --git a/optimum/neuron/models/qwen2/modules.py b/optimum/neuron/models/qwen2/modules.py
index c4ef4a219..e655985c6 100644
--- a/optimum/neuron/models/qwen2/modules.py
+++ b/optimum/neuron/models/qwen2/modules.py
@@ -18,7 +18,6 @@

 class Qwen2ForCausalLM(module.PretrainedModel):
-
     def __init__(self, config: Qwen2Config):
         super().__init__()
         dtype, _, _ = utils.parse_amp(config.amp)
@@ -34,7 +33,6 @@ def get_base_model(self):

 class Qwen2Model(module.LowMemoryModule):
-
     def __init__(self, config: Qwen2Config):
         super().__init__()
         self.embed_tokens = module.LowMemoryEmbedding(config.vocab_size, config.hidden_size)
@@ -43,14 +41,12 @@ def __init__(self, config: Qwen2Config):

 class Qwen2RMSNorm(module.LowMemoryModule):
-
     def __init__(self, config: Qwen2Config) -> None:
         super().__init__()
         self.weight = module.UninitializedParameter()


 class Qwen2DecoderLayer(module.LowMemoryModule):
-
     def __init__(self, config: Qwen2Config):
         super().__init__()
         self.self_attn = Qwen2Attention(config)
@@ -60,7 +56,6 @@ def __init__(self, config: Qwen2Config):

 class Qwen2Attention(module.LowMemoryModule):
-
     def __init__(self, config: Qwen2Config):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -75,7 +70,6 @@ def __init__(self, config: Qwen2Config):

 class Qwen2MLP(module.LowMemoryModule):
-
     def __init__(self, config: Qwen2Config):
         super().__init__()
         dtype, _, _ = utils.parse_amp(config.amp)
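A standalone sketch of the input packing introduced in the text-encoder `forward` hunk above (a restatement for illustration, not the class method itself): the attention mask is forwarded to the compiled model only when it actually masks something, since an all-ones mask matches what the model was traced with.

```python
from typing import Optional, Tuple

import torch


# Standalone restatement of the forward() logic above, for illustration only.
def pack_text_encoder_inputs(
    input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, ...]:
    inputs = (input_ids.to(torch.long),)  # the traced model expects int64 ids
    if attention_mask is not None and not torch.all(attention_mask == 1):
        inputs += (attention_mask,)
    return inputs


ids = torch.ones(1, 16, dtype=torch.long)
assert len(pack_text_encoder_inputs(ids, torch.ones(1, 16))) == 1  # all-ones mask is dropped
assert len(pack_text_encoder_inputs(ids, torch.zeros(1, 16))) == 2  # informative mask is kept
```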
diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index 25edb1102..6d9f25348 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -2182,7 +2182,6 @@ def odds_ratio_loss(
         policy_chosen_logps: torch.FloatTensor,
         policy_rejected_logps: torch.FloatTensor,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
-
         # Neuron-specific change compared to the original implementation in `trl`, the original implementation is:
         #
         # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x)
diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py
index 965aeaaaa..5f8b7fca0 100644
--- a/optimum/neuron/utils/__init__.py
+++ b/optimum/neuron/utils/__init__.py
@@ -56,11 +56,13 @@
         "is_main_worker",
         "is_precompilation",
         "replace_weights",
+        "map_torch_dtype",
     ],
     "model_utils": ["get_tied_parameters_dict", "tie_parameters"],
     "optimization_utils": [
         "get_attention_scores_sd",
         "get_attention_scores_sdxl",
+        "neuron_scaled_dot_product_attention",
     ],
     "patching": [
         "DynamicPatch",
@@ -115,12 +117,14 @@
         get_stable_diffusion_configs,
         is_main_worker,
         is_precompilation,
+        map_torch_dtype,
         replace_weights,
     )
     from .model_utils import get_tied_parameters_dict, tie_parameters
     from .optimization_utils import (
         get_attention_scores_sd,
         get_attention_scores_sdxl,
+        neuron_scaled_dot_product_attention,
     )
     from .patching import (
         DynamicPatch,
diff --git a/optimum/neuron/utils/hub_cache_utils.py b/optimum/neuron/utils/hub_cache_utils.py
index 08360c312..ad2add627 100644
--- a/optimum/neuron/utils/hub_cache_utils.py
+++ b/optimum/neuron/utils/hub_cache_utils.py
@@ -512,6 +512,8 @@ def get_multimodels_configs_from_hub(model_id):

     if "unet" in lookup_configs:
         lookup_configs["model_type"] = "stable-diffusion"
+    if "transformer" in lookup_configs:
+        lookup_configs["model_type"] = "diffusion-transformer"

     return lookup_configs

@@ -558,6 +560,9 @@ def build_cache_config(
     if "unet" in configs:
         # stable diffusion
         clean_configs["model_type"] = "stable-diffusion"
+    elif "transformer" in configs:
+        # diffusion transformer
+        clean_configs["model_type"] = "diffusion-transformer"
     else:
         # seq-to-seq
         clean_configs["model_type"] = next(iter(clean_configs.values()))["model_type"]
diff --git a/optimum/neuron/utils/misc.py b/optimum/neuron/utils/misc.py
index a845ce43e..9c7d17fd7 100644
--- a/optimum/neuron/utils/misc.py
+++ b/optimum/neuron/utils/misc.py
@@ -677,3 +677,19 @@ def get_stable_diffusion_configs(
         configs[name] = models_for_export[name].config

     return configs
+
+
+def map_torch_dtype(dtype: Union[str, torch.dtype]):
+    dtype_mapping = {
+        "bfloat16": torch.bfloat16,
+        "float16": torch.float16,
+        "float32": torch.float32,
+        "float64": torch.float64,
+        "int32": torch.int32,
+        "int64": torch.int64,
+    }
+
+    if isinstance(dtype, str) and dtype in dtype_mapping:
+        dtype = dtype_mapping.get(dtype)
+
+    return dtype
diff --git a/optimum/neuron/utils/model_utils.py b/optimum/neuron/utils/model_utils.py
index a74932674..784981568 100644
--- a/optimum/neuron/utils/model_utils.py
+++ b/optimum/neuron/utils/model_utils.py
@@ -23,7 +23,6 @@

 if TYPE_CHECKING:
-
     if is_torch_neuronx_available():
         from neuronx_distributed.pipeline import NxDPPModel
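A minimal usage sketch of the new `map_torch_dtype` helper added in `misc.py` above: string aliases resolve to `torch.dtype` objects, while anything unrecognized (including actual `torch.dtype` values or the string `"auto"`) passes through unchanged.

```python
import torch

from optimum.neuron.utils import map_torch_dtype

assert map_torch_dtype("bfloat16") is torch.bfloat16    # string alias resolved
assert map_torch_dtype(torch.float16) is torch.float16  # dtype passes through untouched
assert map_torch_dtype("auto") == "auto"                 # unknown strings are returned as-is
```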
"""Optimization utilities.""" +import math + import torch @@ -69,3 +71,22 @@ def _custom_badbmm(a, b, scale): attention_probs = attention_scores.softmax(dim=-1) return attention_probs + + +def neuron_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=None, is_causal=None): + orig_shape = None + if len(query.shape) == 4: + orig_shape = query.shape + + def to3d(x): + return x.reshape(-1, x.shape[2], x.shape[3]) + + query, key, value = map(to3d, [query, key, value]) + attention_scores = torch.bmm(key, query.transpose(-1, -2)) * (1 / math.sqrt(query.size(-1))) + attention_probs = attention_scores.softmax(dim=1) + if query.size() == key.size(): + attention_probs = attention_probs.permute(0, 2, 1) + attn_out = torch.bmm(attention_probs, value) + if orig_shape: + attn_out = attn_out.reshape(orig_shape[0], orig_shape[1], attn_out.shape[1], attn_out.shape[2]) + return attn_out diff --git a/optimum/neuron/utils/peft_utils.py b/optimum/neuron/utils/peft_utils.py index 8c753ec38..34c414701 100644 --- a/optimum/neuron/utils/peft_utils.py +++ b/optimum/neuron/utils/peft_utils.py @@ -43,7 +43,6 @@ ) else: - SAFETENSORS_WEIGHTS_NAME = WEIGHTS_NAME = "" class PeftModel: diff --git a/optimum/neuron/utils/trl_utils.py b/optimum/neuron/utils/trl_utils.py index 31041122f..49918afd1 100644 --- a/optimum/neuron/utils/trl_utils.py +++ b/optimum/neuron/utils/trl_utils.py @@ -43,7 +43,6 @@ class NeuronSFTConfig(NeuronTrainingArguments, SFTConfig): @dataclass class NeuronORPOConfig(NeuronTrainingArguments, ORPOConfig): - @property def neuron_cc_flags_model_type(self) -> Optional[str]: return None diff --git a/tests/cli/test_export_cli.py b/tests/cli/test_export_cli.py index ac333fc9d..6ab40fcea 100644 --- a/tests/cli/test_export_cli.py +++ b/tests/cli/test_export_cli.py @@ -170,6 +170,36 @@ def test_stable_diffusion(self): check=True, ) + @requires_neuronx + def test_pixart(self): + model_ids = ["hf-internal-testing/tiny-pixart-alpha-pipe"] + for model_id in model_ids: + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + [ + "optimum-cli", + "export", + "neuron", + "--model", + model_id, + "--batch_size", + "1", + "--height", + "8", + "--width", + "8", + "--sequence_length", + "16", + "--num_images_per_prompt", + "1", + "--torch_dtype", + "bfloat16", + tempdir, + ], + shell=False, + check=True, + ) + @requires_neuronx def test_stable_diffusion_multi_lora(self): model_id = "hf-internal-testing/tiny-stable-diffusion-torch" @@ -196,7 +226,7 @@ def test_stable_diffusion_multi_lora(self): lora_model_id, "--lora_weight_names", lora_weight_name, - "lora_adapter_names", + "--lora_adapter_names", adpater_name, "--lora_scales", "0.9", diff --git a/tests/inference/test_stable_diffusion_pipeline.py b/tests/inference/test_stable_diffusion_pipeline.py index 37f2e278d..616e31ce4 100644 --- a/tests/inference/test_stable_diffusion_pipeline.py +++ b/tests/inference/test_stable_diffusion_pipeline.py @@ -19,6 +19,7 @@ import cv2 import numpy as np import PIL +import torch from compel import Compel, ReturnedEmbeddingsType from diffusers import UniPCMultistepScheduler from diffusers.utils import load_image @@ -26,6 +27,7 @@ from optimum.neuron import ( NeuronLatentConsistencyModelPipeline, + NeuronPixArtAlphaPipeline, NeuronStableDiffusionControlNetPipeline, NeuronStableDiffusionImg2ImgPipeline, NeuronStableDiffusionInpaintPipeline, @@ -38,6 +40,7 @@ from optimum.neuron.modeling_diffusion import ( NeuronControlNetModel, NeuronModelTextEncoder, + NeuronModelTransformer, NeuronModelUnet, 
diff --git a/tests/inference/test_stable_diffusion_pipeline.py b/tests/inference/test_stable_diffusion_pipeline.py
index 37f2e278d..616e31ce4 100644
--- a/tests/inference/test_stable_diffusion_pipeline.py
+++ b/tests/inference/test_stable_diffusion_pipeline.py
@@ -19,6 +19,7 @@
 import cv2
 import numpy as np
 import PIL
+import torch
 from compel import Compel, ReturnedEmbeddingsType
 from diffusers import UniPCMultistepScheduler
 from diffusers.utils import load_image
@@ -26,6 +27,7 @@

 from optimum.neuron import (
     NeuronLatentConsistencyModelPipeline,
+    NeuronPixArtAlphaPipeline,
     NeuronStableDiffusionControlNetPipeline,
     NeuronStableDiffusionImg2ImgPipeline,
     NeuronStableDiffusionInpaintPipeline,
@@ -38,6 +40,7 @@
 from optimum.neuron.modeling_diffusion import (
     NeuronControlNetModel,
     NeuronModelTextEncoder,
+    NeuronModelTransformer,
     NeuronModelUnet,
     NeuronModelVaeDecoder,
     NeuronModelVaeEncoder,
@@ -435,3 +438,40 @@ def test_from_pipe(self, model_arch):
         prompt = "a dog running, lake, moat"
         image = img2img_pipeline(prompt=prompt, image=init_image).images[0]
         self.assertIsInstance(image, PIL.Image.Image)
+
+
+@is_inferentia_test
+@requires_neuronx
+@require_diffusers
+class NeuronPixArtAlphaPipelineIntegrationTest(unittest.TestCase):
+    ATOL_FOR_VALIDATION = 1e-3
+
+    def test_export_and_inference_non_dyn(self):
+        model_id = "hf-internal-testing/tiny-pixart-alpha-pipe"
+        compiler_args = {"auto_cast": "none"}
+        input_shapes = {"batch_size": 1, "height": 64, "width": 64, "sequence_length": 32}
+        neuron_pipeline = NeuronPixArtAlphaPipeline.from_pretrained(
+            model_id,
+            export=True,
+            torch_dtype=torch.bfloat16,
+            dynamic_batch_size=False,
+            disable_neuron_cache=True,
+            **input_shapes,
+            **compiler_args,
+        )
+        self.assertIsInstance(neuron_pipeline.text_encoder, NeuronModelTextEncoder)
+        self.assertIsInstance(neuron_pipeline.transformer, NeuronModelTransformer)
+        self.assertIsInstance(neuron_pipeline.vae_encoder, NeuronModelVaeEncoder)
+        self.assertIsInstance(neuron_pipeline.vae_decoder, NeuronModelVaeDecoder)
+
+        prompt = "Mario eating hamburgers."
+
+        # Skip the sample size check because the dummy model uses a smaller sample size (8).
+        neuron_pipeline.transformer.config.sample_size = 32
+        # Set `use_resolution_binning=False` to prevent resizing.
+        image = neuron_pipeline(prompt=prompt, use_resolution_binning=False).images[0]
+        self.assertIsInstance(image, PIL.Image.Image)
diff --git a/tests/test_trainers.py b/tests/test_trainers.py
index 02d1eef1d..17f79248c 100644
--- a/tests/test_trainers.py
+++ b/tests/test_trainers.py
@@ -276,7 +276,6 @@ def test_train_and_eval_use_remote_cache(self, hub_test_with_local_cache, tmpdir

     @pytest.mark.skip("Test in later release")
     def test_save_and_resume_from_checkpoint(self, parallel_sizes, tmpdir):
-
         tmpdir = Path(tmpdir)
         _, tp_size, pp_size = parallel_sizes
         train_batch_size = 2
diff --git a/tests/utils.py b/tests/utils.py
index 060c77596..4f6cfc6d4 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -79,7 +79,6 @@ def get_random_string(length) -> str:

 def create_dummy_dataset(input_specs: Dict[str, Tuple[Tuple[int, ...], torch.dtype]], num_examples: int) -> Dataset:
-
     def gen():
         for _ in range(num_examples):
             yield {name: torch.rand(shape) for name, shape in input_specs.items()}
diff --git a/text-generation-inference/tests/fixtures/service.py b/text-generation-inference/tests/fixtures/service.py
index f108e600b..a0702985a 100644
--- a/text-generation-inference/tests/fixtures/service.py
+++ b/text-generation-inference/tests/fixtures/service.py
@@ -30,7 +30,6 @@

 class TestClient(AsyncInferenceClient):
-
     def __init__(self, service_name: str, base_url: str):
         super().__init__(model=base_url)
         self.service_name = service_name
diff --git a/text-generation-inference/tests/integration/test_implicit_env.py b/text-generation-inference/tests/integration/test_implicit_env.py
index bb090d10c..fa88ab673 100644
--- a/text-generation-inference/tests/integration/test_implicit_env.py
+++ b/text-generation-inference/tests/integration/test_implicit_env.py
@@ -38,7 +38,6 @@ async def tgi_service(request, launcher, neuron_model_config):

 @pytest.mark.asyncio
 async def test_model_single_request(tgi_service):
-
     # Just verify that the generation works, and nothing is raised, with several set of params

     # No params
diff --git a/text-generation-inference/tgi_env.py b/text-generation-inference/tgi_env.py
index 430300969..6855b468a 100755
--- a/text-generation-inference/tgi_env.py
+++ b/text-generation-inference/tgi_env.py
@@ -127,7 +127,6 @@ def lookup_compatible_cached_model(model_id: str, revision: Optional[str]) -> Op

 def check_env_and_neuron_config_compatibility(neuron_config: Dict[str, Any], check_compiler_version: bool) -> bool:
-
     logger.debug(
         "Checking the provided neuron config %s is compatible with the local setup and provided environment",
         neuron_config,
diff --git a/tools/list_top_models.py b/tools/list_top_models.py
index 49d90898d..78136164d 100644
--- a/tools/list_top_models.py
+++ b/tools/list_top_models.py
@@ -5,7 +5,6 @@

 class ModelStats(HfApi):
-
     class Sort:
         DOWNLOADS = "downloads"
         TRENDING = "trendingScore"
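The `Sort` constants at the end of `list_top_models.py` are sort keys understood server-side by the Hub API. A minimal sketch of how they might be consumed (assumed usage, not part of the diff):

```python
from huggingface_hub import HfApi

# "downloads" and "trendingScore" are sort keys accepted by list_models.
for model in HfApi().list_models(sort="downloads", direction=-1, limit=5):
    print(model.id, model.downloads)
```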