1. CIFAR-10

This tutorial demonstrates how to build a simple class-conditional diffusion model with Azula and train it to generate CIFAR-10 images.

# !pip install datasets
import torch
import torch.nn as nn

from datasets import load_dataset
from einops import rearrange
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.transforms import RandomHorizontalFlip
from torchvision.transforms.functional import to_pil_image, to_tensor
from tqdm import tqdm

from azula.denoise import PreconditionedDenoiser
from azula.nn.embedding import SineEncoding
from azula.nn.unet import UNet
from azula.noise import VPSchedule
from azula.sample import DDIMSampler

# Train and sample on the GPU when one is available; otherwise fall back to the CPU
# (much slower, but the tutorial still runs instead of failing on the first `.to(device)`).
device = "cuda" if torch.cuda.is_available() else "cpu"

1.1. Data

def transform(rows):
    """Convert a batch of dataset rows to tensors, updating ``rows`` in place.

    This is the callback installed with ``Dataset.with_transform``. Every entry of
    ``rows["img"]`` is passed through ``to_tensor`` and every entry of
    ``rows["label"]`` through ``torch.as_tensor``. The same dict object is returned.
    """
    rows["img"] = [to_tensor(image) for image in rows["img"]]
    rows["label"] = [torch.as_tensor(target) for target in rows["label"]]
    return rows


# Load the CIFAR-10 training split into memory and attach `transform`, so rows are
# converted to tensors when they are accessed.
cifar10 = load_dataset(
    "cifar10",
    split="train",
    keep_in_memory=True,
).with_transform(transform)

# Preview the first training image, upscaled to 256x256 with nearest-neighbour
# resampling so individual pixels stay sharp.
to_pil_image(cifar10[0]["img"]).resize((256, 256), Image.NEAREST)
../_images/e2d83f67d9ec283f987546d28b9aaf39ff57d8baee6712c0e141579600cc2681.png
def preprocess(x):
    """Affinely map values from [0, 1] to [-1, 1] (x -> 2x - 1); inverse of ``postprocess``."""
    return x * 2 - 1
def postprocess(x):
    """Map values from [-1, 1] back to [0, 1], clamping anything that falls outside."""
    rescaled = (x + 1) / 2
    return torch.clamp(rescaled, min=0, max=1)

1.2. Diffusion model

class UNetWrapper(nn.Module):
    """Class-conditional U-Net backbone that works on flattened images.

    The denoiser handles vectors of ``channels * height * width`` features, while
    the U-Net takes ``(B, C, H, W)`` images. This wrapper reshapes between the two
    layouts. It also builds the conditioning vector passed to the U-Net: a
    ``SineEncoding`` of the diffusion time ``t`` plus a learned embedding of the
    class ``label``.

    Arguments:
        channels: Number of image channels (input and output of the U-Net).
        height: Image height in pixels.
        width: Image width in pixels.
        num_classes: Number of distinct class labels.
        emb_features: Size of the conditioning vector given to the U-Net.

    The defaults reproduce the original CIFAR-10 setup: 3 x 32 x 32 images,
    10 classes and 256 conditioning features.
    """

    def __init__(
        self,
        channels: int = 3,
        height: int = 32,
        width: int = 32,
        num_classes: int = 10,
        emb_features: int = 256,
    ):
        super().__init__()

        # Keep the image geometry so `forward` unflattens with exactly the sizes
        # the U-Net was built for (previously duplicated as literals).
        self.channels = channels
        self.height = height
        self.width = width

        self.unet = UNet(
            channels,
            channels,
            emb_features,
            hid_channels=[128, 256, 384],
            hid_blocks=[3, 3, 3],
            attention_heads={2: 1},
        )
        self.time_encoding = SineEncoding(emb_features)
        self.label_embedding = nn.Embedding(num_classes, emb_features)

    def forward(self, x_t, t, label):
        """Run the U-Net on flattened inputs.

        Arguments:
            x_t: Flattened images of shape ``(B, channels * height * width)``.
            t: Diffusion time(s), encoded with ``SineEncoding``.
            label: Integer class label(s) in ``[0, num_classes)``.

        Returns:
            The U-Net output, flattened back to ``(B, channels * height * width)``.
        """
        emb = self.time_encoding(t) + self.label_embedding(label)

        x_t = rearrange(
            x_t, "B (C H W) -> B C H W", C=self.channels, H=self.height, W=self.width
        )
        x_0 = self.unet(x_t, emb)
        x_0 = rearrange(x_0, "B C H W -> B (C H W)")

        return x_0


