brayden-gg commited on
Commit
1b7b649
1 Parent(s): 5e527f9

switched from global vars to State

Browse files
Files changed (1) hide show
  1. app.py +188 -151
app.py CHANGED
@@ -7,171 +7,184 @@ from DataLoader import DataLoader
7
  import convenience
8
  import gradio as gr
9
 
10
- device = 'cpu'
11
- num_samples = 10
12
-
13
- net = SynthesisNetwork(weight_dim=256, num_layers=3).to(device)
14
-
15
- if not torch.cuda.is_available():
16
- net.load_state_dict(torch.load('./model/250000.pt', map_location=torch.device(device))["model_state_dict"])
17
-
18
-
19
- dl = DataLoader(num_writer=1, num_samples=10, divider=5.0, datadir='./data/writers')
20
-
21
-
22
- writer_options = [5, 14, 15, 16, 17, 22, 25, 80, 120, 137, 147, 151]
23
- all_loaded_data = []
24
- chosen_writers = [120, 80]
25
- avail_char = "0 1 2 3 4 5 6 7 8 9 a b c d e f g h i j k l m n o p q r s t u v w x y z A B C D E F G H I J K L M N O P Q R S T U V W X Y Z ! ? \" ' * + - = : ; , . < > \ / [ ] ( ) # $ % &"
26
- avail_char_list = avail_char.split(" ")
27
- for writer_id in chosen_writers:
28
- loaded_data = dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=list(range(num_samples)))
29
- all_loaded_data.append(loaded_data)
30
-
31
- default_loaded_data = all_loaded_data[-1]
32
-
33
- # data for writer interpolation
34
- writer_words = ["hello", "world"]
35
- writer_mean_Ws = []
36
- all_word_writer_Ws = []
37
- all_word_writer_Cs = []
38
- writer_weight = 0.7
39
- writer_svg = None
40
-
41
- # data for char interpolation
42
- blend_chars = ["y", "s"]
43
- char_mean_global_W = None
44
- char_weight = 0.7
45
- default_mean_global_W = convenience.get_mean_global_W(net, default_loaded_data, device)
46
- char_Ws = default_mean_global_W.reshape(1, 1, convenience.L)
47
- char_Cs = all_Cs = torch.zeros(1, 2, convenience.L, convenience.L)
48
- char_svg = None
49
-
50
- # data for MDN
51
- mdn_words = ["hello", "world"]
52
- mdn_mean_global_W = None
53
- all_word_mdn_Ws = []
54
- all_word_mdn_Cs = []
55
- mdn_svg = None
56
-
57
- def update_writer_word(target_word):
58
- writer_words.clear()
59
- for word in target_word.split(" "):
60
- writer_words.append(word)
61
-
62
- all_word_writer_Ws.clear()
63
- all_word_writer_Cs.clear()
64
- for word in writer_words:
65
- all_writer_Ws, all_writer_Cs = convenience.get_DSD(net, word, writer_mean_Ws, all_loaded_data, device)
66
- all_word_writer_Ws.append(all_writer_Ws)
67
- all_word_writer_Cs.append(all_writer_Cs)
68
-
69
- return update_writer_slider(writer_weight)
70
-
71
-
72
- # for writer interpolation
73
- def update_writer_slider(val):
74
- global writer_weight
75
- global writer_svg
76
- writer_weight = val
77
- weights = [1 - writer_weight, writer_weight]
78
 
 
79
  net.clamp_mdn = 0
80
- writer_svg = convenience.draw_words_svg(writer_words, all_word_writer_Ws, all_word_writer_Cs, weights, net)
81
- return gr.HTML.update(value=writer_svg.tostring()), gr.Slider.update(visible=False), gr.Button.update(visible=True)
82
 
83
-
84
- def update_chosen_writers(writer1, writer2):
85
- net.clamp_mdn = 0
86
- chosen_writers[0], chosen_writers[1] = int(writer1.split(" ")[1]), int(writer2.split(" ")[1])
87
-
88
- all_loaded_data.clear()
89
  for writer_id in chosen_writers:
90
  loaded_data = dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=list(range(num_samples)))
91
  all_loaded_data.append(loaded_data)
92
 
