library_name: diffusers
license: mit
language:
  - en
base_model:
  - sd-legacy/stable-diffusion-v1-5
datasets:
  - laion/laion-coco
pipeline_tag: text-to-image

Consistency Models (YSDA CV Week 2024)

This repository contains the weights of the models trained for the final assignment of YSDA CV Week 2024.

All Consistency Models were trained starting from the Stable Diffusion 1.5 (SD 1.5) checkpoint "sd-legacy/stable-diffusion-v1-5".

Training added LoRA modules of rank 64 on top of some layers of the base model. We considered three variants of Consistency Models:

  1. Consistency Training
  2. Consistency Distillation
  3. Multi-boundary Consistency Distillation
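For orientation, the property that all three variants aim to enforce can be written as below. This is a sketch in the notation commonly used in the consistency-models literature; the symbols ($f_\theta$, $\epsilon$, $T$, the boundaries $s_k$) are assumptions and are not taken from this repository or the course materials.

```latex
% Self-consistency: all points x_t, x_{t'} on one probability-flow ODE
% trajectory are mapped to the same output, anchored at the smallest time.
f_\theta(x_t, t) = f_\theta(x_{t'}, t') \quad \text{for all } t, t' \in [\epsilon, T],
\qquad f_\theta(x_\epsilon, \epsilon) = x_\epsilon .

% Multi-boundary variant: [\epsilon, T] is split by boundaries
% \epsilon = s_0 < s_1 < \dots < s_K = T, and consistency is enforced per segment.
f_\theta(x_t, t) = f_\theta(x_{t'}, t') \quad \text{for all } t, t' \in [s_{k-1}, s_k],
\qquad f_\theta(x_{s_{k-1}}, s_{k-1}) = x_{s_{k-1}} .
```

Roughly, Consistency Training builds the pair $(x_t, x_{t'})$ from a data sample and shared noise without a teacher, while Consistency Distillation obtains $x_{t'}$ from $x_t$ with one ODE solver step of a pretrained teacher model.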

We trained each variant on a 5k-image subset of the COCO dataset. For each model, the weights of the corresponding LoRA adapter are saved in the standard PEFT format.
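To give a feel for what rank 64 means in terms of adapter size, here is a minimal sketch that counts the trainable parameters of a single LoRA adapter on a square linear layer. The layer widths 320, 640 and 1280 are the channel widths of the SD 1.5 UNet blocks and are used purely for illustration; which layers actually carry adapters is not stated here, and the helper names are hypothetical.

```python
def lora_param_count(d_in, d_out, rank=64):
    """Parameters of one LoRA adapter: A is (rank x d_in), B is (d_out x rank)."""
    return rank * (d_in + d_out)


def full_param_count(d_in, d_out):
    """Parameters of the dense weight matrix the adapter sits on (bias ignored)."""
    return d_in * d_out


# Illustrative widths only (assumption): SD 1.5 UNet channel sizes.
for width in (320, 640, 1280):
    lora = lora_param_count(width, width)
    full = full_param_count(width, width)
    print(f"{width}x{width}: LoRA {lora:,} vs full {full:,} ({lora / full:.0%})")
```

At rank 64 the saving is modest for narrow layers and grows with layer width, since the adapter size scales linearly with the width while the dense matrix scales quadratically.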

You can reproduce the generation results for the third variant, Multi-boundary Consistency Distillation, as follows:

# Jupyter magic for inline plots; remove this line when running as a plain script.
%matplotlib inline
import matplotlib.pyplot as plt

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from peft import PeftModel

def visualize_images(images):
    """Show four generated images side by side in a single row."""
    assert len(images) == 4
    plt.figure(figsize=(12, 3))
    for i, image in enumerate(images):
        plt.subplot(1, 4, i + 1)
        plt.imshow(image)
        plt.axis('off')

    plt.subplots_adjust(wspace=-0.01, hspace=-0.01)
    plt.show()

# Load the base SD 1.5 pipeline in half precision.
pipe = StableDiffusionPipeline.from_pretrained("sd-legacy/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")

# Switch to DDIM with "trailing" timestep spacing and keep the scheduler tensors on the GPU.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.scheduler.timesteps = pipe.scheduler.timesteps.cuda()
pipe.scheduler.alphas_cumprod = pipe.scheduler.alphas_cumprod.cuda()

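# Attach the Multi-boundary CD LoRA adapter; the UNet is cast to float32 while loading.
loaded_cm_unet = PeftModel.from_pretrained(
    pipe.unet.to(torch.float32),
    "kisnikser/consistency-models",
    subfolder="multi-cd",
    adapter_name="multi-cd",
)

# Switch to eval mode and cast back to float16 for inference.
pipe.unet = loaded_cm_unet.eval().to(torch.float16)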

validation_prompts = [
    "A sad puppy with large eyes",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
    "A girl with pale blue hair and a cami tank top",
    "A lighthouse in a giant wave, origami style",
    "belle epoque, christmas, red house in the forest, photo realistic, 8k",
    "A small cactus with a happy face in the Sahara desert",
    "Green commercial building with refrigerator and refrigeration units outside",
]

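# Generate four images per prompt in 4 sampling steps.
# guidance_scale=1.0 turns off classifier-free guidance; the fixed seed makes runs repeatable.
for prompt in validation_prompts:
    generator = torch.Generator(device="cuda").manual_seed(1)
    images = pipe(
        prompt=prompt,
        guidance_scale=1.0,
        num_inference_steps=4,
        generator=generator,
        num_images_per_prompt=4,
    ).images
    visualize_images(images)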
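Two settings in the script above may deserve a closer look: the "trailing" timestep spacing and guidance_scale=1.0. The following is a minimal pure-Python sketch of both, assuming the usual 1000 training timesteps of SD 1.5. It mirrors our understanding of the formulas rather than the diffusers source code, and the function names are hypothetical.

```python
def trailing_timesteps(num_inference_steps=4, num_train_timesteps=1000):
    """Timesteps counted back from the last training step in equal strides."""
    step_ratio = num_train_timesteps / num_inference_steps
    return [
        round(num_train_timesteps - i * step_ratio) - 1
        for i in range(num_inference_steps)
    ]


def cfg_combine(noise_uncond, noise_cond, guidance_scale):
    """Classifier-free guidance: uncond + scale * (cond - uncond), elementwise."""
    return [u + guidance_scale * (c - u) for u, c in zip(noise_uncond, noise_cond)]


# With 4 steps, sampling starts at the noisiest training timestep.
print(trailing_timesteps())  # [999, 749, 499, 249]

# With scale 1.0 the guided prediction equals the conditional one.
print(cfg_combine([0.5, -1.0], [0.25, 2.0], guidance_scale=1.0))  # [0.25, 2.0]
```

Starting from timestep 999 matters for few-step sampling because the first model call then sees (almost) pure noise, and a guidance scale of 1.0 means the prediction depends only on the text-conditioned branch.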
