Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# coding: utf-8 | |
import gradio as gr | |
import random | |
import torch | |
from collections import defaultdict | |
from diffusers import DiffusionPipeline | |
from functools import partial | |
from itertools import zip_longest | |
from typing import List | |
from PIL import Image | |
# Text of the single choice offered by each per-image radio button.
SELECT_LABEL = "Select as seed"

# Model identifier passed to DiffusionPipeline.from_pretrained below.
MODEL_ID = "CompVis/ldm-text2im-large-256"

# Sampling parameters forwarded to every pipeline call made by infer_grid.
STEPS = 25  # while running on CPU
ETA = 0.3
GUIDANCE_SCALE = 6

# Load the pipeline once, at import time; infer_grid reuses this object.
ldm = DiffusionPipeline.from_pretrained(MODEL_ID)

# NOTE(review): torch is already imported at the top of the file, so this
# second import is redundant (but harmless).
import torch

# Startup diagnostic: report whether torch can see a CUDA device.
print(f"cuda: {torch.cuda.is_available()}")
with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
    # State handed to and returned by the callbacks:
    #   "selected": index of the grid image chosen as seed (-1 means none),
    #   "seeds":    one seed per grid slot, random until the first grid is generated.
    state = gr.Variable(
        {
            "selected": -1,
            "seeds": [random.randint(0, 2 ** 32 - 1) for _ in range(6)],
        }
    )
def infer_seeded_image(prompt, seed):
    """Generate one image for `prompt` using the given `seed`.

    Logs the prompt/seed pair, delegates to `infer_grid` with a one-element
    seed list, and returns the only image produced.
    """
    print(f"Prompt: {prompt}, seed: {seed}")
    generated, _used_seeds = infer_grid(prompt, n=1, seeds=[seed])
    return generated[0]
def infer_grid(prompt, n=6, seeds=None):
    """Generate images for `prompt`, one pipeline call per image.

    Args:
        prompt: Text prompt passed to the pipeline.
        n: Number of images to generate when fewer than `n` seeds are given.
        seeds: Optional sequence of seeds to reuse, in order. Positions beyond
            its length (and entries that are None) get a fresh random 32-bit
            seed. If it holds more than `n` entries, one image is generated
            per entry. None (the default) means "no seeds supplied".

    Returns:
        A tuple `(images, seeds)` of two equally long lists: the generated
        images and the seed that was set before generating each of them.
    """
    # A None default replaces the former shared mutable `[]` default.
    seeds = [] if seeds is None else seeds
    # Enable CUDA autocast only when a CUDA device exists; asking for it on a
    # CPU-only host just makes torch emit a warning and disable it anyway.
    use_cuda_autocast = torch.cuda.is_available()

    images, used_seeds = [], []
    # Unfortunately we have to iterate instead of requesting all images at once,
    # because we have no way to get the intermediate generation seeds.
    for _, seed in zip_longest(range(n), seeds, fillvalue=None):
        if seed is None:
            seed = random.randint(0, 2**32 - 1)
        # Seed the global torch RNG so the same seed can be replayed later.
        torch.manual_seed(seed)
        with torch.autocast("cuda", enabled=use_cuda_autocast):
            samples = ldm(
                [prompt],
                num_inference_steps=STEPS,
                eta=ETA,
                guidance_scale=GUIDANCE_SCALE,
            )["sample"]
        images.append(samples[0])
        used_seeds.append(seed)
    return images, used_seeds
def infer(prompt, state):
    """Generate either a fresh grid or a single seeded image for `prompt`.

    If `state["selected"]` is greater than -1, the seed stored at that index
    is reused to produce one image; otherwise a new grid is generated and its
    seeds are written to `state["seeds"]`.

    Outputs (as one flat list):
    - Grid images (six entries, all None in seeded mode)
    - Seeded image (image or None)
    - Grid box update with the new visibility
    - Seeded box update with the new visibility
    - The state
    """
    selected = state["selected"]
    if selected > -1:
        # Seeded mode: replay the seed of the chosen grid slot.
        chosen_seed = state["seeds"][selected]
        seeded = infer_seeded_image(prompt, chosen_seed)
        tiles = [None] * 6
        grid_visible, seeded_visible = False, True
    else:
        # Grid mode: generate new images and remember their seeds.
        tiles, fresh_seeds = infer_grid(prompt)
        state["seeds"] = fresh_seeds
        seeded = None
        grid_visible, seeded_visible = True, False

    box_updates = [
        gr.Box.update(visible=grid_visible),
        gr.Box.update(visible=seeded_visible),
    ]
    return tiles + [seeded] + box_updates + [state]
