3. Latent diffusion

This tutorial demonstrates how to generate images with latent diffusion models such as Stable Diffusion, Sana and Flux.

# !pip install diffusers transformers accelerate
import torch

from accelerate import cpu_offload
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid

from azula.guidance import CFGDenoiser
from azula.plugins import flux, sana, sd
from azula.sample import zABSampler  # Adams-Bashforth sampler

# Device passed to every `.to`, `cpu_offload` and `sampler.init` call in this tutorial.
device = "cuda"
# Seed PyTorch's global RNG; the result is bound to `_`, presumably so that the
# returned value is not echoed as notebook cell output.
_ = torch.manual_seed(42)
def postprocess(x):
    """Rescale `x` from the nominal [-1, 1] range to [0, 1].

    Values that fall outside [0, 1] after rescaling are clamped to the bounds.
    """
    rescaled = (x + 1) / 2
    return torch.clamp(rescaled, min=0, max=1)

3.1. Stable Diffusion

# Load the three Stable Diffusion 1.5 components.
denoiser, autoencoder, textencoder = sd.load_model("sd_1.5")

# The denoiser is moved to `device`, while the autoencoder and the text encoder
# are wrapped with accelerate's `cpu_offload`, using `device` for execution.
denoiser = denoiser.to(device)
autoencoder = cpu_offload(autoencoder, device)
textencoder = cpu_offload(textencoder, device)
Loading pipeline components...: 100%|██████████| 7/7 [00:03<00:00,  1.83it/s]
# The text encoder is only used for inference here, so gradient tracking is disabled.
with torch.no_grad():
    # Conditioning for the actual prompt.
    prompt = textencoder("an astronaut riding a horse in space")
    # Conditioning for the empty prompt, passed later as `negative=` for classifier-free guidance.
    null = textencoder("")

We first try without classifier-free guidance (CFG). The resulting images do not follow the prompt well and have poor composition.

# Sampler built directly on the denoiser (no guidance wrapper), with 16 steps.
sampler = zABSampler(denoiser, steps=16)

# Starting latents for a batch of 2 samples; the sampler turns z1 into z0.
z1 = sampler.init((2, 4, 64, 64), device=device)  # B C H W
# `prompt` is unpacked as keyword arguments, so it is presumably a mapping of conditioning inputs.
z0 = sampler(z1, **prompt)

# Decode the sampled latents with the autoencoder, without gradient tracking.
with torch.no_grad():
    x = autoencoder.decode(z0)

# Rescale to [0, 1], tile the batch into a single grid and convert it to a PIL
# image; as the last expression of the cell, this is what gets displayed.
to_pil_image(make_grid(postprocess(x)))
100%|########################################| 16/16 [00:02<00:00,  6.57step/s]
../_images/b082fb506c976d1b65ec6828636f0343260cd4476375eac9a845f7776d7b3de1.png

We now sample with classifier-free guidance, using the prompt as the positive condition and the empty prompt as the negative one, which should boost prompt fidelity and image quality.

# Same sampler and step count, but the denoiser is wrapped in `CFGDenoiser`.
sampler = zABSampler(CFGDenoiser(denoiser), steps=16)

# Starting latents for a batch of 2 samples, same shape as in the unguided run.
z1 = sampler.init((2, 4, 64, 64), device=device)
# The prompt embedding is the positive condition and the empty-prompt embedding
# the negative one; `guidance` presumably sets the guidance strength.
z0 = sampler(z1, positive=prompt, negative=null, guidance=3.0)

# Decode the sampled latents with the autoencoder, without gradient tracking.
with torch.no_grad():
    x = autoencoder.decode(z0)

# Rescale to [0, 1], tile the batch into a single grid and display it as a PIL image.
to_pil_image(make_grid(postprocess(x)))
100%|########################################| 16/16 [00:02<00:00,  7.78step/s]
../_images/47b0d126683893de4c336d19f295478ba0365856159425c2c7d9699e96bc4765.png
# `sampler` wraps the denoiser and therefore keeps it alive: it has to be
# deleted as well, otherwise the Stable Diffusion weights stay on the GPU while
# the next model is being loaded.
del sampler, denoiser, autoencoder, textencoder
torch.cuda.empty_cache()  # release the now-unused cached GPU memory

