Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples $10 imes$ to $50 imes$ faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.
Source: Denoising Diffusion Implicit Models (2020-10-06). See: paper link.
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
class DiffusionModel():
def __init__(self, T, model, device):
self.T = T
self.function_approximator = model.to(device)
self.device = device
self.beta = torch.linspace(1e-4, 0.02, T).to(device)
self.alpha = 1. - self.beta
self.alpha_bar = torch.cumprod(self.alpha, dim=0)
def training(self, batch_size, optimizer):
pass # See https://github.com/MaximeVandegar/Papers-in-100-Lines-of-Code/blob/main/Denoising_Diffusion_Probabilistic_Models/diffusion_models.py#L31
@torch.no_grad()
def sampling(self, n_samples=1, image_channels=1, img_size=(32, 32), use_tqdm=True):
pass # See https://github.com/MaximeVandegar/Papers-in-100-Lines-of-Code/blob/main/Denoising_Diffusion_Probabilistic_Models/diffusion_models.py#L54
@torch.no_grad()
def ddim_sampling(self, n_samples=1, image_channels=1, img_size=(32, 32),
n_steps=50, use_tqdm=True):
step_size = self.T // n_steps
x = torch.randn((n_samples, image_channels, img_size[0], img_size[1]),
device=self.device)
progress_bar = tqdm if use_tqdm else lambda x: x
for i in progress_bar(range(n_steps)):
t = self.T - i * step_size
t_tensor = torch.ones(n_samples, dtype=torch.long, device=self.device) * t
alpha_bar_t = self.alpha_bar[t - 1].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
alpha_bar_prev = self.alpha_bar[t - step_size - 1].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) if t > step_size else torch.tensor(1.0).to(self.device)
# Predicted noise
eps_pred = self.function_approximator(x, t_tensor - 1)
# Predicted x_0
x0_pred = (x - torch.sqrt(1 - alpha_bar_t) * eps_pred) / torch.sqrt(alpha_bar_t)
# Direction pointing to xt
dir_xt = torch.sqrt(1 - alpha_bar_prev) * eps_pred
# Update rule
x = torch.sqrt(alpha_bar_prev) * x0_pred + dir_xt
return x
if __name__ == "__main__":
model = torch.load('model_ddpm_mnist') # Model from https://github.com/MaximeVandegar/Papers-in-100-Lines-of-Code/blob/main/Denoising_Diffusion_Probabilistic_Models
diffusion_model = DiffusionModel(1000, model, 'cuda')
nb_images = 81
samples = diffusion_model.ddim_sampling(n_samples=nb_images, use_tqdm=True)
plt.figure(figsize=(17, 17))
for i in range(nb_images):
plt.subplot(9, 9, 1 + i)
plt.axis('off')
plt.imshow(samples[i].squeeze(0).clip(0, 1).data.cpu().numpy(), cmap='gray')
plt.savefig('Imgs/ddim_samples.png')
plt.close()
python implementation Denoising Diffusion Implicit Models in 100 lines
2020-10-06