File size: 4,162 Bytes
e8ce90b
 
5c762ce
e8ce90b
 
 
5c762ce
 
e8ce90b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c762ce
 
 
 
 
 
 
 
c3014da
5c762ce
 
c3014da
5c762ce
 
 
 
 
c3014da
 
5c762ce
 
 
 
 
 
c3014da
 
5c762ce
 
 
 
 
 
 
c3014da
5c762ce
 
c3014da
5c762ce
 
 
 
c3014da
 
 
 
 
5c762ce
c3014da
5c762ce
 
c3014da
5c762ce
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import subprocess
from pathlib import Path

import gradio as gr

from demo import SdmCompressionDemo

# Local destinations for the compressed-model (BK-SDM-Small) U-Net checkpoint files.
dest_path_config = Path('checkpoints/BK-SDM-Small_iter50000/unet/config.json')
dest_path_torch_ckpt = Path('checkpoints/BK-SDM-Small_iter50000/unet/diffusion_pytorch_model.bin')

# Download URLs come from the environment. They are validated with an explicit
# exception because `assert` statements are removed when Python runs with -O.
BK_SDM_CONFIG_URL: str = os.getenv('CHECKPOINT_CONFIG', None)
BK_SDM_TORCH_CKPT_URL: str = os.getenv('CHECKPOINT_PYTORCH_BIN', None)
if BK_SDM_CONFIG_URL is None:
    raise RuntimeError("Environment variable CHECKPOINT_CONFIG must be set to the config.json download URL")
if BK_SDM_TORCH_CKPT_URL is None:
    raise RuntimeError("Environment variable CHECKPOINT_PYTORCH_BIN must be set to the checkpoint download URL")


def _download_checkpoint(url: str, dest: Path) -> None:
    """Download *url* to *dest* with ``wget``, creating the parent directory first.

    Raises:
        subprocess.CalledProcessError: if ``wget`` exits with a non-zero status
            (a failed ``wget -O`` leaves an empty/partial file behind, so continuing
            would only fail later with a less obvious error).
        FileNotFoundError: if ``wget`` is not installed.
    """
    # `wget -O` does not create missing directories.
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Pass an argument list without a shell: the URL is environment-supplied and
    # must not be interpreted by a shell.
    # NOTE(review): --no-check-certificate disables TLS certificate verification;
    # kept for compatibility with the original command, consider removing it.
    subprocess.run(
        ['wget', '--no-check-certificate', '-O', str(dest), url],
        check=True,
    )


_download_checkpoint(BK_SDM_CONFIG_URL, dest_path_config)
_download_checkpoint(BK_SDM_TORCH_CKPT_URL, dest_path_torch_ckpt)

if __name__ == "__main__":
    # Build a side-by-side comparison UI: the original SD 1.4 model vs. the
    # compressed model, both driven by the same prompt and settings.
    servicer = SdmCompressionDemo()
    example_list = servicer.get_example_list()

    # Login credentials for the demo (launched below with a public share link).
    # They can be overridden via the environment so deployments do not depend on
    # the defaults.
    # NOTE(review): the fallback pair is a hard-coded secret committed to source;
    # consider removing it once every deployment sets these variables.
    auth_credentials = (
        os.environ.get('DEMO_AUTH_USERNAME', 'test'),
        os.environ.get('DEMO_AUTH_PASSWORD', 'testasdf@@19'),
    )

    # NOTE(review): `Row().style(...)` and `queue(concurrency_count=...)` are Gradio 3.x
    # APIs (removed in Gradio 4) — confirm the pinned Gradio version.
    with gr.Blocks(theme='nota-ai/theme') as demo:
        gr.Markdown(Path('docs/header.md').read_text(encoding='utf-8'))
        gr.Markdown(Path('docs/description.md').read_text(encoding='utf-8'))
        with gr.Row():
            # Left column: prompt input, generate buttons, advanced settings, examples.
            with gr.Column(variant='panel', scale=30):

                text = gr.Textbox(label="Input Prompt", max_lines=5, placeholder="Enter your prompt")

                with gr.Row().style(equal_height=True):
                    generate_original_button = gr.Button(value="Generate with Original Model", variant="primary")
                    generate_compressed_button = gr.Button(value="Generate with Compressed Model", variant="primary")

                with gr.Accordion("Advanced Settings", open=False):
                    negative = gr.Textbox(label='Negative Prompt', placeholder='Enter aspects to remove (e.g., low quality)')
                    with gr.Row():
                        guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, minimum=4, maximum=11, step=0.5)
                        steps = gr.Slider(label="Denoising Steps", value=25, minimum=10, maximum=75, step=5)
                        seed = gr.Slider(0, 999999, label='Random Seed', value=1234, step=1)

                with gr.Tab("Example Prompts"):
                    examples = gr.Examples(examples=example_list, inputs=[text])

            with gr.Column(variant='panel', scale=35):
                # Define original model output components
                gr.Markdown('<h2 align="center">Original Stable Diffusion 1.4</h2>')
                original_model_output = gr.Image(label="Original Model")
                with gr.Row().style(equal_height=True):
                    original_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
                    original_model_error = gr.Markdown()

            with gr.Column(variant='panel', scale=35):
                # Define compressed model output components
                gr.Markdown('<h2 align="center">Compressed Stable Diffusion (Ours)</h2>')
                compressed_model_output = gr.Image(label="Compressed Model")
                with gr.Row().style(equal_height=True):
                    compressed_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
                    compressed_model_error = gr.Markdown()

        # Both models receive the same inputs, in this order.
        inputs = [text, negative, guidance_scale, steps, seed]

        # Pressing Enter in the prompt box triggers BOTH models; each button triggers one.
        # presumably infer_* returns (image, error markdown, inference time) in the
        # order of these output lists — verify against SdmCompressionDemo.
        original_model_outputs = [original_model_output, original_model_error, original_model_test_time]
        text.submit(servicer.infer_original_model, inputs=inputs, outputs=original_model_outputs)
        generate_original_button.click(servicer.infer_original_model, inputs=inputs, outputs=original_model_outputs)

        compressed_model_outputs = [compressed_model_output, compressed_model_error, compressed_model_test_time]
        text.submit(servicer.infer_compressed_model, inputs=inputs, outputs=compressed_model_outputs)
        generate_compressed_button.click(servicer.infer_compressed_model, inputs=inputs, outputs=compressed_model_outputs)

        gr.Markdown(Path('docs/footer.md').read_text(encoding='utf-8'))

    # Process one generation request at a time.
    demo.queue(concurrency_count=1)
    demo.launch(share=True, auth=auth_credentials)