# Wrap the U-Net backbone in a preconditioned denoiser driven by a VP noise
# schedule, then move all of its parameters to the target device.
denoiser = PreconditionedDenoiser(
    backbone=UNetWrapper(),
    schedule=VPSchedule(),
).to(device)

1.3. Training

optimizer = torch.optim.Adam(denoiser.parameters(), lr=1e-4)

# Exponential moving average (decay 0.999) of the denoiser weights, updated after
# every optimizer step; the averaged copy is the one used after training.
averaged = torch.optim.swa_utils.AveragedModel(
    model=denoiser,
    multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999),
)


def augment(x, p=0.5):
    """Flip each image of a batch horizontally, independently with probability ``p``.

    ``torchvision.transforms.RandomHorizontalFlip`` draws one random number per
    call. Applied to a whole batch tensor, it therefore flips either every image
    or none of them. Here each image gets its own coin flip instead.

    Arguments:
        x: A batch of images whose first axis is the batch and whose last axis
            is the width, e.g. ``(B, C, H, W)``.
        p: Probability of flipping each image.

    Returns:
        A tensor of the same shape as ``x``.
    """
    flip = torch.rand(x.shape[0], device=x.device) < p
    # Reshape the per-image mask to (B, 1, ..., 1) so it broadcasts over the image axes.
    flip = flip.reshape((-1,) + (1,) * (x.ndim - 1))
    return torch.where(flip, x.flip(-1), x)


batch_size = 256

loader = DataLoader(
    cifar10,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,  # every batch has exactly `batch_size` images
    pin_memory=True,  # page-locked host memory enables asynchronous host->GPU copies
    num_workers=4,
    persistent_workers=True,
)

# 64 passes (epochs) over the training set; the progress bar reports the mean loss
# of the last completed epoch.
for _ in (bar := tqdm(range(64))):
    losses = []

    for batch in loader:
        # The loader pins host memory, so non-blocking copies can overlap with
        # GPU work instead of stalling on every batch.
        x = batch["img"].to(device, non_blocking=True)
        label = batch["label"].to(device, non_blocking=True)

        x = augment(x)
        # The denoiser works on flattened images of shape (B, C * H * W).
        x = preprocess(x).flatten(start_dim=1)

        # One diffusion time per image, drawn uniformly from [0, 1).
        t = torch.rand((x.shape[0],), device=device)

        # Divide by the number of features per image (3 * 32 * 32).
        # NOTE(review): presumably `denoiser.loss` sums over features, so this
        # makes the loss per-feature — confirm against the azula implementation.
        loss = denoiser.loss(x, t, label=label).mean() / x.shape[1]
        loss.backward()
        losses.append(loss.detach())

        optimizer.step()
        optimizer.zero_grad()

        # Fold the freshly updated weights into the moving average.
        averaged.update_parameters(denoiser)

    # A single host sync per epoch, rather than one `.item()` per step.
    bar.set_postfix(loss=torch.stack(losses).mean().item())

# From here on, use the averaged (EMA) weights.
denoiser = averaged.module
 19%|█▉        | 12/64 [09:36<41:34, 47.96s/it, loss=0.688]

1.4. Evaluation

# Put the model in inference mode (disables any training-only layer behaviour).
denoiser.eval()

sampler = DDIMSampler(denoiser, steps=256).to(device)

# A random class label in {0, ..., 9} to condition the sample on.
label = torch.randint(10, size=(), device=device)

# Sampling is pure inference: make sure no autograd graph is recorded across the
# 256 denoising steps, which would otherwise keep every intermediate alive.
with torch.no_grad():
    x1 = sampler.init((1, 3 * 32 * 32))
    x0 = sampler(x1, label=label)

# Unflatten the single sample to (3, 32, 32) and show it upscaled to 256x256.
to_pil_image(postprocess(x0).reshape(3, 32, 32)).resize((256, 256), Image.NEAREST)