yizhangliu commited on
Commit
2e0fb71
·
1 Parent(s): 197179b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -140
app.py CHANGED
@@ -7,20 +7,10 @@ import uuid
7
  import torch
8
  from torch import autocast
9
  import cv2
10
-
11
  from io import BytesIO
12
- import requests
13
- import PIL
14
- from PIL import Image
15
- import numpy as np
16
- import os
17
- import uuid
18
- import torch
19
- from torch import autocast
20
- import cv2
21
  from matplotlib import pyplot as plt
22
  from torchvision import transforms
23
- from diffusers import DiffusionPipeline
24
 
25
  import io
26
  import logging
@@ -85,18 +75,17 @@ def read_content(file_path):
85
 
86
  model = None
87
 
88
- def model_process(image, mask, alpha_channel, ext):
89
  global model
 
90
  original_shape = image.shape
91
  interpolation = cv2.INTER_CUBIC
92
 
93
  size_limit = "Original"
94
- print(f'size_limit_2_ = {size_limit}')
95
  if size_limit == "Original":
96
  size_limit = max(image.shape)
97
  else:
98
  size_limit = int(size_limit)
99
- print(f'size_limit_3_ = {size_limit}')
100
 
101
  config = Config(
102
  ldm_steps=25,
@@ -122,108 +111,42 @@ def model_process(image, mask, alpha_channel, ext):
122
  cv2_radius=5,
123
  )
124
 
125
- print(f'config/alpha_channel/size_limit = {config} / {alpha_channel} / {size_limit}')
126
  if config.sd_seed == -1:
127
  config.sd_seed = random.randint(1, 999999999)
128
 
129
- logger.info(f"Origin image shape: {original_shape}")
130
- print(f"Origin image shape: {original_shape} / {image[250][250]}")
131
  image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
132
- logger.info(f"Resized image shape: {image.shape} / {type(image)}")
133
- print(f"Resized image shape: {image.shape} / {image[250][250]}")
134
-
135
  mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
136
- print(f"mask image shape: {mask.shape} / {type(mask)} / {mask[250][250]} / {alpha_channel}")
137
 
138
  if model is None:
139
  return None
140
 
141
- start = time.time()
142
  res_np_img = model(image, mask, config)
143
- logger.info(f"process time: {(time.time() - start) * 1000}ms, {res_np_img.shape}")
144
- print(f"process time_1_: {(time.time() - start) * 1000}ms, {res_np_img.shape} / {res_np_img[250][250]} / {res_np_img.dtype}")
145
-
146
  torch.cuda.empty_cache()
147
-
148
- alpha_channel = None
149
- if alpha_channel is not None:
150
- print(f"liuyz_here_10_: {alpha_channel.shape} / {res_np_img.dtype}")
151
- if alpha_channel.shape[:2] != res_np_img.shape[:2]:
152
- print(f"liuyz_here_20_: {res_np_img.shape}")
153
- alpha_channel = cv2.resize(
154
- alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
155
- )
156
- print(f"liuyz_here_30_: {res_np_img.dtype}")
157
- res_np_img = np.concatenate(
158
- (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
159
- )
160
- print(f"liuyz_here_40_: {res_np_img.dtype}")
161
-
162
- print(f"process time_2_: {(time.time() - start) * 1000}ms, {res_np_img.shape} / {res_np_img[250][250]} / {res_np_img.dtype} /{ext}")
163
 
164
- image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, ext)))
165
  return image # image
166
 
167
  model = ModelManager(
168
  name='lama',
169
  device=device,
170
- # hf_access_token=HF_TOKEN_SD,
171
- # sd_disable_nsfw=False,
172
- # sd_cpu_textencoder=True,
173
- # sd_run_local=True,
174
- # callback=diffuser_callback,
175
  )
176
 
177
- '''
178
- pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", dtype=torch.float16, revision="fp16", use_auth_token=auth_token).to(device)
179
-
180
- transform = transforms.Compose([
181
- transforms.ToTensor(),
182
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
183
- transforms.Resize((512, 512)),
184
- ])
185
- '''
186
-
187
-
188
  image_type = 'filepath' #'pil'
