2. Guidance

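This tutorial demonstrates how to guide a pre-trained diffusion model so that it generates images consistent with an observed measurement — here, a noisy, low-resolution copy of an image — and compares several guidance methods on the same problem.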

# !git clone --depth 1 --single-branch https://github.com/openai/guided-diffusion
import io
import requests
import sys
import torch

sys.path.append("guided-diffusion")

from PIL import Image
from torchvision.transforms.functional import to_pil_image, to_tensor
from torchvision.utils import make_grid

from azula.guidance import (
    DiffPIRDenoiser,
    DPSSampler,
    JFPSDenoiser,
    MMPSDenoiser,
    PGDMSampler,
    TMPDenoiser,
)
from azula.linalg.covariance import DiagonalCovariance, PreconditionedCovariance
from azula.plugins import adm
from azula.sample import DDIMSampler

# Run on the GPU when one is available. Otherwise fall back to the CPU so the
# notebook still executes (sampling a 256x256 model on CPU is much slower).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed PyTorch's random number generators so the noise draws below are repeatable.
_ = torch.manual_seed(42)
def preprocess(x):
    """Map pixel values from the [0, 1] range to the [-1, 1] range.

    Applies the affine map ``x -> 2x - 1``. ``postprocess`` is its inverse
    (followed by clipping).
    """
    doubled = x * 2
    return doubled - 1


def postprocess(x):
    """Map a tensor from the [-1, 1] range back to [0, 1] for display.

    Inverts ``x -> 2x - 1``. Values that fall outside [0, 1] after rescaling
    are clamped, so the result is always a valid image tensor.
    """
    rescaled = x.add(1).div(2)
    return rescaled.clamp(0, 1)

2.1. Pre-trained diffusion model

# Load the pre-trained ADM "imagenet_256x256" denoiser and move it to `device`.
denoiser = adm.load_model("imagenet_256x256").to(device)
# Freeze the weights so autograd does not track gradients for the model's
# parameters. Guidance methods may still differentiate through the network
# with respect to its input; only parameter gradients are disabled.
denoiser = denoiser.requires_grad_(False)  # reduce memory overhead
Loading from /home/frozet/.cache/azula/hub/https.openaipublic.blob.core.windows.net.diffusion.jul-2021.256x256_diffusion_uncond.pt
# Unconditional baseline: a 64-step DDIM sampler driven by the raw denoiser.
sampler = DDIMSampler(denoiser, steps=64)

# Initialize a batch of 4 starting states of shape (4, 3, 256, 256)
# (presumably pure noise at t=1) and run the sampler. `x0_uncond` is reused in
# section 2.8 to estimate the covariance `cov_x` passed to JFPS.
x1 = sampler.init((4, 3, 256, 256), device=device)
x0_uncond = sampler(x1)

# Map the samples back to [0, 1] and display the batch as an image grid.
to_pil_image(make_grid(postprocess(x0_uncond)))
100%|########################################| 64/64 [00:38<00:00,  1.66step/s]
../_images/d3fb962fe2ec8694371c325b07c351a9e3d5285f359e42b5e94384432395a96a.png

2.2. Measurement

# Download the ground-truth cat photo. A User-Agent header is sent explicitly
# (NOTE(review): presumably required by Wikimedia's User-Agent policy).
# A timeout prevents the cell from hanging indefinitely, and HTTP errors are
# raised here instead of passing an error page to PIL.
response = requests.get(
    "https://upload.wikimedia.org/wikipedia/commons/3/3a/Cat03.jpg",
    headers={"User-Agent": "Azula"},
    timeout=30,
)
response.raise_for_status()
image = Image.open(io.BytesIO(response.content)).convert("RGB")

# Keep the top-left square of the photo and resize it to the model's 256x256 input.
side = min(image.size)
image = image.crop((0, 0, side, side)).resize((256, 256))
image
../_images/2bc38707721e20eaf6adb9c2ffeaeea27bddefce873f53faeb373f0980f78b38.png
# Ground truth: convert the PIL image to a [0, 1] tensor, rescale it to [-1, 1]
# with `preprocess`, and move it to `device`.
x = preprocess(to_tensor(image)).to(device)


def A(x):
    """Forward measurement operator: downsample to 32x32 and flatten.

    Resizes a batch of images of shape ``(batch, C, H, W)`` to ``32x32`` with
    antialiased bicubic interpolation. The channel and spatial axes are then
    flattened into a single vector per sample (``3 * 32 * 32 = 3072`` values
    for RGB input).
    """
    low_res = torch.nn.functional.interpolate(
        x,
        size=(32, 32),
        mode="bicubic",
        antialias=True,
    )
    return low_res.flatten(start_dim=-3)


def A_inv(y):
    """Approximate pseudo-inverse of ``A``: reshape and upsample back to 256x256.

    Takes a flattened measurement of shape ``(batch, 3 * 32 * 32)`` and
    reshapes it to ``(batch, 3, 32, 32)``. Each low-resolution pixel is then
    replicated into an 8x8 block using nearest-neighbour upsampling.
    """
    low_res = y.unflatten(-1, (3, 32, 32))
    return torch.nn.functional.interpolate(low_res, size=(256, 256), mode="nearest")


# Standard deviation of the additive Gaussian measurement noise, expressed on
# the [-1, 1] pixel scale produced by `preprocess`.
sigma_y = 0.01

# Simulate the observation: add a batch axis to the ground truth, apply the
# forward operator `A`, and add Gaussian noise with standard deviation `sigma_y`.
y = A(x.unsqueeze(0))
y = y + sigma_y * torch.randn_like(y)

# Visualize the noisy measurement by upsampling it back to 256x256 with `A_inv`.
to_pil_image(make_grid(postprocess(A_inv(y))))
../_images/0d1a0baee0ce0ceeb59639b9080bb61760325406bf8b0415804269bfbc8bffb3.png

2.3. Diffusion Posterior Sampling (DPS)

# DPS: a 64-step sampler conditioned on the measurement `y` through the operator `A`.
cond_sampler = DPSSampler(denoiser, y=y, A=A, steps=64)

# Draw 4 conditional samples from fresh starting states.
x1 = cond_sampler.init((4, 3, 256, 256), device=device)
x0 = cond_sampler(x1)

# Display the samples as a grid in [0, 1].
to_pil_image(make_grid(postprocess(x0)))
100%|########################################| 64/64 [01:49<00:00,  1.71s/step]
../_images/d9425b92f00040d51c2779fa1acbe9bd9d2efade4ac4667b3fd178baeff88381.png

2.4. Pseudo-inverse Guided Diffusion Model (PGDM)

# PGDM: besides `y` and `A`, it needs the pseudo-inverse `A_inv`.
# NOTE(review): `eta=1.0` is presumably the DDIM-style stochasticity parameter
# (1.0 = fully stochastic); confirm against azula's docs.
cond_sampler = PGDMSampler(denoiser, y=y, A=A, A_inv=A_inv, steps=64, eta=1.0)

# Draw 4 conditional samples from fresh starting states.
x1 = cond_sampler.init((4, 3, 256, 256), device=device)
x0 = cond_sampler(x1)

# Display the samples as a grid in [0, 1].
to_pil_image(make_grid(postprocess(x0)))
100%|########################################| 64/64 [01:49<00:00,  1.72s/step]
../_images/10353060bd619ad3b793ebda0bdd1a11592c8da4884ffe2f1122ba42676f0a12.png

2.5. Diffusion Plug-and-Play Image Restoration (DiffPIR)

# DiffPIR (like TMPD, MMPS and JFPS below) is a conditional *denoiser*. It wraps
# the pre-trained denoiser together with the measurement, and is then plugged
# into a plain DDIM sampler. `var_y` is the measurement noise variance used to
# simulate `y`. NOTE(review): the meaning of `iterations=1` is not visible
# here; confirm against azula's docs.
cond_denoiser = DiffPIRDenoiser(denoiser, y=y, A=A, var_y=sigma_y**2, iterations=1)
cond_sampler = DDIMSampler(cond_denoiser, steps=64, eta=1.0)

# Draw 4 conditional samples from fresh starting states.
x1 = cond_sampler.init((4, 3, 256, 256), device=device)
x0 = cond_sampler(x1)

# Display the samples as a grid in [0, 1].
to_pil_image(make_grid(postprocess(x0)))
100%|########################################| 64/64 [00:39<00:00,  1.63step/s]
../_images/c21ad99ce58a0932b44f2a65564adc4ca730764a670453f29f0b6679fe4b7b9c.png

2.6. Tweedie Moment Projected Diffusion (TMPD)

# TMPD: conditional denoiser given the measurement noise variance `sigma_y**2`,
# sampled with the same 64-step DDIM sampler as DiffPIR.
cond_denoiser = TMPDenoiser(denoiser, y=y, A=A, var_y=sigma_y**2)
cond_sampler = DDIMSampler(cond_denoiser, steps=64, eta=1.0)

# Draw 4 conditional samples from fresh starting states.
x1 = cond_sampler.init((4, 3, 256, 256), device=device)
x0 = cond_sampler(x1)

# Display the samples as a grid in [0, 1].
to_pil_image(make_grid(postprocess(x0)))
100%|########################################| 64/64 [03:00<00:00,  2.82s/step]
../_images/097d90690b054deb9388bd5fe3f704589a0c31385f87a3d83ccaecf3cc987cbf.png

2.7. Moment Matching Posterior Sampling (MMPS)

# MMPS: the measurement noise is passed as a covariance object (diagonal, with
# variance `sigma_y**2`) rather than as a scalar variance.
# NOTE(review): `iterations=3` presumably sets the number of inner solver
# iterations; confirm against azula's docs.
cond_denoiser = MMPSDenoiser(denoiser, y=y, A=A, cov_y=DiagonalCovariance(sigma_y**2), iterations=3)  # fmt: off
cond_sampler = DDIMSampler(cond_denoiser, steps=64, eta=1.0)

# Draw 4 conditional samples from fresh starting states.
x1 = cond_sampler.init((4, 3, 256, 256), device=device)
x0 = cond_sampler(x1)

# Display the samples as a grid in [0, 1].
to_pil_image(make_grid(postprocess(x0)))
100%|########################################| 64/64 [05:21<00:00,  5.02s/step]
../_images/6f99ef1ddc51cda349dfb6f626bd11bc2a9fad1710b7ad7eb335625460e87451.png

2.8. Jacobian-Free Posterior Sampling (JFPS)

# JFPS also takes a covariance for x. Here it is estimated from the
# unconditional samples `x0_uncond` drawn in section 2.1.
# NOTE(review): that batch contains only 4 samples, so this is a rough estimate.
cov_x = PreconditionedCovariance.from_data(x0_uncond)

# Conditional denoiser with diagonal measurement noise covariance and the
# estimated `cov_x`, sampled with a 64-step DDIM sampler.
cond_denoiser = JFPSDenoiser(denoiser, y=y, A=A, cov_y=DiagonalCovariance(sigma_y**2), cov_x=cov_x, iterations=11)  # fmt: off
cond_sampler = DDIMSampler(cond_denoiser, steps=64, eta=1.0)

# Draw 4 conditional samples from fresh starting states.
x1 = cond_sampler.init((4, 3, 256, 256), device=device)
x0 = cond_sampler(x1)

# Display the samples as a grid in [0, 1].
to_pil_image(make_grid(postprocess(x0)))
100%|########################################| 64/64 [00:40<00:00,  1.60step/s]
../_images/67827fa89324e907ce27e2f89d3416be6b66e1306421e4be27c5a4ed1d1a0d9d.png