93
- writer_mean_Ws.clear()
94
  for loaded_data in all_loaded_data:
95
  mean_global_W = convenience.get_mean_global_W(net, loaded_data, device)
96
- writer_mean_Ws.append(mean_global_W)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- return gr.Slider.update(label=f"{writer1} vs. {writer2}"), *update_writer_slider(writer_weight)
99
 
100
- def update_writer_download():
101
  writer_svg.saveas("./DSD_writer_interpolation.svg")
102
  return gr.File.update(value="./DSD_writer_interpolation.svg", visible=True), gr.Button.update(visible=False)
103
 
104
  # for character blend
 
 
 
 
 
 
105
 
106
- def update_char_slider(weight):
107
- """Generates an image of handwritten text based on target_sentence"""
108
- global char_weight
109
- global char_svg
110
 
 
 
111
  net.clamp_mdn = 0
112
-
113
- char_weight = weight
114
  character_weights = [1 - weight, weight]
115
 
116
  all_W_c = convenience.get_character_blend_W_c(character_weights, char_Ws, char_Cs)
117
  all_commands = convenience.get_commands(net, blend_chars[0], all_W_c)
118
- char_svg = convenience.commands_to_svg(all_commands, 750, 160, 375)
119
- return gr.HTML.update(value=char_svg.tostring()), gr.Slider.update(visible=False), gr.Button.update(visible=True)
120
-
121
-
122
- def update_blend_chars(c1, c2):
123
- global blend_chars
124
- blend_chars[0], blend_chars[1] = c1, c2
125
-
126
- for i in range(2): # get corners of grid
127
- _, char_matrix = convenience.get_DSD(net, blend_chars[i], default_mean_global_W, [default_loaded_data], device)
128
- char_Cs[:, i, :, :] = char_matrix
129
 
130
- return gr.Slider.update(label=f"'{c1}' vs. '{c2}'")
131
-
132
- def update_char_download():
133
  char_svg.saveas("./DSD_char_interpolation.svg")
134
  return gr.File.update(value="./DSD_char_interpolation.svg", visible=True), gr.Button.update(visible=False)
135
 
136
  # for MDN
137
-
138
-
139
- def update_mdn_word(target_word):
140
- mdn_words.clear()
141
  for word in target_word.split(" "):
142
  mdn_words.append(word)
143
 
144
- all_word_mdn_Ws.clear()
145
- all_word_mdn_Cs.clear()
146
  for word in mdn_words:
147
  all_writer_Ws, all_writer_Cs = convenience.get_DSD(net, word, default_mean_global_W, [default_loaded_data], device)
148
  all_word_mdn_Ws.append(all_writer_Ws)
149
  all_word_mdn_Cs.append(all_writer_Cs)
150
 
151
- return sample_mdn(net.scale_sd, net.clamp_mdn)
152
 
153
 
154
- def sample_mdn(maxs, maxr):
155
- global mdn_svg
156
  net.clamp_mdn = maxr
157
  net.scale_sd = maxs
158
- mdn_svg = convenience.draw_words_svg(mdn_words, all_word_mdn_Ws, all_word_mdn_Cs, [1], net)
159
- return gr.HTML.update(value=mdn_svg.tostring()), gr.Slider.update(visible=False), gr.Button.update(visible=True)
160
 
161
- def update_mdn_download():
162
  mdn_svg.saveas("./DSD_add_randomness.svg")
163
  return gr.File.update(value="./DSD_add_randomness.svg", visible=True), gr.Button.update(visible=False)
164
 
165
- update_writer_word(" ".join(writer_words))
166
- update_chosen_writers(f"Writer {chosen_writers[0]}", f"Writer {chosen_writers[1]}")
167
 
168
- update_mdn_word(" ".join(writer_words))
169
- update_blend_chars(*blend_chars)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  with gr.Tabs():
173
  with gr.TabItem("Blend Writers"):
174
- target_word = gr.Textbox(label="Target Word", value=" ".join(writer_words), max_lines=1)
175
  with gr.Row():
176
  left_ratio_options = ["Style " + str(id) for i, id in enumerate(writer_options) if i % 2 == 0]
177
  right_ratio_options = ["Style " + str(id) for i, id in enumerate(writer_options) if i % 2 == 1]
