Spaces:
Runtime error
Runtime error
File size: 4,185 Bytes
e8ce90b 5c762ce e8ce90b 5c762ce e8ce90b 690f7ad d39bc83 db4b904 d39bc83 e8ce90b 5c762ce 33a2e0a 5c762ce c3014da 5c762ce c3014da 5c762ce c3014da 5c762ce c3014da 5c762ce c3014da 5c762ce c3014da 5c762ce c3014da 5c762ce c3014da 5c762ce c3014da 5c762ce db4b904 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import os
import subprocess
from pathlib import Path

import gradio as gr
import torch

from demo import SdmCompressionDemo
# Local destinations for the BK-SDM-Small U-Net config and weights that are
# fetched every time this module is executed.
dest_path_config = Path('checkpoints/BK-SDM-Small_iter50000/unet/config.json')
dest_path_torch_ckpt = Path('checkpoints/BK-SDM-Small_iter50000/unet/diffusion_pytorch_model.bin')

# Download URLs are read from environment variables.
BK_SDM_CONFIG_URL: str = os.getenv('BK_SDM_CONFIG_URL', None)
BK_SDM_TORCH_CKPT_URL: str = os.getenv('BK_SDM_TORCH_CKPT_URL', None)

# Validate with an explicit raise instead of `assert`, which is stripped when
# Python runs with -O and would let a missing URL reach wget.
_missing_url_vars = [
    var_name
    for var_name, var_value in (
        ('BK_SDM_CONFIG_URL', BK_SDM_CONFIG_URL),
        ('BK_SDM_TORCH_CKPT_URL', BK_SDM_TORCH_CKPT_URL),
    )
    if not var_value
]
if _missing_url_vars:
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(_missing_url_vars)
    )


def _download(url: str, dest: Path) -> None:
    """Download *url* to *dest* with wget.

    Args:
        url: Remote location of the file.
        dest: Local path the file is written to; missing parent directories
            are created first because ``wget -O`` does not create them.

    Raises:
        subprocess.CalledProcessError: If wget exits with a non-zero status.
        FileNotFoundError: If the ``wget`` executable is not available.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    # An argument list (no shell) keeps characters such as '&' or ';' inside
    # the URL from being interpreted by a shell.
    # NOTE(review): --no-check-certificate disables TLS verification; it is
    # kept to preserve the existing behaviour -- confirm it is still needed.
    subprocess.run(
        ["wget", "--no-check-certificate", "-O", str(dest), url],
        check=True,
    )


_download(BK_SDM_CONFIG_URL, dest_path_config)
_download(BK_SDM_TORCH_CKPT_URL, dest_path_torch_ckpt)
if __name__ == "__main__":
    # Script entry point: build the side-by-side comparison UI (original vs.
    # compressed Stable Diffusion) and launch it.
    # `torch` must be imported at the top of the file for the device check.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    servicer = SdmCompressionDemo(device)
    example_list = servicer.get_example_list()

    with gr.Blocks(theme='nota-ai/theme') as demo:
        gr.Markdown(Path('docs/header.md').read_text())
        gr.Markdown(Path('docs/description.md').read_text())

        with gr.Row():
            # Left column: prompt, generate buttons, advanced settings, examples.
            with gr.Column(variant='panel', scale=30):
                text = gr.Textbox(label="Input Prompt", max_lines=5, placeholder="Enter your prompt")

                # NOTE(review): `.style(...)` is version-dependent in Gradio;
                # confirm it exists in the pinned Gradio release.
                with gr.Row().style(equal_height=True):
                    generate_original_button = gr.Button(value="Generate with Original Model", variant="primary")
                    generate_compressed_button = gr.Button(value="Generate with Compressed Model", variant="primary")

                with gr.Accordion("Advanced Settings", open=False):
                    negative = gr.Textbox(label=f'Negative Prompt', placeholder=f'Enter aspects to remove (e.g., {"low quality"})')
                    with gr.Row():
                        guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, minimum=4, maximum=11, step=0.5)
                        steps = gr.Slider(label="Denoising Steps", value=25, minimum=10, maximum=75, step=5)
                    seed = gr.Slider(0, 999999, label='Random Seed', value=1234, step=1)

                with gr.Tab("Example Prompts"):
                    examples = gr.Examples(examples=example_list, inputs=[text])

            with gr.Column(variant='panel', scale=35):
                # Define original model output components
                gr.Markdown('<h2 align="center">Original Stable Diffusion 1.4</h2>')
                original_model_output = gr.Image(label="Original Model")
                with gr.Row().style(equal_height=True):
                    original_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
                    original_model_error = gr.Markdown()

            with gr.Column(variant='panel', scale=35):
                # Define compressed model output components
                gr.Markdown('<h2 align="center">Compressed Stable Diffusion (Ours)</h2>')
                compressed_model_output = gr.Image(label="Compressed Model")
                with gr.Row().style(equal_height=True):
                    compressed_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
                    compressed_model_error = gr.Markdown()

        # Both models receive the same set of input components.
        inputs = [text, negative, guidance_scale, steps, seed]

        # Click the generate button for original model
        original_model_outputs = [original_model_output, original_model_error, original_model_test_time]
        text.submit(servicer.infer_original_model, inputs=inputs, outputs=original_model_outputs)
        generate_original_button.click(servicer.infer_original_model, inputs=inputs, outputs=original_model_outputs)

        # Click the generate button for compressed model
        compressed_model_outputs = [compressed_model_output, compressed_model_error, compressed_model_test_time]
        text.submit(servicer.infer_compressed_model, inputs=inputs, outputs=compressed_model_outputs)
        generate_compressed_button.click(servicer.infer_compressed_model, inputs=inputs, outputs=compressed_model_outputs)

        gr.Markdown(Path('docs/footer.md').read_text())

    # NOTE(review): `concurrency_count` is version-dependent in Gradio;
    # confirm it is accepted by the pinned Gradio release.
    demo.queue(concurrency_count=1)
    demo.launch()
|