yizhangliu commited on
Commit
ede9250
1 Parent(s): 8802f96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -43
app.py CHANGED
@@ -71,10 +71,46 @@ def read_content(file_path):
71
  content = f.read()
72
  return content
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  model = None
75
 
76
- def model_process(image, mask):
77
- global model
78
 
79
  if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
80
  # rotate image
@@ -116,13 +152,13 @@ def model_process(image, mask):
116
  if config.sd_seed == -1:
117
  config.sd_seed = random.randint(1, 999999999)
118
 
119
- print(f"Origin image shape_0_: {original_shape} / {size_limit}")
120
  image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
121
- print(f"Resized image shape_1_: {image.shape}")
122
 
123
- print(f"mask image shape_0_: {mask.shape} / {type(mask)}")
124
  mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
125
- print(f"mask image shape_1_: {mask.shape} / {type(mask)}")
126
 
127
  if model is None:
128
  return None
@@ -131,6 +167,19 @@ def model_process(image, mask):
131
  torch.cuda.empty_cache()
132
 
133
  image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  return image # image
135
 
136
  model = ModelManager(
@@ -139,7 +188,9 @@ model = ModelManager(
139
  )
140
 
141
  image_type = 'pil' # filepath'
142
- def predict(input):
 
 
143
  if image_type == 'filepath':
144
  # input: {'image': '/tmp/tmp8mn9xw93.png', 'mask': '/tmp/tmpn5ars4te.png'}
145
  origin_image_bytes = read_content(input["image"])
@@ -152,25 +203,25 @@ def predict(input):
152
  mask_pil = input['mask']
153
  image = np.array(image_pil)
154
  mask = np.array(mask_pil.convert("L"))
155
- output = model_process(image, mask)
156
  return output
157
 
158
  css = '''
159
  .container {max-width: 100%;margin: auto;padding-top: 1.5rem}
160
  .output-image, .input-image, .image-preview {height: 600px !important;object-fit: contain}
 
161
  #image_upload{min-height:610px}
162
  #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 620px}
163
  #image_output{margin: 0 auto; text-align: center;width:640px}
164
- #prompt-container{margin: 0 auto; text-align: center;width:200px;border-width:5px;border-color:#2c9748}
165
- #mask_radio .gr-form{background:transparent; border: none}
166
- #mask_radio .gr-form{background:transparent; border: none; color:#00ff00}
167
- #word_mask{margin-top: .75em !important}
168
- #word_mask textarea:disabled{opacity: 0.3}
169
  .footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
170
  .footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
171
  .dark .footer {border-color: #303030}
172
  .dark .footer>p {background: #0b0f19}
173
- .acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
174
  #image_upload .touch-none{display: flex}
175
  @keyframes spin {
176
  from {
@@ -180,40 +231,30 @@ css = '''
180
  transform: rotate(360deg);
181
  }
182
  }