@@ -180,70 +193,94 @@ with gr.Blocks() as demo:
180
  with gr.Column():
181
  writer2 = gr.Radio(right_ratio_options, value="Style 80", label="Style for second writer")
182
  with gr.Row():
183
- writer_submit = gr.Button("Submit")
184
- with gr.Row():
185
- writer_slider = gr.Slider(0, 1, value=writer_weight, label="Style 120 vs. Style 80")
186
  with gr.Row():
187
- writer_default_image = update_writer_slider(writer_weight)
188
  writer_output = gr.HTML(writer_default_image[0]["value"])
189
  with gr.Row():
190
  writer_download_btn = gr.Button("Save to SVG file")
 
191
  writer_download = gr.File(interactive=False, show_label=False, visible=False)
192
- writer_submit.click(fn=update_writer_slider, inputs=[writer_slider], outputs=[writer_output, writer_download, writer_download_btn], show_progress=False)
193
- writer_slider.change(fn=update_writer_slider, inputs=[writer_slider], outputs=[writer_output, writer_download, writer_download_btn], show_progress=False)
194
- target_word.submit(fn=update_writer_word, inputs=[target_word], outputs=[writer_output, writer_download, writer_download_btn], show_progress=False)
195
-
196
- writer1.change(fn=update_chosen_writers, inputs=[writer1, writer2], outputs=[writer_slider, writer_output, writer_download, writer_download_btn])
197
- writer2.change(fn=update_chosen_writers, inputs=[writer1, writer2], outputs=[writer_slider, writer_output, writer_download, writer_download_btn])
198
- writer_download_btn.click(fn=update_writer_download, inputs=[], outputs=[writer_download, writer_download_btn])
199
- writer_download_btn.style(full_width="true")
 
 
 
 
 
 
 
 
 
200
  with gr.TabItem("Blend Characters"):
201
  with gr.Row():
202
  with gr.Column():
203
- char1 = gr.Dropdown(choices=avail_char_list, value=blend_chars[0], label="Character 1")
204
  with gr.Column():
205
- char2 = gr.Dropdown(choices=avail_char_list, value=blend_chars[1], label="Character 2")
206
  with gr.Row():
207
- char_submit_button = gr.Button(value="Submit")
208
  with gr.Row():
209
- char_slider = gr.Slider(0, 1, value=char_weight, label=f"'{blend_chars[0]}' vs. '{blend_chars[1]}'")
210
- with gr.Row():
211
- char_default_image = update_char_slider(char_weight)
212
  char_output = gr.HTML(char_default_image[0]["value"])
213
  with gr.Row():
214
  char_download_btn = gr.Button("Save to SVG file")
 
215
  char_download = gr.File(interactive=False, show_label=False, visible=False)
216
 
217
- char_slider.change(fn=update_char_slider, inputs=[char_slider], outputs=[char_output, char_download, char_download_btn], show_progress=False)
 
 
218
 
219
- char1.change(fn=update_blend_chars, inputs=[char1, char2], outputs=[char_slider])
220
- char2.change(fn=update_blend_chars, inputs=[char1, char2], outputs=[char_slider])
 
 
 
 
221
 
222
- char_submit_button.click(fn=update_char_slider, inputs=[char_slider], outputs=[char_output, char_download, char_download_btn], show_progress=False)
 
 
223
 
224
- char_download_btn.click(fn=update_char_download, inputs=[], outputs=[char_download, char_download_btn], show_progress=True)
225
- char_download_btn.style(full_width="true")
226
  with gr.TabItem("Add Randomness"):
227
- mdn_word = gr.Textbox(label="Target Word", value=" ".join(mdn_words), max_lines=1)
228
  with gr.Row():
229
  with gr.Column():
230
- max_rand = gr.Slider(0, 1, value=net.clamp_mdn, label="Maximum Randomness")
231
  with gr.Column():
232
- scale_rand = gr.Slider(0, 3, value=net.scale_sd, label="Scale of Randomness")
233
  with gr.Row():
234
  mdn_sample_button = gr.Button(value="Resample")
235
  with gr.Row():
236
- default_im = sample_mdn(net.scale_sd, net.clamp_mdn)
237
  mdn_output = gr.HTML(default_im[0]["value"])
