yizhangliu commited on
Commit
924af64
1 Parent(s): 24a45a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -36
app.py CHANGED
@@ -1,25 +1,16 @@
1
  import gradio as gr
2
- import PIL
3
  from PIL import Image
4
  import numpy as np
5
- import os
6
  import uuid
7
  import torch
8
- from torch import autocast
9
  import cv2
10
- from io import BytesIO
11
-
12
- from matplotlib import pyplot as plt
13
- from torchvision import transforms
14
 
15
  import io
16
- import logging
17
  import multiprocessing
18
  import random
19
  import time
20
  import imghdr
21
- from pathlib import Path
22
- from typing import Union
23
  from loguru import logger
24
 
25
  from lama_cleaner.model_manager import ModelManager
@@ -33,7 +24,6 @@ try:
33
  except:
34
  pass
35
 
36
-
37
  from lama_cleaner.helper import (
38
  load_img,
39
  numpy_to_bytes,
@@ -58,19 +48,13 @@ HF_TOKEN_SD = os.environ.get('HF_TOKEN_SD')
58
  device = "cuda" if torch.cuda.is_available() else "cpu"
59
  print(f'device = {device}')
60
 
61
- def get_image_ext(img_bytes):
62
- w = imghdr.what("", img_bytes)
63
- if w is None:
64
- w = "jpeg"
65
- return w
66
-
67
- def read_content(file_path):
68
  """read the content of target file
69
  """
70
- with open(file_path, 'rb') as f:
71
  content = f.read()
72
  return content
73
-
74
  def get_image_enhancer(scale = 2, device='cuda:0'):
75
  from basicsr.archs.rrdbnet_arch import RRDBNet
76
  from realesrgan import RealESRGANer
@@ -105,7 +89,9 @@ def get_image_enhancer(scale = 2, device='cuda:0'):
105
  )
106
  return img_enhancer
107
 
108
- image_enhancer = get_image_enhancer(scale = 1, device=device)
 
 
109
 
110
  model = None
111
 
@@ -119,7 +105,7 @@ def model_process(image, mask, img_enhancer):
119
  original_shape = image.shape
120
  interpolation = cv2.INTER_CUBIC
121
 
122
- size_limit = 1080 #1080 # "Original"
123
  if size_limit == "Original":
124
  size_limit = max(image.shape)
125
  else:
@@ -193,10 +179,10 @@ def predict(input, img_enhancer):
193
  return None
194
  if image_type == 'filepath':
195
  # input: {'image': '/tmp/tmp8mn9xw93.png', 'mask': '/tmp/tmpn5ars4te.png'}
196
- origin_image_bytes = read_content(input["image"])
197
  print(f'origin_image_bytes = ', type(origin_image_bytes), len(origin_image_bytes))
198
  image, _ = load_img(origin_image_bytes)
199
- mask, _ = load_img(read_content(input["mask"]), gray=True)
200
  elif image_type == 'pil':
201
  # input: {'image': pil, 'mask': pil}
202
  image_pil = input['image']
@@ -206,22 +192,17 @@ def predict(input, img_enhancer):
206
  output = model_process(image, mask, img_enhancer)
207
  return output
208
 
 
209
  css = '''
210
  .container {max-width: 100%;margin: auto;padding-top: 1.5rem}
211
- .output-image, .input-image, .image-preview {height: 600px !important;object-fit: contain}
212
  #work-container {min-width: min(160px, 100%) !important;flex-grow: 0 !important}
213
- #image_upload{min-height:610px}
214
- #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 620px}
215
  #image_output{margin: 0 auto; text-align: center;width:640px}
216
  #erase-container{margin: 0 auto; text-align: center;width:150px;border-width:5px;border-color:#2c9748}
217
  #enhancer-checkbox{width:520px}
218
  #enhancer-tip{width:450px}
219
  #enhancer-tip-div{text-align: left}
220
  #prompt-container{margin: 0 auto; text-align: center;width:fit-content;min-width: min(150px, 100%);flex-grow: 0; flex-wrap: nowrap;}
221
- .footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
222
- .footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
223
- .dark .footer {border-color: #303030}
224
- .dark .footer>p {background: #0b0f19}
225
  #image_upload .touch-none{display: flex}
226
  @keyframes spin {
227
  from {
@@ -232,15 +213,63 @@ css = '''
232
  }
233
  }
234
  '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  image_blocks = gr.Blocks(css=css)
237
  with image_blocks as demo:
238
- with gr.Group():
 
 
 
 
 
 
 
 
 
 
239
  with gr.Box(elem_id="work-container"):
240
  with gr.Row(elem_id="input-container"):
241
  with gr.Column():
242
  image = gr.Image(source='upload', elem_id="image_upload",tool='sketch', type=f'{image_type}',
243
- label="Upload(载入图片)", show_label=True).style(mobile_collapse=False)
244
  with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
245
  with gr.Column(elem_id="erase-container"):
246
  btn_erase = gr.Button(value = "Erase(擦除↓)",elem_id="erase_btn").style(
@@ -248,13 +277,15 @@ with image_blocks as demo:
248
  rounded=(True, True, True, True),
249
  full_width=True,
250
  ).style(width=100)
251
- with gr.Column(elem_id="enhancer-checkbox"):
252
  enhancer_label = 'Enhanced image(processing is very slow, please check only for blurred images)【增强图像(处理很慢,请仅针对模糊图像做勾选)】'
253
  img_enhancer = gr.Checkbox(label=enhancer_label).style(width=150)
254
  with gr.Row(elem_id="output-container"):
255
  with gr.Column():
256
- image_out = gr.Image(label="Result", elem_id="image_output", visible=True).style(width=640)
257
 
258
  btn_erase.click(fn=predict, inputs=[image, img_enhancer], outputs=[image_out])
259
-
 
 
260
  image_blocks.launch()
 
1
  import gradio as gr
 
2
  from PIL import Image
3
  import numpy as np
4
+ import os,sys
5
  import uuid
6
  import torch
 
7
  import cv2
 
 
 
 
8
 
9
  import io
 
10
  import multiprocessing
11
  import random
12
  import time
13
  import imghdr
 
 
14
  from loguru import logger
15
 
16
  from lama_cleaner.model_manager import ModelManager
 
24
  except:
25
  pass
26
 
 
27
  from lama_cleaner.helper import (
28
  load_img,
29
  numpy_to_bytes,
 
48
  device = "cuda" if torch.cuda.is_available() else "cpu"
49
  print(f'device = {device}')
50
 
51
+ def read_content(file_path: str) -> str:
 
 
 
 
 
 
52
  """read the content of target file
53
  """
54
+ with open(file_path, 'r', encoding='utf-8') as f:
55
  content = f.read()
56
  return content
57
+
58
  def get_image_enhancer(scale = 2, device='cuda:0'):
59
  from basicsr.archs.rrdbnet_arch import RRDBNet
60
  from realesrgan import RealESRGANer
 
89
  )
90
  return img_enhancer
91
 
92
+ image_enhancer = None
93
+ if sys.platform == 'linux':
94
+ image_enhancer = get_image_enhancer(scale = 1, device=device)
95
 
96
  model = None
97
 
 
105
  original_shape = image.shape
106
  interpolation = cv2.INTER_CUBIC
107
 
108
+ size_limit = 1080
109
  if size_limit == "Original":
110
  size_limit = max(image.shape)
111
  else:
 
179
  return None
180
  if image_type == 'filepath':
181
  # input: {'image': '/tmp/tmp8mn9xw93.png', 'mask': '/tmp/tmpn5ars4te.png'}
182
+ origin_image_bytes = open(input["image"], 'rb').read()
183
  print(f'origin_image_bytes = ', type(origin_image_bytes), len(origin_image_bytes))
184
  image, _ = load_img(origin_image_bytes)
185
+ mask, _ = load_img(open(input["mask"], 'rb').read(), gray=True)
186
  elif image_type == 'pil':
187
  # input: {'image': pil, 'mask': pil}
188
  image_pil = input['image']
 
192
  output = model_process(image, mask, img_enhancer)
193
  return output
194
 
195
+
196
  css = '''
197
  .container {max-width: 100%;margin: auto;padding-top: 1.5rem}
198
+ #begin-btn {color: blue; font-size:20px;}
199
  #work-container {min-width: min(160px, 100%) !important;flex-grow: 0 !important}
 
 
200
  #image_output{margin: 0 auto; text-align: center;width:640px}
201
  #erase-container{margin: 0 auto; text-align: center;width:150px;border-width:5px;border-color:#2c9748}
202
  #enhancer-checkbox{width:520px}
203
  #enhancer-tip{width:450px}
204
  #enhancer-tip-div{text-align: left}
205
  #prompt-container{margin: 0 auto; text-align: center;width:fit-content;min-width: min(150px, 100%);flex-grow: 0; flex-wrap: nowrap;}
 
 
 
 
206
  #image_upload .touch-none{display: flex}
207
  @keyframes spin {
208
  from {
 
213
  }
214
  }
215
  '''
