3. Latent diffusion
This tutorial demonstrates how to generate images with latent diffusion models such as Stable Diffusion, Sana and Flux.
# !pip install diffusers transformers accelerate
import torch
import tqdm.auto
import tqdm.asyncio  # imported explicitly instead of relying on tqdm.auto's internals

# Replace tqdm.auto's notebook-aware class with the plain (asyncio/console)
# variant so progress bars render as text rather than HTML widgets.
# NOTE(review): this patch sits above the azula imports, presumably so that any
# `from tqdm.auto import tqdm` they execute binds the patched class — keep it here.
tqdm.auto.tqdm = tqdm.asyncio.tqdm  # disable HTML progress bars

from accelerate import cpu_offload
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid

from azula.guidance import CFGDenoiser
from azula.plugins import flux, sana, sd
from azula.sample import zABSampler  # Adams-Bashforth sampler

device = "cuda"
# Seed the global RNG for reproducible samples; assigning to `_` keeps the
# returned Generator from being echoed as cell output.
_ = torch.manual_seed(42)
def postprocess(x):
    """Rescale decoded images from [-1, 1] to [0, 1].

    Values that fall outside [-1, 1] are clamped to the nearest end of the
    output range, so the result is always a valid image for display.
    """
    shifted = x.add(1).div(2)
    return shifted.clamp(min=0, max=1)
3.1. Stable Diffusion
# Load Stable Diffusion 1.5. The denoiser is moved to the GPU because the
# sampler calls it at every step; the autoencoder and text encoder are wrapped
# with accelerate's cpu_offload, which keeps their weights off the GPU and moves
# them to `device` only while they run.
denoiser, autoencoder, textencoder = sd.load_model("sd_1.5")
denoiser = denoiser.to(device)
autoencoder = cpu_offload(autoencoder, device)
textencoder = cpu_offload(textencoder, device)
Loading weights: 100%|██████████| 396/396 [00:00<00:00, 737.48it/s, Materializing param=visual_projection.weight]
Loading weights: 100%|██████████| 196/196 [00:00<00:00, 789.85it/s, Materializing param=text_model.final_layer_norm.weight]
Loading pipeline components...: 100%|██████████| 7/7 [00:01<00:00, 6.20it/s]
# Encode the text prompt, plus the empty prompt used later as the CFG negative.
# Gradients are disabled since this is pure inference.
with torch.set_grad_enabled(False):
    prompt = textencoder("an astronaut riding a horse in space")
    null = textencoder("")
We first try without classifier-free guidance (CFG). The resulting images do not follow the prompt well and have poor composition.
# Unguided sampling: the raw denoiser, 16 Adams-Bashforth steps, with the
# text encoder's outputs forwarded as keyword arguments.
sampler = zABSampler(denoiser, steps=16)
z1 = sampler.init((2, 4, 64, 64), device=device)  # 2 latents, layout B C H W
z0 = sampler(z1, **prompt)

# Decode the final latents to pixel space without tracking gradients.
with torch.set_grad_enabled(False):
    x = autoencoder.decode(z0)

to_pil_image(make_grid(postprocess(x)))  # display the samples as one grid image
100%|########################################| 16/16 [00:00<00:00, 21.44step/s]
Now we sample with classifier-free guidance, which should improve prompt fidelity and image quality.
# Guided sampling: CFGDenoiser wraps the denoiser and receives the prompt
# (`positive`) and empty-prompt (`negative`) embeddings plus a guidance scale.
sampler = zABSampler(CFGDenoiser(denoiser), steps=16)
z1 = sampler.init((2, 4, 64, 64), device=device)  # layout B C H W
z0 = sampler(z1, positive=prompt, negative=null, guidance=3.0)

# Decode the final latents to pixel space without tracking gradients.
with torch.set_grad_enabled(False):
    x = autoencoder.decode(z0)