238
  with gr.Row():
239
  randomness_download_btn = gr.Button("Save to SVG file")
240
  randomness_download = gr.File(interactive=False, show_label=False, visible=False)
241
 
242
- max_rand.change(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output, randomness_download, randomness_download_btn], show_progress=False)
243
- scale_rand.change(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output, randomness_download, randomness_download_btn], show_progress=False)
244
- mdn_sample_button.click(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output, randomness_download, randomness_download_btn], show_progress=False)
245
- mdn_word.submit(fn=update_mdn_word, inputs=[mdn_word], outputs=[mdn_output, randomness_download, randomness_download_btn], show_progress=False)
246
-
247
- randomness_download_btn.click(fn=update_mdn_download, inputs=[], outputs=[randomness_download, randomness_download_btn])
 
 
 
 
 
 
 
 
 
 
 
248
  randomness_download_btn.style(full_width="true")
249
  demo.launch()
 
7
  import convenience
8
  import gradio as gr
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ def update_chosen_writers(writer1, writer2, weight, words, all_loaded_data):
12
  net.clamp_mdn = 0
13
+ chosen_writers = [int(writer1.split(" ")[1]), int(writer2.split(" ")[1])]
 
14
 
15
+ all_loaded_data = []
 
 
 
 
 
16
  for writer_id in chosen_writers:
17
  loaded_data = dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=list(range(num_samples)))
18
  all_loaded_data.append(loaded_data)
19
 
20
+ writer_mean_Ws = []
21
  for loaded_data in all_loaded_data:
22
  mean_global_W = convenience.get_mean_global_W(net, loaded_data, device)
23
+ writer_mean_Ws.append(mean_global_W.detach())
24
+
25
+ return gr.Slider.update(label=f"{writer1} vs. {writer2}"), chosen_writers, writer_mean_Ws, *update_writer_word(" ".join(words), writer_mean_Ws, all_loaded_data, weight)
26
+
27
+ def update_writer_word(target_word, writer_mean_Ws, all_loaded_data, writer_weight, device="cpu"):
28
+ words = []
29
+ for word in target_word.split(" "):
30
+ if len(word) > 0:
31
+ words.append(word)
32
+
33
+ word_Ws = []
34
+ word_Cs = []
35
+ for word in words:
36
+ writer_Ws, writer_Cs = convenience.get_DSD(net, word, writer_mean_Ws, all_loaded_data, device)
37
+ word_Ws.append(writer_Ws)
38
+ word_Cs.append(writer_Cs)
39
+
40
+ if len(words) == 0:
41
+ word_Ws.append(torch.tensor([]))
42
+ word_Cs.append(torch.tensor([]))
43
+
44
+ return words, word_Ws, word_Cs, *update_writer_slider(writer_weight, words, word_Ws, word_Cs)
45
+
46
+ def update_writer_slider(weight, words, all_word_Ws, all_word_Cs):
47
+ weights = [1 - weight, weight]
48
+ net.clamp_mdn = 0
49
+ svg = convenience.draw_words_svg(words, all_word_Ws, all_word_Cs, weights, net)
50
+ return gr.HTML.update(value=svg.tostring()), gr.File.update(visible=False), gr.Button.update(visible=True), weight, svg
51
 
 
52
 
53
+ def update_writer_download(writer_svg):
54
  writer_svg.saveas("./DSD_writer_interpolation.svg")
55
  return gr.File.update(value="./DSD_writer_interpolation.svg", visible=True), gr.Button.update(visible=False)
56
 
57
  # for character blend
58
+ def update_blend_chars(c1, c2, weight, char_Ws):
59
+ blend_chars = [c1, c2]
60
+ char_Cs = torch.zeros(1, 2, convenience.L, convenience.L)
61
+ for i in range(2): # get corners of grid
62
+ _, char_matrix = convenience.get_DSD(net, blend_chars[i], default_mean_global_W, [default_loaded_data], device)
63
+ char_Cs[:, i, :, :] = char_matrix
64
 
65
+ return gr.Slider.update(label=f"'{c1}' vs. '{c2}'"), char_Cs.detach(), blend_chars, *update_char_slider(weight, char_Ws, char_Cs, blend_chars)
 
 
 
