Spaces:
Running
Running
yizhangliu
commited on
Commit
·
2e0fb71
1
Parent(s):
197179b
Update app.py
Browse files
app.py
CHANGED
@@ -7,20 +7,10 @@ import uuid
|
|
7 |
import torch
|
8 |
from torch import autocast
|
9 |
import cv2
|
10 |
-
|
11 |
from io import BytesIO
|
12 |
-
|
13 |
-
import PIL
|
14 |
-
from PIL import Image
|
15 |
-
import numpy as np
|
16 |
-
import os
|
17 |
-
import uuid
|
18 |
-
import torch
|
19 |
-
from torch import autocast
|
20 |
-
import cv2
|
21 |
from matplotlib import pyplot as plt
|
22 |
from torchvision import transforms
|
23 |
-
from diffusers import DiffusionPipeline
|
24 |
|
25 |
import io
|
26 |
import logging
|
@@ -85,18 +75,17 @@ def read_content(file_path):
|
|
85 |
|
86 |
model = None
|
87 |
|
88 |
-
def model_process(image, mask
|
89 |
global model
|
|
|
90 |
original_shape = image.shape
|
91 |
interpolation = cv2.INTER_CUBIC
|
92 |
|
93 |
size_limit = "Original"
|
94 |
-
print(f'size_limit_2_ = {size_limit}')
|
95 |
if size_limit == "Original":
|
96 |
size_limit = max(image.shape)
|
97 |
else:
|
98 |
size_limit = int(size_limit)
|
99 |
-
print(f'size_limit_3_ = {size_limit}')
|
100 |
|
101 |
config = Config(
|
102 |
ldm_steps=25,
|
@@ -122,108 +111,42 @@ def model_process(image, mask, alpha_channel, ext):
|
|
122 |
cv2_radius=5,
|
123 |
)
|
124 |
|
125 |
-
print(f'config/alpha_channel/size_limit = {config} / {alpha_channel} / {size_limit}')
|
126 |
if config.sd_seed == -1:
|
127 |
config.sd_seed = random.randint(1, 999999999)
|
128 |
|
129 |
-
logger.info(f"Origin image shape: {original_shape}")
|
130 |
-
print(f"Origin image shape: {original_shape} / {image[250][250]}")
|
131 |
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
132 |
-
logger.info(f"Resized image shape: {image.shape} / {type(image)}")
|
133 |
-
print(f"Resized image shape: {image.shape} / {image[250][250]}")
|
134 |
-
|
135 |
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
136 |
-
print(f"mask image shape: {mask.shape} / {type(mask)} / {mask[250][250]} / {alpha_channel}")
|
137 |
|
138 |
if model is None:
|
139 |
return None
|
140 |
|
141 |
-
start = time.time()
|
142 |
res_np_img = model(image, mask, config)
|
143 |
-
logger.info(f"process time: {(time.time() - start) * 1000}ms, {res_np_img.shape}")
|
144 |
-
print(f"process time_1_: {(time.time() - start) * 1000}ms, {res_np_img.shape} / {res_np_img[250][250]} / {res_np_img.dtype}")
|
145 |
-
|
146 |
torch.cuda.empty_cache()
|
147 |
-
|
148 |
-
alpha_channel = None
|
149 |
-
if alpha_channel is not None:
|
150 |
-
print(f"liuyz_here_10_: {alpha_channel.shape} / {res_np_img.dtype}")
|
151 |
-
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
|
152 |
-
print(f"liuyz_here_20_: {res_np_img.shape}")
|
153 |
-
alpha_channel = cv2.resize(
|
154 |
-
alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
|
155 |
-
)
|
156 |
-
print(f"liuyz_here_30_: {res_np_img.dtype}")
|
157 |
-
res_np_img = np.concatenate(
|
158 |
-
(res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
|
159 |
-
)
|
160 |
-
print(f"liuyz_here_40_: {res_np_img.dtype}")
|
161 |
-
|
162 |
-
print(f"process time_2_: {(time.time() - start) * 1000}ms, {res_np_img.shape} / {res_np_img[250][250]} / {res_np_img.dtype} /{ext}")
|
163 |
|
164 |
-
image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img,
|
165 |
return image # image
|
166 |
|
167 |
model = ModelManager(
|
168 |
name='lama',
|
169 |
device=device,
|
170 |
-
# hf_access_token=HF_TOKEN_SD,
|
171 |
-
# sd_disable_nsfw=False,
|
172 |
-
# sd_cpu_textencoder=True,
|
173 |
-
# sd_run_local=True,
|
174 |
-
# callback=diffuser_callback,
|
175 |
)
|
176 |
|
177 |
-
'''
|
178 |
-
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", dtype=torch.float16, revision="fp16", use_auth_token=auth_token).to(device)
|
179 |
-
|
180 |
-
transform = transforms.Compose([
|
181 |
-
transforms.ToTensor(),
|
182 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
183 |
-
transforms.Resize((512, 512)),
|
184 |
-
])
|
185 |
-
'''
|
186 |
-
|
187 |
-
|
188 |
image_type = 'filepath' #'pil'
|
189 |
def predict(input):
|
190 |
-
print(f'liuyz_0_', input)
|
191 |
-
'''
|
192 |
-
image_np = np.array(input["image"])
|
193 |
-
print(f'image_np = {image_np.shape}')
|
194 |
-
mask_np = np.array(input["mask"])
|
195 |
-
print(f'mask_np = {mask_np.shape}')
|
196 |
-
'''
|
197 |
-
'''
|
198 |
-
image = dict["image"] # .convert("RGB") #.resize((512, 512))
|
199 |
-
# target_size = (init_image.shape[0], init_image.shape[1])
|
200 |
-
print(f'liuyz_1_', image.shape)
|
201 |
-
print(f'liuyz_2_', image.convert("RGB").shape)
|
202 |
-
print(f'liuyz_3_', image.convert("RGB").resize((512, 512)).shape)
|
203 |
-
# mask = dict["mask"] # .convert("RGB") #.resize((512, 512))
|
204 |
-
'''
|
205 |
if image_type == 'filepath':
|
206 |
# input: {'image': '/tmp/tmp8mn9xw93.png', 'mask': '/tmp/tmpn5ars4te.png'}
|
207 |
origin_image_bytes = read_content(input["image"])
|
208 |
print(f'origin_image_bytes = ', type(origin_image_bytes), len(origin_image_bytes))
|
209 |
image, _ = load_img(origin_image_bytes)
|
210 |
mask, _ = load_img(read_content(input["mask"]), gray=True)
|
211 |
-
alpha_channel = (np.ones((image.shape[0],image.shape[1]))*255).astype(np.uint8)
|
212 |
-
ext = get_image_ext(origin_image_bytes)
|
213 |
-
|
214 |
-
output = model_process(image, mask, alpha_channel, ext)
|
215 |
elif image_type == 'pil':
|
216 |
# input: {'image': pil, 'mask': pil}
|
217 |
image_pil = input['image']
|
218 |
mask_pil = input['mask']
|
219 |
-
|
220 |
image = np.array(image_pil)
|
221 |
mask = np.array(mask_pil.convert("L"))
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
output = model_process(image, mask, alpha_channel, ext)
|
226 |
-
return output #, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
|
227 |
|
228 |
css = '''
|
229 |
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
|
@@ -264,58 +187,14 @@ css = '''
|
|
264 |
}
|
265 |
'''
|
266 |
|
267 |
-
'''
|
268 |
-
sketchpad = Sketchpad()
|
269 |
-
imageupload = ImageUplaod()
|
270 |
-
interface = gr.Interface(fn=predict, inputs="image", outputs="image", sketchpad, imageupload)
|
271 |
-
|
272 |
-
interface.launch(share=True)
|
273 |
-
'''
|
274 |
-
|
275 |
-
'''
|
276 |
-
# gr.Interface(fn=predict, inputs="image", outputs="image").launch(share=True)
|
277 |
-
|
278 |
-
image = gr.Image(source='upload', tool='sketch', type="pil", label="Upload")# .style(height=400)
|
279 |
-
image_blocks = gr.Interface(
|
280 |
-
fn=predict,
|
281 |
-
inputs=image,
|
282 |
-
outputs=image,
|
283 |
-
# examples=[["cheetah.jpg"]],
|
284 |
-
)
|
285 |
-
|
286 |
-
image_blocks.launch(inline=True)
|
287 |
-
|
288 |
-
import gradio as gr
|
289 |
-
|
290 |
-
def greet(dict, name, is_morning, temperature):
|
291 |
-
image = dict['image']
|
292 |
-
target_size = (image.shape[0], image.shape[1])
|
293 |
-
print(f'liuyz_1_', target_size)
|
294 |
-
salutation = "Good morning" if is_morning else "Good evening"
|
295 |
-
greeting = f"{salutation} {name}. It is {temperature} degrees today"
|
296 |
-
celsius = (temperature - 32) * 5 / 9
|
297 |
-
return image, greeting, round(celsius, 2)
|
298 |
-
|
299 |
-
image = gr.Image(source='upload', tool='sketch', label="上传")# .style(height=400)
|
300 |
-
|
301 |
-
demo = gr.Interface(
|
302 |
-
fn=greet,
|
303 |
-
inputs=[image, "text", "checkbox", gr.Slider(0, 100)],
|
304 |
-
outputs=['image', "text", "number"],
|
305 |
-
)
|
306 |
-
demo.launch()
|
307 |
-
'''
|
308 |
-
|
309 |
image_blocks = gr.Blocks(css=css)
|
310 |
with image_blocks as demo:
|
311 |
-
# gr.HTML(read_content("header.html"))
|
312 |
with gr.Group():
|
313 |
with gr.Box():
|
314 |
with gr.Row():
|
315 |
with gr.Column():
|
316 |
image = gr.Image(source='upload', elem_id="image_upload", tool='editor', type=f'{image_type}', label="Upload").style(height=512)
|
317 |
with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
|
318 |
-
# prompt = gr.Textbox(placeholder = 'Your prompt (what you want in place of what is erased)', show_label=False, elem_id="input-text")
|
319 |
btn_in = gr.Button("Done!").style(
|
320 |
margin=True,
|
321 |
rounded=(True, True, True, True),
|
@@ -324,18 +203,6 @@ with image_blocks as demo:
|
|
324 |
|
325 |
with gr.Column():
|
326 |
image_out = gr.Image(label="Output", elem_id="image_output", visible=True).style(height=512)
|
327 |
-
|
328 |
-
with gr.Group(elem_id="share-btn-container"):
|
329 |
-
community_icon = gr.HTML(community_icon_html, visible=False)
|
330 |
-
loading_icon = gr.HTML(loading_icon_html, visible=False)
|
331 |
-
share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
|
332 |
-
'''
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
# btn.click(fn=predict, inputs=[image, prompt], outputs=[image_out, community_icon, loading_icon, share_button])
|
337 |
-
btn_in.click(fn=predict, inputs=[image], outputs=[image_out]) #, community_icon, loading_icon, share_button])
|
338 |
-
#share_button.click(None, [], [], _js=share_js)
|
339 |
-
|
340 |
|
341 |
image_blocks.launch()
|
|
|
7 |
import torch
|
8 |
from torch import autocast
|
9 |
import cv2
|
|
|
10 |
from io import BytesIO
|
11 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
from matplotlib import pyplot as plt
|
13 |
from torchvision import transforms
|
|
|
14 |
|
15 |
import io
|
16 |
import logging
|
|
|
75 |
|
76 |
model = None
|
77 |
|
78 |
+
def model_process(image, mask):
|
79 |
global model
|
80 |
+
|
81 |
original_shape = image.shape
|
82 |
interpolation = cv2.INTER_CUBIC
|
83 |
|
84 |
size_limit = "Original"
|
|
|
85 |
if size_limit == "Original":
|
86 |
size_limit = max(image.shape)
|
87 |
else:
|
88 |
size_limit = int(size_limit)
|
|
|
89 |
|
90 |
config = Config(
|
91 |
ldm_steps=25,
|
|
|
111 |
cv2_radius=5,
|
112 |
)
|
113 |
|
|
|
114 |
if config.sd_seed == -1:
|
115 |
config.sd_seed = random.randint(1, 999999999)
|
116 |
|
|
|
|
|
117 |
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
|
|
|
|
|
|
118 |
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
|
|
119 |
|
120 |
if model is None:
|
121 |
return None
|
122 |
|
|
|
123 |
res_np_img = model(image, mask, config)
|
|
|
|
|
|
|
124 |
torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
+
image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
|
127 |
return image # image
|
128 |
|
129 |
model = ModelManager(
|
130 |
name='lama',
|
131 |
device=device,
|
|
|
|
|
|
|
|
|
|
|
132 |
)
|
133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
image_type = 'filepath' #'pil'
|
135 |
def predict(input):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
if image_type == 'filepath':
|
137 |
# input: {'image': '/tmp/tmp8mn9xw93.png', 'mask': '/tmp/tmpn5ars4te.png'}
|
138 |
origin_image_bytes = read_content(input["image"])
|
139 |
print(f'origin_image_bytes = ', type(origin_image_bytes), len(origin_image_bytes))
|
140 |
image, _ = load_img(origin_image_bytes)
|
141 |
mask, _ = load_img(read_content(input["mask"]), gray=True)
|
|
|
|
|
|
|
|
|
142 |
elif image_type == 'pil':
|
143 |
# input: {'image': pil, 'mask': pil}
|
144 |
image_pil = input['image']
|
145 |
mask_pil = input['mask']
|
|
|
146 |
image = np.array(image_pil)
|
147 |
mask = np.array(mask_pil.convert("L"))
|
148 |
+
output = model_process(image, mask)
|
149 |
+
return output
|
|
|
|
|
|
|
150 |
|
151 |
css = '''
|
152 |
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
|
|
|
187 |
}
|
188 |
'''
|
189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
image_blocks = gr.Blocks(css=css)
|
191 |
with image_blocks as demo:
|
|
|
192 |
with gr.Group():
|
193 |
with gr.Box():
|
194 |
with gr.Row():
|
195 |
with gr.Column():
|
196 |
image = gr.Image(source='upload', elem_id="image_upload", tool='editor', type=f'{image_type}', label="Upload").style(height=512)
|
197 |
with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
|
|
|
198 |
btn_in = gr.Button("Done!").style(
|
199 |
margin=True,
|
200 |
rounded=(True, True, True, True),
|
|
|
203 |
|
204 |
with gr.Column():
|
205 |
image_out = gr.Image(label="Output", elem_id="image_output", visible=True).style(height=512)
|
206 |
+
btn_in.click(fn=predict, inputs=[image], outputs=[image_out])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
|
208 |
image_blocks.launch()
|