2. Guidance

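This tutorial demonstrates how to guide a pre-trained, unconditional diffusion model to solve an inverse problem: recovering a 256x256 image from a noisy, 64x64 downsampled observation of it. Two guidance methods are shown: diffusion posterior sampling (DPS) and moment matching posterior sampling (MMPS).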

# Notebook setup. The two commented shell commands fetch the `guided-diffusion`
# repository (cloned into ./guided-diffusion) and install Hugging Face `datasets`.
# !git clone --depth 1 --single-branch https://github.com/openai/guided-diffusion
# !pip install datasets
import sys
import torch

# Make the cloned ./guided-diffusion checkout importable. Presumably this is needed
# by azula's `adm` plugin (TODO confirm), so it is kept before the azula imports.
sys.path.append("guided-diffusion")

from azula.guidance import DPSSampler, MMPSDenoiser
from azula.plugins import adm
from azula.sample import DDIMSampler
from datasets import load_dataset
from PIL import Image
from torchvision.transforms.functional import resize, to_pil_image, to_tensor

# Run on the GPU when one is available. Otherwise fall back to the CPU, so the
# tutorial still executes (much more slowly) on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fix the global RNG so the sampled images are reproducible. `torch.manual_seed`
# returns a generator object; binding it to `_` keeps the notebook from echoing it.
_ = torch.manual_seed(0)

2.1. Pre-trained diffusion model

# Load the `imagenet_256x256_uncond` ADM model through azula's `adm` plugin,
# then move it to the selected device.
denoiser = adm.load_model("imagenet_256x256_uncond")
denoiser = denoiser.to(device)
def preprocess(x):
    """Map pixel intensities from ``[0, 1]`` (as produced by ``to_tensor``) to ``[-1, 1]``.

    This is the affine map ``x -> 2x - 1``. It is undone by ``postprocess``.
    """
    doubled = x * 2
    return doubled - 1
def postprocess(x):
    """Map a tensor from ``[-1, 1]`` back to ``[0, 1]`` for display.

    Values that fall outside the range after the affine map ``x -> (x + 1) / 2``
    are clamped, so the result can be passed directly to ``to_pil_image``.
    """
    unit = (x + 1) / 2
    return torch.clamp(unit, min=0, max=1)
# Wrap the unconditional denoiser in azula's DDIM sampler with 64 steps.
sampler = DDIMSampler(denoiser, steps=64).to(device)

# Draw one unconditional sample. Images are handled as flat vectors of length
# 3 * 256 * 256 (batch of 1); `z` is the initial Gaussian noise.
z = torch.randn((1, 3 * 256 * 256), device=device)
x = sampler(z)

# Undo `preprocess`, restore the (C, H, W) layout and let the notebook display it
# (the trailing bare expression is the cell's output).
to_pil_image(postprocess(x).reshape(3, 256, 256))
../_images/5f33d614307569a15eb78b1ce98da7e06e608ba4094d176b0ae7f29de314b807.png

2.2. Measurement

def crop(image, size=256):
    """Crop the top-left square of a PIL image and resize it to ``size x size``.

    The image is first converted to RGB. Datasets such as ImageNet may contain
    grayscale ("L") or CMYK images, which would otherwise give 1 or 4 channels
    after ``to_tensor``. That would break the 3-channel reshapes used later
    (e.g. in the forward operator ``A``).

    Args:
        image: A ``PIL.Image.Image`` of any mode and size.
        size: Side length of the returned square image. The default of 256
            matches the resolution the rest of the tutorial assumes.

    Returns:
        An RGB ``PIL.Image.Image`` of size ``(size, size)``.
    """
    side = min(image.size)
    # Keeps the top-left corner (not a center crop), as in the original tutorial.
    square = image.convert("RGB").crop((0, 0, side, side))
    return square.resize((size, size))
# Stream the ImageNet-1k test split from the Hugging Face Hub (streaming=True
# avoids downloading the whole dataset); only its first example is used below.
# NOTE(review): this dataset appears to be gated on the Hub (accepting its terms
# and logging in may be required) — confirm. `trust_remote_code=True` runs the
# dataset's own loading code locally, and recent `datasets` releases may no
# longer support script-based datasets — verify against the installed version.
imagenet = load_dataset("ILSVRC/imagenet-1k", streaming=True, split="test", trust_remote_code=True)

# Use the first streamed example as the ground-truth image, cropped and resized
# to 256x256 to match the resolution used everywhere else.
x_ref = next(iter(imagenet))["image"]
x_ref = crop(x_ref)
x_ref  # trailing expression: displayed by the notebook
../_images/ffbb9b08d8e2867b5cae2681cd3ef39b093e183c12f2d4abe7106ed104c06299.png
# Convert the PIL image to a (C, 256, 256) tensor in [-1, 1] (C is 3 for RGB).
# NOTE(review): `x` and `y` stay on the CPU; presumably the samplers'
# `.to(device)` calls below move `y` to the model's device — confirm.
x = preprocess(to_tensor(x_ref))

# Simulated observation: downsample to 64x64 and add Gaussian noise with standard
# deviation 0.01 (the same level as `var_y=0.01**2` passed to MMPS below).
y = resize(x, (64, 64))
y = y + 0.01 * torch.randn_like(y)

# Display the observation, upscaled with nearest-neighbour so pixels stay visible.
to_pil_image(postprocess(y)).resize((256, 256), Image.NEAREST)
../_images/6f30fdae0533c9b4986cfa81df2477560c2736f21f048c2b93d09b9f93a70cd1.png

2.3. Diffusion posterior sampling (DPS)

def A(x):
    """Forward (measurement) operator: downsample flattened 256x256 images to 64x64.

    The trailing dimension of ``x`` is unflattened into ``(3, 256, 256)``. The
    image is resized to 64x64, and the last three dimensions are flattened
    again, so that ``A(x)`` has the same layout as ``y.flatten()``.
    """
    images = x.unflatten(-1, (3, 256, 256))
    small = resize(images, (64, 64))
    return small.flatten(-3)
# Diffusion posterior sampling (DPS) with 1000 steps, guided by the observation
# `y` through the forward operator `A`. `zeta` presumably sets the strength of
# the likelihood-gradient correction — confirm in the azula documentation.
cond_sampler = DPSSampler(denoiser, y=y.flatten(), A=A, zeta=0.5, steps=1000).to(device)

# Draw one posterior sample from fresh Gaussian noise.
z = torch.randn((1, 3 * 256 * 256), device=device)
x = cond_sampler(z)

# Trailing expression: the notebook displays the reconstructed 256x256 image.
to_pil_image(postprocess(x).reshape(3, 256, 256))
../_images/1ea2f29510cea98bc3ca7a13c2a740be0b0ba175f8343a0665dbd7b92824119c.png

2.4. Moment matching posterior sampling (MMPS)

# Moment matching posterior sampling (MMPS). Guidance is built into the denoiser
# wrapper, which is given the observation `y`, the operator `A` and the
# observation-noise variance `var_y = 0.01**2` (the noise level used to create `y`).
# `iterations=3` presumably controls an inner solver — confirm in the azula docs.
cond_denoiser = MMPSDenoiser(denoiser, y=y.flatten(), A=A, var_y=0.01**2, iterations=3)
# Because guidance lives in the denoiser, the ordinary 64-step DDIM sampler is reused.
cond_sampler = DDIMSampler(cond_denoiser, steps=64).to(device)

# Draw one posterior sample from fresh Gaussian noise.
z = torch.randn((1, 3 * 256 * 256), device=device)
x = cond_sampler(z)

# Trailing expression: the notebook displays the reconstructed 256x256 image.
to_pil_image(postprocess(x).reshape(3, 256, 256))
../_images/c24a89974862cd059f27344f083d9de69898a3a1236044784a4439bc0bde350f.png