66
 
67
+ def update_char_slider(weight, char_Ws, char_Cs, blend_chars):
68
+ """Generates an image of handwritten text based on target_sentence"""
69
  net.clamp_mdn = 0
 
 
70
  character_weights = [1 - weight, weight]
71
 
72
  all_W_c = convenience.get_character_blend_W_c(character_weights, char_Ws, char_Cs)
73
  all_commands = convenience.get_commands(net, blend_chars[0], all_W_c)
74
+ svg = convenience.commands_to_svg(all_commands, 750, 160, 375)
75
+ return gr.HTML.update(value=svg.tostring()), gr.File.update(visible=False), gr.Button.update(visible=True), weight, svg
 
 
 
 
 
 
 
 
 
76
 
77
+ def update_char_download(char_svg):
 
 
78
  char_svg.saveas("./DSD_char_interpolation.svg")
79
  return gr.File.update(value="./DSD_char_interpolation.svg", visible=True), gr.Button.update(visible=False)
80
 
81
  # for MDN
82
+ def update_mdn_word(target_word, scale_sd, clamp_mdn):
83
+ mdn_words = []
 
 
84
  for word in target_word.split(" "):
85
  mdn_words.append(word)
86
 
87
+ all_word_mdn_Ws = []
88
+ all_word_mdn_Cs = []
89
  for word in mdn_words:
90
  all_writer_Ws, all_writer_Cs = convenience.get_DSD(net, word, default_mean_global_W, [default_loaded_data], device)
91
  all_word_mdn_Ws.append(all_writer_Ws)
92
  all_word_mdn_Cs.append(all_writer_Cs)
93
 
94
+ return mdn_words, all_word_mdn_Ws, all_word_mdn_Cs, *sample_mdn(scale_sd, clamp_mdn, mdn_words, all_word_mdn_Ws, all_word_mdn_Cs)
95
 
96
 
97
+ def sample_mdn(maxs, maxr, mdn_words, all_word_mdn_Ws, all_word_mdn_Cs):
 
98
  net.clamp_mdn = maxr
99
  net.scale_sd = maxs
100
+ svg = convenience.draw_words_svg(mdn_words, all_word_mdn_Ws, all_word_mdn_Cs, [1], net)
101
+ return gr.HTML.update(value=svg.tostring()), gr.File.update(visible=False), gr.Button.update(visible=True), maxr, maxs, svg
102
 
103
+ def update_mdn_download(mdn_svg):
104
  mdn_svg.saveas("./DSD_add_randomness.svg")
105
  return gr.File.update(value="./DSD_add_randomness.svg", visible=True), gr.Button.update(visible=False)
106
 
107
+ device = 'cpu'
108
+ num_samples = 10
109
 
