Skip to content

Commit

Permalink
Added readme and example notebook for diffusion_labs (#495)
Browse files Browse the repository at this point in the history
Summary:
Added a readme and an example Notebook for training to diffusion labs.

Pull Request resolved: #495

Test Plan: These are only documentation changes

Reviewed By: abhinavarora

Differential Revision: D50381163

Pulled By: pbontrager

fbshipit-source-id: 3974968e70551c74241aaaac0c22f0260e5ca368
  • Loading branch information
Philip Bontrager authored and facebook-github-bot committed Oct 17, 2023
1 parent 7968f32 commit 2ddb8cd
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
55 changes: 55 additions & 0 deletions torchmultimodal/diffusion_labs/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# diffusion_labs

Diffusion labs provides components for building diffusion models and for end-to-end training of those models. This
includes definitions for popular models such as

- Dalle2
- Latent Diffusion Models (LDM)

and all the components needed for defining and training these models. All of these modules are compatible with
Pytorch distributed and PT2.

# Concepts

1. Models

This includes diffusion model definitions, like LDM, as well as models used within the diffusion model such as a
U-Net or Transformer. A common model used for denoising within diffusion training is the U-Net from
[ADM](https://arxiv.org/abs/2105.05233), which is available at `diffusion_labs/models/adm_unet`.

2. Adapters

Adapters adapt the underlying architecture to handle various types of conditional inputs both at training and
inference time. They act as wrappers around the model and multiple adapters can be wrapped around each other to
handle multiple types of inputs. All Adapters have the same `forward` signature allowing them to be stacked.

3. Predictor

Predictor defines what the model is trained to predict (e.g. added noise or a clean image). This is used to convert
the model output into a denoised data point.

4. Schedule

The schedule defines the diffusion process being applied to the data. This includes defining what kind of noise,
and how much noise to apply to each diffusion step. The Schedule class contains the noise values along with any
necessary computations related to it.

5. Sampler

The sampler wraps around the model to denoise the input data given the diffusion schedule. This class takes is
defined with the model, the Predictor and the Schedule as inputs. In train mode the Sampler calls the model for one
step while in eval mode it will call the model for the entire diffusion schedule.


6. Transform

diffusion_labs introduces several helper transforms for diffusion that can be used in conjunction with other data
transforms such as vision transforms. All transforms are implemented as nn.Modules and take in a dict of data and
then output and updated dict. This allows all transforms to be stacked together with nn.Sequential and to be
compiled.


# Tutorial

[How to train diffusion on
MNIST](https://github.com/facebookresearch/multimodal/tree/main/torchmultimodal/diffusion_labs/mnist_training.ipynb)
1 change: 1 addition & 0 deletions torchmultimodal/diffusion_labs/mnist_training.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"metadata":{"dataExplorerConfig":{},"bento_stylesheets":{"bento/extensions/flow/main.css":true,"bento/extensions/kernel_selector/main.css":true,"bento/extensions/kernel_ui/main.css":true,"bento/extensions/new_kernel/main.css":true,"bento/extensions/system_usage/main.css":true,"bento/extensions/theme/main.css":true},"kernelspec":{"name":"bento_kernel_torchmultimodal","display_name":"TorchMultimodal","language":"python","metadata":{"kernel_name":"bento_kernel_torchmultimodal","nightly_builds":true,"fbpkg_supported":true,"cinder_runtime":true,"is_prebuilt":true},"isCinder":true},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3"},"last_server_session_id":"e6873c2b-2361-410f-8f06-1d6b18713432","last_kernel_id":"7409c057-f3b4-4d2e-829f-c604f6bc1bf4","last_base_url":"https://bento.edge.x2p.facebook.net/","last_msg_id":"dbd81be2-1c3636c832adfab75199f41d_437","captumWidgetMessage":{},"outputWidgetContext":{}},"nbformat":4,"nbformat_minor":2,"cells":[{"cell_type":"markdown","metadata":{"originalKey":"e091e8e8-88ab-4394-9496-6150044e051d","showInput":false,"customInput":null},"source":[" # Conditional Diffusion MNIST Image Generation"]},{"cell_type":"code","metadata":{"collapsed":false,"originalKey":"a637961f-865a-4ef0-8cb7-8041feed3a1b","outputsInitialized":false,"requestMsgId":"14f3a39d-1975-4061-9bb5-d9f37c318261","executionStartTime":1697568515546,"executionStopTime":1697568519648,"customOutput":null},"source":["import torch\n","import torchvision\n","import torchvision.transforms.functional as F\n","\n","from torch import nn\n","from tqdm import tqdm\n","from torchmultimodal.diffusion_labs.models.adm_unet.adm import adm_unet\n","from torchmultimodal.diffusion_labs.modules.adapters.cfguidance import CFGuidance\n","from torchmultimodal.diffusion_labs.modules.losses.diffusion_hybrid_loss import DiffusionHybridLoss\n","from torchmultimodal.diffusion_labs.samplers.ddpm import DDPModule\n","from torchmultimodal.diffusion_labs.predictors.noise_predictor import NoisePredictor\n","from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import linear_beta_schedule, DiscreteGaussianSchedule\n","from torchmultimodal.diffusion_labs.transforms.diffusion_transform import RandomDiffusionSteps\n","\n","device = \"cuda\""],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"originalKey":"b6751f28-2255-495a-9cb0-21d8b9b680bc","showInput":false,"customInput":null},"source":["# Define Model"]},{"cell_type":"markdown","metadata":{"originalKey":"fa6eba21-d5a8-4150-9e2e-61d547083a29","showInput":false,"customInput":null},"source":["To define a diffusion model you need to define four primary components:\n","\n","1. Network and Adapters\n","2. Diffusion Schedule\n","3. Predictor\n","4. Sampler\n","\n","The network typically used with image diffusion models is a [U-Net](https://paperswithcode.com/method/u-net). A U-Net is a convolutional network that maps the input space directly to a equal sized output space. This makes it ideal for image segmentation and transformation tasks. Since we are denosing an image, a U-Net works very well here. [ADMUnet](https://arxiv.org/abs/2105.05233) is a specific implementation shown to work well for image generation. \n","\n","The default values for adm_unet are a bit over-kill for the tiny MNIST dataset so we'll choose some custom smaller values here."]},{"cell_type":"code","metadata":{"originalKey":"b63e91be-81b7-4a31-add8-7fddad09bf4a","showInput":true,"customInput":null,"collapsed":false,"requestMsgId":"8b5aa7a4-6b1c-4adb-a7bf-070083fdbe3e","executionStartTime":1697567110967,"executionStopTime":1697567111933,"customOutput":null},"source":["unet = adm_unet(\n"," time_embed_dim=128, # Model takes diffusion timestep as a conditional input\n"," cond_embed_dim=128, # Projected size of conditional embedding\n"," embed_dim=768, # Size of conditional embedding for conditional image generation\n"," embed_name=\"digit\", # Name of conditional input\n"," predict_variance_value=True, # If the model should learn per step variance values for sampling\n"," image_channels=1, # MNIST images are single channel\n"," depth=128, # U-Net layer depth\n"," num_resize=3, # Number of upsample/downsampler blocks for U-Net\n"," num_res_per_layer=3, # Residual Blocks per channel.\n",")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"originalKey":"2a894bdd-de9f-46a6-8e6d-f94f2428a229","showInput":false,"customInput":null},"source":["Apart from the core network, we can add adapters to the network design to allow it to handle different tasks common for diffusion training. Here we'll use [classifer free guidance](https://arxiv.org/abs/2207.12598), this is a technique used for conditional generative models that improves image-prompt alignment."]},{"cell_type":"code","metadata":{"originalKey":"f0959cb8-bd48-4509-a359-d493ce7ca838","showInput":true,"customInput":null,"collapsed":false,"requestMsgId":"a1677914-1ad2-4835-99ce-cbade342a2ec","executionStartTime":1697567117884,"executionStopTime":1697567117926,"customOutput":null},"source":["decoder = CFGuidance(unet, # Model being adapted\n"," {\"digit\": 768}, # Define conditional inputs name and size\n"," guidance=2.0) # How strong to to increase image-prompt alignment at the expense of image diversity"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"originalKey":"a732e42c-2955-4830-a8aa-d4b3c77b2926","showInput":false,"customInput":null},"source":["Step 2 is to define the schedule. The schedule is the [diffusion process](https://arxiv.org/abs/2006.11239) which describes the amount of noise added to the image at each diffusion step. Using a Gaussian schedule, we sample Gaussian noise at every step defined as $\\mathcal{N}(0, \\beta)$. Here we define the schedule with a linearly increasing schedule of variance ($\\beta$) values."]},{"cell_type":"code","metadata":{"originalKey":"44ecb076-7ac9-471b-a056-0b12e476dcc7","showInput":true,"customInput":null,"collapsed":false,"requestMsgId":"932f3a2e-85dd-42cb-86f8-36f36dd9db3f","executionStartTime":1697567120423,"executionStopTime":1697567120460,"customOutput":null},"source":["schedule = DiscreteGaussianSchedule(linear_beta_schedule(1000)) # Helper function for vairance values"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"originalKey":"a487d070-5487-4dd4-a06f-e299af9d5a97","showInput":false,"customInput":null},"source":["Step 3 is the Predictor which determines the output of the denoising network. Here we train the model to output the noise to be removed at each step. The predictor contains the methods to convert the model output into the cleaned image."]},{"cell_type":"code","metadata":{"originalKey":"be11e1a6-1898-4d41-8c38-c8d9afbbb8c6","showInput":true,"customInput":null,"collapsed":false,"requestMsgId":"a74da193-2bd6-4986-afac-12320abc9dd5","executionStartTime":1697567123121,"executionStopTime":1697567123123,"customOutput":null},"source":["predictor = NoisePredictor(schedule, # Scale of noise at each step\n"," lambda x: torch.clamp(x, -1, 1)) # Min and max image values"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"originalKey":"5669ffc8-595a-4ee4-b3f5-0a4bdfa16bc8","showInput":false,"customInput":null},"source":["Step 4 is to define the Sampler. The **Sampler** applies the denoising **Network** for each step of the diffusion **Schedule** using the **Predictor** to fully denoise an image. Here we use the [Diffusion Probabilistic Implicit Models](https://arxiv.org/abs/2006.11239) sampler which is the original diffusion sampler."]},{"cell_type":"code","metadata":{"originalKey":"0443e8db-ee28-412d-93b7-ee1c7b2dc502","showInput":true,"customInput":null,"collapsed":false,"requestMsgId":"14f38043-bfc9-4fd9-975a-71aa7f0d8f64","executionStartTime":1697567125555,"executionStopTime":1697567125611,"customOutput":null},"source":["eval_steps = torch.linspace(0, 999, 250, dtype=torch.int) # Diffusion steps to sample at inference\n","decoder = DDPModule(decoder, schedule, predictor, eval_steps) # Sampler"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"originalKey":"ad8846d4-734e-4c47-a30f-a850c41fe839","showInput":false,"customInput":null},"source":["Finally, to condition this model on MNIST digits, lets define a simple conditional encoder to convert digits to conditional embeddings:"]},{"cell_type":"code","metadata":{"originalKey":"c989ea0d-b0f0-48d4-8719-fb172ca24a2f","showInput":true,"customInput":null,"collapsed":false,"requestMsgId":"4d725a5f-36b6-4d0d-9ea4-c92f0ed525c7","executionStartTime":1697567127236,"executionStopTime":1697567127237,"customOutput":null},"source":["encoder = nn.Embedding(10, # Number of digits\n"," 768) # Embed size"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"originalKey":"0cc75105-aff5-408e-82e1-41bde3939d98","showInput":false,"customInput":null},"source":["# Training"]},{"cell_type":"markdown","metadata":{"originalKey":"de8198fc-4109-4b03-b60d-41cf3a84f024","showInput":false,"customInput":null},"source":["For data, we need to define the transforms, a dataset, and a dataloader. For training a diffusion model, you sample each data point from the diffusion process. The RandomDiffusionSteps transform takes in your data point and samples a random diffusion step and applies noise to the data accordingly.?"]},{"cell_type":"code","metadata":{"originalKey":"4322fcc6-1d34-4d4a-a3cd-9e1c8e0893ba","showInput":true,"customInput":null,"collapsed":false,"requestMsgId":"7eeab393-57b2-4935-90e4-52f344b7e89f","executionStartTime":1697567129475,"executionStopTime":1697567129511,"customOutput":null},"source":["from torchvision.transforms import Compose, Resize, ToTensor, Lambda\n","\n","diffusion_transform = RandomDiffusionSteps(schedule, batched=False) # Diffusion transform given schedule\n","transform = Compose([Resize(32), # Resize MNIST image for network\n"," ToTensor(),\n"," Lambda(lambda x: 2*x - 1), # Scale image to [-1, 1]\n"," Lambda(lambda x: diffusion_transform({\"x\": x}))]) # Apply diffusion transform"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"originalKey":"774ca769-7e5c-4bb4-b952-0bec7e54318e","showInput":false,"customInput":null},"source":["Load Dataset"]},{"cell_type":"code","metadata":{"originalKey":"ae22ab08-3cb0-4919-a99f-b477faeeaf6f","showInput":true,"customInput":null,"collapsed":false,"requestMsgId":"d45955c2-9471-4903-a35f-ddb02235ad24","executionStartTime":1697567131735,"executionStopTime":1697567131830,"customOutput":null},"source":["from torchvision.datasets import MNIST\n","from torch.utils.data import DataLoader\n","\n","train_dataset = MNIST(\"mnist\", train=True, download=True, transform=transform)\n","train_dataloader = DataLoader(train_dataset, batch_size=192, shuffle=True, num_workers=2, pin_memory=True)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"originalKey":"7a996d24-3712-413f-9927-444485a42a67","showInput":false,"customInput":null},"source":["For DDPM we'll train using [diffusion hybrid loss](https://arxiv.org/abs/2102.09672) between the model output and added noise. This loss measures the distance between the model output and the target as well as the KL Divergence between the predicted noise variance and actual."]},{"cell_type":"code","metadata":{"originalKey":"d6a009a0-d436-4235-bf72-fd3b2967247c","showInput":true,"customInput":null,"collapsed":false,"requestMsgId":"55343280-281e-44a8-9543-76855f39b3fa","executionStartTime":1697567134478,"executionStopTime":1697567134514,"customOutput":null},"source":["h_loss = DiffusionHybridLoss(schedule)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"originalKey":"387d4078-74d1-42af-8250-ede2613f41e9","showInput":false,"customInput":null},"source":["We can then choose our favorite optimizer and optionally use a scaler for mixed precision training."]},{"cell_type":"code","metadata":{"originalKey":"1790abf5-8ac8-4ba0-95ce-9ab13d6e65df","showInput":true,"customInput":null,"collapsed":false,"requestMsgId":"f8849cb5-5ff3-400c-8332-b845a77140f8","executionStartTime":1697567138677,"executionStopTime":1697567138930,"customOutput":null},"source":["encoder.to(device)\n","decoder.to(device)\n","\n","optimizer = torch.optim.AdamW(\n"," [{\"params\": encoder.parameters()}, {\"params\": decoder.parameters()}], lr=0.0001\n",")\n","scaler = torch.cuda.amp.GradScaler()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"originalKey":"a2d856ee-a21a-491f-bf7e-430ec4cf6b8b","showInput":false,"customInput":null},"source":["# Train"]},{"cell_type":"markdown","metadata":{"originalKey":"b628612a-23f2-4f16-8f21-c49f1a43660a","showInput":false,"customInput":null},"source":["Here is a simple standard Pytorch training loop, just with mixed precision added in for faster training. The diffusion model has a fixed signature\n","\n","$model(x_t, t, cond_{dict})$\n","\n","When training, the model only computes a single denoising step per input. The model is also given a dictionary of conditional inputs that the Adapters and underlying network have access to for conditional generation."]},{"cell_type":"code","metadata":{"originalKey":"2bb185a4-daad-43c0-bfec-66d5c3cdcb43","showInput":true,"customInput":null,"collapsed":true,"requestMsgId":"65b1e017-3aa3-4872-bd41-4d9d29ea36bd","executionStartTime":1697560624621,"executionStopTime":1697560626826,"customOutput":null},"source":["epochs = 5\n","\n","encoder.train()\n","decoder.train()\n","for e in range(epochs):\n","\tfor sample in (pbar := tqdm(train_dataloader)):\n","\t\tx, d = sample\n","\t\tx0, xt, noise, t, d = x[\"x\"].to(device), x[\"xt\"].to(device), x[\"noise\"].to(device), x[\"t\"].to(device), d.to(device)\n","\t\toptimizer.zero_grad()\n","\n","\t\twith torch.autocast(device):\n","\t\t\td = encoder(d)\n","\t\t\tout = decoder(xt, t, {\"digit\": d})\n","\t\t\tloss = h_loss(out.prediction, noise, out.mean, out.log_variance, x0, xt, t)\n","\n","\t\tscaler.scale(loss).backward()\n","\t\tscaler.step(optimizer)\n","\t\tscaler.update()\n","\n","\t\tpbar.set_description(f'{e+1}| Loss: {loss.item()}')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"originalKey":"1256a354-bdab-4aca-933c-5fc92235528c","showInput":false,"customInput":null},"source":["# Eval"]},{"cell_type":"markdown","metadata":{"originalKey":"8c8e550f-9745-4aba-8d66-09320740b95f","showInput":false,"customInput":null},"source":["Likewise, eval is done as a standard Pytorch eval model call. While in train mode the model computes a single denoising step and outputs the raw model output, in eval mode the model computes steps $t, ..., 0$ and returns the denoised data.\n","\n","In eval, if no timestep is provided, it's assumed to be the largest timestep $T$ and the input is $x_T$. Since $x_T$ is equivalent to random noise, you sample the input from torch.randn. "]},{"cell_type":"code","metadata":{"originalKey":"f6d01886-6b6d-41c5-ac66-a497bcedf181","showInput":true,"customInput":null,"collapsed":false,"requestMsgId":"f836b0eb-5fa0-4d6f-b899-85dc9fe0a1b7","executionStartTime":1697567144970,"executionStopTime":1697567145712,"customOutput":null},"source":["encoder.eval()\n","decoder.eval()\n","\n","digit = torch.as_tensor([i for i in range(1,10)]).to(device) # Generate digits 0 to 9\n","noise = torch.randn(size=(9,1,32,32)).to(device) # Sample 9 inputs\n","\n","with torch.no_grad():\n"," d = encoder(digit)\n"," imgs = decoder(noise, conditional_inputs={\"digit\": d})\n","\n","img_grid = torchvision.utils.make_grid(imgs, 3)\n","img = F.to_pil_image((img_grid + 1) / 2)\n","img.resize((288, 288))"],"execution_count":null,"outputs":[]}]}

0 comments on commit 2ddb8cd

Please sign in to comment.