"""Gradio showcase app for Dongba Dreamer.

Builds a ``DongbaDreamer`` from ``config/dongba_dreamer.yaml`` and exposes a
single "Launch Dongba Dreamer" button. The button runs the dreamer's
``process_words`` step followed by its ``process_sd`` step. It shows the
word images and the final images in two galleries, next to a log box that
polls ``dreamer.get_logs``.
"""
import random

import gradio as gr
import numpy as np
import torch
from diffusers import DiffusionPipeline

from model.omegaconf_utils import load_from_config
from model.dongba_dreamer import DongbaDreamer

##################################################
# Query CUDA once and derive both the torch device string and the
# human-readable label shown in the page header from that single result.
_CUDA_AVAILABLE = torch.cuda.is_available()
device = "cuda" if _CUDA_AVAILABLE else "cpu"
power_device = (
    "GPU" if _CUDA_AVAILABLE
    else "CPU (free space, very slow, around 200 seconds)"
)

##################################################
config_files = ['config/dongba_dreamer.yaml']
config = load_from_config(config_files)

# dongba dreamer
dreamer = DongbaDreamer(config)


##################################################
def process_pipeline(image_topic, canvas_width, canvas_height, num_words):
    """Run the full Dongba Dreamer pipeline for one request.

    Clears the dreamer's log buffer, runs ``dreamer.process_words`` on the
    UI inputs, then feeds the resulting prompt and composition image into
    ``dreamer.process_sd``.

    Args:
        image_topic: Text topic entered in the "Image Topic" box.
        canvas_width: Canvas width from the UI (label asks for a multiple of 512).
        canvas_height: Canvas height from the UI (label asks for a multiple of 512).
        num_words: Expected number of characters; the UI label states 0 means auto.

    Returns:
        A ``(word_images, sd_images)`` pair. The first value is taken from
        the ``'word_images'`` entry of ``process_words``'s result. The second
        is whatever ``process_sd`` returns. They fill the word gallery and
        the output gallery respectively.
    """
    dreamer.reset_logs()
    rlt = dreamer.process_words(image_topic, canvas_width, canvas_height, num_words)
    word_images = rlt['word_images']
    image_prompt = rlt['image_prompt']
    composition_image = rlt['composition_image']
    sd_images = dreamer.process_sd(image_prompt, composition_image)
    return word_images, sd_images


##################################################
CSS = """
.contain { display: flex; flex-direction: column; }
#component-0 { height: 100%; }
#output_box { flex-grow: 1; overflow: auto; }
"""

block = gr.Blocks(css=CSS).queue()

with block:
    with gr.Row():
        gr.Markdown(f"## Dongba Dreamer Showcase, running on {power_device}")
    with gr.Row():
        with gr.Column(scale=1):
            image_topic = gr.Textbox(label="Image Topic", value="日出江花红胜火")
            with gr.Row():
                # precision=0 makes gr.Number deliver ints instead of floats;
                # these are pixel sizes / counts, not fractional values.
                canvas_width = gr.Number(label="Canvas Width (multiple of 512)", value=512, precision=0)
                canvas_height = gr.Number(label="Canvas Height (multiple of 512)", value=512, precision=0)
                num_words = gr.Number(label="Expected number of characters (0 = auto)", value=0, precision=0)
            button_pipeline = gr.Button(value="Launch Dongba Dreamer")
            # Each gallery gets its own elem_id: HTML element ids must be unique.
            gallery_words = gr.Gallery(label='Dongba Words', show_label=False,
                                       elem_id="gallery_words", columns=4)
        with gr.Column(scale=2, elem_id="output_box"):
            result_gallery = gr.Gallery(label='Output', show_label=False,
                                        elem_id="gallery_output", columns=2)
            with gr.Row():
                # Passing the callable plus every=1 makes Gradio re-poll the logs each second.
                logbox = gr.Textbox(label='Log', value=dreamer.get_logs, every=1, interactive=False)

    inputs = [image_topic, canvas_width, canvas_height, num_words]
    # The original chain ended in a dangling `.then` (a SyntaxError); no
    # follow-up event is needed, so the click handler stands alone.
    button_pipeline.click(fn=process_pipeline, inputs=inputs,
                          outputs=[gallery_words, result_gallery])

block.launch(share=True)