110
+ net = SynthesisNetwork(weight_dim=256, num_layers=3).to(device)
111
+
112
+ if not torch.cuda.is_available():
113
+ net.load_state_dict(torch.load('./model/250000.pt', map_location=torch.device(device))["model_state_dict"])
114
+
115
+
116
+ dl = DataLoader(num_writer=1, num_samples=10, divider=5.0, datadir='./data/writers')
117
+
118
+ writer_options = [5, 14, 15, 16, 17, 22, 25, 80, 120, 137, 147, 151]
119
+ all_loaded_data_DEFAULT = []
120
+ chosen_writers_DEFAULT = [120, 80]
121
+ avail_char = "0 1 2 3 4 5 6 7 8 9 a b c d e f g h i j k l m n o p q r s t u v w x y z A B C D E F G H I J K L M N O P Q R S T U V W X Y Z ! ? \" ' * + - = : ; , . < > \ / [ ] ( ) # $ % &"
122
+ avail_char_list = avail_char.split(" ")
123
+ for writer_id in chosen_writers_DEFAULT:
124
+ loaded_data = dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=list(range(num_samples)))
125
+ all_loaded_data_DEFAULT.append(loaded_data)
126
+
127
+ default_loaded_data = all_loaded_data_DEFAULT[-1]
128
+ default_mean_global_W = convenience.get_mean_global_W(net, default_loaded_data, device)
129
+
130
+ # data for writer interpolation
131
+ writer_words_DEFAULT = ["hello", "world"]
132
+ writer_mean_Ws_DEFAULT = []
133
+ writer_all_word_Ws_DEFAULT = []
134
+ writer_all_word_Cs_DEFAULT = []
135
+ writer_weight_DEFAULT = 0.7
136
+ writer_svg_DEFAULT = None
137
+
138
+ # data for char interpolation
139
+ char_chosen_DEFAULT = ["y", "s"]
140
+ char_mean_global_W_DEFAULT = None
141
+ char_weight_DEFAULT = 0.7
142
+ char_Ws_DEFAULT = default_mean_global_W.reshape(1, 1, convenience.L)
143
+ char_Cs_DEFAULT = None
144
+ char_svg_DEFAULT = None
145
+
146
+ # # data for MDN
147
+ mdn_words_DEFAULT = ["hello", "world"]
148
+ all_word_mdn_Ws_DEFAULT = None
149
+ all_word_mdn_Cs_DEFAULT = None
150
+ clamp_mdn_DEFAULT = 0.5
151
+ scale_sd_DEFAULT = 1
152
+ mdn_svg_DEFAULT = None
153
+
154
+ _wrds, writer_all_word_Ws_DEFAULT, writer_all_word_Cs_DEFAULT, _html, _file, _btn, _wt, _svg = update_writer_word(" ".join(writer_words_DEFAULT), writer_mean_Ws_DEFAULT, all_loaded_data_DEFAULT, writer_weight_DEFAULT)
155
+ _sldr, _wrtrs, writer_mean_Ws_DEFAULT, _wrds, _waww, _wawc, _html, _file, _btn, _wt, writer_svg_DEFAULT = update_chosen_writers(f"Writer {chosen_writers_DEFAULT[0]}", f"Writer {chosen_writers_DEFAULT[1]}", writer_weight_DEFAULT, writer_words_DEFAULT, all_loaded_data_DEFAULT)
156
+
157
+ _wrds, all_word_mdn_Ws_DEFAULT, all_word_mdn_Cs_DEFAULT, _html, _file, _btn, _maxr, _maxs, mdn_svg_DEFAULT = update_mdn_word(" ".join(mdn_words_DEFAULT), scale_sd_DEFAULT, clamp_mdn_DEFAULT)
158
+ _sldr, char_Cs_DEFAULT, _chrs, _html, _file, _btn, _wght, char_svg_DEFAULT = update_blend_chars(*char_chosen_DEFAULT, char_weight_DEFAULT, char_Ws_DEFAULT)
159
 
160
  with gr.Blocks() as demo:
161
+ all_loaded_data_var = gr.State(all_loaded_data_DEFAULT)
162
+ chosen_writers_var = gr.State(chosen_writers_DEFAULT)
163
+ # data for writer interpolation
164
+ writer_words_var = gr.State(writer_words_DEFAULT)
165
+ writer_mean_Ws_var = gr.State(writer_mean_Ws_DEFAULT)
166
+ writer_all_word_Ws_var = gr.State([e.detach() for e in writer_all_word_Ws_DEFAULT])
167
+ writer_all_word_Cs_var = gr.State([e.detach() for e in writer_all_word_Cs_DEFAULT])
168
+ writer_weight_var = gr.State(writer_weight_DEFAULT)
169
+ writer_svg_var = gr.State(writer_svg_DEFAULT)
170
+ # data for char interpolation
171
+ char_chosen_var = gr.State(char_chosen_DEFAULT)
172
+ char_mean_global_W_var = gr.State(char_mean_global_W_DEFAULT)
173
+ char_weight_var = gr.State(char_weight_DEFAULT)
174
+ char_Ws_var = gr.State(char_Ws_DEFAULT.detach())
175
+ char_Cs_var = gr.State(char_Cs_DEFAULT.detach())
176
+ char_svg_var = gr.State(char_svg_DEFAULT)
177
+ # # data for MDN
178
+ mdn_words_var = gr.State(mdn_words_DEFAULT)
179
+ all_word_mdn_Ws_var = gr.State([e.detach() for e in all_word_mdn_Ws_DEFAULT])
180
+ all_word_mdn_Cs_var = gr.State([e.detach() for e in all_word_mdn_Cs_DEFAULT])
181
+ clamp_mdn_var = gr.State(clamp_mdn_DEFAULT)
182
+ scale_sd_var = gr.State(scale_sd_DEFAULT)
183
+ mdn_svg_var = gr.State(mdn_svg_DEFAULT)
184
+
185
  with gr.Tabs():
