Skip to content

Commit

Permalink
Add pixart sigma and fix NOTthing (#762)
Browse files Browse the repository at this point in the history
# What does this PR do?

as per title
  • Loading branch information
JingyaHuang authored Jan 15, 2025
1 parent 4f05a29 commit 2117ef3
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 4 deletions.
5 changes: 2 additions & 3 deletions docs/source/inference_tutorials/stable_diffusion.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ optimum-cli export neuron --model stabilityai/stable-diffusion-2-1-base \
--batch_size 1 \
--height 512 `# height in pixels of generated image, eg. 512, 768` \
--width 512 `# width in pixels of generated image, eg. 512, 768` \
--num_images_per_prompt 4 `# number of images to generate per prompt, defaults to 1` \
--num_images_per_prompt 1 `# number of images to generate per prompt, defaults to 1` \
--auto_cast matmul `# cast only matrix multiplication operations` \
--auto_cast_type bf16 `# cast operations from FP32 to BF16` \
sd_neuron/
Expand Down Expand Up @@ -231,7 +231,7 @@ optimum-cli export neuron --model stabilityai/stable-diffusion-xl-base-1.0 \
--batch_size 1 \
--height 1024 `# height in pixels of generated image, eg. 768, 1024` \
--width 1024 `# width in pixels of generated image, eg. 768, 1024` \
--num_images_per_prompt 4 `# number of images to generate per prompt, defaults to 1` \
--num_images_per_prompt 1 `# number of images to generate per prompt, defaults to 1` \
--auto_cast matmul `# cast only matrix multiplication operations` \
--auto_cast_type bf16 `# cast operations from FP32 to BF16` \
sd_neuron_xl/
Expand Down Expand Up @@ -340,7 +340,6 @@ prompt = "A majestic lion jumping from a big stone at night"
base = NeuronStableDiffusionXLPipeline.from_pretrained("sd_neuron_xl/")
image = base(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
num_inference_steps=40,
denoising_end=0.8,
output_type="latent",
Expand Down
1 change: 1 addition & 0 deletions docs/source/package_reference/supported_models.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ limitations under the License.
| SDXL Turbo | text-to-image, image-to-image, inpaint |
| LCM | text-to-image |
| PixArt-α | text-to-image |
| PixArt-Σ | text-to-image |

## Sentence Transformers

Expand Down
2 changes: 2 additions & 0 deletions optimum/neuron/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
"NeuronStableDiffusionControlNetPipeline",
"NeuronStableDiffusionXLControlNetPipeline",
"NeuronPixArtAlphaPipeline",
"NeuronPixArtSigmaPipeline",
],
"modeling_decoder": ["NeuronDecoderModel"],
"modeling_seq2seq": ["NeuronModelForSeq2SeqLM"],
Expand Down Expand Up @@ -98,6 +99,7 @@
NeuronDiffusionPipelineBase,
NeuronLatentConsistencyModelPipeline,
NeuronPixArtAlphaPipeline,
NeuronPixArtSigmaPipeline,
NeuronStableDiffusionControlNetPipeline,
NeuronStableDiffusionImg2ImgPipeline,
NeuronStableDiffusionInpaintPipeline,
Expand Down
10 changes: 10 additions & 0 deletions optimum/neuron/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
LatentConsistencyModelPipeline,
LCMScheduler,
PixArtAlphaPipeline,
PixArtSigmaPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
Expand Down Expand Up @@ -1504,6 +1505,15 @@ def __init__(self, **kwargs):
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)


class NeuronPixArtSigmaPipeline(NeuronDiffusionPipelineBase, PixArtSigmaPipeline):
main_input_name = "prompt"
auto_model_class = PixArtSigmaPipeline

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)


class NeuronStableDiffusionXLPipeline(
NeuronStableDiffusionXLPipelineMixin, NeuronDiffusionPipelineBase, StableDiffusionXLPipeline
):
Expand Down
2 changes: 1 addition & 1 deletion tests/cli/test_export_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def test_stable_diffusion(self):
"--width",
"64",
"--num_images_per_prompt",
"4",
"1",
"--auto_cast",
"matmul",
"--auto_cast_type",
Expand Down

0 comments on commit 2117ef3

Please sign in to comment.