---
license: mit
library_name: diffusers
---
Stage-A-ft-HQ
`stage-a-ft-hq` is a version of Würstchen's Stage A that was finetuned to produce slightly nicer-looking textures.

`stage-a-ft-hq` works with any Würstchen-derived model (including Stable Cascade).
Example comparison
Explanation
Image generators like Würstchen and Stable Cascade create images via a multi-stage process. Stage A is the final stage: it decodes the compressed latents produced by the earlier stages into full-resolution, human-interpretable images.
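To make the multi-stage process more concrete, here is a minimal sketch of how a requested image size maps onto the resolution each stage works at. It assumes the default scale factors of the Diffusers Stable Cascade pipelines (`resolution_multiple=42.67` in the prior, `latent_dim_scale=10.67` in the decoder) and a 4× spatial upscaling in Stage A. These constants come from those pipeline defaults, not from this model card, and may differ in other versions; `stage_shapes` is a hypothetical helper for illustration only.

```python
import math

# Assumed defaults of the Diffusers Stable Cascade pipelines (not from this model card):
PRIOR_RESOLUTION_MULTIPLE = 42.67  # image pixels per Stage C latent pixel
DECODER_LATENT_DIM_SCALE = 10.67   # Stage B latent pixels per Stage C latent pixel
STAGE_A_UPSCALE = 4                # image pixels per Stage A latent pixel


def stage_shapes(height: int, width: int) -> dict:
    """Hypothetical helper: spatial (height, width) handled at each stage."""
    # Stage C (prior): round up so the latent grid covers the whole image
    c = (math.ceil(height / PRIOR_RESOLUTION_MULTIPLE),
         math.ceil(width / PRIOR_RESOLUTION_MULTIPLE))
    # Stage B (decoder): upsample the Stage C grid, truncating to whole pixels
    b = (int(c[0] * DECODER_LATENT_DIM_SCALE),
         int(c[1] * DECODER_LATENT_DIM_SCALE))
    # Stage A (VQGAN decoder): render the Stage B latents as full-resolution pixels
    a = (b[0] * STAGE_A_UPSCALE, b[1] * STAGE_A_UPSCALE)
    return {"stage_c_latent": c, "stage_b_latent": b, "stage_a_image": a}


print(stage_shapes(1024, 1024))
# → {'stage_c_latent': (24, 24), 'stage_b_latent': (256, 256), 'stage_a_image': (1024, 1024)}
```

Under these assumptions, Stage A turns a 256×256 latent grid into the final 1024×1024 image, so it is the stage that determines the finest pixel-level textures.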
The original Stage A tends to render slightly smoothed-out images with a distinctive noise pattern on top.
`stage-a-ft-hq` was briefly finetuned on a high-quality dataset to reduce these artifacts.
Suggested Settings
To generate highly detailed images, you probably want to use `stage-a-ft-hq` (which improves very fine detail) together with a large Stage B step count (which improves mid-level detail). In Diffusers, the Stage B step count is the decoder pipeline's `num_inference_steps`.
🧨 Diffusers Usage
⚠️ As of 2024-02-17, the Diffusers pull request adding Stable Cascade support is still under review. I've only tested Stable Cascade with this particular version of the PR:
```bash
pip install --upgrade --force-reinstall https://github.com/kashif/diffusers/archive/a3dc21385b7386beb3dab3a9845962ede6765887.zip
```
```python
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from diffusers.pipelines.wuerstchen import PaellaVQModel
from IPython.display import display  # notebook display; in a plain script, use .save() instead

device = "cuda"
num_images_per_prompt = 1

# Load the Stage-A-ft-HQ model
stage_a_ft_hq = PaellaVQModel.from_pretrained(
    "madebyollin/stage-a-ft-hq", torch_dtype=torch.float16
).to(device)

# Load the normal Stable Cascade pipeline (Stage C prior, Stage B + Stage A decoder)
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to(device)
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.float16
).to(device)

# Swap in the Stage-A-ft-HQ model (the decoder pipeline stores Stage A as `vqgan`)
decoder.vqgan = stage_a_ft_hq

prompt = "Photograph of Seattle streets on a snowy winter morning"
negative_prompt = ""

# Stage C: generate image embeddings from the prompt
prior_output = prior(
    prompt=prompt,
    height=1024,
    width=1024,
    negative_prompt=negative_prompt,
    guidance_scale=4.0,
    num_images_per_prompt=num_images_per_prompt,
    num_inference_steps=20,
)

# Stages B and A: decode the embeddings into images.
# The prior runs in bfloat16, so cast the embeddings to float16 to match the decoder.
decoder_output = decoder(
    image_embeddings=prior_output.image_embeddings.half(),
    prompt=prompt,
    negative_prompt=negative_prompt,
    guidance_scale=0.0,
    output_type="pil",
    num_inference_steps=20,  # Stage B step count; larger values improve mid-level detail
).images

display(decoder_output[0])
# decoder_output[0].save("output.png")  # alternative outside Jupyter
```
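Because `stage-a-ft-hq` only replaces the decoder pipeline's `vqgan` component, one way to produce your own side-by-side comparison is to decode the same embeddings twice with the same seed, once with each Stage A, so that only Stage A differs between the two images. The sketch below is hypothetical (`swapped_stage_a` is not part of Diffusers): a small context manager that swaps Stage A in temporarily and always restores the original. The commented usage assumes the objects from the example above and the decoder pipeline's `generator` argument.

```python
from contextlib import contextmanager


@contextmanager
def swapped_stage_a(decoder, vqgan):
    """Hypothetical helper: temporarily replace a decoder pipeline's Stage A (`vqgan`)."""
    original = decoder.vqgan
    decoder.vqgan = vqgan
    try:
        yield decoder
    finally:
        # Restore the original Stage A even if decoding raised an exception
        decoder.vqgan = original


# Hypothetical usage with the objects from the example above, starting from a decoder
# that still holds its original Stage A (i.e. without `decoder.vqgan = stage_a_ft_hq`):
#
#   def decode(seed=0):
#       return decoder(
#           image_embeddings=prior_output.image_embeddings.half(),
#           prompt=prompt, negative_prompt=negative_prompt, guidance_scale=0.0,
#           output_type="pil", num_inference_steps=20,
#           generator=torch.Generator(device).manual_seed(seed),
#       ).images[0]
#
#   baseline = decode()
#   with swapped_stage_a(decoder, stage_a_ft_hq):
#       finetuned = decode()
```

Using a context manager rather than a plain assignment keeps the original Stage A available for later baseline runs.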