2. Guidance

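This tutorial demonstrates how to guide a pre-trained diffusion model to solve an inverse problem: recovering an image from a noisy, 4× subsampled observation of it, using diffusion posterior sampling (DPS) and moment matching posterior sampling (MMPS).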

# !git clone --depth 1 --single-branch https://github.com/openai/guided-diffusion
import sys
import torch

sys.path.append("guided-diffusion")

from PIL import Image
from torchvision.transforms.functional import to_pil_image, to_tensor
from urllib.request import urlretrieve

from azula.guidance import DPSSampler, MMPSDenoiser
from azula.plugins import adm
from azula.sample import DDIMSampler

# Prefer the GPU when one is available, but fall back to the CPU so the
# tutorial still runs (slowly) on machines without CUDA instead of failing at
# the first `.to(device)` call.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Seed the global RNG so the noise drawn by the samplers below is reproducible;
# the returned torch.Generator is not needed.
_ = torch.manual_seed(0)

2.1. Pre-trained diffusion model

# Load the pre-trained ADM "imagenet_256x256" checkpoint onto the target device
# and freeze its weights with requires_grad_(False) to reduce memory overhead.
denoiser = adm.load_model("imagenet_256x256").to(device).requires_grad_(False)
Loading from /home/frozet/.cache/azula/hub/https.openaipublic.blob.core.windows.net.diffusion.jul-2021.256x256_diffusion_uncond.pt
def preprocess(x):
    """Rescale values from the [0, 1] range to the [-1, 1] range.

    Applies the affine map ``x -> 2 * x - 1``; works on tensors and on
    plain numbers alike.
    """
    doubled = x * 2
    return doubled - 1
def postprocess(x):
    """Rescale a tensor from the [-1, 1] range back to [0, 1].

    Applies ``(x + 1) / 2`` and clamps the result to [0, 1], so values that
    fall outside [-1, 1] saturate instead of producing invalid pixel values.
    """
    unit = (x + 1) / 2
    return torch.clamp(unit, min=0, max=1)
# Unconditional sampling: a 64-step DDIM sampler driven by the frozen denoiser,
# with no measurement involved.
sampler = DDIMSampler(denoiser, steps=64).to(device)

# NOTE(review): sampler.init presumably draws the initial noise x1 for a batch
# of one 3x256x256 image; calling the sampler then integrates it down to x0.
x1 = sampler.init((1, 3, 256, 256))
x0 = sampler(x1)

# Map the sample back to [0, 1] (see postprocess) and drop the batch axis for
# display. Presumably x0 is already 256x256, so the NEAREST resize is
# effectively a no-op here.
to_pil_image(postprocess(x0).squeeze()).resize((256, 256), Image.NEAREST)
../_images/65f8289935d270d8c73648fae211be5b0d849090e350bec26ba0045264cebf33.png

2.2. Measurement

# Download an example photo; urlretrieve returns (local filename, headers).
image, _ = urlretrieve("https://upload.wikimedia.org/wikipedia/commons/3/3a/Cat03.jpg")
image = Image.open(image).convert("RGB")
# Crop the top-left square whose side is the shorter edge, then resize it to
# 256x256 to match the resolution of the "imagenet_256x256" model.
image = image.crop((0, 0, min(image.size), min(image.size))).resize((256, 256))
# Last expression of the cell: displays the cropped image.
image
../_images/2bc38707721e20eaf6adb9c2ffeaeea27bddefce873f53faeb373f0980f78b38.png
# Ground-truth image: to_tensor yields a (3, 256, 256) float tensor in [0, 1],
# which preprocess rescales to [-1, 1].
x = preprocess(to_tensor(image))

# Measurement: keep every 4th pixel along both spatial axes (4x subsampling,
# no anti-aliasing) and add Gaussian noise with standard deviation 0.01.
# NOTE: the subsampling must match the operator A defined in section 2.3, and
# the noise std must match var_y=0.01**2 passed to MMPSDenoiser in section 2.4.
y = x[..., ::4, ::4]
y = y + 0.01 * torch.randn_like(y)

# Show the low-resolution measurement upscaled back to the image size with
# nearest-neighbour interpolation (each measured pixel becomes a 4x4 block).
to_pil_image(postprocess(y)).resize(image.size, Image.NEAREST)
../_images/f0149273e7be8e8e179879d4fa08f6794a4e2788da1f05bc3da4d02f659fa765.png

2.3. Diffusion posterior sampling (DPS)

def A(x):
    """Measurement operator: 4x subsampling followed by flattening.

    Keeps every 4th element along the last two axes, then flattens the last
    three axes into one, so a (..., C, H, W) input becomes
    (..., C * ceil(H / 4) * ceil(W / 4)).
    """
    subsampled = x[..., ::4, ::4]
    return torch.flatten(subsampled, start_dim=-3)
# DPS: the sampler receives the flattened measurement y together with the
# operator A, whose output is flattened the same way (flatten over the last
# three axes), presumably so that A(x) and y line up element-wise.
# NOTE(review): y was built on the CPU; this relies on DPSSampler registering
# it as a buffer so that `.to(device)` moves it as well — confirm.
cond_sampler = DPSSampler(denoiser, y=y.flatten(), A=A, steps=256).to(device)

x1 = cond_sampler.init((1, 3, 256, 256))
x0 = cond_sampler(x1)

# Display the posterior sample mapped back to [0, 1] without the batch axis.
to_pil_image(postprocess(x0).squeeze())
../_images/69db6de9eba5763b0915621082c8192313e6e5ae9d076a99f439ca4d89c9bdab.png

2.4. Moment matching posterior sampling (MMPS)

# MMPS: wrap the pre-trained denoiser into a denoiser conditioned on the
# measurement y. var_y=0.01**2 is the variance of the noise added to y in
# section 2.2 (std 0.01) and must be kept in sync with it.
# NOTE(review): `iterations` presumably sets the number of inner solver
# iterations MMPSDenoiser uses per call — confirm against the azula docs.
cond_denoiser = MMPSDenoiser(denoiser, y=y.flatten(), A=A, var_y=0.01**2, iterations=3)
# The conditioned denoiser plugs into a plain DDIM sampler (16 steps).
# NOTE(review): eta=1.0 presumably makes DDIM fully stochastic — confirm.
cond_sampler = DDIMSampler(cond_denoiser, steps=16, eta=1.0).to(device)

x1 = cond_sampler.init((1, 3, 256, 256))
x0 = cond_sampler(x1)

# Display the posterior sample mapped back to [0, 1] without the batch axis.
to_pil_image(postprocess(x0).squeeze())
../_images/079d1f5f21a1a4c9e5cfdc7571803b287c889736bc16452ea53447bcbcfe0164.png