3. Latent diffusion

This tutorial demonstrates how to generate images with latent diffusion models such as Stable Diffusion, Flux and Sana.

# !pip install diffusers transformers accelerate
import torch

from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid

from azula.guidance import CFGDenoiser
from azula.nn.utils import cpu_offload
from azula.plugins import sd, flux, sana
from azula.sample import ABSampler  # Adams-Bashforth sampler

# Device that the models and samplers below are moved to.
device = "cuda"
# Seed PyTorch's default random number generator. The returned value is bound
# to `_` and never used (presumably to keep it out of the cell output).
_ = torch.manual_seed(42)
def postprocess(x):
    """Rescale a tensor from the [-1, 1] range to [0, 1] and clamp the result.

    Each value `v` becomes `(v + 1) / 2`; anything that falls outside
    [0, 1] after rescaling is clamped to the nearest bound.

    Arguments:
        x: The tensor to rescale.

    Returns:
        A tensor of the same shape as `x` with values in [0, 1].
    """
    rescaled = (x + 1) / 2
    return torch.clamp(rescaled, min=0, max=1)

3.1. Stable Diffusion

# Load the three components returned by the Stable Diffusion plugin for the
# "sd_1.5" identifier.
denoiser, autoencoder, textencoder = sd.load_model("sd_1.5")
# Only the denoiser and the autoencoder are moved to the device here; the text
# encoder is handled by `cpu_offload` below.
denoiser, autoencoder = denoiser.to(device), autoencoder.to(device)
# Encode the prompts without tracking gradients.
# NOTE(review): `cpu_offload` presumably keeps the text encoder on `device`
# only for the duration of this block -- confirm in azula.nn.utils.
with torch.no_grad(), cpu_offload(textencoder, device):
    prompt = textencoder("an astronaut riding a horse in space")
    # Encoding of the empty string, passed as the `negative` input when
    # sampling with guidance further down.
    null = textencoder("")

We first try without classifier-free guidance (CFG). The resulting images do not follow the prompt well and have poor composition.

# Adams-Bashforth sampler with 16 steps around the bare denoiser (no guidance).
sampler = ABSampler(denoiser, steps=16).to(device)

# Initial latents for a batch of 2 (presumably noise drawn by the sampler --
# see the azula documentation of `init`).
z1 = sampler.init((2, 4, 64, 64))  # B C H W
# `prompt` is unpacked into keyword arguments, so the text encoder output is
# expected to be a mapping.
z0 = sampler(z1, **prompt)

# Decode the sampled latents with the autoencoder; no gradients are needed.
with torch.no_grad():
    x = autoencoder.decode(z0)

# Rescale to [0, 1], tile the batch into one grid and convert it to a PIL
# image. As the last expression of the cell, the image is the cell output.
to_pil_image(make_grid(postprocess(x)))
../_images/d03c812d397c1524976d8a2b539d311827e51261f5af0746a4f2b2fd304c356a.png

We now sample with classifier-free guidance, which should improve prompt fidelity and image quality.

# Same 16-step sampler as before, but the denoiser is wrapped in `CFGDenoiser`
# so that sampling uses classifier-free guidance.
sampler = ABSampler(CFGDenoiser(denoiser), steps=16).to(device)

# Initial latents for a batch of 2, same shape as in the unguided run.
z1 = sampler.init((2, 4, 64, 64))
# Unlike the unguided call, the encodings are passed whole: the prompt as
# `positive` and the empty-string encoding as `negative`.
# NOTE(review): `guidance` is presumably the guidance strength -- confirm
# against the `CFGDenoiser` documentation.
z0 = sampler(z1, positive=prompt, negative=null, guidance=3.0)

# Decode the sampled latents with the autoencoder; no gradients are needed.
with torch.no_grad():
    x = autoencoder.decode(z0)