189
  def predict(input):
190
- print(f'liuyz_0_', input)
191
- '''
192
- image_np = np.array(input["image"])
193
- print(f'image_np = {image_np.shape}')
194
- mask_np = np.array(input["mask"])
195
- print(f'mask_np = {mask_np.shape}')
196
- '''
197
- '''
198
- image = dict["image"] # .convert("RGB") #.resize((512, 512))
199
- # target_size = (init_image.shape[0], init_image.shape[1])
200
- print(f'liuyz_1_', image.shape)
201
- print(f'liuyz_2_', image.convert("RGB").shape)
202
- print(f'liuyz_3_', image.convert("RGB").resize((512, 512)).shape)
203
- # mask = dict["mask"] # .convert("RGB") #.resize((512, 512))
204
- '''
205
  if image_type == 'filepath':
206
  # input: {'image': '/tmp/tmp8mn9xw93.png', 'mask': '/tmp/tmpn5ars4te.png'}
207
  origin_image_bytes = read_content(input["image"])
208
  print(f'origin_image_bytes = ', type(origin_image_bytes), len(origin_image_bytes))
209
  image, _ = load_img(origin_image_bytes)
210
  mask, _ = load_img(read_content(input["mask"]), gray=True)
211
- alpha_channel = (np.ones((image.shape[0],image.shape[1]))*255).astype(np.uint8)
212
- ext = get_image_ext(origin_image_bytes)
213
-
214
- output = model_process(image, mask, alpha_channel, ext)
215
  elif image_type == 'pil':
216
  # input: {'image': pil, 'mask': pil}
217
  image_pil = input['image']
218
  mask_pil = input['mask']
219
-
220
  image = np.array(image_pil)
221
  mask = np.array(mask_pil.convert("L"))
222
- alpha_channel = (np.ones((image.shape[0],image.shape[1]))*255).astype(np.uint8)
223
- ext = 'png'
224
-
225
- output = model_process(image, mask, alpha_channel, ext)
226
- return output #, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
227
 
228
  css = '''
229
  .container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
@@ -264,58 +187,14 @@ css = '''
264
  }
