3. Latent diffusion

This tutorial demonstrates how to generate images with latent diffusion models such as Stable Diffusion, Sana and Flux.

# !pip install diffusers transformers accelerate
import torch
import tqdm.auto

# Replace tqdm's auto-selected progress bar (an HTML widget inside notebooks)
# with the plain-text asyncio variant, so progress renders as text. Import
# `tqdm.asyncio` explicitly instead of relying on `tqdm.auto` having imported
# it as an internal side effect.
# NOTE(review): this patch sits above the azula imports, presumably so modules
# that bind `from tqdm.auto import tqdm` at import time pick up the text bar.
import tqdm.asyncio

tqdm.auto.tqdm = tqdm.asyncio.tqdm

from accelerate import cpu_offload
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid

from azula.guidance import CFGDenoiser
from azula.plugins import flux, sana, sd
from azula.sample import zABSampler  # Adams-Bashforth sampler

# Every model and latent below targets the first CUDA device. The global seed
# is fixed so random draws repeat across runs; the Generator it returns is
# bound to `_` so the notebook does not echo it.
device = "cuda"
_ = torch.manual_seed(42)


def postprocess(x):
    """Rescale images from [-1, 1] to [0, 1] for display.

    Applies the affine map ``(x + 1) / 2`` and clamps the result, so values
    outside [-1, 1] saturate at 0 or 1. Assumes the decoder output lies
    roughly in [-1, 1].
    """
    unit = (x + 1) / 2
    return torch.clamp(unit, 0, 1)

3.1. Stable Diffusion

# Load Stable Diffusion 1.5. The denoiser is called by the sampler at every
# step, so it is moved to the GPU outright. The autoencoder and text encoder
# are called far less often (once per decode / per prompt), so they are wrapped
# with accelerate's `cpu_offload`, which keeps their weights in CPU memory and
# moves them to `device` only while they execute.
denoiser, autoencoder, textencoder = sd.load_model("sd_1.5")
denoiser = denoiser.to(device)
autoencoder = cpu_offload(autoencoder, device)
textencoder = cpu_offload(textencoder, device)
Loading weights: 100%|██████████| 396/396 [00:00<00:00, 737.48it/s, Materializing param=visual_projection.weight]
Loading weights: 100%|██████████| 196/196 [00:00<00:00, 789.85it/s, Materializing param=text_model.final_layer_norm.weight]
Loading pipeline components...: 100%|██████████| 7/7 [00:01<00:00,  6.20it/s]
# Embed the prompt and the empty ("null") prompt once; both are reused by the
# sampling cells below, where `null` serves as the `negative` condition for
# classifier-free guidance. Inference only, so no autograd graph is recorded.
with torch.no_grad():
    prompt, null = map(
        textencoder,
        ("an astronaut riding a horse in space", ""),
    )

We first try without classifier-free guidance (CFG). The resulting images do not follow the prompt well and have poor composition.

# Unguided baseline: 16 Adams-Bashforth steps with the bare denoiser.
sampler = zABSampler(denoiser, steps=16)

# Starting latents for a batch of 2 (presumably noise drawn by the sampler),
# then sampled to `z0`, conditioned on the prompt embedding.
# NOTE(review): the sampler call is outside `torch.no_grad()`; presumably the
# sampler disables autograd internally — confirm.
z1 = sampler.init((2, 4, 64, 64), device=device)  # B C H W
z0 = sampler(z1, **prompt)

# Decode latents to images without tracking gradients.
with torch.no_grad():
    x = autoencoder.decode(z0)

# Last expression of the cell, so the notebook displays the image grid.
to_pil_image(make_grid(postprocess(x)))
100%|########################################| 16/16 [00:00<00:00, 21.44step/s]
../_images/e0e68c66d46f1cdced6c4de85db8f99c2cceb72c3e3f0aee6c90d6caec1e0534.png

Now we enable classifier-free guidance, using the empty prompt as the negative condition and a guidance scale of 3.0. This should improve prompt fidelity and image quality.

# Same sampler, but the denoiser is wrapped in CFGDenoiser, which takes a
# `positive` and a `negative` condition plus a guidance scale.
sampler = zABSampler(CFGDenoiser(denoiser), steps=16)

# Fresh starting latents (same B C H W shape as the unguided run); the prompt
# is the positive condition and the empty prompt the negative one.
# NOTE(review): the sampler call is outside `torch.no_grad()`; presumably the
# sampler disables autograd internally — confirm.
z1 = sampler.init((2, 4, 64, 64), device=device)
z0 = sampler(z1, positive=prompt, negative=null, guidance=3.0)

# Decode latents to images without tracking gradients.
with torch.no_grad():
    x = autoencoder.decode(z0)

# Last expression of the cell, so the notebook displays the image grid.
to_pil_image(make_grid(postprocess(x)))
100%|########################################| 16/16 [00:00<00:00, 25.41step/s]
../_images/a4940d599352409aa76831f49efa46865f0c66455f7d6a0501fd17c7b0b53896.png
# Free the Stable Diffusion models before loading the next pipeline. `sampler`
# still references the denoiser (through CFGDenoiser), so it must be deleted
# too. Deleting only the three model names would leave the denoiser's weights
# on the GPU while the next denoiser is loaded and moved to the device.
del sampler, denoiser, autoencoder, textencoder
torch.cuda.empty_cache()  # hand cached, now-unused blocks back to the driver

