afeng committed on
Commit
850ea5b
1 Parent(s): 8fa9206
Files changed (4)
  1. .gitignore +1 -0
  2. app copy.py +349 -0
  3. app.py +60 -66
  4. segment.py +25 -23
.gitignore CHANGED
@@ -4,6 +4,7 @@ example1_example2_512/
  example1_example2_1024/
  example1/
  old/
+ example_tmp/
 
  out_active.png
  out_mask.png
app copy.py ADDED
@@ -0,0 +1,349 @@
+
+ import os
+ import copy
+ from PIL import Image
+ import matplotlib
+ import numpy as np
+ import gradio as gr
+ from utils import load_mask, load_mask_edit
+ from utils_mask import process_mask_to_follow_priority, mask_union, visualize_mask_list_clean
+ from pathlib import Path
+ import subprocess
+ from PIL import Image
+
+ LENGTH=512 #length of the square area displaying/editing images
+ TRANSPARENCY = 150 # transparency of the mask in display
+
+ def add_mask(mask_np_list_updated, mask_label_list):
+ mask_new = np.zeros_like(mask_np_list_updated[0])
+ mask_np_list_updated.append(mask_new)
+ mask_label_list.append("new")
+ return mask_np_list_updated, mask_label_list
+
+ def create_segmentation(mask_np_list):
+ viridis = matplotlib.pyplot.get_cmap(name = 'viridis', lut = len(mask_np_list))
+ segmentation = 0
+ for i, m in enumerate(mask_np_list):
+ color = matplotlib.colors.to_rgb(viridis(i))
+ color_mat = np.ones_like(m)
+ color_mat = np.stack([color_mat*color[0], color_mat*color[1],color_mat*color[2] ], axis = 2)
+ color_mat = color_mat * m[:,:,np.newaxis]
+ segmentation += color_mat
+ segmentation = Image.fromarray(np.uint8(segmentation*255))
+ return segmentation
+
+ def load_mask_ui(input_folder,load_edit = False):
+ if not load_edit:
+ mask_list, mask_label_list = load_mask(input_folder)
+ else:
+ mask_list, mask_label_list = load_mask_edit(input_folder)
+
+ mask_np_list = []
+ for m in mask_list:
+ mask_np_list. append( m.cpu().numpy())
+
+ return mask_np_list, mask_label_list
+
+ def load_image_ui(input_folder, load_edit):
+ try:
+ for img_path in Path(input_folder).iterdir():
+ if img_path.name in ["img.png", "img_1024.png", "img_512.png"]:
+ image = Image.open(img_path)
+ mask_np_list, mask_label_list = load_mask_ui(input_folder, load_edit = load_edit)
+ image = image.convert('RGB')
+ segmentation = create_segmentation(mask_np_list)
+ return image, segmentation, mask_np_list, mask_label_list, image
+ except:
+ print("Image folder invalid: The folder should contain image.png")
+ return None, None, None, None, None
+
+ def run_segmentation(input_folder):
+ subprocess.run(["python", "segment.py" , "--name={}".format(input_folder)])
+ return
+
+
+
+ def run_edit_text(
+ input_folder,
+ num_tokens,
+ num_sampling_steps,
+ strength,
+ edge_thickness,
+ tgt_prompt,
+ tgt_idx,
+ guidance_scale
+ ):
+ subprocess.run(["python",
+ "main.py" ,
+ "--text",
+ "--name={}".format(input_folder),
+ "--dpm={}".format("sd"),
+ "--resolution={}".format(512),
+ "--load_trained",
+ "--num_tokens={}".format(num_tokens),
+ "--seed={}".format(2024),
+ "--guidance_scale={}".format(guidance_scale),
+ "--num_sampling_step={}".format(num_sampling_steps),
+ "--strength={}".format(strength),
+ "--edge_thickness={}".format(edge_thickness),
+ "--num_imgs={}".format(2),
+ "--tgt_prompt={}".format(tgt_prompt) ,
+ "--tgt_index={}".format(tgt_idx)
+ ])
+
+ return Image.open(os.path.join(input_folder, "text", "out_text_0.png"))
+
+
+ def run_optimization(
+ input_folder,
+ num_tokens,
+ embedding_learning_rate,
+ max_emb_train_steps,
+ diffusion_model_learning_rate,
+ max_diffusion_train_steps,
+ train_batch_size,
+ gradient_accumulation_steps
+ ):
+ subprocess.run(["python",
+ "main.py" ,
+ "--name={}".format(input_folder),
+ "--dpm={}".format("sd"),
+ "--resolution={}".format(512),
+ "--num_tokens={}".format(num_tokens),
+ "--embedding_learning_rate={}".format(embedding_learning_rate),
+ "--diffusion_model_learning_rate={}".format(diffusion_model_learning_rate),
+ "--max_emb_train_steps={}".format(max_emb_train_steps),
+ "--max_diffusion_train_steps={}".format(max_diffusion_train_steps),
+ "--train_batch_size={}".format(train_batch_size),
+ "--gradient_accumulation_steps={}".format(gradient_accumulation_steps)
+
+ ])
+ return
+
+
+ def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
+ backimg_solid_np = np.array(backimg)
+ bimg = backimg.copy()
+ fimg = foreimg.copy()
+ fimg.putalpha(transparency)
+ bimg.paste(fimg, (0,0), fimg)
+
+ bimg_np = np.array(bimg)
+ mask_np = mask_np[:,:,np.newaxis]
+ try:
+ new_img_np = bimg_np*mask_np + (1-mask_np)* backimg_solid_np
+ return Image.fromarray(new_img_np)
+ except:
+ import pdb; pdb.set_trace()
+
+ def show_segmentation(image, segmentation, flag):
+ if flag is False:
+ flag = True
+ mask_np = np.ones([image.size[0],image.size[1]]).astype(np.uint8)
+ image_edit = transparent_paste_with_mask(image, segmentation, mask_np ,transparency = TRANSPARENCY)
+ return image_edit, flag
+ else:
+ flag = False
+ return image,flag
+
+ def edit_mask_add(canvas, image, idx, mask_np_list):
+ mask_sel = mask_np_list[idx]
+ mask_new = np.uint8(canvas["mask"][:, :, 0]/ 255.)
+ mask_np_list_updated = []
+ for midx, m in enumerate(mask_np_list):
+ if midx == idx:
+ mask_np_list_updated.append(mask_union(mask_sel, mask_new))
+ else:
+ mask_np_list_updated.append(m)
+
+ priority_list = [0 for _ in range(len(mask_np_list_updated))]
+ priority_list[idx] = 1
+ mask_np_list_updated = process_mask_to_follow_priority(mask_np_list_updated, priority_list)
+ mask_ones = np.ones([mask_sel.shape[0], mask_sel.shape[1]]).astype(np.uint8)
+ segmentation = create_segmentation(mask_np_list_updated)
+ image_edit = transparent_paste_with_mask(image, segmentation, mask_ones ,transparency = TRANSPARENCY)
+ return mask_np_list_updated, image_edit
+
+ def slider_release(index, image, mask_np_list_updated, mask_label_list):
+ if index > len(mask_np_list_updated):
+ return image, "out of range"
+ else:
+ mask_np = mask_np_list_updated[index]
+ mask_label = mask_label_list[index]
+ segmentation = create_segmentation(mask_np_list_updated)
+ new_image = transparent_paste_with_mask(image, segmentation, mask_np, transparency = TRANSPARENCY)
+ return new_image, mask_label
+
+ def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder):
+ try:
+ assert np.all(sum(mask_np_list_updated)==1)
+ except:
+ print("please check mask")
+ # plt.imsave( "out_mask.png", mask_list_edit[0])
+ import pdb; pdb.set_trace()
+
+ for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
+ # np.save(os.path.join(input_folder, "maskEDIT{}_{}.npy".format(midx, mask_label)),mask )
+ np.save(os.path.join(input_folder, "mask{}_{}.npy".format(midx, mask_label)),mask )
+ savepath = os.path.join(input_folder, "seg_current.png")
+ visualize_mask_list_clean(mask_np_list_updated, savepath)
+
+ def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder):
+ try:
+ assert np.all(sum(mask_np_list_updated)==1)
+ except:
+ print("please check mask")
+ # plt.imsave( "out_mask.png", mask_list_edit[0])
+ import pdb; pdb.set_trace()
+ for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
+ np.save(os.path.join(input_folder, "maskEdited{}_{}.npy".format(midx, mask_label)), mask)
+ savepath = os.path.join(input_folder, "seg_edited.png")
+ visualize_mask_list_clean(mask_np_list_updated, savepath)
+
+ with gr.Blocks() as demo:
+ image = gr.State() # store mask
+ image_loaded = gr.State()
+ segmentation = gr.State()
+
+ mask_np_list = gr.State([])
+ mask_label_list = gr.State([])
+ mask_np_list_updated = gr.State([])
+ true = gr.State(True)
+ false = gr.State(False)
+
+
+ with gr.Row():
+ gr.Markdown("""# D-Edit""")
+
+ with gr.Tab(label="1 Edit mask"):
+ with gr.Row():
+ with gr.Column():
+ canvas = gr.Image(value = None, type="numpy", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
+ input_folder = gr.Textbox(value="example1", label="input folder", interactive= True, )
+
+ segment_button = gr.Button("1.1 Run segmentation")
+ segment_button.click(run_segmentation,
+ [input_folder] ,
+ [] )
+
+
+ text_button = gr.Button("1.2 Load original masks")
+ text_button.click(load_image_ui,
+ [input_folder, false] ,
+ [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
+
+ load_edit_button = gr.Button("1.2 Load edited masks")
+ load_edit_button.click(load_image_ui,
+ [input_folder, true] ,
+ [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
+
+ show_segment = gr.Checkbox(label = "Show Segmentation")
+
+ flag = gr.State(False)
+ show_segment.select(show_segmentation,
+ [image_loaded, segmentation, flag],
+ [canvas, flag])
+
+ mask_np_list_updated = copy.deepcopy(mask_np_list)
+
+ with gr.Column():
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
+ slider = gr.Slider(0, 20, step=1, interactive=True)
+ label = gr.Textbox()
+ slider.release(slider_release,
+ inputs = [slider, image_loaded, mask_np_list_updated, mask_label_list],
+ outputs= [canvas, label]
+ )
+ add_button = gr.Button("Add")
+ add_button.click( edit_mask_add,
+ [canvas, image_loaded, slider, mask_np_list_updated] ,
+ [mask_np_list_updated, canvas]
+ )
+
+ save_button2 = gr.Button("Set and Save as edited masks")
+ save_button2.click( save_as_edit_mask,
+ [mask_np_list_updated, mask_label_list, input_folder] ,
+ [] )
+
+ save_button = gr.Button("Set and Save as original masks")
+ save_button.click( save_as_orig_mask,
+ [mask_np_list_updated, mask_label_list, input_folder] ,
+ [] )
+
+ back_button = gr.Button("Back to current seg")
+ back_button.click( load_mask_ui,
+ [input_folder] ,
+ [ mask_np_list_updated,mask_label_list] )
+
+ add_mask_button = gr.Button("Add new empty mask")
+ add_mask_button.click(add_mask,
+ [mask_np_list_updated, mask_label_list] ,
+ [mask_np_list_updated, mask_label_list] )
+
+ with gr.Tab(label="2 Optimization"):
+ with gr.Row():
+ with gr.Column():
+ canvas_opt = gr.Image(value = canvas.value, type="pil", label="Loaded Image", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
+
+ with gr.Column():
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
+ num_tokens = gr.Textbox(value="5", label="num tokens to represent each object", interactive= True)
+ embedding_learning_rate = gr.Textbox(value="1e-4", label="Embedding optimization: Learning rate", interactive= True )
+ max_emb_train_steps = gr.Textbox(value="500", label="embedding optimization: Training steps", interactive= True )
+
+ diffusion_model_learning_rate = gr.Textbox(value="5e-5", label="UNet Optimization: Learning rate", interactive= True )
+ max_diffusion_train_steps = gr.Textbox(value="500", label="UNet Optimization: Learning rate: Training steps", interactive= True )
+
+ train_batch_size = gr.Textbox(value="5", label="Batch size", interactive= True )
+ gradient_accumulation_steps=gr.Textbox(value="5", label="Gradient accumulation", interactive= True )
+
+ add_button = gr.Button("Run optimization")
+ add_button.click(run_optimization,
+ inputs = [
+ input_folder,
+ num_tokens,
+ embedding_learning_rate,
+ max_emb_train_steps,
+ diffusion_model_learning_rate,
+ max_diffusion_train_steps,
+ train_batch_size,gradient_accumulation_steps
+ ],
+ outputs = []
+ )
+
+
+ with gr.Tab(label="3 Editing"):
+ with gr.Tab(label="3.1 Text-based editing"):
+ canvas_text_edit = gr.State() # store mask
+ with gr.Row():
+ with gr.Column():
+ canvas_text_edit = gr.Image(value = None, label="Editing results", show_label=True, height=LENGTH, width=LENGTH)
+ # canvas_text_edit = gr.Gallery(label = "Edited results")
+
+ with gr.Column():
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting (SD)</p>""")
+
+ tgt_prompt = gr.Textbox(value="Dog", label="Editing: Text prompt", interactive= True )
+ tgt_idx = gr.Textbox(value="0", label="Editing: Object index", interactive= True )
+ guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive= True )
+ num_sampling_steps = gr.Textbox(value="50", label="Editing: Sampling steps", interactive= True )
+ edge_thickness = gr.Textbox(value="10", label="Editing: Edge thickness", interactive= True )
+ strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )
+
+ add_button = gr.Button("Run Editing")
+ add_button.click(run_edit_text,
+ inputs = [
+ input_folder,
+ num_tokens,
+ num_sampling_steps,
+ strength,
+ edge_thickness,
+ tgt_prompt,
+ tgt_idx,
+ guidance_scale
+ ],
+ outputs = [canvas_text_edit]
+
+
+ )
+ demo.queue().launch(share=True, debug=True)
app.py CHANGED
@@ -57,12 +57,6 @@ def load_image_ui(input_folder, load_edit):
  print("Image folder invalid: The folder should contain image.png")
  return None, None, None, None, None
 
- def run_segmentation(input_folder):
- subprocess.run(["python", "segment.py" , "--name={}".format(input_folder)])
- return
-
-
-
  def run_edit_text(
  input_folder,
  num_tokens,
@@ -200,6 +194,8 @@ def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder):
  savepath = os.path.join(input_folder, "seg_edited.png")
  visualize_mask_list_clean(mask_np_list_updated, savepath)
 
+
+ from segment import run_segmentation
  with gr.Blocks() as demo:
  image = gr.State() # store mask
  image_loaded = gr.State()
@@ -211,22 +207,20 @@ with gr.Blocks() as demo:
  true = gr.State(True)
  false = gr.State(False)
 
-
  with gr.Row():
  gr.Markdown("""# D-Edit""")
 
  with gr.Tab(label="1 Edit mask"):
  with gr.Row():
  with gr.Column():
- canvas = gr.Image(value = None, type="numpy", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
+ canvas = gr.Image(value = None, type="pil", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
  input_folder = gr.Textbox(value="example1", label="input folder", interactive= True, )
 
  segment_button = gr.Button("1.1 Run segmentation")
  segment_button.click(run_segmentation,
- [input_folder] ,
+ [canvas] ,
  [] )
-
-
+
  text_button = gr.Button("1.2 Load original masks")
  text_button.click(load_image_ui,
  [input_folder, false] ,
@@ -280,70 +274,70 @@ with gr.Blocks() as demo:
  [mask_np_list_updated, mask_label_list] ,
  [mask_np_list_updated, mask_label_list] )
 
- with gr.Tab(label="2 Optimization"):
- with gr.Row():
- with gr.Column():
- canvas_opt = gr.Image(value = canvas.value, type="pil", label="Loaded Image", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
+ # with gr.Tab(label="2 Optimization"):
+ # with gr.Row():
+ # with gr.Column():
+ # canvas_opt = gr.Image(value = canvas.value, type="pil", label="Loaded Image", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
 
- with gr.Column():
- gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
- num_tokens = gr.Textbox(value="5", label="num tokens to represent each object", interactive= True)
- embedding_learning_rate = gr.Textbox(value="1e-4", label="Embedding optimization: Learning rate", interactive= True )
- max_emb_train_steps = gr.Textbox(value="500", label="embedding optimization: Training steps", interactive= True )
+ # with gr.Column():
+ # gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
+ # num_tokens = gr.Textbox(value="5", label="num tokens to represent each object", interactive= True)
+ # embedding_learning_rate = gr.Textbox(value="1e-4", label="Embedding optimization: Learning rate", interactive= True )
+ # max_emb_train_steps = gr.Textbox(value="500", label="embedding optimization: Training steps", interactive= True )
 
- diffusion_model_learning_rate = gr.Textbox(value="5e-5", label="UNet Optimization: Learning rate", interactive= True )
- max_diffusion_train_steps = gr.Textbox(value="500", label="UNet Optimization: Learning rate: Training steps", interactive= True )
+ # diffusion_model_learning_rate = gr.Textbox(value="5e-5", label="UNet Optimization: Learning rate", interactive= True )
+ # max_diffusion_train_steps = gr.Textbox(value="500", label="UNet Optimization: Learning rate: Training steps", interactive= True )
 
- train_batch_size = gr.Textbox(value="5", label="Batch size", interactive= True )
- gradient_accumulation_steps=gr.Textbox(value="5", label="Gradient accumulation", interactive= True )
+ # train_batch_size = gr.Textbox(value="5", label="Batch size", interactive= True )
+ # gradient_accumulation_steps=gr.Textbox(value="5", label="Gradient accumulation", interactive= True )
 
- add_button = gr.Button("Run optimization")
- add_button.click(run_optimization,
- inputs = [
- input_folder,
- num_tokens,
- embedding_learning_rate,
- max_emb_train_steps,
- diffusion_model_learning_rate,
- max_diffusion_train_steps,
- train_batch_size,gradient_accumulation_steps
- ],
- outputs = []
- )
+ # add_button = gr.Button("Run optimization")
+ # add_button.click(run_optimization,
+ # inputs = [
+ # input_folder,
+ # num_tokens,
+ # embedding_learning_rate,
+ # max_emb_train_steps,
+ # diffusion_model_learning_rate,
+ # max_diffusion_train_steps,
+ # train_batch_size,gradient_accumulation_steps
+ # ],
+ # outputs = []
+ # )
 
 
- with gr.Tab(label="3 Editing"):
- with gr.Tab(label="3.1 Text-based editing"):
- canvas_text_edit = gr.State() # store mask
- with gr.Row():
- with gr.Column():
- canvas_text_edit = gr.Image(value = None, label="Editing results", show_label=True, height=LENGTH, width=LENGTH)
- # canvas_text_edit = gr.Gallery(label = "Edited results")
+ # with gr.Tab(label="3 Editing"):
+ # with gr.Tab(label="3.1 Text-based editing"):
+ # canvas_text_edit = gr.State() # store mask
+ # with gr.Row():
+ # with gr.Column():
+ # canvas_text_edit = gr.Image(value = None, label="Editing results", show_label=True, height=LENGTH, width=LENGTH)
+ # # canvas_text_edit = gr.Gallery(label = "Edited results")
 
- with gr.Column():
- gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting (SD)</p>""")
+ # with gr.Column():
+ # gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting (SD)</p>""")
 
- tgt_prompt = gr.Textbox(value="Dog", label="Editing: Text prompt", interactive= True )
- tgt_idx = gr.Textbox(value="0", label="Editing: Object index", interactive= True )
- guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive= True )
- num_sampling_steps = gr.Textbox(value="50", label="Editing: Sampling steps", interactive= True )
- edge_thickness = gr.Textbox(value="10", label="Editing: Edge thickness", interactive= True )
- strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )
+ # tgt_prompt = gr.Textbox(value="Dog", label="Editing: Text prompt", interactive= True )
+ # tgt_idx = gr.Textbox(value="0", label="Editing: Object index", interactive= True )
+ # guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive= True )
+ # num_sampling_steps = gr.Textbox(value="50", label="Editing: Sampling steps", interactive= True )
+ # edge_thickness = gr.Textbox(value="10", label="Editing: Edge thickness", interactive= True )
+ # strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )
 
- add_button = gr.Button("Run Editing")
- add_button.click(run_edit_text,
- inputs = [
- input_folder,
- num_tokens,
- num_sampling_steps,
- strength,
- edge_thickness,
- tgt_prompt,
- tgt_idx,
- guidance_scale
- ],
- outputs = [canvas_text_edit]
- )
+ # add_button = gr.Button("Run Editing")
+ # add_button.click(run_edit_text,
+ # inputs = [
+ # input_folder,
+ # num_tokens,
+ # num_sampling_steps,
+ # strength,
+ # edge_thickness,
+ # tgt_prompt,
+ # tgt_idx,
+ # guidance_scale
+ # ],
+ # outputs = [canvas_text_edit]
+ # )
 
 
  demo.queue().launch(share=True, debug=True)
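Note: the hunks above replace the old subprocess call (python segment.py --name=<folder>) with an in-process call: app.py now imports run_segmentation from segment.py and passes the PIL image held by the canvas directly to it. A minimal sketch of the resulting wiring (an illustration only, assuming a Gradio Blocks version that supports height/width on gr.Image, as the app already uses; not the full app):

# Sketch only: segmentation now runs in-process on the canvas image.
# Assumes the run_segmentation(image, name, size, noseg) added to segment.py below.
import gradio as gr
from segment import run_segmentation

LENGTH = 512  # square display size, as in app.py

with gr.Blocks() as demo:
    # type="pil" so the click handler receives a PIL.Image instead of a numpy array
    canvas = gr.Image(value=None, type="pil", label="Draw Mask",
                      height=LENGTH, width=LENGTH, interactive=True)
    segment_button = gr.Button("1.1 Run segmentation")
    # The image itself is now the only input; outputs land in ./example_tmp/ by default
    segment_button.click(run_segmentation, inputs=[canvas], outputs=[])

demo.queue().launch()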
segment.py CHANGED
@@ -32,7 +32,7 @@ def load_image(image_path, left=0, right=0, top=0, bottom=0, size = 512):
  image = np.array(Image.fromarray(image).resize((size, size)))
  return image
 
- def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, noseg = False):
+ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, noseg = False, model =None):
  if torch.max(segmentation)==torch.min(segmentation)==-1:
  print("nothing is detected!")
  noseg=True
@@ -88,28 +88,30 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
 
 
 
- parser = argparse.ArgumentParser()
- parser.add_argument("--name", type=str, default="obama")
- parser.add_argument("--size", type=int, default=512)
- parser.add_argument("--noseg", default=False, action="store_true" )
- args = parser.parse_args()
- base_folder_path = "."
 
- processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
- model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
- input_folder = os.path.join(base_folder_path, args.name )
- try:
- image = load_image(os.path.join(input_folder, "img.png" ), size = args.size)
- except:
- image = load_image(os.path.join(input_folder, "img.jpg" ), size = args.size)
 
- image =Image.fromarray(image)
- image.save(os.path.join(input_folder,"img_{}.png".format(args.size)))
- inputs = processor(image, return_tensors="pt")
- with torch.no_grad():
- outputs = model(**inputs)
 
- panoptic_segmentation = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
- save_folder = os.path.join(base_folder_path, args.name)
- os.makedirs(save_folder, exist_ok=True)
- draw_panoptic_segmentation(**panoptic_segmentation, save_folder = save_folder, noseg = args.noseg)
+ def run_segmentation(image, name="example_tmp", size = 512, noseg=False):
+ base_folder_path = "."
+ processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
+ model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
+
+
+ # input_folder = os.path.join(base_folder_path, name )
+ # try:
+ # image = load_image(os.path.join(input_folder, "img.png" ), size = size)
+ # except:
+ # image = load_image(os.path.join(input_folder, "img.jpg" ), size = size)
+ # image =Image.fromarray(image)
+ os.makedirs(name, exist_ok=True)
+ image.save(os.path.join(name,"img_{}.png".format(size)))
+ inputs = processor(image, return_tensors="pt")
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ panoptic_segmentation = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
+ save_folder = os.path.join(base_folder_path, name)
+ os.makedirs(save_folder, exist_ok=True)
+ draw_panoptic_segmentation(**panoptic_segmentation, save_folder = save_folder, noseg = noseg, model = model)
+ print("Finish segment")
+ return
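With segment.py refactored into a callable function, segmentation can also be run outside the Gradio app. A minimal usage sketch (the input path example1/img.png is only an assumed example; any RGB image works):

# Standalone usage sketch of the refactored segment.run_segmentation.
from PIL import Image
from segment import run_segmentation

img = Image.open("example1/img.png").convert("RGB")  # assumed example path
# Writes img_512.png and the panoptic-segmentation outputs under ./example_tmp/
run_segmentation(img, name="example_tmp", size=512, noseg=False)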