216
+ set_page_elements = """async () => {
217
+ function isMobile() {
218
+ try {
219
+ document.createEvent("TouchEvent"); return true;
220
+ } catch(e) {
221
+ return false;
222
+ }
223
+ }
224
+
225
+ var gradioEl = document.querySelector('body > gradio-app').shadowRoot;
226
+ if (!gradioEl) {
227
+ gradioEl = document.querySelector('body > gradio-app');
228
+ }
229
+ var group1 = gradioEl.querySelectorAll('#group_1')[0];
230
+ var group2 = gradioEl.querySelectorAll('#group_2')[0];
231
+ var image_upload = gradioEl.querySelectorAll('#image_upload')[0];
232
+ var image_output = gradioEl.querySelectorAll('#image_output')[0];
233
+ var data_image = gradioEl.querySelectorAll('#image_upload [data-testid="image"]')[0];
234
+ var data_image_div = gradioEl.querySelectorAll('#image_upload [data-testid="image"] > div')[0];
235
+
236
+ if (isMobile()) {
237
+ var group1_width = group1.offsetWidth;
238
+ image_upload.setAttribute('style', 'width:' + (group1_width - 13*2) + 'px; min-height:none;');
239
+ data_image.setAttribute('style', 'width: ' + (group1_width - 14*2) + 'px;min-height:none;');
240
+ data_image_div.setAttribute('style', 'width: ' + (group1_width - 14*2) + 'px;min-height:none;');
241
+ image_output.setAttribute('style', 'width: ' + (group1_width - 13*2) + 'px;min-height:none;');
242
+ var enhancer = gradioEl.querySelectorAll('#enhancer-checkbox')[0];
243
+ enhancer.style.display = "none";
244
+ } else {
245
+ image_upload.setAttribute('style', 'min-height: 600px; overflow-x: overlay');
246
+ data_image.setAttribute('style', 'height: 600px');
247
+ data_image_div.setAttribute('style', 'min-height: 600px');
248
+ image_output.setAttribute('style', 'width: 600px');
249
+ }
250
+ group1.style.display = "none";
251
+ group2.style.display = "block";
252
+
253
+ }"""
254
 
255
  image_blocks = gr.Blocks(css=css)
256
  with image_blocks as demo:
257
+ with gr.Group(elem_id="group_1", visible=True) as group_1:
258
+ with gr.Box():
259
+ with gr.Row():
260
+ with gr.Column():
261
+ gallery = gr.Gallery(value=['./sample_00.jpg','./sample_00_e.jpg'], show_label=False)
262
+ gallery.style(grid=[2], width=320)
263
+ with gr.Row():
264
+ with gr.Column():
265
+ begin_button = gr.Button("Let's GO!", elem_id="begin-btn", visible=True)
266
+
267
+ with gr.Group(elem_id="group_2", visible=False) as group_2:
268
  with gr.Box(elem_id="work-container"):
269
  with gr.Row(elem_id="input-container"):
270
  with gr.Column():
271
  image = gr.Image(source='upload', elem_id="image_upload",tool='sketch', type=f'{image_type}',
272
+ label="Upload(载入图片)", show_label=False).style(mobile_collapse=False)
273
  with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
274
  with gr.Column(elem_id="erase-container"):
275
  btn_erase = gr.Button(value = "Erase(擦除↓)",elem_id="erase_btn").style(
 
277
  rounded=(True, True, True, True),
278
  full_width=True,
279
  ).style(width=100)
280
+ with gr.Column(elem_id="enhancer-checkbox", visible=True if image_enhancer is not None else False):
281
  enhancer_label = 'Enhanced image(processing is very slow, please check only for blurred images)【增强图像(处理很慢,请仅针对模糊图像做勾选)】'
282
  img_enhancer = gr.Checkbox(label=enhancer_label).style(width=150)
283
  with gr.Row(elem_id="output-container"):
284
  with gr.Column():
285
+ image_out = gr.Image(elem_id="image_output",label="Result", show_label=False)
286
 
287
  btn_erase.click(fn=predict, inputs=[image, img_enhancer], outputs=[image_out])
288
+
289
+ begin_button.click(fn=None, inputs=[], outputs=[group_1, group_2], _js=set_page_elements)
290
+
291
  image_blocks.launch()