def update_state(selected_index: int, value, state):
    """Handle a change event from the radio button at `selected_index`.

    Args:
        selected_index: Position of the radio that fired the event.
        value: The radio's new value. '' is the value this app itself writes
            into a radio to clear it (see the other branch and `clear_seed`).
        state: The state dict; `state["selected"]` is updated in place when
            the event is a real selection.

    Returns:
        A list with five updates for the other radios followed by the state.
    """
    if value == '':
        # This radio was cleared by the app, not picked by the user, so it
        # must not become the selected seed: leave the state untouched.
        # Previously the cleared radio's index overwrote state["selected"].
        # NOTE(review): assumes change events can also fire for programmatic
        # updates, and that value=None means "no change" for the others.
        # Confirm both against the Gradio version in use.
        others = gr.Radio.update(value=None)
        return [others] * 5 + [state]

    # A real selection: remember it and clear the five other radios.
    state["selected"] = selected_index
    others = gr.Radio.update(value='')
    return [others] * 5 + [state]
def clear_seed(state):
    """Forget the selected seed and switch the UI back to the grid.

    Returns, in order: six '' values (one per radio button), an update that
    shows the grid box, an update that hides the seeded box, and the state
    with "selected" reset to -1.
    """
    state["selected"] = -1
    radio_values = [''] * 6
    show_grid = gr.Box.update(visible=True)
    hide_seeded = gr.Box.update(visible=False)
    return [*radio_values, show_grid, hide_seeded, state]
def image_block():
    """Create a non-interactive image component without a label.

    Only the first two entries of the `rounded` style tuple are enabled.
    """
    component = gr.Image(interactive=False, show_label=False)
    return component.style(rounded=(True, True, False, False))
def radio_block():
    """Create the single-choice radio button paired with each grid image.

    The only choice is SELECT_LABEL; the component is interactive, has no
    label, and is styled without a container.
    """
    selector = gr.Radio(choices=[SELECT_LABEL], interactive=True, show_label=False)
    return selector.style(container=False)
# Title and usage hint rendered through a Markdown component.
intro_html = """
<h1><center>Latent Diffusion Demo</center></h1>
<p>Type anything to generate a few images that represent your prompt.
Select one of the results to use as a <b>seed</b> for the next generation:
you can try variations of your prompt starting from the same state and see how it changes.
For example, <i>Labrador in the style of Vermeer</i> could be tweaked to
<i>Labrador in the style of Picasso</i> or <i>Lynx in the style of Van Gogh</i>.
If your prompts are similar, the tweaked result should also have a similar structure
but different details or style.</p>
"""
gr.Markdown(intro_html)
# Prompt input: a one-line textbox followed by the "Run" button.
# NOTE(review): the nesting below is reconstructed from statement order because
# the source had lost its indentation. Confirm that the button belongs in the Row.
with gr.Group():
    with gr.Box():
        with gr.Row().style(mobile_collapse=False, equal_height=True):
            text = gr.Textbox(
                label="Enter your prompt",
                show_label=False,
                max_lines=1,
            ).style(
                border=(True, False, True, True),
                rounded=(True, False, False, True),
                container=False,
            )
            btn = gr.Button("Run").style(
                margin=False,
                rounded=(False, True, True, False),
            )
## Can we create a Component with these, so it can participate as an output?
# Grid of 2 rows x 3 cells; each cell is a borderless box holding one image
# followed by its radio button. Components are created in the same order as
# before: image 1, radio 1, image 2, radio 2, ...
images = []
selectors = []
with (grid := gr.Box()):
    for _ in range(2):
        with gr.Row():
            for _ in range(3):
                with gr.Box().style(border=None):
                    images.append(image_block())
                    selectors.append(radio_block())

# Keep the individual names bound as well.
image1, image2, image3, image4, image5, image6 = images
select1, select2, select3, select4, select5, select6 = selectors
# Register update_state on every radio: the handler receives the radio's index
# (bound via partial), the radio's value and the state, and writes to the five
# other radios plus the state.
for i, radio in enumerate(selectors):
    others = [s for s in selectors if s != radio]
    radio.change(
        partial(update_state, i),
        inputs=[radio, state],
        outputs=others + [state],
    )
# Box for seeded mode: the single seeded image and a button to go back.
# NOTE(review): the nesting is reconstructed because the source had lost its
# indentation. Confirm that the button lives inside the box.
seeded_box = gr.Box()
with seeded_box:
    seeded_image = image_block()
    clear_seed_button = gr.Button("Return to Grid")

# Start hidden; `infer` and `clear_seed` return visibility updates for this box.
seeded_box.visible = False

clear_seed_button.click(
    clear_seed,
    inputs=[state],
    outputs=[*selectors, grid, seeded_box, state],
)
# Output components for `infer`, in the order it returns its values:
# six grid images, the seeded image, the two boxes, then the state.
all_images = [*images, seeded_image]
boxes = [grid, seeded_box]
infer_outputs = [*all_images, *boxes, state]

# Submitting the textbox and clicking "Run" both trigger the same handler.
for register in (text.submit, btn.click):
    register(
        infer,
        inputs=[text, state],
        outputs=infer_outputs,
    )

# Start the app once the whole UI has been declared.
demo.launch(enable_queue=True)