File size: 6,308 Bytes
9c00f5c
 
 
 
8be0786
 
9c00f5c
 
fde1134
55cf602
fde1134
ac14842
55cf602
592e1ab
55cf602
80597e4
 
9c00f5c
80597e4
9c00f5c
 
8fad05b
 
 
 
 
9c00f5c
 
 
ac14842
5d7586b
80597e4
9c00f5c
 
fde1134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80597e4
 
fde1134
 
 
 
9c00f5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fde1134
 
8be0786
80597e4
8be0786
 
9c00f5c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#!/usr/bin/env python

from __future__ import annotations

import gradio as gr

from model import AppModel

# Markdown heading rendered at the top of both the simple and advanced tabs.
TITLE = ('# <a href="https://github.com/THUDM/CogView2">CogView2</a>'
         ' (text2image)')

# Markdown usage notes shown under the title in the "Advanced Mode" tab.
# The leading and trailing newlines are part of the rendered text.
DESCRIPTION = (
    '\n'
    'The model accepts English or Chinese as input.\n'
    'In general, Chinese input produces better results than English input.\n'
    'By checking the "Translate to Chinese" checkbox, the results of English'
    ' to Chinese translation with [this Space]'
    '(https://huggingface.co/spaces/chinhon/translation_eng2ch)'
    ' will be used as input. Since the translation model may mistranslate,'
    ' you may want to use the translation results from other translation'
    ' services.\n'
)

# Markdown bullet shown at the bottom of the "Advanced Mode" tab.
NOTES = (
    '\n'
    '- This app is adapted from '
    '<a href="https://github.com/hysts/CogView2_demo">'
    'https://github.com/hysts/CogView2_demo</a>.'
    ' It would be recommended to use the repo if you want to run the app'
    ' yourself.\n'
)

# Visitor-counter badge rendered below the tabs.
# NOTE(review): served by an external host (visitor-badge.glitch.me);
# confirm it is still reachable.
FOOTER = ('<img id="visitor-badge" alt="visitor badge" '
          'src="https://visitor-badge.glitch.me/badge?page_id=THUDM.CogView2"'
          ' />')


def set_example_text(example: list) -> list[dict]:
    """Copy a clicked example row into the advanced-mode input widgets.

    Args:
        example: One row of the examples dataset, ordered like that
            dataset's components: ``[prompt_text, style_name]``.

    Returns:
        Two Gradio update payloads, in order: one setting the prompt
        textbox's value, one setting the style dropdown's value.
    """
    prompt = example[0]
    text_update = gr.Textbox.update(value=prompt)
    chosen_style = example[1]
    style_update = gr.Dropdown.update(value=chosen_style)
    return [text_update, style_update]


def _load_samples(path: str = 'samples.txt') -> list[list[str]]:
    """Read the tab-separated example rows shown in "Advanced Mode".

    Each non-blank line is stripped of surrounding whitespace and split on
    tabs. Rows presumably hold ``[prompt, style]`` to match the two
    ``gr.Dataset`` components they feed.

    Args:
        path: Location of the samples file.

    Returns:
        One list of string fields per non-blank line, in file order.
    """
    # Read explicitly as UTF-8 (assumes the file is UTF-8 encoded -- confirm).
    # Prompts may be Chinese, and the platform default encoding
    # (e.g. cp1252 on Windows) would garble or reject them.
    with open(path, encoding='utf-8') as f:
        # Skip blank lines: ''.split('\t') gives the one-field row [''],
        # which does not fit the two dataset columns.
        return [line.strip().split('\t') for line in f if line.strip()]


def main():
    """Build the CogView2 Gradio demo and start serving it.

    Creates a single ``AppModel`` shared by both tabs. Then it lays out:

    - a "Simple Mode" tab: prompt in, image out;
    - an "Advanced Mode" tab: prompt, translation, style, seed, first-stage
      and image-count controls, plus clickable examples read from
      ``samples.txt``.

    It wires the Run buttons to ``model.run_simple`` / ``model.run_advanced``
    and launches the app with queuing enabled.
    """
    only_first_stage = False
    max_inference_batch_size = 8
    model = AppModel(max_inference_batch_size, only_first_stage)

    with gr.Blocks(css='style.css') as demo:

        with gr.Tabs():
            with gr.TabItem('Simple Mode'):
                gr.Markdown(TITLE)

                # The border/rounded tuples on the textbox and button
                # presumably make the two render as one joined control.
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    text_simple = gr.Textbox(placeholder='Enter your prompt',
                                             show_label=False,
                                             max_lines=1).style(
                                                 border=(True, False, True,
                                                         True),
                                                 rounded=(True, False, False,
                                                          True),
                                                 container=False,
                                             )
                    run_button_simple = gr.Button('Run').style(
                        margin=False,
                        rounded=(False, True, True, False),
                    )
                result_grid_simple = gr.Image(show_label=False)

            with gr.TabItem('Advanced Mode'):
                gr.Markdown(TITLE)
                gr.Markdown(DESCRIPTION)

                with gr.Row():
                    with gr.Column():
                        with gr.Group():
                            text = gr.Textbox(label='Input Text')
                            translate = gr.Checkbox(
                                label='Translate to Chinese', value=False)
                            style = gr.Dropdown(
                                choices=[
                                    'none',
                                    'mainbody',
                                    'photo',
                                    'flat',
                                    'comics',
                                    'oil',
                                    'sketch',
                                    'isometric',
                                    'chinese',
                                    'watercolor',
                                ],
                                value='mainbody',
                                label='Style',
                            )
                            seed = gr.Slider(0,
                                             100000,
                                             step=1,
                                             value=1234,
                                             label='Seed')
                            # Named separately from the ``only_first_stage``
                            # bool so the startup setting is not shadowed by
                            # the UI component. The checkbox is hidden when
                            # the model itself was built first-stage-only.
                            only_first_stage_checkbox = gr.Checkbox(
                                label='Only First Stage',
                                value=only_first_stage,
                                visible=not only_first_stage)
                            num_images = gr.Slider(1,
                                                   16,
                                                   step=1,
                                                   value=4,
                                                   label='Number of Images')
                            run_button = gr.Button('Run')

                            samples = _load_samples('samples.txt')
                            examples = gr.Dataset(components=[text, style],
                                                  samples=samples)

                    with gr.Column():
                        with gr.Group():
                            translated_text = gr.Textbox(
                                label='Translated Text')
                            with gr.Tabs():
                                with gr.TabItem('Output (Grid View)'):
                                    result_grid = gr.Image(show_label=False)
                                with gr.TabItem('Output (Gallery)'):
                                    result_gallery = gr.Gallery(
                                        show_label=False)

                gr.Markdown(NOTES)

        gr.Markdown(FOOTER)

        run_button_simple.click(fn=model.run_simple,
                                inputs=text_simple,
                                outputs=result_grid_simple)
        # Gradio passes inputs positionally. This order must match the
        # parameters of ``model.run_advanced`` (defined in model.py).
        run_button.click(fn=model.run_advanced,
                         inputs=[
                             text,
                             translate,
                             style,
                             seed,
                             only_first_stage_checkbox,
                             num_images,
                         ],
                         outputs=[
                             translated_text,
                             result_grid,
                             result_gallery,
                         ])
        # Filling the form from an example skips the queue
        # (presumably because it is cheap and should feel instant).
        examples.click(fn=set_example_text,
                       inputs=examples,
                       outputs=examples.components,
                       queue=False)

    demo.launch(enable_queue=True)


# Build and launch the demo only when this file is run as a script,
# not when it is imported.
if __name__ == '__main__':
    main()