#!/usr/bin/env python
# coding: utf-8
# # Denoising Diffusion Probabilistic Model (DDPM)
#
# This notebook implements DDPM [[1](https://proceedings.neurips.cc/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf)] from scratch and is inspired by three great blogs: [[2](https://huggingface.co/blog/annotated-diffusion)], [[3](https://benanne.github.io/2022/05/26/guidance.html)], and [[10](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/)]
# ## Setup
#
# Diffusion models roughly consist of two parts:
#
# - A predefined **forward** diffusion process $q(\boldsymbol{x}_t|\boldsymbol{x}_{t-1})$ of our choosing, which gradually adds Gaussian noise to an image until it becomes pure noise.
# - A **reverse** denoising diffusion process $q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t)$, approximated by a neural network, which gradually denoises an image starting from pure noise.
#
#
# ### Forward process
#
# **Nice property:** Let $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{i=1}^t \alpha_i$
#
# \begin{align}
# \boldsymbol{x}_t & = \sqrt{\alpha_t} \boldsymbol{x}_{t-1} + \sqrt{1 - \alpha_t}\boldsymbol{z}_{t-1} \;\;\; \text{where} \;\boldsymbol{z}_{t-1}, \boldsymbol{z}_{t-2}, \dots \sim \mathcal{N}(\mathbf{0}, \mathbf{I}).\\
# & = \sqrt{\alpha_t} \lbrack \sqrt{\alpha_{t-1}} \boldsymbol{x}_{t-2} + \sqrt{1- \alpha_{t-1}}\boldsymbol{z}_{t-2} \rbrack + \sqrt{1 - \alpha_t}\boldsymbol{z}_{t-1}\\
# & = \sqrt{\alpha_t \alpha_{t-1}} \boldsymbol{x}_{t-2} + \sqrt{\alpha_t - \alpha_t \alpha_{t-1}}\boldsymbol{z}_{t-2} + \sqrt{1 - \alpha_t}\boldsymbol{z}_{t-1}\\
# & = \sqrt{\alpha_t \alpha_{t-1}} \boldsymbol{x}_{t-2} + \sqrt{1 - \alpha_t \alpha_{t-1}}\boldsymbol{\bar{z}}_{t-2} \;\;\; \text{where} \; \boldsymbol{\bar{z}}_{t-2} \; \text{merges two Gaussians}^1 \\
# & = \dots \\
# & = \sqrt{\bar{\alpha}_t} \boldsymbol{x}_0 + \sqrt{1 - \bar{\alpha}_t} \boldsymbol{\bar{z}_0} \;\;\; \text{where} \; \boldsymbol{\bar{z}_0} = \boldsymbol{z} \\
# q(\boldsymbol{x}_t|\boldsymbol{x}_0) & = \mathcal{N}\big(\sqrt{\bar{\alpha}_t} \boldsymbol{x}_0, (1 - \bar{\alpha}_t)\mathbf{I}\big)
# \end{align}
#
# $^1$ the sum of two Gaussian variables $\boldsymbol{z}_1 \sim \mathcal{N}(\mathbf{0}, \sigma_1^2 \mathbf{I})$ and $\boldsymbol{z}_2 \sim \mathcal{N}(\mathbf{0}, \sigma_2^2 \mathbf{I})$ is a new variable $\bar{\boldsymbol{z}} \sim \mathcal{N}(\mathbf{0}, (\sigma_1^2 + \sigma_2^2) \mathbf{I})$. In our case, $\sqrt{1 - \alpha_t}\,\boldsymbol{z}_{t-1} \sim \mathcal{N}(\mathbf{0}, (1-\alpha_t) \mathbf{I})$ and $\sqrt{\alpha_t - \alpha_t \alpha_{t-1}}\,\boldsymbol{z}_{t-2} \sim \mathcal{N}(\mathbf{0}, (\alpha_t - \alpha_t \alpha_{t-1}) \mathbf{I})$, so $\sqrt{1 - \alpha_t \alpha_{t-1}}\,\boldsymbol{\bar{z}}_{t-2} \sim \mathcal{N}(\mathbf{0}, (1 - \alpha_t \alpha_{t-1}) \mathbf{I})$.
# In[1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
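# As a quick numeric sanity check of the merged-Gaussian footnote above (a sketch with arbitrary example values for $\alpha_t$ and $\alpha_{t-1}$):
torch.manual_seed(0)
alpha_t, alpha_tm1 = 0.99, 0.98  # arbitrary example values
merged = (1 - alpha_t) ** 0.5 * torch.randn(1_000_000) + (alpha_t - alpha_t * alpha_tm1) ** 0.5 * torch.randn(1_000_000)
print(merged.var().item())      # empirical variance, ~0.0298
print(1 - alpha_t * alpha_tm1)  # predicted variance 1 - alpha_t * alpha_{t-1} = 0.0298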
# First, we define the scheduler that computes $\beta_t$. The simplest is the linear scheduler, but more advanced schedulers (e.g., the cosine scheduler sketched below) can give better results.
# In[2]:
def linear_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
return torch.linspace(beta_start, beta_end, timesteps)
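# For example, the cosine scheduler proposed in [[5](https://arxiv.org/abs/2102.09672)]; this sketch follows the formulation popularized in [[2](https://huggingface.co/blog/annotated-diffusion)]:
import math
def cosine_beta_schedule(timesteps, s=0.008):
    # alpha_bar follows a squared cosine; betas are recovered from its ratios
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)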
# In[3]:
timesteps = 200
# compute betas
betas = linear_beta_schedule(timesteps=timesteps)
# compute alphas
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
# calculations for the closed-form forward diffusion q(x_t | x_0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
# Let's illustrate how noise is added to a sample image at each time step of the diffusion process.
# In[4]:
from PIL import Image
image = Image.open("img/lion_short.png")
image
# Next, we resize the image, convert it to a PyTorch tensor, and rescale it to $[-1, 1]$.
# In[5]:
from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize
image_size = 128
transform = Compose([
Resize(image_size),
    ToTensor(), # turn into a torch tensor of shape CHW with values in [0,1] (divides by 255)
Lambda(lambda t: (t * 2) - 1), # [0,1] --> [-1,1]
])
x_start = transform(image).unsqueeze(0)
x_start.shape
# We also define the reverse transform, which maps a PyTorch tensor with values in $[-1,1]$ back into a PIL image.
# In[6]:
reverse_transform = Compose([
Lambda(lambda t: (t + 1) / 2),
Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
Lambda(lambda t: t * 255.),
Lambda(lambda t: t.numpy().astype(np.uint8)),
ToPILImage(),
])
reverse_transform(x_start.squeeze())
# We can now define the forward diffusion process.
# In[7]:
# utility function to extract the value of a at index t for a batch of indices,
# reshaped so it broadcasts against a tensor of shape x_shape.
# e.g., t=[10,11], x_shape=[b,c,h,w] --> out.shape = [2,1,1,1]
# e.g., t=[7,12,15,20], x_shape=[b,h,w] --> out.shape = [4,1,1]
def extract(a, t, x_shape):
batch_size = t.shape[0]
out = a.gather(-1, t.cpu())
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
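# quick shape check matching the examples above (uses betas from In[3])
_vals = extract(betas, torch.tensor([10, 11]), (2, 3, 128, 128))
print(_vals.shape)  # torch.Size([2, 1, 1, 1]) -- broadcastable against a [b,c,h,w] batch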
# forward diffusion (using the nice property)
def q_sample(x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start) # z (it does not depend on t!)
# adjust the shape
sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
sqrt_one_minus_alphas_cumprod_t = extract(
sqrt_one_minus_alphas_cumprod, t, x_start.shape
)
return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
# Let's test on a specific time step, $t=19$ (time indices are 0-based, so this is the 20th step):
# In[8]:
def get_noisy_image(x_start, t):
x_noisy = q_sample(x_start, t=t) # add noise
noisy_image = reverse_transform(x_noisy.squeeze()) # turn back into PIL image
return noisy_image
# In[9]:
t = torch.tensor([19])
get_noisy_image(x_start[0], t)
# In[10]:
def plot_seq(imgs, with_orig=False, row_title=None, **imshow_kwargs):
if not isinstance(imgs[0], list):
# Make a 2d grid even if there's just 1 row
imgs = [imgs]
num_rows = len(imgs)
num_cols = len(imgs[0]) + with_orig
    fig, axs = plt.subplots(figsize=(2 * num_cols, 2 * num_rows), nrows=num_rows, ncols=num_cols, squeeze=False)
for row_idx, row in enumerate(imgs):
row = [image] + row if with_orig else row
for col_idx, img in enumerate(row):
ax = axs[row_idx, col_idx]
ax.imshow(np.asarray(img), **imshow_kwargs)
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if with_orig:
axs[0, 0].set(title='Original image')
axs[0, 0].title.set_size(8)
if row_title is not None:
for row_idx in range(num_rows):
axs[row_idx, 0].set(ylabel=row_title[row_idx])
plt.tight_layout()
# In[11]:
plot_seq([get_noisy_image(x_start, torch.tensor([t])) for t in [1, 50, 100, 150, 199]])
# ### Backward process
#
#
# - If $\beta_t$ is small enough, $q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t)$ will be a Gaussian $\mathcal{N}\big(\tilde{\mu}(\boldsymbol{x}_t), \tilde{\sigma}(\boldsymbol{x}_t)\big)$
# - However, we cannot compute $\tilde{\mu}(\boldsymbol{x}_t)$ and $\tilde{\sigma}(\boldsymbol{x}_t)$ analytically, because it requires knowing the true data distribution $p(\boldsymbol{x})$
# $$ q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t)q(\boldsymbol{x}_{t-2}|\boldsymbol{x}_{t-1})\dots \rightarrow p(\boldsymbol{x}_0)$$
# - We can use a NN to learn $\tilde{\mu}(\boldsymbol{x}_t)$ and $\tilde{\sigma}(\boldsymbol{x}_t)$
# ### Simplification
#
# Represent $\tilde{\sigma}(\boldsymbol{x}_t)$ with a time-dependent constant $\tilde{\beta}_t$
#
# ### Manipulation
#
# Condition on $\boldsymbol{x}_0$ and obtain
#
# $$q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{x}_0) = \mathcal{N}\big(\tilde{\mu}(\boldsymbol{x}_t, \boldsymbol{x}_0), \tilde{\beta}_t\big) = \mathcal{N}\big(\tilde{\mu}_t, \tilde{\beta}_t\big)$$
#
# where
#
# $$\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_{t}} \beta_t
# \;\;\; \text{and} \;\;\;
# \tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}} \Big(\boldsymbol{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \boldsymbol{\tilde{z}}_t\Big)$$
#
#
# ### NN Model
#
# - Take a look at $\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_{t}} \beta_t$ and $\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}} \Big(\boldsymbol{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \boldsymbol{\tilde{z}}_t\Big)$
#
# - The only thing we do not know is $\boldsymbol{\tilde z}_t$ .
# - We approximate $\boldsymbol{\tilde z}_t$ with a NN.
#
# As the NN, we use a U-Net.
# ![unet.png](img/unet.png)
# In[12]:
# Let's look at what the time embeddings look like
from scripts.unet import SinusoidalPositionEmbeddings
time_emb = SinusoidalPositionEmbeddings(100)
t1 = time_emb(torch.tensor([10]))
t2 = time_emb(torch.tensor([12]))
t3 = time_emb(torch.tensor([30]))
fig, ax = plt.subplots(figsize=(5,3))
ax.plot(t1.numpy()[0], label='t=10')
ax.plot(t2.numpy()[0], label='t=12')
ax.plot(t3.numpy()[0], label='t=30')
plt.legend();
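# For reference, a minimal sketch of what ``SinusoidalPositionEmbeddings`` might compute (transformer-style sinusoidal embeddings; the actual implementation lives in ``scripts/unet.py`` and may differ in details):
import math
class SinusoidalPositionEmbeddingsSketch(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    def forward(self, time):
        # frequencies spaced geometrically between 1 and 1/10000
        half_dim = self.dim // 2
        scale = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim) * -scale)
        args = time[:, None].float() * freqs[None, :]
        return torch.cat((args.sin(), args.cos()), dim=-1)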
# In[13]:
# Let's check the input and output of the U-net
from scripts.unet import Unet
temp_model = Unet(
dim=image_size,
channels=3,
dim_mults=(1, 2, 4,)
)
with torch.no_grad():
out = temp_model(x_start, torch.tensor([40]))
print(f"input shape: {x_start.shape}, output shape: { out.shape}")
# ### Loss
#
# - Loss: $\| \boldsymbol{z}_{t-1} - \boldsymbol{\tilde{z}}_{t} \|_1 = \| \boldsymbol{z}_{t-1} - \text{NN}(\boldsymbol{x}_t, t) \|_1$
# - $\boldsymbol{z}_{t-1}$ is the noise used to compute $\boldsymbol{x}_t$ in the forward process
# - Remember that to compute $\boldsymbol{x}_t$ we dropped the time index on $\boldsymbol{z}$ (nice property):
# $$ \boldsymbol{x}_t = \sqrt{\bar{\alpha}_t} \boldsymbol{x}_0 + \sqrt{1 - \bar{\alpha}_t} \boldsymbol{z} $$
# - Simply put, we can sample $\boldsymbol{z}$, compute $\boldsymbol{x}_t$, and recover $\boldsymbol{z}$ with the NN
# ### Training
#
#
#
# **repeat**
# - Sample an image from the training data
# - $\boldsymbol{x}_0 \sim p(\boldsymbol{x}_0)$
# - Sample $\boldsymbol{ z}$ and $t$ randomly
# - $t \sim \mathcal{U}([1, T])$
# - $\boldsymbol{ z}$ $\sim \mathcal{N}(\boldsymbol{0}, \boldsymbol{I})$
# - Take a gradient descent step on
# $$ \nabla_\theta \big\| \boldsymbol{z} - \text{NN}\big( \sqrt{\bar{\alpha}_t} \boldsymbol{x}_0 + \sqrt{1 - \bar{\alpha}_t}\, \boldsymbol{z},\; t \big) \big\|_1 $$
#
# **until** converged
# In[14]:
def p_losses(denoise_model, x_start, t, loss_type="huber"):
# random sample z
noise = torch.randn_like(x_start)
# compute x_t
x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
# recover z from x_t with the NN
predicted_noise = denoise_model(x_noisy, t)
if loss_type == 'l1':
loss = F.l1_loss(noise, predicted_noise)
elif loss_type == 'l2':
loss = F.mse_loss(noise, predicted_noise)
elif loss_type == "huber":
loss = F.smooth_l1_loss(noise, predicted_noise)
else:
raise NotImplementedError()
return loss
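# As a quick smoke test, we can evaluate the loss of the (untrained) ``temp_model`` from above on the lion image at an arbitrary time step:
with torch.no_grad():
    print(p_losses(temp_model, x_start, torch.tensor([40])).item())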
# ### Sampling
#
# We said that $\boldsymbol{x}_{t-1} \sim \mathcal{N}\big(\tilde{\mu}_t, \tilde{\beta}_t\big)$
# where
#
# $$\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_{t}} \beta_t
# \;\;\; \text{and} \;\;\;
# \tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}} \Big(\boldsymbol{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \boldsymbol{\tilde{z}}_t\Big)$$
#
#
#
# #### *Sampling algorithm*
# $\boldsymbol{x}_T \sim \mathcal{N}(\boldsymbol{0}, \boldsymbol{I})$
#
# **for** $t=T-1, \dots, 0$ **do**
# - $\boldsymbol{z} \sim \mathcal{N}(\boldsymbol{0}, \boldsymbol{I})$ if $t > 0$ else $\boldsymbol{z}=\boldsymbol{0}$
# - $\boldsymbol{x}_{t-1} = \frac{1}{\sqrt{\alpha_t}} \big( \boldsymbol{x}_{t} - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \text{NN}(\boldsymbol{x}_t, t) \big) + \sqrt{\tilde{\beta}_t} \boldsymbol{z}$
#
# **return** $\boldsymbol{x}_{0} = \tilde{\mu}_0$
# In[16]:
# calculations for the posterior q(x_{t-1} | x_t, x_0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) # tilde_beta_t
@torch.no_grad()
def p_sample(model, x, t, t_index):
# adjust shapes
betas_t = extract(betas, t, x.shape)
sqrt_one_minus_alphas_cumprod_t = extract(
sqrt_one_minus_alphas_cumprod, t, x.shape
)
sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
# Use the NN to predict the mean
model_mean = sqrt_recip_alphas_t * (
x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
# Draw the next sample
if t_index == 0:
return model_mean
else:
        posterior_variance_t = extract(posterior_variance, t, x.shape) # tilde_beta_t
noise = torch.randn_like(x) # z
return model_mean + torch.sqrt(posterior_variance_t) * noise # x_{t-1}
# In[17]:
# Sampling loop
@torch.no_grad()
def p_sample_loop(model, shape):
device = next(model.parameters()).device
# start from pure noise (for each example in the batch)
img = torch.randn(shape, device=device)
imgs = []
for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
img = p_sample(model, img, torch.full((shape[0],), i, device=device, dtype=torch.long), i)
imgs.append(img)
return imgs
@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
# ### Train the model on Fashion MNIST
#
#
# In[18]:
from datasets import load_dataset
dataset = load_dataset("fashion_mnist")
image_size = 28
channels = 1
batch_size = 128
# Next, we define some basic image preprocessing on-the-fly: random horizontal flips, conversion to tensor, and rescaling to the $[-1,1]$ range.
#
# We use ``with_transform`` to apply the transformations to the elements in the dataset.
# In[19]:
from torchvision import transforms
from torch.utils.data import DataLoader
# define image transformations (e.g. using torchvision)
transform = Compose([
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Lambda(lambda t: (t * 2) - 1)
])
# preprocessing function applied on-the-fly by the datasets library
def preprocess(examples):
    examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
    del examples["image"]
    return examples
transformed_dataset = dataset.with_transform(preprocess).remove_columns("label")
# create dataloader
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)
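# quick look at one batch to confirm the shapes
batch = next(iter(dataloader))
print(batch["pixel_values"].shape)  # torch.Size([128, 1, 28, 28])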
# In[20]:
from torch.optim import Adam
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Unet(
dim=image_size,
channels=channels,
dim_mults=(1, 2, 4,)
)
model.to(device)
optimizer = Adam(model.parameters(), lr=1e-3)
# In[21]:
from torchvision.utils import save_image
epochs = 10
for epoch in range(epochs):
for step, batch in enumerate(dataloader):
optimizer.zero_grad()
# x0
batch_size = batch["pixel_values"].shape[0]
batch = batch["pixel_values"].to(device)
        # sample t uniformly from {0, ..., T-1}
t = torch.randint(0, timesteps, (batch_size,), device=device).long()
loss = p_losses(model, batch, t)
if step % 100 == 0:
print(f"Epoch: {epoch}, step: {step} -- Loss: {loss.item():.3f}")
loss.backward()
optimizer.step()
# ### Inference
#
# Once trained, we can sample from the model using the ``sample`` function defined above:
# In[22]:
# Generate 64 images
samples = sample(model, image_size=image_size, batch_size=64, channels=channels)
# Get the last sample and normalize it in [0,1]
last_sample = (samples[-1] - samples[-1].min())/(samples[-1].max()-samples[-1].min())
grid_img = torchvision.utils.make_grid(last_sample, nrow=16)
get_ipython().run_line_magic('matplotlib', 'inline')
plt.figure(figsize = (20,10))
plt.imshow(grid_img.permute(1, 2, 0).cpu().numpy(), cmap='gray');
# Finally, let's visualize the actual denoising process with an animation.
# In[25]:
import matplotlib.animation as animation
get_ipython().run_line_magic('matplotlib', 'notebook')
random_index = 6
fig = plt.figure()
ims = []
for i in range(timesteps):
im = plt.imshow(samples[i][random_index].cpu().numpy().reshape(image_size, image_size, channels), cmap="gray", animated=True)
ims.append([im])
animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
animate.save('diffusion.gif')
plt.show()
# ## Next steps
#
# - Sample generation can be conditioned on inputs such as text (text2img) and images (img2img) [[4](https://arxiv.org/abs/2204.06125)].
# - Besides the mean, learning the standard deviation $\tilde{\sigma}(\boldsymbol{x}_t)$ of the conditional distribution helps in improving performance [[5](https://arxiv.org/abs/2102.09672)].
# - More advanced noise schedulers than the linear one can give better performance and decrease inference time [[6](https://arxiv.org/abs/2010.02502)]
# - Diffusion in the pixel space is slow. Latent diffusion, which operates in an embedding space, can speed up inference greatly [[7](https://openaccess.thecvf.com/content/CVPR2022/html/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.html)]
# ## Conditioned generation with diffusion guidance
# ### Score-Based Generative Models
#
# - Different from DDPM, yet equivalent in nature, are the **score-based** methods [[8](https://openreview.net/forum?id=PxTIG12RRHS)]
# - *Score-based generative models* and DDPM are different perspectives of the same model family
# - In this perspective, the denoising process corresponds to taking steps in the direction of the **score function**
# $$\nabla_x \log{p(x)}$$
# i.e., the direction where the log-likelihood increases the most
# - In the process, noise is added to avoid getting stuck between the modes of the distribution
# - This is called Stochastic Gradient Langevin Dynamics (SGLD)
# ![score_fun](img/score_fun.jpeg)
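# To make SGLD concrete, here is a toy sketch (our own illustration, not from [8]) on a 1-D standard Gaussian, whose score $\nabla_x \log{p(x)} = -x$ is known in closed form; the step size and iteration count are arbitrary choices.
import math
def sgld_toy(n_samples=10_000, n_steps=500, eps=0.05):
    x = 5 + torch.randn(n_samples)  # initialize far from the mode
    for _ in range(n_steps):
        score = -x  # analytic score of N(0, 1)
        # Langevin step: follow the score, plus noise to keep exploring
        x = x + 0.5 * eps * score + math.sqrt(eps) * torch.randn_like(x)
    return x
x_sgld = sgld_toy()
print(x_sgld.mean().item(), x_sgld.var().item())  # ~0 and ~1 (up to discretization bias)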
# ### Conditional models
#
# - In conditional models, we want to model $p(x|y)$
# - In the score-based perspective, it means following the score function $\nabla_x \log{p(x|y)}$ during diffusion
# - By applying Bayes' rule
# $$ p(x|y) = \frac{p(y|x)p(x)}{p(y)}$$
# that gives us
# $$ \nabla_x \log{p(x|y)} = \nabla_x \log{p(y|x)} + \nabla_x \log{p(x)} - \nabla_x \log{p(y)} = \nabla_x \log{p(y|x)} + \nabla_x \log{p(x)},$$ where the last term drops because $p(y)$ does not depend on $x$
# - Notice that $p(y|x)$ is exactly what an image classifier tries to fit
# - If the image classifier is a neural net, it is easy to compute $\nabla_x \log{p(y|x)}$
# - We can weight the contributions of the score function and the classifier differently
# $$\nabla_x \log{p_\gamma(x|y)} = \nabla_x \log{p(x)} + \gamma \nabla_x \log{p(y|x)},$$ where $\gamma$ is the **guidance scale** [[9](https://proceedings.neurips.cc/paper/2021/file/49ad23d1ec9fa4bd8d77d02681df5cfa-Paper.pdf)]
# - The larger $\gamma$, the larger is the influence of the conditioning signal
# - Undoing the gradient and the logarithm operations, we see that $p_\gamma(x|y) \propto p(x)p(y|x)^\gamma$
# - When $\gamma > 1$, this sharpens the distribution and shifts probability mass towards the values $x$ that are most likely associated with label $y$ (see the sketch below)
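# In code, one classifier-guidance step might look as follows (a sketch; ``score_model`` and ``classifier`` are hypothetical stand-ins, not defined in this notebook):
def guided_score(x, y, score_model, classifier, gamma=3.0):
    # combine the unconditional score with the classifier gradient,
    # weighted by the guidance scale gamma
    x = x.detach().requires_grad_(True)
    log_probs = F.log_softmax(classifier(x), dim=-1)
    log_p_y = log_probs[torch.arange(len(y)), y]  # log p(y|x) per sample
    classifier_grad = torch.autograd.grad(log_p_y.sum(), x)[0]
    with torch.no_grad():
        score = score_model(x)
    return score + gamma * classifier_grad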
#
#
# - ⚠️ Important: the diffusion model is trained as before and is conditioned by a (pre-trained) classifier *only* at inference time
# - Similar to training a powerful unconditional NLP model and adapting it to downstream tasks via few-shot learning
# **Limitations**
#
# ❌ Pre-trained classifiers cannot deal with noise
# - The classifier used for guidance needs to cope with high noise to provide a useful signal through the sampling process
# - This usually requires training the classifier specifically for the purpose of guidance
# - At that point, it might be easier to train a traditional conditional generative model end-to-end...
#
# ❌ Classifiers take shortcuts
# - Most of the information in the input $x$ is not used when predicting $y$
# - As a result, taking the gradient of the classifier w.r.t. its input can yield arbitrary directions in input space
# ### Classifier-free guidance (cfg)
#
# - We said that $\nabla_x \log{p(x|y)} = \nabla_x \log{p(y|x)} + \nabla_x \log{p(x)}$
# - Now, let's express the conditioning term as a function of the conditional and unconditional score functions, both of which our diffusion model provides
# $$ \nabla_x \log{p(y|x)} = \nabla_x \log{p(x|y)} - \nabla_x \log{p(x)}$$
#
# - Then, we substitute this into the formula for classifier guidance:
# $$ \nabla_x \log{p_\gamma(x|y)} = \nabla_x \log{p(x)} + \gamma \big[ \nabla_x \log{p(x|y)} - \nabla_x \log{p(x)} \big]$$
# - We expressed the conditional score function as a combination of the conditional and the unconditional score function itself
# - We no longer need a classifier!
# - cfg is implemented by removing the conditioning signal 10-20% of the time during training (**conditioning dropout**) [[11](https://arxiv.org/abs/2112.10741)]
# - In practice, one replaces $y$ with a special input value representing the absence of conditioning information (e.g., ``""``)
# - The resulting model is now able to function both as a conditional model $p(x|y)$ and as an unconditional model $p(x)$, depending on whether the conditioning signal is provided
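# At sampling time, cfg combines the conditional and unconditional predictions (a sketch; it assumes a noise-prediction model that accepts an optional conditioning input ``y``, which the unconditional U-Net above does not):
def cfg_noise_prediction(model, x, t, y, gamma=3.0):
    eps_cond = model(x, t, y)       # conditional prediction
    eps_uncond = model(x, t, None)  # unconditional prediction (conditioning dropped)
    # push the unconditional prediction in the direction of the conditional one
    return eps_uncond + gamma * (eps_cond - eps_uncond)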
#
# Example: "*A stained glass window of a panda eating bamboo.*" with classifier guidance (left) and with cfg (right)
#
# ![pandas](img/pandas.png)
# Why is cfg better than classifier guidance?
# - cfg replaces the guidance from a standard classifier $\nabla_x \log{p(y|x)}$ with a "classifier" built from a generative model $\nabla_x \log{p(x|y)} - \nabla_x \log{p(x)}$.
# - Standard classifiers can take shortcuts and ignore most of the input $x$ while still obtaining competitive classification results
# - Instead, generative models are forced to consider the whole data distribution, making the gradient much more robust
# - In addition, we only have to train a single (generative) model, and conditioning dropout is trivial to implement
#
# Limitation:
# - cfg dramatically improves adherence to the conditioning signal and the overall sample quality at the cost of diversity
# - This is a typical trade-off in generative models
# ### References
#
#
# [[1](https://proceedings.neurips.cc/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf)] Ho, J., Jain, A., Abbeel, P., (2020). Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems.
#
# [[2](https://huggingface.co/blog/annotated-diffusion)] Rogge, N., Rasul, K., (2022). The Annotated Diffusion Model.
#
# [[3](https://benanne.github.io/2022/05/26/guidance.html)] Dieleman, S., (2022). Guidance: a cheat code for diffusion models.
#
# [[4](https://arxiv.org/abs/2204.06125)] Ramesh, A., Dhariwal, P., Nichol, A., Chu, C., & Chen, M. (2022). Hierarchical text-conditional image generation with clip latents.
#
# [[5](https://arxiv.org/abs/2102.09672)] Nichol, A. Q., & Dhariwal, P. (2021). Improved denoising diffusion probabilistic models. In International Conference on Machine Learning.
#
# [[6](https://arxiv.org/abs/2010.02502)] Song, J., Meng, C., & Ermon, S. (2020). Denoising diffusion implicit models.
#
# [[7](https://openaccess.thecvf.com/content/CVPR2022/html/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.html)] Rombach, R., Blattmann, A., Lorenz, D., Esser, P., & Ommer, B. (2022). High-resolution image synthesis with latent diffusion models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition.
#
# [[8](https://openreview.net/forum?id=PxTIG12RRHS)] Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2021). Score-Based Generative Modeling through Stochastic Differential Equations. International Conference on Learning Representations.
#
# [[9](https://proceedings.neurips.cc/paper/2021/file/49ad23d1ec9fa4bd8d77d02681df5cfa-Paper.pdf)] Dhariwal, P., & Nichol, A. (2021). Diffusion models beat gans on image synthesis. Advances in Neural Information Processing Systems.
#
# [[10](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/)] Weng, L., (2021). What are diffusion models?
#
# [[11](https://arxiv.org/abs/2112.10741)] Nichol, A., Dhariwal, P., Ramesh, A., Shyam, P., Mishkin, P., McGrew, B., ... & Chen, M. (2021). Glide: Towards photorealistic image generation and editing with text-guided diffusion models.
#