hysts HF staff commited on
Commit
61d3740
·
1 Parent(s): 9c4b85e

Split the demo code

Browse files
Files changed (1) hide show
  1. app.py +72 -64
app.py CHANGED
@@ -55,64 +55,41 @@ def update_scheduler_type(name: str) -> dict:
55
  value=20)
56
 
57
 
58
- def main():
59
- parser = argparse.ArgumentParser()
60
- parser.add_argument('--device', type=str, default='cpu')
61
- args = parser.parse_args()
62
-
63
- model = Model(args.device)
64
-
65
- with gr.Blocks(css='style.css') as demo:
66
- gr.Markdown(TITLE)
67
-
68
- with gr.Tabs():
69
- with gr.TabItem('Simple Mode'):
70
- run_button_simple = gr.Button('Generate')
71
- result_simple = gr.Image(show_label=False,
72
- elem_id='result-grid')
73
-
74
- with gr.TabItem('Advanced Mode'):
75
- gr.Markdown(DESCRIPTION)
76
-
77
- with gr.Row():
78
- with gr.Column():
79
- with gr.Group():
80
- model_name = gr.Dropdown(
81
- model.MODEL_NAMES,
82
- value=model.MODEL_NAMES[0],
83
- label='Model',
84
- interactive=False)
85
- scheduler_type = gr.Radio(
86
- choices=['DDPM', 'DDIM', 'PNDM'],
87
- value='DDIM',
88
- label='Scheduler')
89
- num_steps = gr.Slider(1,
90
- 200,
91
- step=1,
92
- value=20,
93
- label='Number of Steps')
94
- seed = gr.Slider(0,
95
- 100000,
96
- step=1,
97
- value=1234,
98
- label='Seed')
99
- run_button = gr.Button('Run')
100
- with gr.Column():
101
- result = gr.Image(show_label=False, elem_id='result')
102
-
103
- with gr.TabItem('Sample Images'):
104
- with gr.Row():
105
- model_name2 = gr.Dropdown([
106
- 'ddpm-128-exp000 (DDPM)',
107
- 'ddpm-128-exp000 (DDIM, 20 steps)',
108
- ],
109
- value='ddpm-128-exp000 (DDPM)',
110
- label='Model')
111
- with gr.Row():
112
- text = get_sample_image_markdown(model_name2.value)
113
- sample_images = gr.Markdown(text)
114
-
115
- gr.Markdown(FOOTER)
116
 
117
  model_name.change(fn=model.set_pipeline,
118
  inputs=[
@@ -130,9 +107,6 @@ def main():
130
  scheduler_type,
131
  ],
132
  outputs=None)
133
- run_button_simple.click(fn=model.run_simple,
134
- inputs=None,
135
- outputs=result_simple)
136
  run_button.click(fn=model.run,
137
  inputs=[
138
  model_name,
@@ -141,10 +115,44 @@ def main():
141
  seed,
142
  ],
143
  outputs=result)
144
- model_name2.change(fn=get_sample_image_markdown,
145
- inputs=model_name2,
146
- outputs=sample_images)
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  demo.launch(enable_queue=True, share=False)
149
 
150
 
 
55
  value=20)
56
 
57
 
58
+ def create_simple_demo(model: Model) -> gr.Blocks:
59
+ with gr.Blocks() as demo:
60
+ run_button = gr.Button('Generate')
61
+ result = gr.Image(show_label=False, elem_id='result-grid')
62
+ run_button.click(fn=model.run_simple, inputs=None, outputs=result)
63
+ return demo
64
+
65
+
66
+ def create_advanced_demo(model: Model) -> gr.Blocks:
67
+ with gr.Blocks() as demo:
68
+ gr.Markdown(DESCRIPTION)
69
+
70
+ with gr.Row():
71
+ with gr.Column():
72
+ with gr.Group():
73
+ model_name = gr.Dropdown(model.MODEL_NAMES,
74
+ value=model.MODEL_NAMES[0],
75
+ label='Model',
76
+ interactive=False)
77
+ scheduler_type = gr.Radio(choices=['DDPM', 'DDIM', 'PNDM'],
78
+ value='DDIM',
79
+ label='Scheduler')
80
+ num_steps = gr.Slider(1,
81
+ 200,
82
+ step=1,
83
+ value=20,
84
+ label='Number of Steps')
85
+ seed = gr.Slider(0,
86
+ 100000,
87
+ step=1,
88
+ value=1234,
89
+ label='Seed')
90
+ run_button = gr.Button('Run')
91
+ with gr.Column():
92
+ result = gr.Image(show_label=False, elem_id='result')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  model_name.change(fn=model.set_pipeline,
95
  inputs=[
 
107
  scheduler_type,
108
  ],
109
  outputs=None)
 
 
 
110
  run_button.click(fn=model.run,
111
  inputs=[
112
  model_name,
 
115
  seed,
116
  ],
117
  outputs=result)
118
+ return demo
 
 
119
 
120
+
121
+ def create_sample_image_view_demo() -> gr.Blocks:
122
+ with gr.Blocks() as demo:
123
+ with gr.Row():
124
+ model_name = gr.Dropdown([
125
+ 'ddpm-128-exp000 (DDPM)',
126
+ 'ddpm-128-exp000 (DDIM, 20 steps)',
127
+ ],
128
+ value='ddpm-128-exp000 (DDPM)',
129
+ label='Model')
130
+ with gr.Row():
131
+ text = get_sample_image_markdown(model_name.value)
132
+ sample_images = gr.Markdown(text)
133
+
134
+ model_name.change(fn=get_sample_image_markdown,
135
+ inputs=model_name,
136
+ outputs=sample_images)
137
+ return demo
138
+
139
+
140
+ def main():
141
+ parser = argparse.ArgumentParser()
142
+ parser.add_argument('--device', type=str, default='cpu')
143
+ args = parser.parse_args()
144
+ model = Model(args.device)
145
+
146
+ with gr.Blocks(css='style.css') as demo:
147
+ gr.Markdown(TITLE)
148
+ with gr.Tabs():
149
+ with gr.TabItem('Simple Mode'):
150
+ create_simple_demo(model)
151
+ with gr.TabItem('Advanced Mode'):
152
+ create_advanced_demo(model)
153
+ with gr.TabItem('Sample Images'):
154
+ create_sample_image_view_demo()
155
+ gr.Markdown(FOOTER)
156
  demo.launch(enable_queue=True, share=False)
157
 
158