2. Guidance¶
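This tutorial demonstrates how to use guidance with a pre-trained diffusion model to solve an inverse problem. The goal is to draw images that are consistent with a noisy, low-resolution measurement of a photograph. Several guided sampling methods are compared on the same measurement.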
import io
import requests
import torch
from PIL import Image
from torchvision.transforms.functional import to_pil_image, to_tensor
from torchvision.utils import make_grid
from azula.guidance import (
DiffPIRDenoiser,
DPSSampler,
JFPSDenoiser,
MMPSDenoiser,
PGDMSampler,
TMPDenoiser,
)
from azula.linalg.covariance import IsotropicCovariance, KroneckerCovariance
from azula.plugins import adm
from azula.sample import DDIMSampler
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# tutorial still executes (slowly) on machines without CUDA. On CUDA machines
# this is identical to the previous hard-coded "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the global RNG so the random draws below (e.g. the measurement noise
# from torch.randn_like) are reproducible across runs.
_ = torch.manual_seed(42)
def preprocess(x):
    """Rescale pixel intensities from [0, 1] to [-1, 1].

    Applies the affine map ``x -> 2x - 1``, the inverse of the rescaling
    performed by ``postprocess``.
    """
    doubled = x * 2
    return doubled - 1
def postprocess(x):
    """Map values from the [-1, 1] working range back to [0, 1] pixel intensities.

    Computes ``(x + 1) / 2`` and clamps the result to [0, 1], so values that
    fall outside [-1, 1] are saturated rather than wrapped when displayed.
    """
    rescaled = (x + 1) / 2
    return torch.clamp(rescaled, min=0, max=1)
2.1. Pre-trained diffusion model¶
# Load the pre-trained ADM "imagenet_256x256" denoiser (the logged checkpoint is
# 256x256_diffusion_uncond.pt), move it to the compute device, and freeze its
# weights. Parameter gradients are never needed here, which reduces memory
# overhead; guidance methods may still differentiate w.r.t. the denoiser's input.
denoiser = adm.load_model("imagenet_256x256").to(device).requires_grad_(False)
Loading from /home/rozetf/.cache/azula/hub/https.openaipublic.blob.core.windows.net.diffusion.jul.2021.256x256_diffusion_uncond.pt
# Baseline: unconditional sampling with a 64-step DDIM sampler (no guidance).
sampler = DDIMSampler(denoiser, steps=64)
# Initial state for a batch of 4 RGB 256x256 images. NOTE(review): presumably
# drawn from the global RNG seeded above, so results depend on the order in
# which the cells are executed.
x1 = sampler.init((4, 3, 256, 256), device=device)
# Kept under its own name: the JFPS cell later estimates cov_x from these samples.
x0_uncond = sampler(x1)
# Rescale to [0, 1] (clipped), tile the batch into one image, and display it.
to_pil_image(make_grid(postprocess(x0_uncond)))
100%|########################################| 64/64 [00:05<00:00, 12.08step/s]
2.2. Measurement¶
# Download a test photograph. The timeout keeps the cell from hanging forever
# on a stalled connection. raise_for_status() reports HTTP errors (403, 404, ...)
# directly instead of letting PIL fail later with "cannot identify image file".
# The custom User-Agent is kept; presumably required by Wikimedia's User-Agent
# policy.
response = requests.get(
    "https://upload.wikimedia.org/wikipedia/commons/3/3a/Cat03.jpg",
    headers={"User-Agent": "Azula"},
    timeout=30,
)
response.raise_for_status()
image = Image.open(io.BytesIO(response.content)).convert("RGB")
# Crop the largest square anchored at the top-left corner, then resize to 256x256.
side = min(image.size)
image = image.crop((0, 0, side, side)).resize((256, 256))
image
# Ground-truth image: PIL image -> (C, H, W) float tensor in [0, 1] (torchvision
# to_tensor), rescaled to [-1, 1] and moved to the compute device.
x = to_tensor(image)
x = preprocess(x).to(device)
def A(x):
    """Forward measurement operator: downsample to 32x32 and flatten.

    Args:
        x: Batch of images of shape (N, C, H, W).

    Returns:
        Tensor of shape (N, C * 32 * 32) holding each image downsampled to
        32x32 with antialiased bicubic interpolation, flattened per sample.
    """
    low_res = torch.nn.functional.interpolate(
        x,
        size=(32, 32),
        mode="bicubic",
        antialias=True,
    )
    return low_res.flatten(start_dim=-3)
def A_inv(y):
    """Approximate inverse of ``A``: unflatten measurements and upsample them.

    Args:
        y: Flattened measurements whose last dimension has 3 * 32 * 32 entries.

    Returns:
        Images of shape (..., 3, 256, 256), obtained by reshaping the last
        dimension to (3, 32, 32) and repeating each pixel with
        nearest-neighbour upsampling.
    """
    images = y.unflatten(-1, (3, 32, 32))
    return torch.nn.functional.interpolate(images, size=(256, 256), mode="nearest")
# Simulate a noisy measurement y = A(x) + sigma_y * eps, with eps standard normal.
sigma_y = 0.01
y = A(x[None])  # add a leading batch dimension of size 1
noise = sigma_y * torch.randn_like(y)
y = y + noise
# Visualize the measurement by mapping it back to 256x256 with A_inv.
to_pil_image(make_grid(postprocess(A_inv(y))))
2.3. Diffusion Posterior Sampling (DPS)¶
# DPS is used as a complete sampler: it receives the measurement y and the
# forward operator A directly instead of wrapping the denoiser for a DDIMSampler.
# NOTE(review): unlike the later cells, no var_y / cov_y is passed, so the
# measurement noise level sigma_y is not given to DPS here.
cond_sampler = DPSSampler(denoiser, y=y, A=A, steps=64)
# Initial state for a batch of 4 RGB 256x256 images (presumably drawn from the
# global RNG, so results depend on cell execution order).
x1 = cond_sampler.init((4, 3, 256, 256), device=device)
x0 = cond_sampler(x1)
# Rescale to [0, 1] (clipped), tile the batch into one image, and display it.
to_pil_image(make_grid(postprocess(x0)))
100%|########################################| 64/64 [00:14<00:00, 4.41step/s]
2.4. Pseudo-inverse Guided Diffusion Model (PGDM)¶
# PGDM is also a complete sampler. Besides y and A, it takes A_inv, the
# approximate inverse of A defined above (reshape + nearest-neighbour upsampling).
# NOTE(review): eta=1.0 presumably selects fully stochastic DDIM-style updates.
cond_sampler = PGDMSampler(denoiser, y=y, A=A, A_inv=A_inv, steps=64, eta=1.0)
# Initial state for a batch of 4 RGB 256x256 images (presumably RNG-dependent).
x1 = cond_sampler.init((4, 3, 256, 256), device=device)
x0 = cond_sampler(x1)
# Rescale to [0, 1] (clipped), tile the batch into one image, and display it.
to_pil_image(make_grid(postprocess(x0)))
100%|########################################| 64/64 [00:13<00:00, 4.60step/s]
2.5. Diffusion Plug-and-Play Image Restoration (DiffPIR)¶
# DiffPIR wraps the pre-trained denoiser into a measurement-conditioned denoiser.
# var_y = sigma_y**2 is the variance of the noise added to y above.
cond_denoiser = DiffPIRDenoiser(denoiser, y=y, A=A, var_y=sigma_y**2, iterations=1)
# The guided denoiser is plugged into an ordinary DDIM sampler.
# NOTE(review): eta=1.0 presumably selects fully stochastic DDIM updates.
cond_sampler = DDIMSampler(cond_denoiser, steps=64, eta=1.0)
# Initial state for a batch of 4 RGB 256x256 images (presumably RNG-dependent).
x1 = cond_sampler.init((4, 3, 256, 256), device=device)
x0 = cond_sampler(x1)
# Rescale to [0, 1] (clipped), tile the batch into one image, and display it.
to_pil_image(make_grid(postprocess(x0)))
100%|########################################| 64/64 [00:05<00:00, 12.69step/s]
2.6. Tweedie Moment Projected Diffusion (TMPD)¶
# TMPD wraps the pre-trained denoiser; var_y = sigma_y**2 matches the
# measurement noise added to y above.
cond_denoiser = TMPDenoiser(denoiser, y=y, A=A, var_y=sigma_y**2)
# Plugged into an ordinary DDIM sampler (the recorded run is ~2.8 step/s versus
# ~12 step/s for unguided DDIM).
cond_sampler = DDIMSampler(cond_denoiser, steps=64, eta=1.0)
# Initial state for a batch of 4 RGB 256x256 images (presumably RNG-dependent).
x1 = cond_sampler.init((4, 3, 256, 256), device=device)
x0 = cond_sampler(x1)
# Rescale to [0, 1] (clipped), tile the batch into one image, and display it.
to_pil_image(make_grid(postprocess(x0)))
100%|########################################| 64/64 [00:22<00:00, 2.81step/s]
2.7. Moment Matching Posterior Sampling (MMPS)¶
# MMPS wraps the pre-trained denoiser. The measurement noise is given as an
# isotropic covariance with variance sigma_y**2, and iterations=3 sets the
# number of inner iterations. This is the slowest method in the recorded run
# (~1.6 step/s).
cond_denoiser = MMPSDenoiser(denoiser, y=y, A=A, cov_y=IsotropicCovariance(sigma_y**2), iterations=3) # fmt: off
cond_sampler = DDIMSampler(cond_denoiser, steps=64, eta=1.0)
# Initial state for a batch of 4 RGB 256x256 images (presumably RNG-dependent).
x1 = cond_sampler.init((4, 3, 256, 256), device=device)
x0 = cond_sampler(x1)
# Rescale to [0, 1] (clipped), tile the batch into one image, and display it.
to_pil_image(make_grid(postprocess(x0)))
100%|########################################| 64/64 [00:40<00:00, 1.57step/s]
2.8. Jacobian-Free Posterior Sampling (JFPS)¶
# Prior covariance for JFPS, estimated from the 4 unconditional samples
# (x0_uncond) produced by the baseline DDIM cell. NOTE(review): presumably a
# Kronecker-factored covariance; with so few samples the estimate is crude —
# confirm this is intended.
cov_x = KroneckerCovariance.from_data(x0_uncond)
# JFPS wraps the denoiser, with isotropic measurement noise of variance
# sigma_y**2 and 11 inner iterations. The recorded run is ~12 step/s,
# comparable to unguided DDIM.
cond_denoiser = JFPSDenoiser(denoiser, y=y, A=A, cov_y=IsotropicCovariance(sigma_y**2), cov_x=cov_x, iterations=11) # fmt: off
cond_sampler = DDIMSampler(cond_denoiser, steps=64, eta=1.0)
# Initial state for a batch of 4 RGB 256x256 images (presumably RNG-dependent).
x1 = cond_sampler.init((4, 3, 256, 256), device=device)
x0 = cond_sampler(x1)
# Rescale to [0, 1] (clipped), tile the batch into one image, and display it.
to_pil_image(make_grid(postprocess(x0)))
100%|########################################| 64/64 [00:05<00:00, 12.03step/s]