#!/usr/bin/env python
# coding: utf-8
"""Gradio demo: text-to-image with latent diffusion and reusable seeds.

The UI shows a grid of images generated from a prompt.  Picking one of them
stores the index of its seed in the session state; subsequent runs then
generate a single image from that seed so that prompt variations can be
compared from the same starting noise.
"""

import gradio as gr
import random
import torch

from collections import defaultdict
from diffusers import DiffusionPipeline
from functools import partial
from itertools import zip_longest
from typing import List
from PIL import Image

SELECT_LABEL = "Select as seed"

MODEL_ID = "CompVis/ldm-text2im-large-256"
STEPS = 25  # while running on CPU
ETA = 0.3
GUIDANCE_SCALE = 6

# Layout of the result grid; every per-cell list below has GRID_SIZE entries.
GRID_ROWS = 2
GRID_COLS = 3
GRID_SIZE = GRID_ROWS * GRID_COLS

ldm = DiffusionPipeline.from_pretrained(MODEL_ID)

print(f"cuda: {torch.cuda.is_available()}")

with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
    # Per-session state:
    #   'selected' - index of the grid cell chosen as seed, or -1 for none
    #   'seeds'    - the seed used for each grid cell of the last grid run
    state = gr.Variable({
        'selected': -1,
        'seeds': [random.randint(0, 2 ** 32 - 1) for _ in range(GRID_SIZE)]
    })

    def infer_seeded_image(prompt, seed):
        """Generate and return a single image for `prompt` using `seed`."""
        print(f"Prompt: {prompt}, seed: {seed}")
        images, _ = infer_grid(prompt, n=1, seeds=[seed])
        return images[0]

    def infer_grid(prompt, n=GRID_SIZE, seeds=None):
        """Generate images for `prompt`, one pipeline call per image.

        Args:
            prompt: Text prompt passed to the pipeline.
            n: Number of images to generate.
            seeds: Optional seeds to reuse.  Positions without a seed get a
                fresh random one.  If more seeds than `n` are supplied, one
                image is generated per seed.

        Returns:
            A `(images, seeds)` tuple of two lists of equal length, where
            `seeds[i]` is the seed that produced `images[i]`.
        """
        # A `None` sentinel avoids sharing one list object across calls.
        seeds = [] if seeds is None else seeds

        # Unfortunately we have to iterate instead of requesting all images at once,
        # because we have no way to get the intermediate generation seeds.
        images_out = []
        seeds_out = []
        for _, seed in zip_longest(range(n), seeds, fillvalue=None):
            if seed is None:
                seed = random.randint(0, 2 ** 32 - 1)
            # Seed the global RNG so the same seed reproduces the same image.
            torch.manual_seed(seed)
            with torch.autocast("cuda"):
                images = ldm(
                    [prompt],
                    num_inference_steps=STEPS,
                    eta=ETA,
                    guidance_scale=GUIDANCE_SCALE
                )["sample"]
            images_out.append(images[0])
            seeds_out.append(seed)
        return images_out, seeds_out

    def infer(prompt, state):
        """Run a generation for `prompt`, seeded or as a full grid.

        Outputs (in order):
        - Grid images (list of GRID_SIZE entries, `None` when not generated)
        - Seeded Image (Image or None)
        - Grid Box with updated visibility
        - Seeded Box with updated visibility
        - The session state (with refreshed seeds after a grid run)
        """
        grid_images = [None] * GRID_SIZE
        image_with_seed = None
        if (seed_index := state["selected"]) > -1:
            # A cell is selected: regenerate only that seed, show the seeded box.
            seed = state["seeds"][seed_index]
            image_with_seed = infer_seeded_image(prompt, seed)
            visible = (False, True)
        else:
            # Nothing selected: generate a fresh grid and remember its seeds.
            grid_images, seeds = infer_grid(prompt)
            state["seeds"] = seeds
            visible = (True, False)

        boxes = [gr.Box.update(visible=v) for v in visible]
        return grid_images + [image_with_seed] + boxes + [state]

    def update_state(selected_index: int, value, state):
        """Handle a change on the radio at `selected_index`.

        Returns one update for each of the other radios followed by the state.
        """
        if value == '':
            # This radio was cleared; leave the others and the selection alone.
            # NOTE(review): assumes `value=None` means "no change" on the
            # Gradio release this targets - TODO confirm.
            others_value = None
        else:
            # This radio was picked: clear every other radio and record it.
            # NOTE(review): the source formatting was lost; recording the
            # index only on selection is the reconstruction that keeps a
            # cleared radio from overwriting the chosen index.
            others_value = ''
            state["selected"] = selected_index
        others = gr.Radio.update(value=others_value)
        return [others] * (GRID_SIZE - 1) + [state]

    def clear_seed(state):
        """Update state of Radio buttons, grid, seeded_box"""
        state["selected"] = -1
        return (
            [''] * GRID_SIZE
            + [gr.Box.update(visible=True), gr.Box.update(visible=False)]
            + [state]
        )

    def image_block():
        """Create a read-only, unlabeled image component."""
        return gr.Image(
            interactive=False, show_label=False
        ).style(
            # border = (True, True, False, True),
            rounded=(True, True, False, False),
        )

    def radio_block():
        """Create the single-choice radio used to pick an image as seed."""
        radio = gr.Radio(
            choices=[SELECT_LABEL], interactive=True, show_label=False,
        ).style(
            # border = (False, True, True, True),
            # rounded = (False, False, True, True)
            container=False
        )
        return radio

    gr.Markdown(
        """
Type anything to generate a few images that represent your prompt. Select one of the results to use as a seed for the next generation: you can try variations of your prompt starting from the same state and see how it changes. For example, Labrador in the style of Vermeer could be tweaked to Labrador in the style of Picasso or Lynx in the style of Van Gogh. If your prompts are similar, the tweaked result should also have a similar structure but different details or style.
"""
    )
    with gr.Group():
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt", show_label=False, max_lines=1
                ).style(
                    border=(True, False, True, True),
                    # margin=False,
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Run").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )

        ## Can we create a Component with these, so it can participate as an output?
        with (grid := gr.Box()):
            # Build the grid row by row; list order matches creation order.
            images = []
            selectors = []
            for _ in range(GRID_ROWS):
                with gr.Row():
                    for _ in range(GRID_COLS):
                        with gr.Box().style(border=None):
                            images.append(image_block())
                            selectors.append(radio_block())

            # Each radio updates every *other* radio plus the session state.
            for i, radio in enumerate(selectors):
                others = [s for s in selectors if s != radio]
                radio.change(
                    partial(update_state, i),
                    inputs=[radio, state],
                    outputs=others + [state]
                )

        with (seeded_box := gr.Box()):
            seeded_image = image_block()
            clear_seed_button = gr.Button("Return to Grid")
        # The seeded view starts hidden; `infer` / `clear_seed` toggle it.
        seeded_box.visible = False
        clear_seed_button.click(
            clear_seed,
            inputs=[state],
            outputs=selectors + [grid, seeded_box] + [state]
        )

        all_images = images + [seeded_image]
        boxes = [grid, seeded_box]
        infer_outputs = all_images + boxes + [state]

        text.submit(
            infer,
            inputs=[text, state],
            outputs=infer_outputs
        )
        btn.click(
            infer,
            inputs=[text, state],
            outputs=infer_outputs
        )

demo.launch(enable_queue=True)