# DongbaDreamer / app.py
# initialneil's picture
# - update space
# 761be79
import gradio as gr
import numpy as np
import random
from diffusers import DiffusionPipeline
import torch
from model.omegaconf_utils import load_from_config
from model.dongba_dreamer import DongbaDreamer
##################################################
# Select the torch device name, then derive the human-readable label that is
# shown in the page heading from that same decision.
device = "cuda" if torch.cuda.is_available() else "cpu"
power_device = (
    "GPU"
    if device == "cuda"
    else "CPU (free space, very slow, around 200 seconds)"
)
##################################################
# Configuration files handed to load_from_config (currently a single YAML).
config_files = ["config/dongba_dreamer.yaml"]
config = load_from_config(config_files)

# One module-level DongbaDreamer built from that config; the request handler
# and the log textbox both use this instance.
dreamer = DongbaDreamer(config)
##################################################
def process_pipeline(image_topic, canvas_width, canvas_height, num_words):
    """Run the word stage and then the image stage for one request.

    The dreamer's logs are reset first. ``dreamer.process_words`` is called
    with the four arguments unchanged, and the ``image_prompt`` and
    ``composition_image`` entries of its result are forwarded to
    ``dreamer.process_sd``.

    Args:
        image_topic: Value of the "Image Topic" textbox.
        canvas_width: Value of the "Canvas Width" number input.
        canvas_height: Value of the "Canvas Height" number input.
        num_words: Value of the "Expected number of characters" number input.

    Returns:
        A ``(word_images, sd_images)`` tuple: the ``word_images`` entry of the
        ``process_words`` result and whatever ``process_sd`` returned.
    """
    dreamer.reset_logs()

    words_result = dreamer.process_words(
        image_topic, canvas_width, canvas_height, num_words
    )
    word_images = words_result['word_images']

    # The second stage only needs the prompt and the composed layout image.
    sd_images = dreamer.process_sd(
        words_result['image_prompt'],
        words_result['composition_image'],
    )
    return word_images, sd_images
##################################################
# Custom stylesheet passed to gr.Blocks:
#   .contain      -> laid out as a vertical flex container
#   #component-0  -> stretched to full height
#                    (presumably Gradio's id for the first component; verify)
#   #output_box   -> grows to fill the remaining space and scrolls on overflow
CSS = (
    "\n"
    ".contain { display: flex; flex-direction: column; }\n"
    "#component-0 { height: 100%; }\n"
    "#output_box { flex-grow: 1; overflow: auto; }\n"
)
# Gradio UI: a heading row, a main row with the controls (left column) and the
# generated images (right column), and a log row at the bottom.
# NOTE(review): the nesting below was reconstructed from a copy of this file
# that had lost its indentation -- in particular, confirm whether `num_words`
# belongs inside the width/height row or directly in the column.
block = gr.Blocks(css=CSS).queue()
with block:
    with gr.Row():
        gr.Markdown(f"## Dongba Dreamer Showcase, running on {power_device}")

    with gr.Row():
        with gr.Column(scale=1):
            image_topic = gr.Textbox(label="Image Topic", value="日出江花红胜火")
            with gr.Row():
                # precision=0 makes Gradio deliver these values as ints;
                # without it a Number input is delivered as a float (512.0).
                canvas_width = gr.Number(
                    label="Canvas Width (multiple of 512)", value=512, precision=0
                )
                canvas_height = gr.Number(
                    label="Canvas Height (multiple of 512)", value=512, precision=0
                )
            num_words = gr.Number(
                label="Expected number of characters (0 = auto)", value=0, precision=0
            )
            button_pipeline = gr.Button(value="Launch Dongba Dreamer")
            # Each gallery gets its own elem_id so the page has no duplicate
            # DOM ids (both were previously "gallery").
            gallery_words = gr.Gallery(
                label='Dongba Words', show_label=False,
                elem_id="gallery_words", columns=4,
            )
        # "output_box" is the id targeted by the #output_box rule in CSS.
        with gr.Column(scale=2, elem_id="output_box"):
            result_gallery = gr.Gallery(
                label='Output', show_label=False,
                elem_id="gallery_output", columns=2,
            )

    with gr.Row():
        # `value` is a callable and `every=1` asks Gradio to re-run it every
        # second, so the box keeps showing the latest dreamer.get_logs() output.
        logbox = gr.Textbox(label='Log', value=dreamer.get_logs, every=1,
                            interactive=False)

    inputs = [image_topic, canvas_width, canvas_height, num_words]
    # One click runs the whole pipeline and fills both galleries. The former
    # trailing `.then` was a bare attribute access (never called) and has been
    # dropped.
    button_pipeline.click(
        fn=process_pipeline,
        inputs=inputs,
        outputs=[gallery_words, result_gallery],
    )

block.launch(share=True)