Andranik Sargsyan committed on
Commit 1df97f6 • 1 Parent(s): fd3e2fa

add saving/recovering tmp user data for faster processing

Files changed (3)
  1. app.py +77 -25
  2. assets/sr_info.png +3 -0
  3. lib/methods/sr.py +8 -3
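
The commit stores each request's inputs and low-resolution inpainting results under a per-session temporary folder and hands only the session id back to the browser, so the later upscale call can recover everything from disk instead of re-uploading it. A minimal, self-contained sketch of that round trip (the TMP_DIR path and the toy images below are placeholders, not the app's actual configuration):

import uuid
from pathlib import Path
from PIL import Image

TMP_DIR = '/tmp/hd_painter_sessions'  # placeholder; the app defines its own temp location

def save_session(hr_image, hr_mask, lr_results, prompt, session_id=''):
    # First call: no session yet, so name a fresh folder with a UUID
    if session_id == '':
        session_id = str(uuid.uuid4())
    session_dir = Path(TMP_DIR) / session_id
    (session_dir / 'lr_results').mkdir(parents=True, exist_ok=True)
    hr_image.save(session_dir / 'hr_image.png')
    hr_mask.save(session_dir / 'hr_mask.png')
    for i, lr in enumerate(lr_results):
        lr.save(session_dir / 'lr_results' / f'{i}.png')
    (session_dir / 'prompt.txt').write_text(prompt)
    return session_id

def recover_session(session_id):
    # Later call: read everything back from disk using only the id
    session_dir = Path(TMP_DIR) / session_id
    hr_image = Image.open(session_dir / 'hr_image.png')
    hr_mask = Image.open(session_dir / 'hr_mask.png')
    result_paths = sorted((session_dir / 'lr_results').glob('*.png'), key=lambda p: int(p.stem))
    gallery = [Image.open(p) for p in result_paths]
    prompt = (session_dir / 'prompt.txt').read_text()
    return hr_image, hr_mask, gallery, prompt

sid = save_session(Image.new('RGB', (64, 64)), Image.new('L', (64, 64)),
                   [Image.new('RGB', (32, 32))], 'a red car')
hr, mask, gallery, prompt = recover_session(sid)
assert prompt == 'a red car' and len(gallery) == 1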
app.py CHANGED
@@ -75,11 +75,57 @@ def set_model_from_name(inp_model_name):
     inp_model = inpainting_models[inp_model_name]
 
 
+def save_user_session(hr_image, hr_mask, lr_results, prompt, session_id=None):
+    if session_id == '':
+        session_id = str(uuid.uuid4())
+
+    tmp_dir = Path(TMP_DIR)
+    session_dir = tmp_dir / session_id
+    session_dir.mkdir(exist_ok=True, parents=True)
+
+    hr_image.save(session_dir / 'hr_image.png')
+    hr_mask.save(session_dir / 'hr_mask.png')
+
+    lr_results_dir = session_dir / 'lr_results'
+    if lr_results_dir.exists():
+        shutil.rmtree(lr_results_dir)
+    lr_results_dir.mkdir(parents=True)
+    for i, lr_result in enumerate(lr_results):
+        lr_result.save(lr_results_dir / f'{i}.png')
+
+    with open(session_dir / 'prompt.txt', 'w') as f:
+        f.write(prompt)
+
+    return session_id
+
+
+def recover_user_session(session_id):
+    if session_id == '':
+        return None, None, []
+
+    tmp_dir = Path(TMP_DIR)
+    session_dir = tmp_dir / session_id
+    lr_results_dir = session_dir / 'lr_results'
+
+    hr_image = Image.open(session_dir / 'hr_image.png')
+    hr_mask = Image.open(session_dir / 'hr_mask.png')
+
+    lr_result_paths = list(lr_results_dir.glob('*.png'))
+    gallery = []
+    for lr_result_path in sorted(lr_result_paths):
+        gallery.append(Image.open(lr_result_path))
+
+    with open(session_dir / 'prompt.txt', "r") as f:
+        prompt = f.read()
+
+    return hr_image, hr_mask, gallery, prompt
+
+
 def rasg_run(
     use_painta, prompt, input, seed, eta,
     negative_prompt, positive_prompt, ddim_steps,
     guidance_scale=7.5,
-    batch_size=1
+    batch_size=1, session_id=''
 ):
     torch.cuda.empty_cache()
 
@@ -119,15 +165,18 @@ def rasg_run(
             dilation=12
         )
         blended_images.append(blended_image)
-        inpainted_images.append(inpainted_image.numpy()[0])
+        inpainted_images.append(inpainted_image.pil())
+
+    session_id = save_user_session(
+        input['image'], input['mask'], inpainted_images, prompt, session_id=session_id)
 
-    return blended_images, inpainted_images
+    return blended_images, session_id
 
 
 def sd_run(use_painta, prompt, input, seed, eta,
            negative_prompt, positive_prompt, ddim_steps,
            guidance_scale=7.5,
-           batch_size=1
+           batch_size=1, session_id=''
 ):
     torch.cuda.empty_cache()
 
@@ -167,32 +216,37 @@ def sd_run(use_painta, prompt, input, seed, eta,
             dilation=12
         )
         blended_images.append(blended_image)
-        inpainted_images.append(inpainted_image.numpy()[0])
+        inpainted_images.append(inpainted_image.pil())
 
-    return blended_images, inpainted_images
+    session_id = save_user_session(
+        input['image'], input['mask'], inpainted_images, prompt, session_id=session_id)
+
+    return blended_images, session_id
 
 
 def upscale_run(
-    prompt, input, ddim_steps, seed, use_sam_mask, gallery, img_index,
+    ddim_steps, seed, use_sam_mask, session_id, img_index,
     negative_prompt='',
    positive_prompt=', high resolution professional photo'
 ):
+    hr_image, hr_mask, gallery, prompt = recover_user_session(session_id)
+
+    if len(gallery) == 0:
+        return Image.open('./assets/sr_info.png')
+
     torch.cuda.empty_cache()
 
     seed = int(seed)
     img_index = int(img_index)
-
+
     img_index = 0 if img_index < 0 else img_index
     img_index = len(gallery) - 1 if img_index >= len(gallery) else img_index
-    img_info = gallery[img_index if img_index >= 0 else 0]
-    inpainted_image = image_from_url_text(img_info)
-    lr_image = IImage(inpainted_image)
-    hr_image = IImage(input['image']).resize(2048)
-    hr_mask = IImage(input['mask']).resize(2048)
+    inpainted_image = gallery[img_index if img_index >= 0 else 0]
+
     output_image = sr.run(
         sr_model,
         sam_predictor,
-        lr_image,
+        inpainted_image,
         hr_image,
         hr_mask,
         prompt=prompt + positive_prompt,
@@ -203,8 +257,8 @@ def upscale_run(
         seed=seed,
         use_sam_mask=use_sam_mask
    )
-    output_image.info = input['image'].info # save metadata
-    return output_image, output_image
+
+    return output_image
 
 
 def switch_run(use_rasg, model_name, *args):
@@ -316,8 +370,7 @@ with gr.Blocks(css='style.css') as demo:
        [input, prompt, example_container]
    )
 
-    mock_output_gallery = gr.Gallery([], columns = 4, visible=False)
-    mock_hires = gr.Image(label = "__MHRO__", visible = False)
+    session_id = gr.Textbox(value='', visible=False)
    html_info = gr.HTML(elem_id=f'html_info', elem_classes="infotext")
 
    inpaint_btn.click(
@@ -334,25 +387,24 @@ with gr.Blocks(css='style.css') as demo:
            positive_prompt,
            ddim_steps,
            guidance_scale,
-            batch_size
+            batch_size,
+            session_id
        ],
-        outputs=[output_gallery, mock_output_gallery],
+        outputs=[output_gallery, session_id],
        api_name="inpaint"
    )
    upscale_btn.click(
        fn=upscale_run,
        inputs=[
-            prompt,
-            input,
            ddim_steps,
            seed,
            use_sam_mask,
-            mock_output_gallery,
+            session_id,
            html_info
        ],
-        outputs=[hires_image, mock_hires],
+        outputs=[hires_image],
        api_name="upscale",
-        _js="function(a, b, c, d, e, f, g){ return [a, b, c, d, e, f, selected_gallery_index()] }",
+        _js="function(a, b, c, d, e){ return [a, b, c, d, selected_gallery_index()] }",
    )
 
 demo.queue(max_size=20)
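
In the Blocks wiring above, the mock gallery and mock hi-res image components are replaced by a single hidden gr.Textbox that is returned by the inpaint handler and fed into the upscale handler, which is how the session id survives between the two button clicks. A standalone toy sketch of that hidden-Textbox state pattern (component names and handler logic here are illustrative, not the app's):

import gradio as gr

def first_step(text, session_id):
    # Pretend to create a session on the first call and reuse it afterwards
    session_id = session_id or 'demo-session'
    return f'processed: {text}', session_id

def second_step(session_id):
    return f'reusing session: {session_id}'

with gr.Blocks() as demo:
    session_id = gr.Textbox(value='', visible=False)  # hidden state carried between clicks
    prompt = gr.Textbox(label='prompt')
    out1 = gr.Textbox(label='first result')
    out2 = gr.Textbox(label='second result')
    btn1 = gr.Button('step 1')
    btn2 = gr.Button('step 2')
    btn1.click(first_step, inputs=[prompt, session_id], outputs=[out1, session_id])
    btn2.click(second_step, inputs=[session_id], outputs=[out2])

demo.launch()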
assets/sr_info.png ADDED

Git LFS Details

  • SHA256: 2f79345db6231a0f325c265ddc7567121ac20e76dd09a7f0b0c53525d60f32a1
  • Pointer size: 130 Bytes
  • Size of remote file: 36.9 kB
lib/methods/sr.py CHANGED
@@ -73,6 +73,11 @@ def run(
     negative_prompt = '',
     use_sam_mask = False
 ):
+    hr_image_info = hr_image.info
+    lr_image = IImage(lr_image)
+    hr_image = IImage(hr_image).resize(2048)
+    hr_mask = IImage(hr_mask).resize(2048)
+
     torch.manual_seed(seed)
     dtype = ddim.vae.encoder.conv_in.weight.dtype
     device = ddim.vae.encoder.conv_in.weight.device
@@ -143,6 +148,6 @@ def run(
             fake_img=hr_result,
             mask=hr_mask_orig.alpha().data[0]
         )
-        return Image.fromarray(hr_result)
-    else:
-        return Image.fromarray(hr_result)
+    hr_result = Image.fromarray(hr_result)
+    hr_result.info = hr_image_info # save metadata
+    return hr_result
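
A note on the metadata line at the end: PIL keeps Image.info as an in-memory dict, so it travels with the returned object but is not automatically written out if the result is later saved as a PNG. A small sketch of carrying text metadata through an explicit save (file names are placeholders):

from PIL import Image
from PIL.PngImagePlugin import PngInfo

result = Image.open('hr_result.png')  # placeholder input file
pnginfo = PngInfo()
for key, value in result.info.items():
    if isinstance(value, str):  # copy only plain text chunks
        pnginfo.add_text(key, value)
result.save('hr_result_with_meta.png', pnginfo=pnginfo)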