3. Latent diffusion
This tutorial demonstrates how to generate images with latent diffusion models such as Stable Diffusion, Sana and Flux.
# !pip install diffusers transformers accelerate
import torch
from accelerate import cpu_offload
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid
from azula.guidance import CFGDenoiser
from azula.plugins import flux, sana, sd
from azula.sample import zABSampler # Adams-Bashforth sampler
device = "cuda"
_ = torch.manual_seed(42)
def postprocess(x):
    # Map decoded images from [-1, 1] to [0, 1], the float range expected by to_pil_image.
    return torch.clip((x + 1) / 2, min=0, max=1)
3.1. Stable Diffusion
denoiser, autoencoder, textencoder = sd.load_model("sd_1.5")
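# Keep the denoiser, which runs at every sampling step, on the GPU. The autoencoder and
# text encoder run only once per batch, so they are offloaded and moved to the GPU on demand.
denoiser, autoencoder, textencoder = (
    denoiser.to(device),
    cpu_offload(autoencoder, device),
    cpu_offload(textencoder, device),
)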
with torch.no_grad():
    prompt = textencoder("an astronaut riding a horse in space")
    null = textencoder("")
We first try without classifier-free guidance (CFG). The resulting images do not follow the prompt well and have poor composition.
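sampler = zABSampler(denoiser, steps=16)
z1 = sampler.init((2, 4, 64, 64), device=device)  # B C H W: 2 samples, 4 latent channels, 64x64 grid
z0 = sampler(z1, **prompt)  # sample from noise z1 to latents z0; the text encoding is passed as keyword arguments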
with torch.no_grad():
    x = autoencoder.decode(z0)
to_pil_image(make_grid(postprocess(x)))
100%|########################################| 16/16 [00:02<00:00, 6.57step/s]
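The sampler above integrates from noise to data in 16 steps. The idea behind an Adams-Bashforth multistep solver is to reuse derivative evaluations from previous steps, so each step still costs a single function (here, denoiser) call while gaining accuracy over a plain Euler step. The sketch below is a generic second-order Adams-Bashforth (AB2) integrator on a toy scalar ODE; it is not azula's zABSampler, whose parameterisation, time schedule, and order may differ. The names adams_bashforth2, euler, and decay are hypothetical.

```python
import math
from typing import Callable


def adams_bashforth2(
    f: Callable[[float, float], float], x0: float, t0: float, t1: float, steps: int
) -> float:
    """Integrate dx/dt = f(t, x) from t0 to t1 with the two-step Adams-Bashforth method."""
    h = (t1 - t0) / steps  # negative when integrating backwards, e.g. from t=1 to t=0
    t, x = t0, x0
    f_prev = None  # slope from the previous step (none before the first step)
    for _ in range(steps):
        f_curr = f(t, x)  # the only function evaluation in this step
        if f_prev is None:
            x = x + h * f_curr  # no history yet: bootstrap with an Euler step
        else:
            x = x + h * (1.5 * f_curr - 0.5 * f_prev)  # AB2 extrapolates from two slopes
        f_prev = f_curr
        t = t + h
    return x


def euler(f: Callable[[float, float], float], x0: float, t0: float, t1: float, steps: int) -> float:
    """Plain explicit Euler with the same number of function evaluations, for comparison."""
    h = (t1 - t0) / steps
    t, x = t0, x0
    for _ in range(steps):
        x = x + h * f(t, x)
        t = t + h
    return x


def decay(t: float, x: float) -> float:
    # Toy ODE dx/dt = -x, whose exact solution is x(t) = x(0) * exp(-t).
    return -x


exact = math.exp(-1.0)
print(f"AB2 error:   {abs(adams_bashforth2(decay, 1.0, 0.0, 1.0, 16) - exact):.1e}")
print(f"Euler error: {abs(euler(decay, 1.0, 0.0, 1.0, 16) - exact):.1e}")
```

With the same 16 function evaluations, AB2 lands roughly two orders of magnitude closer to the exact value than Euler on this problem, which is the reason multistep solvers are attractive when every evaluation is an expensive network call.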
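We now wrap the denoiser in CFGDenoiser and pass both the prompt (positive) and the empty prompt (negative), with guidance=3.0 setting the guidance strength. This should improve prompt fidelity and image quality.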
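Classifier-free guidance needs both a prediction conditioned on the positive prompt and one conditioned on the negative prompt at every step (possibly computed in a single batched call), and then extrapolates from the latter towards the former. The sketch below shows the common formulation on plain lists; whether azula's CFGDenoiser uses exactly this convention for guidance (rather than, for example, p + w * (p - n), where guidance=0 already means the plain conditional model) is an assumption not confirmed here. The function cfg_combine is hypothetical.

```python
from typing import List


def cfg_combine(pred_pos: List[float], pred_neg: List[float], guidance: float) -> List[float]:
    """Classifier-free guidance: start at the negative (unconditional) prediction and
    move `guidance` times the difference towards the positive (conditional) one."""
    return [n + guidance * (p - n) for p, n in zip(pred_pos, pred_neg)]


# Two hypothetical denoiser outputs, flattened to plain lists for illustration.
pos = [1.0, 2.0]
neg = [0.0, 1.0]
print(cfg_combine(pos, neg, 1.0))  # → [1.0, 2.0]  (guidance 1 reproduces the positive prediction)
print(cfg_combine(pos, neg, 3.0))  # → [3.0, 4.0]  (guidance > 1 extrapolates beyond it)
```

Under this convention, guidance=0 returns the negative prediction, guidance=1 the positive one, and larger values push samples further in the direction that distinguishes the prompt from the empty prompt.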
sampler = zABSampler(CFGDenoiser(denoiser), steps=16)
z1 = sampler.init((2, 4, 64, 64), device=device)
z0 = sampler(z1, positive=prompt, negative=null, guidance=3.0)
with torch.no_grad():
    x = autoencoder.decode(z0)
to_pil_image(make_grid(postprocess(x)))
100%|########################################| 16/16 [00:02<00:00, 7.78step/s]
del denoiser, autoencoder, textencoder
3.2. Sana
denoiser, autoencoder, textencoder = sana.load_model("sana_1.6b_512")
denoiser, autoencoder, textencoder = (
    denoiser.to(device),
    cpu_offload(autoencoder, device),
    cpu_offload(textencoder, device),
)
with torch.no_grad():
    prompt = textencoder("a cyberpunk cat with a neon sign that says 'Sana'")
    null = textencoder("")
sampler = zABSampler(CFGDenoiser(denoiser), steps=16)
z1 = sampler.init((2, 32, 16, 16), device=device)  # B C H W: Sana's autoencoder uses 32 latent channels
z0 = sampler(z1, positive=prompt, negative=null, guidance=3.0)
with torch.no_grad():
    x = autoencoder.decode(z0)
to_pil_image(make_grid(postprocess(x)))
100%|########################################| 16/16 [00:01<00:00, 8.98step/s]
del denoiser, autoencoder, textencoder
3.3. Flux
denoiser, autoencoder, textencoder = flux.load_model()
denoiser, autoencoder, textencoder = (
    cpu_offload(denoiser, device),  # unlike SD and Sana, the denoiser is offloaded as well
    cpu_offload(autoencoder, device),
    cpu_offload(textencoder, device),
)
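The Flux denoiser is offloaded to the CPU, unlike the SD and Sana denoisers. A rough estimate of the memory taken by the weights alone suggests why. The sketch below assumes 16-bit weights (2 bytes per parameter; the dtype used by flux.load_model is not shown here) and uses rounded public parameter counts; activations, the autoencoder, and the text encoders are ignored. The helper weights_gib is hypothetical.

```python
def weights_gib(num_params: float, bytes_per_param: int = 2) -> float:
    """Memory for the model weights alone, in GiB (2 bytes per parameter for fp16/bf16)."""
    return num_params * bytes_per_param / 2**30


# Approximate denoiser sizes (rounded public parameter counts).
denoisers = {
    "SD 1.5 U-Net": 0.86e9,
    "Sana 1.6B": 1.6e9,
    "Flux transformer": 12e9,
}
for name, n in denoisers.items():
    print(f"{name:>16}: {weights_gib(n):5.1f} GiB in 16-bit precision")
```

Roughly 22 GiB for the Flux transformer's weights alone exceeds what many GPUs can hold alongside everything else. Offloading trades speed for memory, since weights are moved to the GPU at every forward pass, which likely accounts for much of the slow sampling (about 12 s per step) reported below.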
with torch.no_grad():
    prompt = textencoder("a forest with a big warning sign that says 'Flux'")
sampler = zABSampler(denoiser, steps=16)  # no CFGDenoiser and no negative prompt here
z1 = sampler.init((2, 32, 32, 64), device=device)  # B H W C: Flux latents are channels-last
z0 = sampler(z1, **prompt)
with torch.no_grad():
    x = autoencoder.decode(z0)
to_pil_image(make_grid(postprocess(x)))
100%|########################################| 16/16 [03:13<00:00, 12.08s/step]
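The three models use different latent shapes: (2, 4, 64, 64) for SD, (2, 32, 16, 16) for Sana, and (2, 32, 32, 64) for Flux. The sketch below relates these to image size under stated assumptions: an 8x downsampling, 4-channel VAE for SD 1.5; a 32x downsampling, 32-channel autoencoder (DC-AE) for Sana; and for Flux an 8x downsampling, 16-channel VAE whose 2x2 latent patches are packed into the channel axis (16 * 4 = 64 channels), as in the diffusers Flux pipeline. That azula's flux plugin packs latents the same way is an assumption consistent with the B H W C shape above. The helper latent_shape is hypothetical.

```python
from typing import Tuple


def latent_shape(model: str, batch: int, height: int, width: int) -> Tuple[int, ...]:
    """Latent tensor shape for `batch` images of size height x width (hypothetical helper)."""
    if model == "sd":
        # SD 1.5 VAE: 8x spatial downsampling, 4 channels, layout B C H W.
        f, c = 8, 4
        return (batch, c, height // f, width // f)
    if model == "sana":
        # Sana's DC-AE: 32x spatial downsampling, 32 channels, layout B C H W.
        f, c = 32, 32
        return (batch, c, height // f, width // f)
    if model == "flux":
        # Flux VAE: 8x downsampling, 16 channels; 2x2 patches are then packed into
        # the channel axis, giving 64 channels on an (H/16, W/16) grid, layout B H W C.
        f, c, p = 8, 16, 2
        return (batch, height // (f * p), width // (f * p), c * p * p)
    raise ValueError(f"unknown model: {model!r}")


for model in ("sd", "sana", "flux"):
    print(model, latent_shape(model, batch=2, height=512, width=512))
```

Under these assumptions, all three shapes used in this tutorial correspond to 512x512 images.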