3. Latent diffusion
This tutorial demonstrates how to generate images with latent diffusion models such as Stable Diffusion, Flux and Sana.
# !pip install diffusers transformers accelerate
import torch
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid
from azula.guidance import CFGDenoiser
from azula.nn.utils import cpu_offload
from azula.plugins import sd, flux, sana
from azula.sample import ABSampler # Adams-Bashforth sampler
# Target device: the denoisers, autoencoders and samplers below are moved to it
# with `.to(device)`, and it is passed to `cpu_offload` for the text encoders.
device = "cuda"
# Seed PyTorch's global random number generator. `torch.manual_seed` returns a
# generator object; it is bound to `_` and not used again in this file.
_ = torch.manual_seed(42)
def postprocess(x: torch.Tensor) -> torch.Tensor:
    """Rescale image values from [-1, 1] to [0, 1] and clamp the result.

    Args:
        x: Tensor of image values. Inputs outside [-1, 1] are allowed; they
            saturate at 0 or 1 after clamping.

    Returns:
        A tensor with the same shape as `x`, equal to `(x + 1) / 2` clipped
        to the range [0, 1].
    """
    return torch.clip((x + 1) / 2, min=0, max=1)
3.1. Stable Diffusion
# Load the three Stable Diffusion 1.5 components returned by the plugin and
# move the denoiser and autoencoder to the target device.
denoiser, autoencoder, textencoder = sd.load_model("sd_1.5")
denoiser, autoencoder = denoiser.to(device), autoencoder.to(device)

# Encode the prompts without tracking gradients. The text encoder is not moved
# with `.to(device)`; it is handed to `cpu_offload` instead.
# NOTE(review): presumably `cpu_offload` keeps the encoder on `device` only for
# the duration of the block -- confirm against azula.nn.utils.
with torch.no_grad(), cpu_offload(textencoder, device):
    prompt = textencoder("an astronaut riding a horse in space")
    null = textencoder("")  # empty prompt, later passed as `negative=` for CFG
We first try without classifier-free guidance (CFG). The resulting images do not follow the prompt well and have poor composition.
# Sample with the raw denoiser (no guidance) using 16 sampler steps.
sampler = ABSampler(denoiser, steps=16).to(device)

z1 = sampler.init((2, 4, 64, 64))  # B C H W
z0 = sampler(z1, **prompt)  # `prompt` is unpacked into keyword arguments

# Decode the sampled latents; gradients are not needed for decoding.
with torch.no_grad():
    x = autoencoder.decode(z0)

# Kept as a bare expression outside the `with` block so that a notebook
# displays the resulting image grid.
to_pil_image(make_grid(postprocess(x)))
We now sample with classifier-free guidance, which should improve prompt fidelity and image quality.
# Wrap the denoiser in `CFGDenoiser` so the sampler can be called with a
# positive prompt, a negative prompt and a guidance value.
sampler = ABSampler(CFGDenoiser(denoiser), steps=16).to(device)

z1 = sampler.init((2, 4, 64, 64))  # B C H W
z0 = sampler(z1, positive=prompt, negative=null, guidance=3.0)

# Decode the sampled latents; gradients are not needed for decoding.
with torch.no_grad():
    x = autoencoder.decode(z0)

to_pil_image(make_grid(postprocess(x)))

# Drop the Stable Diffusion modules before the next model is loaded.
# `sampler` was built from `denoiser`, so it is deleted as well; otherwise the
# model would presumably stay referenced until `sampler` is rebound.
del sampler, denoiser, autoencoder, textencoder
3.2. Flux
# Load the Flux components and move the denoiser and autoencoder to the device.
denoiser, autoencoder, textencoder = flux.load_model()
denoiser, autoencoder = denoiser.to(device), autoencoder.to(device)

# Encode the prompt without tracking gradients, with the text encoder handed to
# `cpu_offload` rather than moved with `.to(device)`.
with torch.no_grad(), cpu_offload(textencoder, device):
    prompt = textencoder("a forest with a big warning sign that says 'Flux'")

# Flux is sampled here with the raw denoiser: no `CFGDenoiser`, no null prompt.
sampler = ABSampler(denoiser, steps=16).to(device)

z1 = sampler.init((2, 32, 32, 64))  # B H W C
z0 = sampler(z1, **prompt)  # `prompt` is unpacked into keyword arguments

# Decode the sampled latents; gradients are not needed for decoding.
with torch.no_grad():
    x = autoencoder.decode(z0)

to_pil_image(make_grid(postprocess(x)))

# Drop the Flux modules before the next model is loaded. `sampler` was built
# from `denoiser`, so it is deleted as well; otherwise the model would
# presumably stay referenced until `sampler` is rebound.
del sampler, denoiser, autoencoder, textencoder
3.3. Sana
# Load the Sana components and move the denoiser and autoencoder to the device.
denoiser, autoencoder, textencoder = sana.load_model("sana_1.6b_512")
denoiser, autoencoder = denoiser.to(device), autoencoder.to(device)

# Encode the prompt and the empty (negative) prompt without tracking gradients,
# with the text encoder handed to `cpu_offload`.
with torch.no_grad(), cpu_offload(textencoder, device):
    prompt = textencoder("a cyberpunk cat with a neon sign that says 'Sana'")
    null = textencoder("")

# Sample with classifier-free guidance.
sampler = ABSampler(CFGDenoiser(denoiser), steps=16).to(device)

# NOTE(review): layout is presumably B C H W (32 latent channels) -- confirm.
z1 = sampler.init((2, 32, 16, 16))
z0 = sampler(z1, positive=prompt, negative=null, guidance=3.0)

# Decode the sampled latents; gradients are not needed for decoding.
with torch.no_grad():
    x = autoencoder.decode(z0)

to_pil_image(make_grid(postprocess(x)))