3.2. Sana

# Load the three Sana (1.6B, 512px) components.
denoiser, autoencoder, textencoder = sana.load_model("sana_1.6b_512")

# The denoiser is moved to `device`, while the autoencoder and the text encoder
# are wrapped with accelerate's `cpu_offload`, using `device` for execution.
denoiser = denoiser.to(device)
autoencoder = cpu_offload(autoencoder, device)
textencoder = cpu_offload(textencoder, device)
Loading pipeline components...:  20%|██        | 1/5 [00:01<00:07,  1.81s/it]
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  7.14it/s]
Loading pipeline components...: 100%|██████████| 5/5 [00:08<00:00,  1.65s/it]
# Encode the prompt and the empty prompt without gradient tracking.
with torch.no_grad():
    prompt = textencoder("a cyberpunk cat with a neon sign that says 'Sana'")
    # Empty-prompt conditioning, used below as the negative condition.
    null = textencoder("")
# Sampler with 16 steps on top of the denoiser wrapped in `CFGDenoiser`.
sampler = zABSampler(CFGDenoiser(denoiser), steps=16)

# Starting latents for a batch of 2 samples.
# NOTE(review): the shape is presumably (B, C, H, W) as for Stable Diffusion - confirm.
z1 = sampler.init((2, 32, 16, 16), device=device)
# Prompt as positive condition, empty prompt as negative condition.
z0 = sampler(z1, positive=prompt, negative=null, guidance=3.0)

# Decode the sampled latents with the autoencoder, without gradient tracking.
with torch.no_grad():
    x = autoencoder.decode(z0)

# Rescale to [0, 1], tile the batch into a single grid and display it as a PIL image.
to_pil_image(make_grid(postprocess(x)))
100%|########################################| 16/16 [00:01<00:00,  8.98step/s]
../_images/72b2f8e58fe17a4cf8ed6963be5825848018564e4084cdd10f1384e8544a3408.png
# `sampler` wraps the denoiser and therefore keeps it alive: it has to be
# deleted as well, otherwise the Sana weights stay on the GPU while the next
# model is being loaded.
del sampler, denoiser, autoencoder, textencoder
torch.cuda.empty_cache()  # release the now-unused cached GPU memory

3.3. Flux

# Load the three Flux components with the plugin's default settings.
denoiser, autoencoder, textencoder = flux.load_model()

# Unlike the previous sections, the denoiser itself is also wrapped with
# accelerate's `cpu_offload` instead of being moved to `device`.
denoiser = cpu_offload(denoiser, device)
autoencoder = cpu_offload(autoencoder, device)
textencoder = cpu_offload(textencoder, device)
Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 13.49it/s]
Loading pipeline components...:  43%|████▎     | 3/7 [00:00<00:00,  4.50it/s]
Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]
Loading checkpoint shards:  33%|███▎      | 1/3 [00:14<00:29, 14.91s/it]
Loading checkpoint shards:  67%|██████▋   | 2/3 [00:29<00:14, 14.66s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:35<00:00, 11.72s/it]
Loading pipeline components...: 100%|██████████| 7/7 [00:36<00:00,  5.20s/it]
# Encode the prompt without gradient tracking; no empty-prompt embedding is
# computed in this section.
with torch.no_grad():
    prompt = textencoder("a forest with a big warning sign that says 'Flux'")
# Sampler built directly on the denoiser (no `CFGDenoiser` wrapper), with 16 steps.
sampler = zABSampler(denoiser, steps=16)

# Starting latents for a batch of 2 samples; note the channels-last layout.
z1 = sampler.init((2, 32, 32, 64), device=device)  # B H W C
# `prompt` is unpacked as keyword arguments, so it is presumably a mapping of conditioning inputs.
z0 = sampler(z1, **prompt)

# Decode the sampled latents with the autoencoder, without gradient tracking.
with torch.no_grad():
    x = autoencoder.decode(z0)

# Rescale to [0, 1], tile the batch into a single grid and display it as a PIL image.
to_pil_image(make_grid(postprocess(x)))
100%|########################################| 16/16 [03:13<00:00, 12.08s/step]
../_images/a98c6449061f7b16e169a40f5605ee4b47e446583baf19260bf41eeff98a72d4.png