1. MNIST
This tutorial demonstrates how to build a simple diffusion model with Azula, and train it to generate MNIST images.
# !pip install datasets
import torch
import torch.nn as nn
from datasets import load_dataset
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.transforms.functional import to_pil_image, to_tensor
from tqdm import tqdm
from azula.denoise import PreconditionedDenoiser
from azula.nn.embedding import SineEncoding
from azula.nn.unet import UNet
from azula.noise import VPSchedule
from azula.sample import DDIMSampler
# Use the GPU when CUDA is available; otherwise fall back to the CPU so the
# rest of the tutorial still runs (more slowly) on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
1.1. Data
def transform(rows):
    """Convert a batch of dataset rows to tensors, in place.

    Each entry of ``rows["image"]`` is passed through ``to_tensor`` and each
    entry of ``rows["label"]`` through ``torch.as_tensor``. The same (mutated)
    ``rows`` mapping is returned.
    """
    rows["image"] = [to_tensor(image) for image in rows["image"]]
    rows["label"] = [torch.as_tensor(label) for label in rows["label"]]
    return rows
# Load the "mnist" training split (kept in memory) and register `transform`
# on it through `with_transform`.
dataset = load_dataset(
    "mnist",
    split="train",
    keep_in_memory=True,
).with_transform(transform)

# Preview the first image, resized to 64x64 with nearest-neighbour resampling.
first_image = dataset[0]["image"]
to_pil_image(first_image).resize((64, 64), Image.NEAREST)
def preprocess(x):
    """Rescale `x` affinely so that 0 maps to -1 and 1 maps to 1."""
    doubled = 2 * x
    return doubled - 1
def postprocess(x):
    """Undo `2 * x - 1` and clamp the result to the range [0, 1].

    Values of -1 map to 0 and values of 1 map to 1; anything that falls
    outside [0, 1] after rescaling is clamped to the nearest bound.
    """
    rescaled = (x + 1) / 2
    return torch.clamp(rescaled, min=0, max=1)
1.2. Diffusion model
class UNetWrapper(nn.Module):
    """U-Net backbone conditioned on a noise-level input and a class label.

    The noise-level input goes through a sine encoding followed by a small
    two-layer MLP, and the label through a learned embedding table. The two
    vectors are summed and passed to the U-Net as its second argument.

    Arguments:
        channels: Number of input and output channels of the U-Net.
        emb_features: Size of the conditioning vector; also passed to the
            U-Net as `mod_features`.
        num_classes: Number of rows in the label embedding table, i.e. the
            number of distinct labels supported. Defaults to 10.
    """

    def __init__(
        self,
        channels: int = 1,
        emb_features: int = 256,
        num_classes: int = 10,
    ):
        super().__init__()

        self.unet = UNet(
            in_channels=channels,
            out_channels=channels,
            hid_channels=[16, 32, 64],
            hid_blocks=[2, 2, 2],
            attention_heads={2: 1},
            mod_features=emb_features,
        )

        # One learned vector of size `emb_features` per label.
        self.label_embedding = nn.Embedding(num_classes, emb_features)

        # Sine features of the noise-level input, refined by a two-layer MLP.
        self.time_encoding = nn.Sequential(
            SineEncoding(emb_features, omega=1e3),
            nn.Linear(emb_features, emb_features),
            nn.SiLU(),
            nn.Linear(emb_features, emb_features),
        )

    def forward(self, x_t, log_snr_t, label):
        """Apply the U-Net to `x_t`, conditioned on `log_snr_t` and `label`.

        Arguments:
            x_t: Input handed to the U-Net.
                NOTE(review): presumably a batch of noisy images -- confirm
                against what `PreconditionedDenoiser` passes to its backbone.
            log_snr_t: Noise-level input fed to the time encoding.
                NOTE(review): presumably the log signal-to-noise ratio, one
                value per sample -- confirm.
            label: Integer indices into the label embedding table; must lie
                in `[0, num_classes)`.

        Returns:
            The output of the U-Net.
        """
        emb = self.time_encoding(log_snr_t) + self.label_embedding(label)
        x_0 = self.unet(x_t, emb)

        return x_0
# Combine a default `UNetWrapper` backbone with a default `VPSchedule` in a
# `PreconditionedDenoiser`, then move the result to `device`.
denoiser = PreconditionedDenoiser(
    backbone=UNetWrapper(),
    schedule=VPSchedule(),
)
denoiser = denoiser.to(device)
1.3. Training
# Adam over every parameter of the denoiser.
optimizer = torch.optim.Adam(denoiser.parameters(), lr=3e-4)

# Shuffled batches of 256 samples; the last incomplete batch is dropped.
loader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=2,
)

# 64 passes over the loader, with a progress bar that reports the mean loss
# of the pass that just finished.
bar = tqdm(range(64))

for _ in bar:
    losses = []

    for batch in loader:
        x = preprocess(batch["image"].to(device=device))
        label = batch["label"].to(device=device)

        # One random value in [0, 1) per sample, passed to the loss as `t`.
        t = torch.rand(len(x), device=device)

        loss = denoiser.loss(x, t, label=label)
        loss.backward()

        # Keep a detached copy so the stored losses do not retain the graph.
        losses.append(loss.detach())

        optimizer.step()
        optimizer.zero_grad()

    mean_loss = torch.stack(losses).mean().item()
    bar.set_postfix(loss=mean_loss)
100%|██████████| 64/64 [18:35<00:00, 17.42s/it, loss=0.151]
1.4. Evaluation
# 64-step DDIM sampler built around the trained denoiser.
sampler = DDIMSampler(denoiser, steps=64)
sampler = sampler.to(device)

# Condition generation on the label 7.
label = torch.tensor(7, device=device)

# Draw the sampler's initial state for one 1x28x28 sample and run the sampler.
shape = (1, 1, 28, 28)
x1 = sampler.init(shape)
x0 = sampler(x1, label=label)

# Postprocess to [0, 1], drop the singleton dimensions, and show the result
# resized to 64x64 with nearest-neighbour resampling.
image = postprocess(x0).squeeze()
to_pil_image(image).resize((64, 64), Image.NEAREST)