Releases: huggingface/diffusers
v0.3.0: New API, Stable Diffusion pipelines, low-memory inference, MPS backend, ONNX
📚 Shiny new docs!
Thanks to community efforts on [Docs] and [Type Hints], we've started populating the Diffusers documentation pages with lots of helpful guides, links, and API references.
📝 New API & breaking changes
New API
Pipeline, Model, and Scheduler outputs can now be accessed as dataclasses, dicts, or tuples:
image = pipe("The red cat is sitting on a chair")["sample"][0]
is now replaced by:
image = pipe("The red cat is sitting on a chair").images[0]
# or
image = pipe("The red cat is sitting on a chair")["images"][0]
# or
image = pipe("The red cat is sitting on a chair")[0][0]
Similarly:
sample = unet(...).sample
and
prev_sample = scheduler.step(...).prev_sample
are now possible!
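To make the three access styles concrete, here is a toy stand-in for an output object. It is not the library's implementation; `ToyPipelineOutput` is a hypothetical class that only mirrors the access patterns described above, including the detail that integer indexing returns the whole first field (the `images` list), so a single image needs a second index.

```python
from dataclasses import dataclass, fields
from typing import Any, List


@dataclass
class ToyPipelineOutput:
    """Hypothetical stand-in for a pipeline output that supports three access styles."""

    images: List[Any]

    def to_tuple(self) -> tuple:
        # Tuple view: field values in declaration order
        return tuple(getattr(self, f.name) for f in fields(self))

    def __getitem__(self, key):
        # String keys act like dict access, integer keys like tuple access
        if isinstance(key, str):
            return getattr(self, key)
        return self.to_tuple()[key]


out = ToyPipelineOutput(images=["img0", "img1"])
first_by_attr = out.images[0]
first_by_key = out["images"][0]
first_by_index = out[0][0]  # out[0] is the whole `images` list
```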
🚨🚨🚨 Breaking change 🚨🚨🚨
This PR introduces breaking changes for the following public-facing methods:
VQModel.encode
-> returns a dict/dataclass instead of a single tensor, since future versions will likely need to return more than one tensor. Please change
latents = model.encode(...)
to
latents = model.encode(...)[0]
or
latents = model.encode(...).latents
VQModel.decode
-> returns a dict/dataclass instead of a single tensor, since future versions will likely need to return more than one tensor. Please change
sample = model.decode(...)
to
sample = model.decode(...)[0]
or
sample = model.decode(...).sample
VQModel.forward
-> returns a dict/dataclass instead of a single tensor, since future versions will likely need to return more than one tensor. Please change
sample = model(...)
to
sample = model(...)[0]
or
sample = model(...).sample
AutoencoderKL.encode
-> returns a dict/dataclass instead of a single tensor, since future versions will likely need to return more than one tensor. Please change
latent_dist = model.encode(...)
to
latent_dist = model.encode(...)[0]
or
latent_dist = model.encode(...).latent_dist
AutoencoderKL.decode
-> returns a dict/dataclass instead of a single tensor, since future versions will likely need to return more than one tensor. Please change
sample = model.decode(...)
to
sample = model.decode(...)[0]
or
sample = model.decode(...).sample
AutoencoderKL.forward
-> returns a dict/dataclass instead of a single tensor, since future versions will likely need to return more than one tensor. Please change
sample = model(...)
to
sample = model(...)[0]
or
sample = model(...).sample
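If you maintain code that must run against both the old (bare tensor) and new (output object) return types during the migration, a small shim can hide the difference. This is a hypothetical helper, not part of diffusers, and it assumes the old bare return value has no attribute with the same name as the new output field.

```python
from typing import Any


def get_output_field(result: Any, field: str) -> Any:
    """Hypothetical migration helper.

    New-style calls return an output object exposing `field` as an attribute
    (e.g. "sample", "latents", "latent_dist"); old-style calls returned the
    value directly, so fall back to returning `result` unchanged.
    """
    if hasattr(result, field):
        return getattr(result, field)
    return result


# Example (assumes `model` and `latents` already exist):
# sample = get_output_field(model.decode(latents), "sample")
```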
🎨 New Stable Diffusion pipelines
A couple of new pipelines have been added to Diffusers! We invite you to experiment with them, and to take them as inspiration to create your cool new tasks. These are the new pipelines:
- Image-to-image generation. In addition to using a text prompt, this pipeline lets you include an example image to be used as the initial state of the process. 🤗 Diffuse the Rest is a cool demo about it!
- Inpainting (experimental). You can provide an image and a mask and ask Stable Diffusion to repaint the masked region.
For more details about how they work, please visit our new API documentation.
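The two new pipelines build on simple ideas that a toy example can show. For image-to-image, the example image is pushed part of the way through the forward (noising) process and the denoising loop starts from there; for inpainting, the generated content is blended with the original content through the mask. The sketch below works on plain lists rather than latents, and the helper names are hypothetical. The `strength` semantics are modeled on the img2img pipeline's `strength` argument, but treat the exact rule here as an assumption, as is the mask convention (1 marks the region to repaint).

```python
import math
import random
from typing import List


def img2img_steps_to_run(num_inference_steps: int, strength: float) -> int:
    # Assumption: `strength` in [0, 1] is the fraction of the schedule that is re-run.
    # 1.0 starts from (almost) pure noise; 0.0 leaves the init image untouched.
    return min(int(num_inference_steps * strength), num_inference_steps)


def noise_init_image(init: List[float], alpha_bar: float, rng: random.Random) -> List[float]:
    # Forward diffusion q(x_t | x_0): x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps
    return [
        math.sqrt(alpha_bar) * x + math.sqrt(1.0 - alpha_bar) * rng.gauss(0.0, 1.0)
        for x in init
    ]


def inpaint_blend(generated: List[float], known: List[float], mask: List[float]) -> List[float]:
    # Assumption: mask == 1 marks the region to repaint; elsewhere keep the known content
    return [m * g + (1.0 - m) * k for g, k, m in zip(generated, known, mask)]


rng = random.Random(0)
steps = img2img_steps_to_run(50, 0.75)  # 37 of 50 steps
x_t = noise_init_image([0.5, -0.25], alpha_bar=0.3, rng=rng)
blended = inpaint_blend([1.0, 1.0], [0.0, 0.0], [1.0, 0.0])
```

In the actual pipelines this happens on latents and inside the denoising loop; the toy only isolates the arithmetic.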
This is a summary of all the Stable Diffusion tasks that can be easily used with 🤗 Diffusers:
Pipeline | Tasks | Colab | Demo |
---|---|---|---|
pipeline_stable_diffusion.py | Text-to-Image Generation | | 🤗 Stable Diffusion |
pipeline_stable_diffusion_img2img.py | Image-to-Image Text-Guided Generation | | 🤗 Diffuse the Rest |
pipeline_stable_diffusion_inpaint.py | Experimental – Text-Guided Image Inpainting | | Coming soon |
🍬 Less memory usage for smaller GPUs
The diffusion models can now use significantly less VRAM (3.2 GB for Stable Diffusion) at the cost of about 10% speed, thanks to the optimizations discussed in basujindal/stable-diffusion#117.
To use the attention optimization, just enable it with .enable_attention_slicing() after loading the pipeline:
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="fp16",
torch_dtype=torch.float16,
use_auth_token=True
)
pipe = pipe.to("cuda")
pipe.enable_attention_slicing()
This will allow many more users to play with Stable Diffusion on their own computers! We can't wait to see what new ideas and results the community will create!
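Attention slicing trades a little speed for lower peak memory by computing attention in pieces instead of materializing every score matrix at once; the result is unchanged. Below is a toy, pure-Python sketch of that idea. We understand diffusers slices along the batch/head dimension, so this sketch slices over heads, but treat the axis choice and all function names as illustrative assumptions rather than the library's code.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def head_scores(q: Matrix, k: Matrix) -> Matrix:
    # Scaled dot-product scores Q K^T / sqrt(d): a (queries x keys) matrix per head
    d = len(q[0])
    return [[sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k] for qi in q]


def apply_scores(scores: Matrix, v: Matrix) -> Matrix:
    # softmax(scores) @ V, row by row
    out = []
    for row in scores:
        w = softmax(row)
        out.append([sum(w[j] * v[j][c] for j in range(len(v))) for c in range(len(v[0]))])
    return out


def full_attention(qs, ks, vs):
    # Materializes the score matrices of *all* heads before using them
    all_scores = [head_scores(q, k) for q, k in zip(qs, ks)]
    return [apply_scores(s, v) for s, v in zip(all_scores, vs)]


def sliced_attention(qs, ks, vs, slice_size: int):
    # Only `slice_size` heads' score matrices exist at any one time
    out = []
    for start in range(0, len(qs), slice_size):
        end = min(start + slice_size, len(qs))
        scores = [head_scores(qs[h], ks[h]) for h in range(start, end)]
        out.extend(apply_scores(s, vs[h]) for s, h in zip(scores, range(start, end)))
    return out


# Toy data: 4 heads, 3 tokens, head dimension 2
qs = [[[0.1 * h + i, 1.0 - i] for i in range(3)] for h in range(4)]
ks = [[[1.0 - 0.2 * i, 0.5 * h] for i in range(3)] for h in range(4)]
vs = [[[float(i + h), float(i - h)] for i in range(3)] for h in range(4)]
full = full_attention(qs, ks, vs)
sliced = sliced_attention(qs, ks, vs, slice_size=1)
```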
🐈‍⬛ Textual Inversion
Textual Inversion lets you personalize a Stable Diffusion model on your own images with just 3-5 samples.
GitHub: https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion
Training: https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb
Inference: https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb
🍎 MPS backend for Apple Silicon
🤗 Diffusers is compatible with Apple silicon for Stable Diffusion inference, using the PyTorch mps device. You need to install PyTorch Preview (Nightly) on a Mac with an M1 or M2 chip, and then use the pipeline as usual:
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
pipe = pipe.to("mps")
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
We are seeing great speedups (31s vs 214s on an M1 Max), but there are still a couple of limitations. We encourage you to read the documentation for the details.
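If the same script should run on CUDA machines, Apple silicon, and plain CPUs, a small device picker keeps the `pipe.to(...)` call portable. This is a hypothetical helper, not a diffusers API; it uses `torch.cuda.is_available()` and, when present, `torch.backends.mps.is_available()` (which only exists in PyTorch 1.12 and later).

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Hypothetical helper: prefer CUDA, then Apple's MPS backend, then CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def detect_device() -> str:
    try:
        import torch
    except ImportError:
        return "cpu"
    # torch.backends.mps is missing on PyTorch < 1.12, so look it up defensively
    mps = getattr(torch.backends, "mps", None)
    return pick_device(torch.cuda.is_available(), bool(mps is not None and mps.is_available()))


# Example (assumes `pipe` was created as above):
# pipe = pipe.to(detect_device())
```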
🏭 Experimental ONNX exporter and pipeline for Stable Diffusion
We introduce a new (and experimental) Stable Diffusion pipeline compatible with the ONNX Runtime. This allows you to run Stable Diffusion on any hardware that supports ONNX (including a significant speedup on CPUs).
You need to use StableDiffusionOnnxPipeline instead of StableDiffusionPipeline. You also need to download the weights from the onnx branch of the repository, and indicate the runtime provider you want to use (CPU, in the following example):
from diffusers import StableDiffusionOnnxPipeline
pipe = StableDiffusionOnnxPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="onnx",
provider="CPUExecutionProvider",
use_auth_token=True,
)
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
To convert your own checkpoint, run the conversion script locally:
python scripts/convert_stable_diffusion_checkpoint_to_onnx.py --model_path="CompVis/stable-diffusion-v1-4" --output_path="./stable_diffusion_onnx"
After that it can be loaded from the local path:
pipe = StableDiffusionOnnxPipeline.from_pretrained("./stable_diffusion_onnx", provider="CPUExecutionProvider")
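The provider string is an ONNX Runtime execution provider name. A sketch for choosing one from whatever your onnxruntime build reports via `onnxruntime.get_available_providers()` is shown below. The preference order and the `pick_provider` helper are assumptions for illustration, and this release's notes only demonstrate the CPU provider with the experimental pipeline.

```python
from typing import Sequence

# Assumed preference order; adjust to the providers your onnxruntime build ships with
PREFERRED_PROVIDERS = ("CUDAExecutionProvider", "CPUExecutionProvider")


def pick_provider(available: Sequence[str], preferred: Sequence[str] = PREFERRED_PROVIDERS) -> str:
    """Return the first preferred ONNX Runtime execution provider that is available."""
    for name in preferred:
        if name in available:
            return name
    raise RuntimeError(f"None of {list(preferred)} is available; got {list(available)}")


if __name__ == "__main__":
    try:
        import onnxruntime
        print(pick_provider(onnxruntime.get_available_providers()))
    except ImportError:
        print("onnxruntime is not installed")
```

The returned name would then be passed as `provider=` to `StableDiffusionOnnxPipeline.from_pretrained`.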
Improvements and bugfixes
- Mark in painting experimental by @patrickvonplaten in #430
- Add config docs by @patrickvonplaten in #429
- [Docs] Models by @kashif in #416
- [Docs] Using diffusers by @patrickvonplaten in #428
- [Outputs] Improve syntax by @patrickvonplaten in #423
- Initial ONNX doc (TODO: Installation) by @pcuenca in #426
- [Tests] Correct image folder tests by @patrickvonplaten in #427
- [MPS] Make sure it doesn't break torch < 1.12 by @patrickvonplaten in #425
- [ONNX] Stable Diffusion exporter and pipeline by @anton-l in #399
- [Tests] Make image-based SD tests reproducible with fixed datasets by @anton-l in #424
- [Docs] Outputs.mdx by @patrickvonplaten in #422
- [Docs] Fix scheduler docs by @patrickvonplaten in #421
- [Docs] DiffusionPipeline by @patrickvonplaten in #418
- Improve unconditional diffusers example by @satpalsr in #414
- Improve latent diff example by @satpalsr in #413
- Inference support for mps device by @pcuenca in #355
- [Docs] Minor fixes in optimization section by @patrickvonplaten in #420
- [Docs]...
v0.2.4: Patch release
This patch release allows the Stable Diffusion pipelines to be loaded with float16 precision:
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="fp16",
torch_dtype=torch.float16,
use_auth_token=True
)
pipe = pipe.to("cuda")
The resulting models take up less than 6900 MiB of GPU memory.
- [Loading] allow modules to be loaded in fp16 by @patrickvonplaten in #230
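A back-of-the-envelope calculation shows where part of the saving comes from: storing each weight in float16 takes 2 bytes instead of 4. The sketch below uses the parameter counts quoted in the v0.2.0 notes (860M UNet, 123M text encoder) and deliberately ignores the VAE, activations, and framework overhead, which make up the rest of the measured footprint.

```python
# Parameter counts quoted in the v0.2.0 notes; the VAE and all activations are left out
UNET_PARAMS = 860_000_000
TEXT_ENCODER_PARAMS = 123_000_000
MIB = 2**20


def weights_mib(num_params: int, bytes_per_param: int) -> float:
    # Memory for the weights alone: parameters x bytes per parameter, in MiB
    return num_params * bytes_per_param / MIB


total_params = UNET_PARAMS + TEXT_ENCODER_PARAMS
fp32_mib = weights_mib(total_params, 4)  # float32: 4 bytes per parameter
fp16_mib = weights_mib(total_params, 2)  # float16: 2 bytes per parameter
print(f"fp32 weights: {fp32_mib:.0f} MiB, fp16 weights: {fp16_mib:.0f} MiB")
```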
v0.2.3: Stable Diffusion public release
🎨 Stable Diffusion public release
The Stable Diffusion checkpoints are now public and can be loaded by anyone! 🥳
Make sure to accept the license terms on the model page first (requires login): https://huggingface.co/CompVis/stable-diffusion-v1-4
Install the required packages: pip install diffusers==0.2.3 transformers scipy
And log in on your machine using the huggingface-cli login command.
from torch import autocast
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
# this will substitute the default PNDM scheduler for K-LMS
lms = LMSDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear"
)
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
scheduler=lms,
use_auth_token=True
).to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
    image = pipe(prompt)["sample"][0]
image.save("astronaut_rides_horse.png")
The safety checker
Following the model authors' guidelines and code, the Stable Diffusion inference results will now be filtered to exclude unsafe content. Any images classified as unsafe will be returned as blank. To check programmatically whether the safety module was triggered, check the nsfw_content_detected flag like so:
outputs = pipe(prompt)
image = outputs["sample"][0]
if any(outputs["nsfw_content_detected"]):
    print("Potential unsafe content was detected in one or more images. Try again with a different prompt and/or seed.")
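When generating several images at once, you may want to keep only the ones that were not flagged. The hypothetical helper below does that, assuming `nsfw_content_detected` holds one boolean per image, in the same order as the images.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_by_safety(images: Sequence[T], nsfw_flags: Sequence[bool]) -> Tuple[List[T], List[int]]:
    """Hypothetical helper: separate kept images from the indices the safety checker flagged.

    Assumes one flag per image, in the same order as the images.
    """
    if len(images) != len(nsfw_flags):
        raise ValueError("expected exactly one safety flag per image")
    kept: List[T] = []
    flagged: List[int] = []
    for index, (image, flagged_unsafe) in enumerate(zip(images, nsfw_flags)):
        if flagged_unsafe:
            flagged.append(index)
        else:
            kept.append(image)
    return kept, flagged


# Example (assumes `outputs` from the snippet above):
# kept, flagged = split_by_safety(outputs["sample"], outputs["nsfw_content_detected"])
```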
Improvements and bugfixes
- add add_noise method in LMSDiscreteScheduler, PNDMScheduler by @patil-suraj in #227
- hotfix for pdnm test by @natolambert in #220
- Restore is_modelcards_available in .utils by @pcuenca in #224
- Update README for 0.2.3 release by @pcuenca in #225
- Pipeline to device by @pcuenca in #210
- fix safety check by @patil-suraj in #217
- Add safety module by @patil-suraj in #213
- Support one-string prompts and custom image size in LDM by @anton-l in #212
- Add is_torch_available, is_flax_available by @anton-l in #204
- Revive make quality by @anton-l in #203
- [StableDiffusionPipeline] use default params in call by @patil-suraj in #196
- fix test_from_pretrained_hub_pass_model by @patil-suraj in #194
- Match params with official Stable Diffusion lib by @apolinario in #192
Full Changelog: v0.2.2...v0.2.3
v0.2.2
This patch release fixes an import of the StableDiffusionPipeline.
- [K-LMS Scheduler] fix import by @patrickvonplaten in #191
v0.2.1 Patch release
This patch release fixes a small bug in the StableDiffusionPipeline.
- [Stable diffusion] Hot fix by @patrickvonplaten in 50a9ae
v0.2.0: Stable Diffusion early access, K-LMS sampling
Stable Diffusion
Stable Diffusion is a text-to-image latent diffusion model created by the researchers and engineers from CompVis, Stability AI and LAION. It's trained on 512x512 images from a subset of the LAION-5B database. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. With its 860M UNet and 123M text encoder, the model is relatively lightweight and runs on a GPU with at least 10GB VRAM.
See the model card for more information.
The Stable Diffusion weights are currently only available to universities, academics, research institutions and independent researchers. Please request access by applying through this form.
from torch import autocast
from diffusers import StableDiffusionPipeline
# make sure you're logged in with `huggingface-cli login`
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", use_auth_token=True)
prompt = "a photograph of an astronaut riding a horse"
with autocast("cuda"):
    image = pipe(prompt, guidance_scale=7)["sample"][0]  # image here is in PIL format
image.save("astronaut_rides_horse.png")
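The guidance_scale argument controls classifier-free guidance: at each step the UNet predicts the noise both without and with the text prompt, and the pipeline extrapolates from the unconditional prediction toward the text-conditioned one. The toy function below shows only that combination rule on plain lists; the function name is hypothetical, and a scale of 1 reduces to the text-conditioned prediction.

```python
from typing import List


def classifier_free_guidance(noise_uncond: List[float], noise_text: List[float],
                             guidance_scale: float) -> List[float]:
    # eps = eps_uncond + guidance_scale * (eps_text - eps_uncond)
    return [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]


uncond, text = [0.0, 1.0], [1.0, 1.5]
guided = classifier_free_guidance(uncond, text, guidance_scale=7.0)
```

Larger scales push samples to follow the prompt more closely, usually at some cost in diversity.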
K-LMS sampling
The new LMSDiscreteScheduler is a port of k-lms from k-diffusion by Katherine Crowson.
The scheduler can be easily swapped into existing pipelines like so:
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
model_id = "CompVis/stable-diffusion-v1-3-diffusers"
# Use the K-LMS scheduler here instead
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, use_auth_token=True)
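The beta_schedule="scaled_linear" setting used with beta_start=0.00085 and beta_end=0.012 interpolates linearly in square-root space and then squares, which we believe mirrors how diffusers builds this schedule. The sketch computes the betas and the cumulative product alpha_bar_t that schedulers use to relate clean and noisy samples; the function names are illustrative.

```python
import math
from typing import List


def scaled_linear_betas(beta_start: float = 0.00085, beta_end: float = 0.012,
                        num_train_timesteps: int = 1000) -> List[float]:
    # Interpolate linearly between sqrt(beta_start) and sqrt(beta_end), then square
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    n = num_train_timesteps
    return [(lo + (hi - lo) * i / (n - 1)) ** 2 for i in range(n)]


def alphas_cumprod(betas: List[float]) -> List[float]:
    # alpha_bar_t = prod_{s <= t} (1 - beta_s): how much of the clean signal survives at step t
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


betas = scaled_linear_betas()
alpha_bars = alphas_cumprod(betas)
```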
Integration test with text-to-image script of Stable-Diffusion
#182 and #186 ensure that the DDIM and PNDM/PLMS schedulers yield exactly the same results as the original Stable Diffusion implementation.
Try it out yourself:
In Stable-Diffusion:
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --n_samples 4 --n_iter 1 --fixed_code --plms
or
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --n_samples 4 --n_iter 1 --fixed_code
In diffusers:
from diffusers import StableDiffusionPipeline, DDIMScheduler
from time import time
from PIL import Image
from einops import rearrange
import numpy as np
import torch
from torch import autocast
from torchvision.utils import make_grid
torch.manual_seed(42)
prompt = "a photograph of an astronaut riding a horse"
#prompt = "a photograph of the eiffel tower on the moon"
#prompt = "an oil painting of a futuristic forest gives"
# uncomment to use DDIM
# scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
# pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", use_auth_token=True, scheduler=scheduler) # make sure you're logged in with `huggingface-cli login`
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", use_auth_token=True) # make sure you're logged in with `huggingface-cli login`
all_images = []
num_rows = 1
num_columns = 4
for _ in range(num_rows):
    with autocast("cuda"):
        images = pipe(num_columns * [prompt], guidance_scale=7.5, output_type="np")["sample"]  # NumPy array of shape (batch, height, width, channels)
    all_images.append(torch.from_numpy(images))
# additionally, save as grid
grid = torch.stack(all_images, 0)
grid = rearrange(grid, 'n b h w c -> (n b) h w c')
grid = rearrange(grid, 'n h w c -> n c h w')
grid = make_grid(grid, nrow=num_rows)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
image = Image.fromarray(grid.astype(np.uint8))
image.save(f"./images/diffusers/{'_'.join(prompt.split())}_{round(time())}.png")
Improvements and bugfixes
- Allow passing non-default modules to pipeline by @pcuenca in #188
- Add K-LMS scheduler from k-diffusion by @anton-l in #185
- [Naming] correct config naming of DDIM pipeline by @patrickvonplaten in #187
- [PNDM] Stable diffusion by @patrickvonplaten in #186
- [Half precision] Make sure half-precision is correct by @patrickvonplaten in #182
- allow custom height, width in StableDiffusionPipeline by @patil-suraj in #179
- add tests for stable diffusion pipeline by @patil-suraj in #178
- Stable diffusion pipeline by @patil-suraj in #168
- [LDM pipeline] fix eta condition. by @patil-suraj in #171
- [PNDM in LDM pipeline] use inspect in pipeline instead of unused kwargs by @patil-suraj in #167
- allow pndm scheduler to be used with ldm pipeline by @patil-suraj in #165
- add scaled_linear schedule in PNDM and DDPM by @patil-suraj in #164
- add attention up/down blocks for VAE by @patil-suraj in #161
- Add an alternative Karras et al. stochastic scheduler for VE models by @anton-l in #160
- [LDMTextToImagePipeline] make text model generic by @patil-suraj in #162
- Minor typos by @pcuenca in #159
- Fix arg key for dataset_name in create_model_card by @pcuenca in #158
- [VAE] fix the downsample block in Encoder. by @patil-suraj in #156
- [UNet2DConditionModel] add cross_attention_dim as an argument by @patil-suraj in #155
- Added diffusers to conda-forge and updated README for installation instruction by @sugatoray in #129
- Add issue templates for feature requests and bug reports by @osanseviero in #153
- Support training with a local image folder by @anton-l in #152
- Allow DDPM scheduler to use model's predicated variance by @eyalmazuz in #132
Full Changelog: 0.1.3...v0.2.0
0.1.3 Patch release
This patch release refactors the model architecture of VQModel and AutoencoderKL, including the weight naming. Therefore, the official weights of the CompVis organization have been re-uploaded, see:
- https://huggingface.co/CompVis/ldm-celebahq-256/commit/63b33cf3bbdd833de32080a8ba55ba4d0b111859
- https://huggingface.co/CompVis/ldm-celebahq-256/commit/03978f22272a3c2502da709c3940e227c9714bdd
- https://huggingface.co/CompVis/ldm-text2im-large-256/commit/31ff4edafd3ee09656d2068d05a4d5338129d592
- https://huggingface.co/CompVis/ldm-text2im-large-256/commit/9bd2b48d2d45e6deb6fb5a03eb2a601e4b95bd91
Corresponding PR: #137
Please make sure to upgrade diffusers to have those models running correctly: pip install --upgrade diffusers
Bug fixes
- Fix FileNotFoundError: 'model_card_template.md' #136
Initial release of 🧨 Diffusers
These are the release notes of the 🧨 Diffusers library
Introducing Hugging Face's new library for diffusion models.
Diffusion models have proven very effective at artificial synthesis, even beating GANs for images. Because of that, they have gained traction in the machine learning community and play an important role in systems like DALL-E 2 or Imagen that generate photorealistic images from text prompts.
While the most prolific successes of diffusion models have been in the computer vision community, these models have also achieved remarkable results in other domains, such as:
and more.
Goals
The goals of diffusers are:
- to centralize the research of diffusion models from independent repositories to a clear and maintained project,
- to reproduce high-impact machine learning systems such as DALL-E and Imagen in a manner that is accessible to the public, and
- to create an easy to use API that enables one to train their own models or re-use checkpoints from other repositories for inference.
Release overview
Quickstart:
- For a light walk-through of the library, please have a look at the Official 🧨 Diffusers Notebook.
- To directly jump into training a diffusion model yourself, please have a look at the Training Diffusers Notebook.
Diffusers aims to be a modular toolbox for diffusion techniques, with a focus on the following categories:
🚄 Inference pipelines
Inference pipelines are a collection of end-to-end diffusion systems that can be used out-of-the-box. The goal is for them to stick as close as possible to their original implementation, and they can include components of other libraries (such as text encoders).
The original release contains the following pipelines:
- DDPM for unconditional image generation with discrete scheduling in pipeline_ddpm.
- DDIM for unconditional image generation with discrete scheduling in pipeline_ddim.
- PNDM for unconditional image generation with discrete scheduling in pipeline_pndm.
- Stochastic Differential Equations for unconditional image generation with continuous scheduling in score_sde_ve
- Latent diffusion for text to image generation / conditional image generation in pipeline_latent_diffusion as well as for unconditional image generation in latent_diffusion_uncond
We are currently working on enabling other pipelines for different modalities. The following pipelines are expected to land in a subsequent release:
- BDDMPipeline for spectrogram-to-sound vocoding
- GLIDEPipeline to support OpenAI's GLIDE model
- Grad-TTS for text to audio generation / conditional audio generation
- A reinforcement learning pipeline (happening in #105)
⏰ Schedulers
- Schedulers are the algorithms used to run diffusion models, both at inference and during training. They include the noise schedules and define algorithm-specific diffusion steps.
- Schedulers can be used interchangeably between diffusion models at inference time to find the preferred trade-off between speed and generation quality.
- Schedulers are available in numpy, but can easily be transformed into PyTorch.
The goal is for each scheduler to provide one or more step() functions that should be called iteratively to unroll the diffusion loop during the forward pass. They are framework-agnostic, but offer conversion methods which should allow easy conversion to PyTorch utilities.
The initial release contains the following schedulers:
- DDIM, from the Denoising Diffusion Implicit Models paper.
- DDPM, from the Denoising Diffusion Probabilistic Models paper.
- PNDM, from the Pseudo Numerical Methods for Diffusion Models on Manifolds paper.
- SDE_VE, from the Score-Based Generative Modeling through Stochastic Differential Equations paper.
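The division of labor between model and scheduler can be sketched as a loop: the model predicts noise at the current timestep, and the scheduler's step() turns that prediction into the previous, less noisy sample. Below is a toy deterministic DDIM step (eta = 0) on plain Python lists, driven by an "oracle" model that knows the clean target, so the loop should recover it. `ToyDDIMScheduler`, `denoise`, and the oracle are hypothetical illustrations, not the library's classes; in particular, the library's step() returns an output object rather than a bare list.

```python
import math
import random
from typing import Callable, List


def linear_alphas_cumprod(n: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> List[float]:
    # DDPM-style linear beta schedule and its cumulative product alpha_bar_t
    out, prod = [], 1.0
    for i in range(n):
        prod *= 1.0 - (beta_start + (beta_end - beta_start) * i / (n - 1))
        out.append(prod)
    return out


class ToyDDIMScheduler:
    """Hypothetical deterministic DDIM (eta = 0) scheduler over plain Python lists."""

    def __init__(self, alphas_cumprod: List[float], num_inference_steps: int):
        self.alphas_cumprod = alphas_cumprod
        self.stride = len(alphas_cumprod) // num_inference_steps
        # Evenly spaced training timesteps, visited from noisiest to cleanest
        self.timesteps = list(range(0, len(alphas_cumprod), self.stride))[::-1]

    def step(self, noise_pred: List[float], t: int, sample: List[float]) -> List[float]:
        a_t = self.alphas_cumprod[t]
        prev_t = t - self.stride
        a_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else 1.0
        # Estimate x_0 from the current sample and the predicted noise
        x0 = [(x - math.sqrt(1.0 - a_t) * e) / math.sqrt(a_t) for x, e in zip(sample, noise_pred)]
        # Re-noise the estimate to the previous (less noisy) timestep
        return [math.sqrt(a_prev) * x0_i + math.sqrt(1.0 - a_prev) * e for x0_i, e in zip(x0, noise_pred)]


def denoise(model: Callable[[List[float], int], List[float]], scheduler: ToyDDIMScheduler,
            sample: List[float]) -> List[float]:
    # The diffusion loop: call the model, then let the scheduler take one step
    for t in scheduler.timesteps:
        noise_pred = model(sample, t)
        sample = scheduler.step(noise_pred, t, sample)
    return sample


# Usage with an "oracle" model that knows the clean target, so the loop must recover it
target = [0.3, -0.7, 1.2]
scheduler = ToyDDIMScheduler(linear_alphas_cumprod(), num_inference_steps=50)


def oracle(sample: List[float], t: int) -> List[float]:
    a_t = scheduler.alphas_cumprod[t]
    return [(x - math.sqrt(a_t) * x0) / math.sqrt(1.0 - a_t) for x, x0 in zip(sample, target)]


rng = random.Random(0)
result = denoise(oracle, scheduler, [rng.gauss(0.0, 1.0) for _ in target])
```

Swapping schedulers means swapping only the object passed to the loop, which is what makes them interchangeable at inference time.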
🏭 Models
Models are hosted in the src/diffusers/models folder.
For the initial release, you'll get to see a few building blocks, as well as some resulting models:
- UNet2DModel can be seen as a version of the UNet architectures used in recent papers. It is the unconditional version of the UNet model, as opposed to the conditional version described next.
- UNet2DConditionModel is similar to UNet2DModel, but is conditional: it uses cross-attention layers in its downsample and upsample blocks, and these cross-attention layers can be fed by the outputs of other models (such as a text encoder). An example of a pipeline using a conditional UNet model is the latent diffusion pipeline.
- AutoencoderKL and VQModel are still experimental models that are prone to breaking changes in the near future. However, they can already be used as part of the Latent Diffusion pipelines.
📃 Training example
The first release contains a dataset-agnostic unconditional example and a training notebook:
- The train_unconditional.py example, which trains a DDPM UNet model on a dataset of your choice.
- More examples can be found under the Hugging Face Diffusers Notebooks.
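At its core, DDPM training is a noise-prediction regression: pick a random timestep, noise a clean sample in closed form, and penalize the squared error between the model's noise prediction and the true noise. We believe train_unconditional.py follows this objective (via the scheduler's add_noise and an MSE loss), but the toy below is a plain-Python illustration with hypothetical names, not the script's code.

```python
import math
import random
from typing import Callable, List


def linear_alphas_cumprod(n: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> List[float]:
    # DDPM-style linear beta schedule and its cumulative product alpha_bar_t
    out, prod = [], 1.0
    for i in range(n):
        prod *= 1.0 - (beta_start + (beta_end - beta_start) * i / (n - 1))
        out.append(prod)
    return out


def ddpm_training_loss(model: Callable[[List[float], int], List[float]], x0: List[float],
                       alphas_cumprod: List[float], rng: random.Random) -> float:
    # 1. Pick a random training timestep
    t = rng.randrange(len(alphas_cumprod))
    a_t = alphas_cumprod[t]
    # 2. Sample Gaussian noise and build the noisy input x_t in closed form
    noise = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [math.sqrt(a_t) * x + math.sqrt(1.0 - a_t) * n for x, n in zip(x0, noise)]
    # 3. The model predicts the noise; the loss is the mean squared error against the true noise
    pred = model(x_t, t)
    return sum((p - n) ** 2 for p, n in zip(pred, noise)) / len(x0)


# A model that always predicts zero noise gets a positive loss
loss = ddpm_training_loss(lambda x_t, t: [0.0] * len(x_t), [0.5, -1.0],
                          linear_alphas_cumprod(), random.Random(0))
```

In a real training loop, this loss would be backpropagated through the UNet and averaged over mini-batches.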
Credits
This library concretizes previous work by many different authors and would not have been possible without their great research and implementations. We'd like to thank, in particular, the following implementations which have helped us in our development and without which the API could not have been as polished today:
- @CompVis' latent diffusion models library, available here
- @hojonathanho original DDPM implementation, available here as well as the extremely useful translation into PyTorch by @pesser, available here
- @ermongroup's DDIM implementation, available here.
- @yang-song's Score-VE and Score-VP implementations, available here
We also want to thank @heejkoo for the very helpful overview of papers, code and resources on diffusion models, available here.