2. Guidance

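This tutorial demonstrates how to perform guidance with a pre-trained diffusion model, that is, how to steer its samples towards images consistent with a noisy, low-resolution observation, using diffusion posterior sampling (DPS) and moment matching posterior sampling (MMPS).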

# !git clone --depth 1 --single-branch https://github.com/openai/guided-diffusion
# !pip install datasets
import sys
import torch

sys.path.append("guided-diffusion")

from azula.guidance import DPSSampler, MMPSDenoiser
from azula.plugins import adm
from azula.sample import DDIMSampler
from datasets import load_dataset
from PIL import Image
from torchvision.transforms.functional import resize, to_pil_image, to_tensor

# Device on which the denoiser and all samplers below are placed (via `.to(device)`).
device = "cuda"
# Seed PyTorch's global RNG for reproducibility; the return value is bound
# to `_` so that it is not displayed as the cell output.
_ = torch.manual_seed(0)

2.1. Pre-trained diffusion model

# Load the pre-trained "imagenet_256x256" model through azula's ADM plugin
# and move it to `device`. The first call downloads the checkpoint (see log).
denoiser = adm.load_model("imagenet_256x256").to(device)
Downloading https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt to /root/.cache/azula/hub/https.openaipublic.blob.core.windows.net.diffusion.jul-2021.256x256_diffusion_uncond.pt
100%|██████████| 2.06G/2.06G [06:55<00:00, 5.32MB/s]
def preprocess(x):
    """Rescale `x` affinely as ``2 * x - 1`` (so 0 maps to -1 and 1 maps to 1)."""
    doubled = 2 * x
    return doubled - 1
def postprocess(x):
    """Map `x` through ``(x + 1) / 2`` and saturate the result to ``[0, 1]``.

    Values at or below -1 become 0 and values at or above 1 become 1.
    """
    unit = (x + 1) / 2
    # `torch.clamp` is the function `torch.clip` aliases.
    return torch.clamp(unit, min=0, max=1)
# Unconditional sampling: DDIM sampler with 64 steps around the pre-trained denoiser.
sampler = DDIMSampler(denoiser, steps=64).to(device)

# Samples are handled as flat vectors of 3 * 256 * 256 values, with a batch of 1.
# presumably `init` draws the starting noise and calling the sampler runs the
# full reverse process — confirm against the azula documentation
x1 = sampler.init((1, 3 * 256 * 256))
x0 = sampler(x1)

# Rescale to [0, 1], restore the (3, 256, 256) layout and display as a PIL image.
to_pil_image(postprocess(x0).reshape(3, 256, 256))
../_images/8d7487d3b4d5f8c86a6c69c522f33e383e6dbe8db4d68d3e688771b86b6c073e.png

2.2. Measurement

def crop(image):
    """Return a 256x256 version of the top-left square of `image`.

    The square's side is the shorter of the image's two dimensions and it is
    anchored at the (0, 0) corner, i.e. this is not a centre crop.
    """
    side = min(image.size)
    square = image.crop((0, 0, side, side))
    return square.resize((256, 256))
# Stream the ImageNet-1k test split rather than downloading it in full.
# NOTE(review): trust_remote_code=True lets the dataset repository's loading
# script run locally — confirm this is acceptable and still supported by the
# installed `datasets` version.
imagenet = load_dataset("ILSVRC/imagenet-1k", streaming=True, split="test", trust_remote_code=True)

# Use the first streamed example as the reference image, squared to 256x256 by `crop`.
x_ref = next(iter(imagenet))["image"]
x_ref = crop(x_ref)
# Bare expression: displays the reference image as the cell output.
x_ref
../_images/ffbb9b08d8e2867b5cae2681cd3ef39b093e183c12f2d4abe7106ed104c06299.png
# Reference image as a tensor, rescaled by `preprocess` (2 * x - 1).
x = preprocess(to_tensor(x_ref))

# Measurement: downsample the reference to 64x64, then add Gaussian noise
# with standard deviation 0.01.
y = resize(x, (64, 64))
y = y + 0.01 * torch.randn_like(y)

# Display the measurement enlarged back to 256x256 with nearest-neighbour
# resampling, so the low-resolution pixels remain visible.
to_pil_image(postprocess(y)).resize((256, 256), Image.NEAREST)
../_images/6f30fdae0533c9b4986cfa81df2477560c2736f21f048c2b93d09b9f93a70cd1.png

2.3. Diffusion posterior sampling (DPS)

def A(x):
    """Measurement operator: downsample flattened 256x256 images to flattened 64x64.

    The last dimension of `x` (of size 3 * 256 * 256) is unfolded into
    (3, 256, 256), resized to 64x64, and the last three dimensions are
    flattened again.
    """
    images = x.unflatten(-1, (3, 256, 256))
    downsampled = resize(images, (64, 64))
    return downsampled.flatten(-3)
# DPS sampler conditioned on the flattened measurement `y` through the operator `A`,
# using 1000 steps.
# presumably `zeta` sets the strength of the guidance term — confirm against
# azula's DPSSampler documentation
cond_sampler = DPSSampler(denoiser, y=y.flatten(), A=A, zeta=0.5, steps=1000).to(device)

# Same flat (1, 3 * 256 * 256) sample layout as the unconditional case.
x1 = cond_sampler.init((1, 3 * 256 * 256))
x0 = cond_sampler(x1)

# Rescale to [0, 1], restore the (3, 256, 256) layout and display as a PIL image.
to_pil_image(postprocess(x0).reshape(3, 256, 256))
../_images/3a7b4ca1aa91973cdfd78e7092472d8d336e7ce87802d95e6ee80f9af4eb10b0.png

2.4. Moment matching posterior sampling (MMPS)

# Wrap the denoiser so that it is conditioned on `y` through `A`.
# var_y = 0.01**2 is the variance of the noise (std 0.01) added when building `y`.
# presumably `iterations` bounds an inner iterative solve — confirm against
# azula's MMPSDenoiser documentation
cond_denoiser = MMPSDenoiser(denoiser, y=y.flatten(), A=A, var_y=0.01**2, iterations=3)
# The conditioned denoiser plugs into the same 64-step DDIM sampler as the unconditional case.
cond_sampler = DDIMSampler(cond_denoiser, steps=64).to(device)

# Same flat (1, 3 * 256 * 256) sample layout as before.
x1 = cond_sampler.init((1, 3 * 256 * 256))
x0 = cond_sampler(x1)

# Rescale to [0, 1], restore the (3, 256, 256) layout and display as a PIL image.
to_pil_image(postprocess(x0).reshape(3, 256, 256))
../_images/77efecb9a7e0723525040a02a884e64fb7f17c0c9f4fe8e1bd6c02462f2218c9.png