183
- #share-btn-container {
184
- display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
185
- }
186
- #share-btn {
187
- all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
188
- }
189
- #share-btn * {
190
- all: unset;
191
- }
192
- #share-btn-container div:nth-child(-n+2){
193
- width: auto !important;
194
- min-height: 0px !important;
195
- }
196
- #share-btn-container .wrap {
197
- display: none !important;
198
- }
199
  '''
200
 
201
  image_blocks = gr.Blocks(css=css)
202
  with image_blocks as demo:
203
  with gr.Group():
204
- with gr.Box():
205
- with gr.Row():
206
  with gr.Column():
207
- image = gr.Image(source='upload', elem_id="image_upload",tool='sketch', type=f'{image_type}', label="Upload").style(mobile_collapse=False)
208
- with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
209
- btn_in = gr.Button("Erase(↓)").style(
210
- margin=True,
211
- rounded=(True, True, True, True),
212
- full_width=True,
213
- )
214
- with gr.Row():
 
 
 
 
 
215
  with gr.Column():
216
- image_out = gr.Image(label="Output", elem_id="image_output", visible=True).style(width=640)
217
- btn_in.click(fn=predict, inputs=[image], outputs=[image_out])
 
218
 
219
- image_blocks.launch()
 
71
  content = f.read()
72
  return content
73
 
74
+ def get_image_enhancer(scale = 2, device='cuda:0'):
75
+ from basicsr.archs.rrdbnet_arch import RRDBNet
76
+ from realesrgan import RealESRGANer
77
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
78
+ from gfpgan import GFPGANer
79
+
80
+ realesrgan_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
81
+ num_block=23, num_grow_ch=32, scale=4
82
+ )
83
+ netscale = scale
84
+
85
+ model_realesrgan = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
86
+ upsampler = RealESRGANer(
87
+ scale=netscale,
88
+ model_path=model_realesrgan,
89
+ model=realesrgan_model,
90
+ tile=0,
91
+ tile_pad=10,
92
+ pre_pad=0,
93
+ half=False if device=='cpu' else True,
94
+ device=device
95
+ )
96
+
97
+ model_GFPGAN = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
98
+ img_enhancer = GFPGANer(
99
+ model_path=model_GFPGAN,
100
+ upscale=scale,
101
+ arch='clean',
102
+ channel_multiplier=2,
103
+ bg_upsampler=upsampler,
104
+ device=device
105
+ )
106
+ return img_enhancer
107
+
108
+ image_enhancer = get_image_enhancer(scale = 1, device=device)
109
+
110
  model = None
111
 
112
+ def model_process(image, mask, img_enhancer):
113
+ global model,image_enhancer
114
 
115
  if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
116
  # rotate image
 
152
  if config.sd_seed == -1:
153
  config.sd_seed = random.randint(1, 999999999)
154
 
155
+ logger.info(f"Origin image shape_0_: {original_shape} / {size_limit}")
156
  image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
157
+ logger.info(f"Resized image shape_1_: {image.shape}")
158
 
159
+ logger.info(f"mask image shape_0_: {mask.shape} / {type(mask)}")
160
  mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
161
+ logger.info(f"mask image shape_1_: {mask.shape} / {type(mask)}")
162
 
163
  if model is None:
164
  return None
 
167
  torch.cuda.empty_cache()
168
 
169
  image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
170
+
171
+ if image_enhancer is not None and img_enhancer:
172
+ start = time.time()
173
+ input_img_rgb = np.array(image)
174
+ input_img_bgr = input_img_rgb[...,[2,1,0]]
175
+ _, _, enhance_img = image_enhancer.enhance(input_img_bgr, has_aligned=False,
176
+ only_center_face=False, paste_back=True)
177
+ input_img_rgb = enhance_img[...,[2,1,0]]
178
+ img_enhance = Image.fromarray(np.uint8(input_img_rgb))
179
+ image = img_enhance
180
+ log_info = f"image_enhancer_: {(time.time() - start) * 1000}ms, {res_np_img.shape} "
181
+ logger.info(log_info)
182
+
183
  return image # image
184
 
185
  model = ModelManager(
 
188
  )
189
 
190
  image_type = 'pil' # filepath'
191
+ def predict(input, img_enhancer):
192
+ if input is None:
193
+ return None
194
  if image_type == 'filepath':
195
  # input: {'image': '/tmp/tmp8mn9xw93.png', 'mask': '/tmp/tmpn5ars4te.png'}
196
  origin_image_bytes = read_content(input["image"])
 
203
  mask_pil = input['mask']
204
  image = np.array(image_pil)
205
  mask = np.array(mask_pil.convert("L"))
206
+ output = model_process(image, mask, img_enhancer)
207
  return output
208
 
209
  css = '''
210
  .container {max-width: 100%;margin: auto;padding-top: 1.5rem}
211
  .output-image, .input-image, .image-preview {height: 600px !important;object-fit: contain}
212
+ #work-container {min-width: min(160px, 100%) !important;flex-grow: 0 !important}
213
  #image_upload{min-height:610px}
214
  #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 620px}
215
  #image_output{margin: 0 auto; text-align: center;width:640px}
216
+ #erase-container{margin: 0 auto; text-align: center;width:150px;border-width:5px;border-color:#2c9748}
217
+ #enhancer-checkbox{width:520px}
218
+ #enhancer-tip{width:450px}
219
+ #enhancer-tip-div{text-align: left}
220
+ #prompt-container{margin: 0 auto; text-align: center;width:fit-content;min-width: min(150px, 100%);flex-grow: 0; flex-wrap: nowrap;}
221
  .footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
222
  .footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
223
  .dark .footer {border-color: #303030}
224
  .dark .footer>p {background: #0b0f19}
 
225
  #image_upload .touch-none{display: flex}
226
  @keyframes spin {
227
  from {
 
231
  transform: rotate(360deg);
232
  }
233
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  '''
235
 
236
  image_blocks = gr.Blocks(css=css)
237
  with image_blocks as demo:
238
  with gr.Group():
239
+ with gr.Box(elem_id="work-container"):
240
+ with gr.Row(elem_id="input-container"):
241
  with gr.Column():
242
+ image = gr.Image(source='upload', elem_id="image_upload",tool='sketch', type=f'{image_type}',
243
+ label="Upload(载入图片)", show_label=True).style(mobile_collapse=False)
244
+ with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
245
+ with gr.Column(elem_id="erase-container"):
246
+ btn_erase = gr.Button(value = "Erase(擦除↓)",elem_id="erase_btn").style(
247
+ margin=True,
248
+ rounded=(True, True, True, True),
249
+ full_width=True,
250
+ ).style(width=100)
251
+ with gr.Column(elem_id="enhancer-checkbox"):
252
+ enhancer_label = 'Enhanced image(processing is very slow, please check only for blurred images)【增强图像(处理很慢,请仅针对模糊图像做勾选)】'
253
+ img_enhancer = gr.Checkbox(label=enhancer_label).style(width=150)
254
+ with gr.Row(elem_id="output-container"):
255
  with gr.Column():
256
+ image_out = gr.Image(label="Result", elem_id="image_output", visible=True).style(width=640)
257
+
258
+ btn_erase.click(fn=predict, inputs=[image, img_enhancer], outputs=[image_out])
259
 
260
+ image_blocks.launch()