265
  '''
266
 
267
- '''
268
- sketchpad = Sketchpad()
269
- imageupload = ImageUplaod()
270
- interface = gr.Interface(fn=predict, inputs="image", outputs="image", sketchpad, imageupload)
271
-
272
- interface.launch(share=True)
273
- '''
274
-
275
- '''
276
- # gr.Interface(fn=predict, inputs="image", outputs="image").launch(share=True)
277
-
278
- image = gr.Image(source='upload', tool='sketch', type="pil", label="Upload")# .style(height=400)
279
- image_blocks = gr.Interface(
280
- fn=predict,
281
- inputs=image,
282
- outputs=image,
283
- # examples=[["cheetah.jpg"]],
284
- )
285
-
286
- image_blocks.launch(inline=True)
287
-
288
- import gradio as gr
289
-
290
- def greet(dict, name, is_morning, temperature):
291
- image = dict['image']
292
- target_size = (image.shape[0], image.shape[1])
293
- print(f'liuyz_1_', target_size)
294
- salutation = "Good morning" if is_morning else "Good evening"
295
- greeting = f"{salutation} {name}. It is {temperature} degrees today"
296
- celsius = (temperature - 32) * 5 / 9
297
- return image, greeting, round(celsius, 2)
298
-
299
- image = gr.Image(source='upload', tool='sketch', label="上传")# .style(height=400)
300
-
301
- demo = gr.Interface(
302
- fn=greet,
303
- inputs=[image, "text", "checkbox", gr.Slider(0, 100)],
304
- outputs=['image', "text", "number"],
305
- )
306
- demo.launch()
307
- '''
308
-
309
  image_blocks = gr.Blocks(css=css)
310
  with image_blocks as demo:
311
- # gr.HTML(read_content("header.html"))
312
  with gr.Group():
313
  with gr.Box():
314
  with gr.Row():
315
  with gr.Column():
316
  image = gr.Image(source='upload', elem_id="image_upload", tool='editor', type=f'{image_type}', label="Upload").style(height=512)
317
  with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
318
- # prompt = gr.Textbox(placeholder = 'Your prompt (what you want in place of what is erased)', show_label=False, elem_id="input-text")
319
  btn_in = gr.Button("Done!").style(
320
  margin=True,
321
  rounded=(True, True, True, True),
@@ -324,18 +203,6 @@ with image_blocks as demo:
324
 
325
  with gr.Column():
326
  image_out = gr.Image(label="Output", elem_id="image_output", visible=True).style(height=512)
327
- '''
328
- with gr.Group(elem_id="share-btn-container"):
329
- community_icon = gr.HTML(community_icon_html, visible=False)
330
- loading_icon = gr.HTML(loading_icon_html, visible=False)
331
- share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
332
- '''
333
-
334
-
335
-
336
- # btn.click(fn=predict, inputs=[image, prompt], outputs=[image_out, community_icon, loading_icon, share_button])
337
- btn_in.click(fn=predict, inputs=[image], outputs=[image_out]) #, community_icon, loading_icon, share_button])
338
- #share_button.click(None, [], [], _js=share_js)
339
-
340
 
341
  image_blocks.launch()
 
7
  import torch
8
  from torch import autocast
9
  import cv2
 
10
  from io import BytesIO
11
+
 
 
 
 
 
 
 
 
12
  from matplotlib import pyplot as plt
13
  from torchvision import transforms
 
14
 
15
  import io
16
  import logging
 
75
 
76
  model = None
77
 
78
+ def model_process(image, mask):
79
  global model
80
+
81
  original_shape = image.shape
82
  interpolation = cv2.INTER_CUBIC
83
 
84
  size_limit = "Original"
 
85
  if size_limit == "Original":
86
  size_limit = max(image.shape)
87
  else:
88
  size_limit = int(size_limit)
 
89
 
90
  config = Config(
91
  ldm_steps=25,
 
111
  cv2_radius=5,
112
  )
113
 
 
114
  if config.sd_seed == -1:
115
  config.sd_seed = random.randint(1, 999999999)
116
 
 
 
117
  image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
 
 
 
118
  mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
 
119
 
120
  if model is None:
121
  return None
122
 
 
123
  res_np_img = model(image, mask, config)
 
 
 
124
  torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
+ image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
127
  return image # image
128
 
129
  model = ModelManager(
130
  name='lama',
131
  device=device,
 
 
 
 
 
132
  )
133
 
 
 
 
 
 
 
 
 
 
 
 
134
  image_type = 'filepath' #'pil'
135
  def predict(input):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  if image_type == 'filepath':
137
  # input: {'image': '/tmp/tmp8mn9xw93.png', 'mask': '/tmp/tmpn5ars4te.png'}
138
  origin_image_bytes = read_content(input["image"])
139
  print(f'origin_image_bytes = ', type(origin_image_bytes), len(origin_image_bytes))
140
  image, _ = load_img(origin_image_bytes)
141
  mask, _ = load_img(read_content(input["mask"]), gray=True)
 
 
 
 
142
  elif image_type == 'pil':
143
  # input: {'image': pil, 'mask': pil}
144
  image_pil = input['image']
145
  mask_pil = input['mask']
 
146
  image = np.array(image_pil)
147
  mask = np.array(mask_pil.convert("L"))
148
+ output = model_process(image, mask)
149
+ return output
 
 
 
150
 
151
  css = '''
152
  .container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
 
187
  }
188
  '''
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  image_blocks = gr.Blocks(css=css)
191
  with image_blocks as demo:
 
192
  with gr.Group():
193
  with gr.Box():
194
  with gr.Row():
195
  with gr.Column():
196
  image = gr.Image(source='upload', elem_id="image_upload", tool='editor', type=f'{image_type}', label="Upload").style(height=512)
197
  with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
 
198
  btn_in = gr.Button("Done!").style(
199
  margin=True,
200
  rounded=(True, True, True, True),
 
203
 
204
  with gr.Column():
205
  image_out = gr.Image(label="Output", elem_id="image_output", visible=True).style(height=512)
206
+ btn_in.click(fn=predict, inputs=[image], outputs=[image_out])
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
  image_blocks.launch()