2. Guidance
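This tutorial demonstrates how to guide a pre-trained diffusion model with a measurement — here, a noisy low-resolution version of an image — in order to sample from the corresponding posterior.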
# !git clone --depth 1 --single-branch https://github.com/openai/guided-diffusion
# !pip install datasets
import sys
import torch
sys.path.append("guided-diffusion")
from datasets import load_dataset
from PIL import Image
from torchvision.transforms.functional import resize, to_pil_image, to_tensor
from azula.guidance import DPSSampler, MMPSDenoiser
from azula.plugins import adm
from azula.sample import DDIMSampler
# Run on the GPU when one is available and fall back to the CPU otherwise,
# so the notebook still executes (slowly) on CPU-only machines instead of
# failing on the first `.to("cuda")`.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Seed torch's RNGs so the random draws below are repeatable. The returned
# Generator is assigned to `_` only to suppress the notebook's echo.
# GPU kernels may still introduce some nondeterminism.
_ = torch.manual_seed(0)
2.1. Pre-trained diffusion model
# Load the pre-trained ADM (guided-diffusion) ImageNet 256x256 denoiser.
# Judging by the checkpoint name in the download log, this is the
# unconditional model.
denoiser = adm.load_model("imagenet_256x256")
# Move the model's parameters to the selected device.
denoiser = denoiser.to(device)
Downloading https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt to /root/.cache/azula/hub/https.openaipublic.blob.core.windows.net.diffusion.jul-2021.256x256_diffusion_uncond.pt
100%|██████████| 2.06G/2.06G [06:55<00:00, 5.32MB/s]
def preprocess(x):
    """Affinely map values from [0, 1] (e.g. ``to_tensor`` output) to [-1, 1].

    Inverse of ``postprocess`` up to its final clamp.
    """
    return x * 2 - 1
def postprocess(x):
    """Map a tensor from [-1, 1] back to [0, 1] for display.

    Values falling outside [-1, 1] are clamped, so the result always lies in
    [0, 1] as ``to_pil_image`` expects for float inputs.
    """
    rescaled = (x + 1) / 2
    return rescaled.clamp(0, 1)
# Unconditional sampling with a 64-step DDIM sampler. Samples are handled as
# flat vectors of size 3 * 256 * 256 with a leading batch dimension of 1.
sampler = DDIMSampler(denoiser, steps=64).to(device)
# x1 is the sampler's starting point (presumably Gaussian noise drawn from
# torch's seeded RNG — confirm in azula); x0 is the final sample.
x1 = sampler.init((1, 3 * 256 * 256))
x0 = sampler(x1)
# Map back to [0, 1] and reshape the flat vector into a 3x256x256 image for display.
to_pil_image(postprocess(x0).reshape(3, 256, 256))
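2.2. Measurement
We simulate a measurement y by downsampling a reference image from 256×256 to 64×64 pixels and adding Gaussian noise with standard deviation 0.01.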
def crop(image):
    """Return a 256x256 RGB image made from the top-left square of ``image``.

    The square's side equals the image's shorter dimension, so the content is
    not stretched. Note that the square is anchored at the top-left corner
    (box ``(0, 0, side, side)``) rather than being centred. The square is then
    resized to 256x256 with PIL's default resampling filter.

    The image is converted to RGB first. Dataset images are not guaranteed
    to be RGB (e.g. grayscale "L" or "CMYK" JPEGs), and ``to_tensor`` would
    then produce 1 or 4 channels. The rest of the notebook assumes exactly
    3 channels (``A`` unflattens to ``(3, 256, 256)``). For images that are
    already RGB, the conversion leaves the pixels unchanged.
    """
    rgb = image.convert("RGB")
    side = min(rgb.size)
    return rgb.crop((0, 0, side, side)).resize((256, 256))
# Open the ImageNet-1k test split in streaming mode, so samples are fetched
# lazily while iterating instead of the whole split being downloaded first.
# NOTE(review): trust_remote_code=True lets `datasets` run loading code
# fetched from the Hub. Keep it only if that code is trusted. Newer
# `datasets` releases may also reject this argument — confirm against the
# installed version.
imagenet = load_dataset(
    "ILSVRC/imagenet-1k",
    split="test",
    streaming=True,
    trust_remote_code=True,
)
# The first streamed image, cropped and resized to 256x256, is the reference.
x_ref = crop(next(iter(imagenet))["image"])
x_ref
# Simulated measurement. The 256x256 reference image is mapped to [-1, 1],
# downsampled 4x to 64x64 with the same `resize` call that `A` uses, then
# corrupted with Gaussian noise of std 0.01. Both x and y live on the CPU here.
x = preprocess(to_tensor(x_ref))
y = resize(x, (64, 64))
# Keep this noise std consistent with var_y=0.01**2 passed to MMPSDenoiser.
y = y + 0.01 * torch.randn_like(y)
# Nearest-neighbour upscaling is only for displaying y at 256x256.
to_pil_image(postprocess(y)).resize((256, 256), Image.NEAREST)
2.3. Diffusion posterior sampling (DPS)
def A(x):
    """Measurement (forward) operator: downsample flat 3x256x256 images to 64x64.

    The last dimension of ``x`` holds a flattened ``(3, 256, 256)`` image, and
    any leading (batch) dimensions are preserved. Each image is resized to
    64x64 with torchvision's ``resize``, the same call used to build ``y``.
    The result is flattened back, so its last dimension is ``3 * 64 * 64``,
    matching ``y.flatten()``.
    """
    images = x.unflatten(-1, (3, 256, 256))
    downsampled = resize(images, (64, 64))
    return downsampled.flatten(-3)
# DPS guides the sampler itself, using the flat measurement y and the forward
# operator A. zeta=0.5 is the guidance strength (presumably the step size on
# the measurement-error gradient — confirm in azula's DPSSampler docs). DPS is
# run with 1000 steps here, versus 64 for the DDIM-based samplers.
# NOTE(review): y was created on the CPU. This relies on DPSSampler storing it
# so that .to(device) moves it too — confirm.
cond_sampler = DPSSampler(denoiser, y=y.flatten(), A=A, zeta=0.5, steps=1000).to(device)
x1 = cond_sampler.init((1, 3 * 256 * 256))
x0 = cond_sampler(x1)
to_pil_image(postprocess(x0).reshape(3, 256, 256))
2.4. Moment-matching posterior sampling (MMPS)
# MMPS conditions the denoiser rather than the sampler: MMPSDenoiser wraps the
# pre-trained denoiser with the measurement y, the operator A and the
# measurement noise variance var_y. Here var_y = 0.01**2, matching the noise
# std used to create y. `iterations=3` is an MMPSDenoiser setting (presumably
# the number of internal solver iterations — confirm in azula's docs). The
# wrapped denoiser is then used with a plain 64-step DDIM sampler.
cond_denoiser = MMPSDenoiser(denoiser, y=y.flatten(), A=A, var_y=0.01**2, iterations=3)
# NOTE(review): .to(device) is applied to the sampler. This presumably also
# moves the wrapped denoiser and its copy of y — confirm.
cond_sampler = DDIMSampler(cond_denoiser, steps=64).to(device)
x1 = cond_sampler.init((1, 3 * 256 * 256))
x0 = cond_sampler(x1)
to_pil_image(postprocess(x0).reshape(3, 256, 256))