3.2. Sana

# Load Sana 1.6B (512px variant). As for Stable Diffusion, the per-step
# denoiser lives on the GPU, while the autoencoder and text encoder are wrapped
# with accelerate's `cpu_offload`. That keeps their weights in CPU memory and
# moves them to `device` only while they execute.
denoiser, autoencoder, textencoder = sana.load_model("sana_1.6b_512")
denoiser = denoiser.to(device)
autoencoder = cpu_offload(autoencoder, device)
textencoder = cpu_offload(textencoder, device)
Loading weights: 100%|██████████| 288/288 [00:00<00:00, 749.30it/s, Materializing param=norm.weight]
Loading pipeline components...: 100%|██████████| 5/5 [00:03<00:00,  1.58it/s]
# Embed the prompt and the empty ("null") prompt once. `null` is the `negative`
# condition for classifier-free guidance in the sampling code that follows.
# Inference only, so no autograd graph is recorded.
with torch.no_grad():
    prompt, null = map(
        textencoder,
        ("a cyberpunk cat with a neon sign that says 'Sana'", ""),
    )
# Guided sampling: 16 Adams-Bashforth steps with a CFG-wrapped denoiser.
sampler = zABSampler(CFGDenoiser(denoiser), steps=16)

# Sana latents are much smaller spatially but have more channels than the
# Stable Diffusion ones: (2, 32, 16, 16) versus (2, 4, 64, 64).
# NOTE(review): assumed B C H W layout, as in the Stable Diffusion section —
# confirm against the sana plugin.
z1 = sampler.init((2, 32, 16, 16), device=device)
z0 = sampler(z1, positive=prompt, negative=null, guidance=3.0)

# Decode latents to images without tracking gradients.
with torch.no_grad():
    x = autoencoder.decode(z0)

# Last expression of the cell, so the notebook displays the image grid.
to_pil_image(make_grid(postprocess(x)))
100%|########################################| 16/16 [00:00<00:00, 18.23step/s]
../_images/ba84148469f7033a69e4502c5b458c25a744ebae3128d80305dfa9cc6a954c7a.png
# Free the Sana models before loading Flux. `sampler` still references the
# denoiser (through CFGDenoiser), so it must be deleted too. Deleting only the
# three model names would leave Sana's denoiser weights on the GPU while the
# Flux denoiser is loaded and moved to the device.
del sampler, denoiser, autoencoder, textencoder
torch.cuda.empty_cache()  # hand cached, now-unused blocks back to the driver

3.3. Flux

# Load Flux with the plugin's default checkpoint. The per-step denoiser lives
# on the GPU, while the autoencoder and text encoder are wrapped with
# accelerate's `cpu_offload`. That keeps their weights in CPU memory and moves
# them to `device` only while they execute.
denoiser, autoencoder, textencoder = flux.load_model()
denoiser = denoiser.to(device)
autoencoder = cpu_offload(autoencoder, device)
textencoder = cpu_offload(textencoder, device)
Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 36.87it/s]
Loading weights: 100%|██████████| 219/219 [00:00<00:00, 499.31it/s, Materializing param=shared.weight]
Loading weights: 100%|██████████| 196/196 [00:00<00:00, 788.41it/s, Materializing param=text_model.final_layer_norm.weight]
Loading pipeline components...: 100%|██████████| 7/7 [00:01<00:00,  6.28it/s]
# Embed the prompt only. The Flux sampler below uses the bare denoiser (no
# CFGDenoiser), so no empty/"null" negative embedding is needed. Inference
# only, so no autograd graph is recorded.
with torch.no_grad():
    prompt = textencoder(
        "a forest with a big warning sign that says 'Flux'"
    )
# Flux is sampled with the bare denoiser, conditioned on the prompt only.
# NOTE(review): presumably no CFGDenoiser because the default Flux checkpoint
# is guidance-distilled — confirm against `flux.load_model`.
sampler = zABSampler(denoiser, steps=16)

# Channels-last latents here (B H W C), unlike the channels-first (B C H W)
# latents used for Stable Diffusion.
# NOTE(review): the sampler call is outside `torch.no_grad()`; presumably the
# sampler disables autograd internally — confirm.
z1 = sampler.init((2, 32, 32, 64), device=device)  # B H W C
z0 = sampler(z1, **prompt)

# Decode latents to images without tracking gradients.
with torch.no_grad():
    x = autoencoder.decode(z0)

# Last expression of the cell, so the notebook displays the image grid.
to_pil_image(make_grid(postprocess(x)))
100%|########################################| 16/16 [00:01<00:00,  8.06step/s]
../_images/a912405b977bb735c48b183aff3bc1467e8b7a1b3bd183b824d1e2831ec7a7a5.png