186
  with gr.TabItem("Blend Writers"):
187
+ target_word = gr.Textbox(label="Target Word", value=" ".join(writer_words_DEFAULT), max_lines=1)
188
  with gr.Row():
189
  left_ratio_options = ["Style " + str(id) for i, id in enumerate(writer_options) if i % 2 == 0]
190
  right_ratio_options = ["Style " + str(id) for i, id in enumerate(writer_options) if i % 2 == 1]
 
193
  with gr.Column():
194
  writer2 = gr.Radio(right_ratio_options, value="Style 80", label="Style for second writer")
195
  with gr.Row():
196
+ writer_slider = gr.Slider(0, 1, value=writer_weight_DEFAULT, label="Style 120 vs. Style 80")
 
 
197
  with gr.Row():
198
+ writer_default_image = update_writer_slider(writer_weight_DEFAULT, writer_words_DEFAULT, writer_all_word_Ws_DEFAULT, writer_all_word_Cs_DEFAULT)
199
  writer_output = gr.HTML(writer_default_image[0]["value"])
200
  with gr.Row():
201
  writer_download_btn = gr.Button("Save to SVG file")
202
+ writer_download_btn.style(full_width="true")
203
  writer_download = gr.File(interactive=False, show_label=False, visible=False)
204
+
205
+ writer_slider.change(fn=update_writer_slider,
206
+ inputs=[writer_slider, writer_words_var, writer_all_word_Ws_var, writer_all_word_Cs_var],
207
+ outputs=[writer_output, writer_download, writer_download_btn, writer_weight_var, writer_svg_var], show_progress=False)
208
+ target_word.submit(fn=update_writer_word,
209
+ inputs=[target_word, writer_mean_Ws_var, all_loaded_data_var, writer_weight_var],
210
+ outputs=[writer_words_var, writer_all_word_Ws_var, writer_all_word_Cs_var, writer_output, writer_download, writer_download_btn, writer_weight_var, writer_svg_var], show_progress=False)
211
+ writer1.change(fn=update_chosen_writers,
212
+ inputs=[writer1, writer2, writer_weight_var, writer_words_var, all_loaded_data_var],
213
+ outputs=[writer_slider, chosen_writers_var, writer_mean_Ws_var, writer_words_var, writer_all_word_Ws_var, writer_all_word_Cs_var, writer_output, writer_download, writer_download_btn, writer_weight_var, writer_svg_var])
214
+ writer2.change(fn=update_chosen_writers,
215
+ inputs=[writer1, writer2, writer_weight_var, writer_words_var, all_loaded_data_var],
216
+ outputs=[writer_slider, chosen_writers_var, writer_mean_Ws_var, writer_words_var, writer_all_word_Ws_var, writer_all_word_Cs_var, writer_output, writer_download, writer_download_btn, writer_weight_var, writer_svg_var])
217
+ writer_download_btn.click(fn=update_writer_download,
218
+ inputs=[writer_svg_var],
219
+ outputs=[writer_download, writer_download_btn])
220
+
221
  with gr.TabItem("Blend Characters"):
222
  with gr.Row():
223
  with gr.Column():
224
+ char1 = gr.Dropdown(choices=avail_char_list, value=char_chosen_DEFAULT[0], label="Character 1")
225
  with gr.Column():
226
+ char2 = gr.Dropdown(choices=avail_char_list, value=char_chosen_DEFAULT[1], label="Character 2")
227
  with gr.Row():
228
+ char_slider = gr.Slider(0, 1, value=char_weight_DEFAULT, label=f"'{char_chosen_DEFAULT[0]}' vs. '{char_chosen_DEFAULT[1]}'")
229
  with gr.Row():
230
+ char_default_image = update_char_slider(char_weight_DEFAULT, char_Ws_DEFAULT, char_Cs_DEFAULT, char_chosen_DEFAULT)
 
 
231
  char_output = gr.HTML(char_default_image[0]["value"])
