
APTP: Adaptive Prompt-Tailored Pruning of T2I Diffusion Models

arXiv Github

The implementation of the paper "Not All Prompts Are Made Equal: Prompt-based Pruning of Text-to-Image Diffusion Models"

Abstract

Text-to-image (T2I) diffusion models have demonstrated impressive image generation capabilities. Still, their computational intensity prohibits resource-constrained organizations from deploying T2I models after fine-tuning them on their internal target data. While pruning techniques offer a potential solution to reduce the computational burden of T2I models, static pruning methods use the same pruned model for all input prompts, overlooking the varying capacity requirements of different prompts. Dynamic pruning addresses this issue by utilizing a separate sub-network for each prompt, but it prevents batch parallelism on GPUs. To overcome these limitations, we introduce Adaptive Prompt-Tailored Pruning (APTP), a novel prompt-based pruning method designed for T2I diffusion models. Central to our approach is a prompt router model, which learns to determine the required capacity for an input text prompt and routes it to an architecture code, given a total desired compute budget for prompts. Each architecture code represents a specialized model tailored to the prompts assigned to it, and the number of codes is a hyperparameter. We train the prompt router and architecture codes using contrastive learning, ensuring that similar prompts are mapped to nearby codes. Further, we employ optimal transport to prevent the codes from collapsing into a single one. We demonstrate APTP's effectiveness by pruning Stable Diffusion (SD) V2.1 using CC3M and COCO as target datasets. APTP outperforms the single-model pruning baselines in terms of FID, CLIP, and CMMD scores. Our analysis of the clusters learned by APTP reveals they are semantically meaningful. We also show that APTP can automatically discover previously empirically found challenging prompts for SD, e.g., prompts for generating text images, assigning them to higher capacity codes.

APTP Overview

APTP: We prune a text-to-image diffusion model like Stable Diffusion (left) into a mixture of efficient experts (right) in a prompt-based manner. Our prompt router routes distinct types of prompts to different experts, allowing experts' architectures to be separately specialized by removing layers or channels.

APTP Pruning Scheme

APTP pruning scheme. We train the prompt router and the set of architecture codes to prune a T2I diffusion model into a mixture of experts. The prompt router consists of three modules. We use a Sentence Transformer as the prompt encoder to encode the input prompt into a representation z. Then, the architecture predictor transforms z into the architecture embedding e, which has the same dimensionality as the architecture codes. Finally, the router maps the embedding e to an architecture code a(i). We use optimal transport to evenly distribute the prompts in a training batch among the architecture codes. The architecture code a(i) = (u(i), v(i)) determines how the model's width and depth are pruned. We train the prompt router's parameters and the architecture codes end-to-end using the denoising objective of the pruned model L_DDPM, a distillation loss between the pruned and original models L_distill, the average resource usage of the samples in the batch R, and a contrastive objective L_cont that encourages the embeddings e to preserve the semantic similarity of the representations z.
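
The optimal transport step can be pictured as a balanced assignment between the architecture embeddings of a batch of prompts and the architecture codes. The snippet below is a minimal, illustrative sketch of such a balanced assignment using Sinkhorn-style normalization over cosine similarities; the function name, number of iterations, and temperature are assumptions, not the repository's implementation.

import torch
import torch.nn.functional as F

def balanced_code_assignment(arch_embeddings, arch_codes, n_iters=3, eps=0.05):
    # arch_embeddings: (batch, dim) architecture embeddings e
    # arch_codes:      (n_codes, dim) architecture codes a(i)
    # Cosine similarity between every embedding and every code.
    scores = F.normalize(arch_embeddings, dim=-1) @ F.normalize(arch_codes, dim=-1).t()

    # Sinkhorn-style iterations: alternately normalize columns and rows so that
    # the prompts in the batch are spread evenly over the codes (avoiding collapse).
    q = torch.exp(scores / eps)
    for _ in range(n_iters):
        q = q / q.sum(dim=0, keepdim=True)  # each code receives equal total mass
        q = q / q.sum(dim=1, keepdim=True)  # each prompt is fully assigned
    return q.argmax(dim=1)                  # hard code index per prompt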

Model Description

  • Developed by: UMD Efficiency Group
  • Model type: Text-to-Image Diffusion Model
  • Model Description: APTP is a pruning scheme for text-to-image diffusion models like Stable Diffusion, resulting in a mixture of efficient experts specialized for different prompt types.

License

APTP is released under the MIT License. Please see the LICENSE file for details.

Training Dataset

We used Conceptual Captions and MS-COCO 2014 datasets for training the models. Details for downloading and preparing these datasets are provided in the Github Repository.

File Structure

APTP
├── APTP-Base-CC3M
│   ├── arch0
│   ├── ...
│   └── arch15
├── APTP-Small-CC3M
│   ├── arch0
│   ├── ...
│   └── arch7
├── APTP-Base-COCO
│   ├── arch0
│   ├── ...
│   └── arch7
└── APTP-Small-COCO
    ├── arch0
    ├── ...
    └── arch7
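
Each variant's experts live in their own subfolder, so you can fetch just the variant you need instead of the whole repository. A minimal sketch using huggingface_hub; the allow_patterns filter simply mirrors the folder layout above.

from huggingface_hub import snapshot_download

# Download only the APTP-Base-CC3M experts from the APTP repository.
local_path = snapshot_download(
    repo_id="rezashkv/APTP",
    allow_patterns=["APTP-Base-CC3M/*"],
)
print(local_path)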

Simple Inference Example

Make sure you follow the instructions in the Github Repository to install the pdm package from source.

from diffusers import StableDiffusionPipeline, PNDMScheduler
from pdm.models import HyperStructure, StructureVectorQuantizer, UNet2DConditionModelPruned
from pdm.utils.data_utils import get_mpnet_embeddings
from transformers import AutoTokenizer, AutoModel
import torch

prompt_encoder_model_name_or_path = "sentence-transformers/all-mpnet-base-v2"
aptp_model_name_or_path = "rezashkv/APTP"
aptp_variant = "APTP-Base-CC3M"
sd_model_name_or_path = "stabilityai/stable-diffusion-2-1"

# Load the Sentence Transformer prompt encoder and its tokenizer.
prompt_encoder = AutoModel.from_pretrained(prompt_encoder_model_name_or_path)
prompt_encoder_tokenizer = AutoTokenizer.from_pretrained(prompt_encoder_model_name_or_path)

# Load the prompt router: the hypernetwork (architecture predictor) and the
# quantizer holding the architecture codes.
hyper_net = HyperStructure.from_pretrained(aptp_model_name_or_path, subfolder=f"{aptp_variant}/hypernet")
quantizer = StructureVectorQuantizer.from_pretrained(aptp_model_name_or_path, subfolder=f"{aptp_variant}/quantizer")

prompts = ["a woman on a white background looks down and away from the camera the a forlorn look on her face"]
prompt_embedding = get_mpnet_embeddings(prompts, prompt_encoder, prompt_encoder_tokenizer)

# Route the prompt: predict its architecture embedding and pick the closest architecture code (expert).
arch_embedding = hyper_net(prompt_embedding)
expert_id = quantizer.get_cosine_sim_min_encoding_indices(arch_embedding)[0].item()

# Load the pruned U-Net corresponding to the selected expert.
unet = UNet2DConditionModelPruned.from_pretrained(aptp_model_name_or_path,
                                                  subfolder=f"{aptp_variant}/arch{expert_id}/checkpoint-30000/unet")
noise_scheduler = PNDMScheduler.from_pretrained(sd_model_name_or_path, subfolder="scheduler")

pipeline = StableDiffusionPipeline.from_pretrained(sd_model_name_or_path, unet=unet, scheduler=noise_scheduler)

pipeline.to('cuda')

generator = torch.Generator(device='cuda').manual_seed(43)

image = pipeline(
    prompt=prompts[0],
    guidance_scale=7.5,
    generator=generator,
    output_type='pil',
).images[0]

image.save("image.png")
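
Because routing happens per prompt, a batch of prompts can be grouped by expert and each group generated in a single batched call, which is the batch-parallelism benefit described above. The following is a minimal sketch reusing the objects created in the example; the batched behavior of get_cosine_sim_min_encoding_indices is assumed here, and the prompts are illustrative.

batch_prompts = [
    "a photo of a cat sitting on a windowsill",
    "an oil painting of a mountain lake at sunrise",
]
batch_embeddings = get_mpnet_embeddings(batch_prompts, prompt_encoder, prompt_encoder_tokenizer)
expert_ids = quantizer.get_cosine_sim_min_encoding_indices(hyper_net(batch_embeddings)).flatten().tolist()

# Group prompts by the expert they were routed to.
prompts_per_expert = {}
for p, e in zip(batch_prompts, expert_ids):
    prompts_per_expert.setdefault(e, []).append(p)

# Generate each group with its expert's pruned U-Net in one batched call.
for e, expert_prompts in prompts_per_expert.items():
    expert_unet = UNet2DConditionModelPruned.from_pretrained(
        aptp_model_name_or_path,
        subfolder=f"{aptp_variant}/arch{e}/checkpoint-30000/unet",
    )
    expert_pipeline = StableDiffusionPipeline.from_pretrained(
        sd_model_name_or_path, unet=expert_unet, scheduler=noise_scheduler
    ).to('cuda')
    images = expert_pipeline(prompt=expert_prompts, guidance_scale=7.5).images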

Uses

This model is designed for academic and research purposes, specifically for exploring the efficiency of text-to-image diffusion models through prompt-based pruning. Potential applications include:

  1. Research: Researchers can use the model to study prompt-based pruning techniques and their impact on the performance and efficiency of text-to-image generation models.
  2. Education: Educators and students can use this model as a learning tool for understanding advanced concepts in neural network pruning, diffusion models, and prompt engineering.
  3. Benchmarking: The model can be used for benchmarking against other text-to-image generation models to assess the trade-offs between computational efficiency and output quality.
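
For instance, prompt-image alignment of a pruned expert can be scored with a CLIP score, one of the metrics reported in the paper. The snippet below is a minimal sketch using torchmetrics; the metric backbone and the PIL-to-tensor conversion are assumptions, and any CLIP-score implementation can be substituted.

import numpy as np
import torch
from torchmetrics.multimodal.clip_score import CLIPScore

clip_score = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")

# `image` is the PIL image generated in the inference example above.
image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1)  # (C, H, W), uint8
caption = "a woman on a white background looks down and away from the camera the a forlorn look on her face"
score = clip_score(image_tensor, caption)
print(f"CLIP score: {score.item():.2f}")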

Safety

When using these models, it is important to consider the following safety and ethical guidelines:

  1. Content Generation: The model can generate a wide range of images based on text prompts. Users should ensure that the generated content adheres to ethical guidelines and does not produce harmful, offensive, or inappropriate images.
  2. Bias and Fairness: Like other AI models, APTP may exhibit biases present in the training data. Users should be aware of these potential biases and take steps to mitigate their impact, particularly when the model is used in sensitive or critical applications.
  3. Data Privacy: Ensure that any data used with the model complies with data privacy regulations. Avoid using personally identifiable information (PII) or sensitive data without proper consent.
  4. Responsible Use: Users are encouraged to use the model responsibly, considering the potential social and ethical implications of their work. This includes avoiding the generation of misleading or false information and respecting the rights and dignity of individuals depicted in generated images.

By adhering to these guidelines, users can help ensure the responsible and ethical use of the APTP model.

Contact

In case of any questions or issues, please contact the authors of the paper:
