#!/usr/bin/env python
"""Gradio demo comparing Stable Diffusion with and without Attend-and-Excite.

When Attend-and-Excite is enabled, the user-selected prompt token indices are
passed to ``StableDiffusionAttendAndExcitePipeline``; otherwise the plain
``StableDiffusionPipeline`` is used. Both pipelines are loaded only when a CUDA
device is available.
"""

from __future__ import annotations

import os

import gradio as gr
import PIL.Image
import torch
from diffusers import StableDiffusionAttendAndExcitePipeline, StableDiffusionPipeline

DESCRIPTION = """\
# Attend-and-Excite

This is a demo for [Attend-and-Excite](https://arxiv.org/abs/2301.13826).
Attend-and-Excite performs attention-based generative semantic guidance to mitigate subject neglect in Stable Diffusion.
Select a prompt and a set of indices matching the subjects you wish to strengthen (the `Check token indices` cell can help map between a word and its index).
"""

# Default value of the Attend-and-Excite pipeline's `thresholds` argument.
# Kept at module level (and copied per call) instead of being a mutable
# default argument of `run`.
DEFAULT_THRESHOLDS: dict[int, float] = {
    10: 0.5,
    20: 0.8,
}

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    model_id = "CompVis/stable-diffusion-v1-4"
    ax_pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(model_id)
    ax_pipe.to(device)
    sd_pipe = StableDiffusionPipeline.from_pretrained(model_id)
    sd_pipe.to(device)
else:
    # NOTE: `device`, `ax_pipe` and `sd_pipe` stay undefined on CPU, so the
    # callbacks below cannot run there; this notice warns users up front.
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


def get_token_table(prompt: str) -> list[tuple[int, str]]:
    """Return ``(index, token)`` pairs for *prompt* as tokenized by the pipeline.

    The first and last tokenizer ids are dropped (presumably the start/end
    special tokens -- TODO confirm for other tokenizers) and numbering starts
    at 1, matching the indices the user enters in the token-indices box.
    """
    input_ids = ax_pipe.tokenizer(prompt)["input_ids"][1:-1]
    decoded = [ax_pipe.tokenizer.decode(token_id) for token_id in input_ids]
    return list(enumerate(decoded, start=1))


def _parse_token_indices(text: str) -> list[int]:
    """Parse a comma-separated list of integer token indices.

    Whitespace around entries and empty entries (e.g. a trailing comma) are
    ignored.

    Raises:
        ValueError: If no index is given or an entry is not an integer.
    """
    try:
        indices = [int(part) for part in text.split(",") if part.strip()]
    except (AttributeError, TypeError, ValueError) as err:
        raise ValueError("Invalid token indices.") from err
    if not indices:
        raise ValueError("Invalid token indices.")
    return indices


def run(
    prompt: str,
    indices_to_alter_str: str,
    seed: int = 0,
    apply_attend_and_excite: bool = True,
    num_steps: int = 50,
    guidance_scale: float = 7.5,
    scale_factor: int = 20,
    thresholds: dict[int, float] | None = None,
    max_iter_to_alter: int = 25,
) -> PIL.Image.Image:
    """Generate one image for *prompt*.

    Args:
        prompt: Text prompt.
        indices_to_alter_str: Comma-separated token indices (as listed by
            ``get_token_table``) to strengthen. Only read when
            *apply_attend_and_excite* is true.
        seed: Seed for the ``torch.Generator`` on ``device``.
        apply_attend_and_excite: Use the Attend-and-Excite pipeline when true,
            the plain Stable Diffusion pipeline otherwise.
        num_steps: Number of inference steps.
        guidance_scale: Classifier-free guidance scale.
        scale_factor: Forwarded to the Attend-and-Excite pipeline only.
        thresholds: Forwarded to the Attend-and-Excite pipeline only;
            ``None`` means a copy of ``DEFAULT_THRESHOLDS``.
        max_iter_to_alter: Forwarded to the Attend-and-Excite pipeline only.

    Returns:
        The first image produced by the selected pipeline.

    Raises:
        ValueError: If *indices_to_alter_str* cannot be parsed.
    """
    # Gradio sliders may hand back whole numbers as floats; the generator
    # seed and the step count must be ints.
    generator = torch.Generator(device=device).manual_seed(int(seed))
    num_steps = int(num_steps)

    if apply_attend_and_excite:
        token_indices = _parse_token_indices(indices_to_alter_str)
        if thresholds is None:
            # Fresh copy per call so the shared default can never be mutated.
            thresholds = dict(DEFAULT_THRESHOLDS)
        out = ax_pipe(
            prompt=prompt,
            token_indices=token_indices,
            guidance_scale=guidance_scale,
            generator=generator,
            num_inference_steps=num_steps,
            max_iter_to_alter=max_iter_to_alter,
            thresholds=thresholds,
            scale_factor=scale_factor,
        )
    else:
        out = sd_pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            generator=generator,
            num_inference_steps=num_steps,
        )
    return out.images[0]


def process_example(
    prompt: str,
    indices_to_alter_str: str,
    seed: int,
    apply_attend_and_excite: bool,
) -> tuple[list[tuple[int, str]], PIL.Image.Image]:
    """Compute the token table and the generated image for a gallery example.

    All generation settings other than the four arguments use ``run``'s
    defaults.
    """
    token_table = get_token_table(prompt)
    result = run(
        prompt=prompt,
        indices_to_alter_str=indices_to_alter_str,
        seed=seed,
        apply_attend_and_excite=apply_attend_and_excite,
    )
    return token_table, result


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )

    with gr.Row():
        with gr.Column():
            prompt = gr.Text(
                label="Prompt",
                max_lines=1,
                placeholder="A pod of dolphins leaping out of the water in an ocean with a ship on the background",
            )
            with gr.Accordion(label="Check token indices", open=False):
                show_token_indices_button = gr.Button("Show token indices")
                token_indices_table = gr.Dataframe(label="Token indices", headers=["Index", "Token"], col_count=2)
            token_indices_str = gr.Text(
                label="Token indices (a comma-separated list indices of the tokens you wish to alter)",
                max_lines=1,
                placeholder="4,16",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=100000,
                step=1,
                value=0,
            )
            apply_attend_and_excite = gr.Checkbox(label="Apply Attend-and-Excite", value=True)
            num_steps = gr.Slider(
                label="Number of steps",
                minimum=1,  # zero inference steps cannot produce an image
                maximum=100,
                step=1,
                value=50,
            )
            guidance_scale = gr.Slider(
                label="CFG scale",
                minimum=0,
                maximum=50,
                step=0.1,
                value=7.5,
            )
            run_button = gr.Button("Generate")
        with gr.Column():
            result = gr.Image(label="Result")

    with gr.Row():
        # (prompt, token indices, seed); every entry is shown with
        # Attend-and-Excite on and then off, for side-by-side comparison.
        example_settings = [
            ("A mouse and a red car", "2,6", 2098),
            ("A horse and a dog", "2,5", 123),
            ("A painting of an elephant with glasses", "5,7", 123),
            ("A playful kitten chasing a butterfly in a wildflower meadow", "3,6,10", 123),
            (
                "A grizzly bear catching a salmon in a crystal clear river surrounded by a forest",
                "2,6,15",
                123,
            ),
            (
                "A pod of dolphins leaping out of the water in an ocean with a ship on the background",
                "4,16",
                123,
            ),
        ]
        examples = [
            [example_prompt, example_indices, example_seed, use_ax]
            for example_prompt, example_indices, example_seed in example_settings
            for use_ax in (True, False)
        ]
        gr.Examples(
            examples=examples,
            inputs=[
                prompt,
                token_indices_str,
                seed,
                apply_attend_and_excite,
            ],
            outputs=[
                token_indices_table,
                result,
            ],
            fn=process_example,
            cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
            examples_per_page=20,
        )

    show_token_indices_button.click(
        fn=get_token_table,
        inputs=prompt,
        outputs=token_indices_table,
        queue=False,
        api_name="get-token-table",
    )

    inputs = [
        prompt,
        token_indices_str,
        seed,
        apply_attend_and_excite,
        num_steps,
        guidance_scale,
    ]
    # Each trigger first refreshes the token table (unqueued, cheap) and then
    # runs generation; only the "Generate" button exposes the `run` API name.
    for trigger, run_api_name in (
        (prompt.submit, False),
        (token_indices_str.submit, False),
        (run_button.click, "run"),
    ):
        trigger(
            fn=get_token_table,
            inputs=prompt,
            outputs=token_indices_table,
            queue=False,
            api_name=False,
        ).then(
            fn=run,
            inputs=inputs,
            outputs=result,
            api_name=run_api_name,
        )

if __name__ == "__main__":
    demo.queue(max_size=10).launch()