232
  with gr.Row():
233
  char_download_btn = gr.Button("Save to SVG file")
234
+ char_download_btn.style(full_width="true")
235
  char_download = gr.File(interactive=False, show_label=False, visible=False)
236
 
237
+ char_slider.change(fn=update_char_slider,
238
+ inputs=[char_slider, char_Ws_var, char_Cs_var, char_chosen_var],
239
+ outputs=[char_output, char_download, char_download_btn, char_weight_var, char_svg_var], show_progress=False)
240
 
241
+ char1.change(fn=update_blend_chars,
242
+ inputs=[char1, char2, char_weight_var, char_Ws_var],
243
+ outputs=[char_slider, char_Cs_var, char_chosen_var, char_output, char_download, char_download_btn, char_weight_var, char_svg_var])
244
+ char2.change(fn=update_blend_chars,
245
+ inputs=[char1, char2, char_weight_var, char_Ws_var],
246
+ outputs=[char_slider, char_Cs_var, char_chosen_var, char_output, char_download, char_download_btn, char_weight_var, char_svg_var])
247
 
248
+ char_download_btn.click(fn=update_char_download,
249
+ inputs=[char_svg_var],
250
+ outputs=[char_download, char_download_btn], show_progress=True)
251
 
 
 
252
  with gr.TabItem("Add Randomness"):
253
+ mdn_word = gr.Textbox(label="Target Word", value=" ".join(mdn_words_DEFAULT), max_lines=1)
254
  with gr.Row():
255
  with gr.Column():
256
+ max_rand = gr.Slider(0, 1, value=clamp_mdn_DEFAULT, label="Maximum Randomness")
257
  with gr.Column():
258
+ scale_rand = gr.Slider(0, 3, value=scale_sd_DEFAULT, label="Scale of Randomness")
259
  with gr.Row():
260
  mdn_sample_button = gr.Button(value="Resample")
261
  with gr.Row():
262
+ default_im = sample_mdn(scale_sd_DEFAULT, clamp_mdn_DEFAULT, mdn_words_DEFAULT, all_word_mdn_Ws_DEFAULT, all_word_mdn_Cs_DEFAULT)
263
  mdn_output = gr.HTML(default_im[0]["value"])
264
  with gr.Row():
265
  randomness_download_btn = gr.Button("Save to SVG file")
266
  randomness_download = gr.File(interactive=False, show_label=False, visible=False)
267
 
268
+ max_rand.change(fn=sample_mdn,
269
+ inputs=[scale_rand, max_rand, mdn_words_var, all_word_mdn_Ws_var, all_word_mdn_Cs_var],
270
+ outputs=[mdn_output, randomness_download, randomness_download_btn, clamp_mdn_var, scale_sd_var, mdn_svg_var], show_progress=False)
271
+ scale_rand.change(fn=sample_mdn,
272
+ inputs=[scale_rand, max_rand, mdn_words_var, all_word_mdn_Ws_var, all_word_mdn_Cs_var],
273
+ outputs=[mdn_output, randomness_download, randomness_download_btn, clamp_mdn_var, scale_sd_var, mdn_svg_var], show_progress=False)
274
+ mdn_sample_button.click(fn=sample_mdn,
275
+ inputs=[scale_rand, max_rand, mdn_words_var, all_word_mdn_Ws_var, all_word_mdn_Cs_var],
276
+ outputs=[mdn_output, randomness_download, randomness_download_btn, clamp_mdn_var, scale_sd_var, mdn_svg_var], show_progress=False)
277
+
278
+ mdn_word.submit(fn=update_mdn_word,
279
+ inputs=[mdn_word, scale_sd_var, clamp_mdn_var],
280
+ outputs=[mdn_words_var, all_word_mdn_Ws_var, all_word_mdn_Cs_var, mdn_output, randomness_download, randomness_download_btn, clamp_mdn_var, scale_sd_var, mdn_svg_var], show_progress=False)
281
+
282
+ randomness_download_btn.click(fn=update_mdn_download,
283
+ inputs=[mdn_svg_var],
284
+ outputs=[randomness_download, randomness_download_btn])
285
  randomness_download_btn.style(full_width="true")
286
  demo.launch()