# Rescale to [0, 1], tile the batch into one grid and convert it to a PIL
# image. As the last expression of the cell, the image is the cell output.
to_pil_image(make_grid(postprocess(x)))
../_images/5a44465382cd309f01153cea6d4bf6c272bcc475ce262101d2a34da351257362.png
# Drop every reference to the Stable Diffusion models before loading the next
# ones. `sampler` has to go too: it wraps the denoiser, so deleting only the
# three model names would keep the denoiser weights alive until `sampler` is
# rebound, which happens after the next model is already on the device.
del sampler, denoiser, autoencoder, textencoder

3.2. Flux

# Load the three components returned by the Flux plugin (called with no
# arguments, so whatever `load_model` defaults to is used).
denoiser, autoencoder, textencoder = flux.load_model()
# Only the denoiser and the autoencoder are moved to the device here; the text
# encoder is handled by `cpu_offload` below.
denoiser, autoencoder = denoiser.to(device), autoencoder.to(device)
# Encode the prompt without tracking gradients. No empty-prompt encoding is
# computed because this section samples without `CFGDenoiser`.
with torch.no_grad(), cpu_offload(textencoder, device):
    prompt = textencoder("a forest with a big warning sign that says 'Flux'")
# Adams-Bashforth sampler with 16 steps around the bare denoiser.
sampler = ABSampler(denoiser, steps=16).to(device)

# Initial latents for a batch of 2. Note the channel-last layout stated in the
# trailing comment, unlike the Stable Diffusion latents.
z1 = sampler.init((2, 32, 32, 64))  # B H W C
# `prompt` is unpacked into keyword arguments, so the text encoder output is
# expected to be a mapping.
z0 = sampler(z1, **prompt)

# Decode the sampled latents with the autoencoder; no gradients are needed.
with torch.no_grad():
    x = autoencoder.decode(z0)

# Rescale to [0, 1], tile the batch into one grid and convert it to a PIL
# image. As the last expression of the cell, the image is the cell output.
to_pil_image(make_grid(postprocess(x)))
../_images/efaf6c2f000ba264187baea0bfca6eda9570c889df1862366c72a408cac27c62.png
# Drop every reference to the Flux models before loading the next ones.
# `sampler` has to go too: it wraps the denoiser, so deleting only the three
# model names would keep the denoiser weights alive until `sampler` is
# rebound, which happens after the next model is already on the device.
del sampler, denoiser, autoencoder, textencoder

3.3. Sana

# Load the three components returned by the Sana plugin for the
# "sana_1.6b_512" identifier.
denoiser, autoencoder, textencoder = sana.load_model("sana_1.6b_512")
# Only the denoiser and the autoencoder are moved to the device here; the text
# encoder is handled by `cpu_offload` below.
denoiser, autoencoder = denoiser.to(device), autoencoder.to(device)
# Encode the prompt and the empty string without tracking gradients; the
# latter is passed as the `negative` input for guidance.
with torch.no_grad(), cpu_offload(textencoder, device):
    prompt = textencoder("a cyberpunk cat with a neon sign that says 'Sana'")
    null = textencoder("")
# 16-step Adams-Bashforth sampler around the denoiser wrapped in `CFGDenoiser`.
sampler = ABSampler(CFGDenoiser(denoiser), steps=16).to(device)

# Initial latents for a batch of 2.
# NOTE(review): the layout is presumably B C H W as in the Stable Diffusion
# section -- confirm against the Sana plugin.
z1 = sampler.init((2, 32, 16, 16))
# The encodings are passed whole as `positive` / `negative`.
# NOTE(review): `guidance` is presumably the guidance strength -- confirm
# against the `CFGDenoiser` documentation.
z0 = sampler(z1, positive=prompt, negative=null, guidance=3.0)

# Decode the sampled latents with the autoencoder; no gradients are needed.
with torch.no_grad():
    x = autoencoder.decode(z0)

# Rescale to [0, 1], tile the batch into one grid and convert it to a PIL
# image. As the last expression of the cell, the image is the cell output.
to_pil_image(make_grid(postprocess(x)))
../_images/1eb4039e8ca93f11e8d53c58b12da7bab5e6d7fe07e30246b4feb0691162f993.png