to_pil_image(make_grid(postprocess(x)))  # display the samples as one grid image
100%|########################################| 16/16 [00:00<00:00, 25.41step/s]
# Release the Stable Diffusion models before loading the next pipeline.
# `sampler` must go too: it was built from `denoiser` and is later called
# without it, so it holds its own reference; deleting only the three model
# names would leave the GPU-resident denoiser alive.
del denoiser, autoencoder, textencoder, sampler
torch.cuda.empty_cache()  # return the freed blocks held by PyTorch's caching allocator
3.2. Sana
# Load Sana 1.6B (512 px). As before, the denoiser lives on the GPU, while the
# autoencoder and text encoder are wrapped with accelerate's cpu_offload so
# their weights move to `device` only while they run.
denoiser, autoencoder, textencoder = sana.load_model("sana_1.6b_512")
denoiser = denoiser.to(device)
autoencoder = cpu_offload(autoencoder, device)
textencoder = cpu_offload(textencoder, device)
Loading weights: 100%|██████████| 288/288 [00:00<00:00, 749.30it/s, Materializing param=norm.weight]
Loading pipeline components...: 100%|██████████| 5/5 [00:03<00:00, 1.58it/s]
# Encode the prompt and the empty prompt (CFG negative) without gradients.
with torch.set_grad_enabled(False):
    prompt = textencoder("a cyberpunk cat with a neon sign that says 'Sana'")
    null = textencoder("")
# Sana is sampled with classifier-free guidance (scale 3.0) directly.
sampler = zABSampler(CFGDenoiser(denoiser), steps=16)
z1 = sampler.init((2, 32, 16, 16), device=device)  # layout B C H W: 32 channels, 16x16
z0 = sampler(z1, positive=prompt, negative=null, guidance=3.0)

# Decode the final latents to pixel space without tracking gradients.
with torch.set_grad_enabled(False):
    x = autoencoder.decode(z0)

to_pil_image(make_grid(postprocess(x)))  # display the samples as one grid image
100%|########################################| 16/16 [00:00<00:00, 18.23step/s]
# Release the Sana models before loading Flux. `sampler` still references the
# GPU-resident denoiser (it was constructed from it), so it must be deleted as
# well or the weights stay allocated.
del denoiser, autoencoder, textencoder, sampler
torch.cuda.empty_cache()  # return the freed blocks held by PyTorch's caching allocator
3.3. Flux
# Load Flux with the plugin's default checkpoint. The denoiser goes to the GPU;
# the autoencoder and text encoder are wrapped with accelerate's cpu_offload so
# their weights move to `device` only while they run.
denoiser, autoencoder, textencoder = flux.load_model()
denoiser = denoiser.to(device)
autoencoder = cpu_offload(autoencoder, device)
textencoder = cpu_offload(textencoder, device)
Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 36.87it/s]
Loading weights: 100%|██████████| 219/219 [00:00<00:00, 499.31it/s, Materializing param=shared.weight]
Loading weights: 100%|██████████| 196/196 [00:00<00:00, 788.41it/s, Materializing param=text_model.final_layer_norm.weight]
Loading pipeline components...: 100%|██████████| 7/7 [00:01<00:00, 6.28it/s]
# Flux is sampled without a CFG wrapper, so only the prompt is encoded
# (no empty-prompt negative).
with torch.set_grad_enabled(False):
    prompt = textencoder("a forest with a big warning sign that says 'Flux'")
# Unguided sampling: raw denoiser, prompt outputs forwarded as keyword arguments.
sampler = zABSampler(denoiser, steps=16)
# Latents here are channels-last (B H W C), unlike the B C H W layout used for
# Stable Diffusion and Sana; presumably a 32x32 grid of 64-dim packed latent
# tokens — TODO confirm against the flux plugin.
z1 = sampler.init((2, 32, 32, 64), device=device)
z0 = sampler(z1, **prompt)

# Decode the final latents to pixel space without tracking gradients.
with torch.set_grad_enabled(False):
    x = autoencoder.decode(z0)

to_pil_image(make_grid(postprocess(x)))  # display the samples as one grid image
100%|########################################| 16/16 [00:01<00:00, 8.06step/s]