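# ImgCleaner / app.py
# Source captured from the Hugging Face Spaces file viewer
# (uploader: yizhangliu; commit 811f999 "Update app.py"; file size 16.5 kB).
# The viewer's "raw" / "history" / "blame" links are page chrome, not part of the program.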
import gradio as gr
import PIL
from PIL import Image
import numpy as np
import os
import uuid
import torch
from torch import autocast
import cv2
from io import BytesIO
import requests
import PIL
from PIL import Image
import numpy as np
import os
import uuid
import torch
from torch import autocast
import cv2
from matplotlib import pyplot as plt
from torchvision import transforms
from diffusers import DiffusionPipeline
import io
import logging
import multiprocessing
import random
import time
import imghdr
from pathlib import Path
from typing import Union
from loguru import logger
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config
# Switch off the TorchScript fusers before any model is loaded.
# These are private ``torch._C`` hooks, so the whole group is best effort:
# if one of them is missing or raises, the remaining toggles are skipped and
# start-up continues. Only ``Exception`` is caught so that Ctrl-C and
# interpreter shutdown (KeyboardInterrupt / SystemExit) are never swallowed.
try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except Exception:
    pass
from lama_cleaner.helper import (
load_img,
numpy_to_bytes,
resize_max_size,
)
# Thread count handed to every native math backend, as a string because it
# is only ever written into environment variables.
NUM_THREADS = str(multiprocessing.cpu_count())

# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# Point each backend's thread-count variable at the CPU count.
os.environ.update(
    dict.fromkeys(
        (
            "OMP_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "MKL_NUM_THREADS",
            "VECLIB_MAXIMUM_THREADS",
            "NUMEXPR_NUM_THREADS",
        ),
        NUM_THREADS,
    )
)

if os.environ.get("CACHE_DIR"):
    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
# NOTE(review): this assignment immediately overrides the CACHE_DIR value set
# just above, so TORCH_HOME always ends up as './' -- confirm that is intended.
os.environ["TORCH_HOME"] = './'

# Location of the front-end build, overridable through the environment.
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
from share_btn import community_icon_html, loading_icon_html, share_js
# Optional access token read from the environment (None when unset).
HF_TOKEN_SD = os.environ.get('HF_TOKEN_SD')

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f'device = {device}')
def get_image_ext(img_bytes):
    """Return the image type ``imghdr`` detects in *img_bytes*.

    Falls back to ``"jpeg"`` when the header is not recognised.
    """
    detected = imghdr.what("", img_bytes)
    return "jpeg" if detected is None else detected
def diffuser_callback(i, t, latents):
    """Do nothing; all three arguments are ignored and None is returned."""
    return None
def preprocess_image(image):
    """Turn an image into a batched float tensor scaled as ``2 * x / 255 - 1``.

    Args:
        image: Object exposing ``.size`` and ``.resize`` (used here with
            ``PIL.Image.LANCZOS``); its array form must be 3-D because the
            result is permuted as a 4-D batch.

    Returns:
        ``torch.Tensor`` with the channel axis moved in front of the two
        spatial axes and a leading batch axis of size 1.
    """
    # Round both sides down to the nearest multiple of 32.
    width, height = (side - side % 32 for side in image.size)
    resized = image.resize((width, height), resample=PIL.Image.LANCZOS)

    pixels = np.array(resized).astype(np.float32) / 255.0
    # Add the batch axis, then move the last (channel) axis to position 1.
    batched = pixels[None].transpose(0, 3, 1, 2)
    tensor = torch.from_numpy(batched)
    return 2.0 * tensor - 1.0
def preprocess_mask(mask):
    """Convert a mask image into an inverted, 4-times-tiled float tensor.

    Args:
        mask: Object exposing ``.convert``, ``.size`` and ``.resize`` (used
            here with ``PIL.Image.NEAREST``).

    Returns:
        ``torch.Tensor`` of shape ``(1, 4, h // 8, w // 8)`` where ``w`` and
        ``h`` are the mask sides rounded down to a multiple of 32. Values are
        ``1 - pixel / 255``, so white mask pixels become 0 (repaint) and black
        pixels become 1 (keep).
    """
    gray = mask.convert("L")
    # Round both sides down to the nearest multiple of 32.
    width, height = (side - side % 32 for side in gray.size)
    small = gray.resize((width // 8, height // 8), resample=PIL.Image.NEAREST)

    values = np.array(small).astype(np.float32) / 255.0
    # Repeat the single plane four times along a new leading axis.
    stacked = np.tile(values, (4, 1, 1))
    # Add the batch axis. The original followed this with transpose(0, 1, 2, 3),
    # which is the identity permutation and therefore changes nothing.
    batched = stacked[None]
    inverted = 1 - batched  # repaint white, keep black
    return torch.from_numpy(inverted)
def load_img_1_(nparr, gray: bool = False):
    """Decode an encoded image buffer with OpenCV.

    Args:
        nparr: Buffer accepted by ``cv2.imdecode`` (the commented-out code
            this replaces built it with ``np.frombuffer(img_bytes, np.uint8)``).
        gray: When true, decode straight to a single-channel grayscale image.

    Returns:
        Tuple ``(np_img, alpha_channel)``. ``np_img`` is the grayscale plane
        when ``gray`` is true, otherwise an RGB array. ``alpha_channel`` is the
        last plane of a 4-channel input and ``None`` in every other case.

    Raises:
        ValueError: If OpenCV cannot decode the buffer.
    """
    # Must be initialised up front: the gray path and the 3-channel path never
    # assign it, and the final return would otherwise hit UnboundLocalError.
    alpha_channel = None

    flag = cv2.IMREAD_GRAYSCALE if gray else cv2.IMREAD_UNCHANGED
    np_img = cv2.imdecode(nparr, flag)
    if np_img is None:
        raise ValueError("cv2.imdecode could not decode the image buffer")

    if gray:
        return np_img, alpha_channel

    if np_img.ndim == 3 and np_img.shape[2] == 4:
        # Keep the alpha plane before the colour conversion drops it.
        alpha_channel = np_img[:, :, -1]
        np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
    elif np_img.ndim == 2:
        # IMREAD_UNCHANGED yields a 2-D array for single-channel files;
        # BGR2RGB would reject it, so expand it to three channels instead.
        np_img = cv2.cvtColor(np_img, cv2.COLOR_GRAY2RGB)
    else:
        np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
    return np_img, alpha_channel
# Module-level model handle. It starts as None so the functions below can
# declare ``global model``; model_process returns None while it is unset.
model = None
def model_process(input):
    """Run the global ``model`` on an image/mask pair and return an RGBA image.

    Args:
        input: Mapping with NumPy arrays under the keys ``'image'`` and
            ``'mask'`` (presumably the dict produced by the Gradio sketch
            image component -- confirm against the caller).

    Returns:
        ``PIL.Image.Image`` built from the model output with a fully opaque
        alpha plane appended, or ``None`` when ``model`` has not been loaded.

    Side effects:
        Writes ``./mask_pil.png`` and ``./result_image.png`` and prints debug
        information to stdout.
    """
    global model

    def _probe(arr):
        """Return one sample pixel for the debug logs without going out of bounds."""
        row = min(250, arr.shape[0] - 1)
        col = min(250, arr.shape[1] - 1)
        return arr[row][col]

    image = input['image']
    mask = input['mask']
    print('liuyz_2_here_', type(image), image.shape)

    image_pil = Image.fromarray(image)
    mask_pil = Image.fromarray(mask).convert("L")
    print(f'image_pil_ = {type(image_pil)}')
    print(f'mask_pil_ = {type(mask_pil)}')
    mask_pil.save('./mask_pil.png')

    # Fully opaque alpha plane. It has to be uint8: a float plane would upcast
    # the concatenated result to float64, which Image.fromarray cannot convert
    # for a 4-channel array.
    alpha_channel = np.full(image.shape[:2], 255, dtype=np.uint8)
    original_shape = image.shape
    interpolation = cv2.INTER_CUBIC
    print('liuyz_3_here_', original_shape, alpha_channel)

    # Equivalent of the "Original" size-limit setting: use the largest array
    # dimension as the limit passed to resize_max_size.
    size_limit = max(image.shape)

    config = Config(
        ldm_steps=25,
        ldm_sampler='plms',
        zits_wireframe=True,
        hd_strategy='Original',
        hd_strategy_crop_margin=196,
        hd_strategy_crop_trigger_size=1280,
        hd_strategy_resize_limit=2048,
        prompt='',
        use_croper=False,
        croper_x=0,
        croper_y=0,
        croper_height=512,
        croper_width=512,
        sd_mask_blur=5,
        sd_strength=0.75,
        sd_steps=50,
        sd_guidance_scale=7.5,
        sd_sampler='ddim',
        sd_seed=42,
        cv2_flag='INPAINT_NS',
        cv2_radius=5,
    )
    print(f'config/alpha_channel/size_limit = {config} / {alpha_channel} / {size_limit}')
    if config.sd_seed == -1:
        config.sd_seed = random.randint(1, 999999999)

    print(f"Origin image shape: {original_shape} / {_probe(image)}")
    image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
    print(f"Resized image shape: {image.shape} / {_probe(image)}")

    mask = np.array(mask_pil)
    mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
    print(f"mask image shape: {mask.shape} / {type(mask)} / {_probe(mask)}")

    if model is None:
        return None

    start = time.time()
    res_np_img = model(image, mask, config)
    logger.info(f"process time: {(time.time() - start) * 1000}ms, {res_np_img.shape}")
    print(f"process time_1_: {(time.time() - start) * 1000}ms, {alpha_channel.shape}, {res_np_img.shape} / {_probe(res_np_img)}")
    torch.cuda.empty_cache()

    # Image.fromarray needs 8-bit data for a multi-channel image; this is a
    # no-op when the model already returns uint8.
    res_np_img = np.clip(res_np_img, 0, 255).astype(np.uint8)
    # NOTE(review): the channel order of the model output is not visible in
    # this file; if it is BGR the saved image will have red and blue swapped.

    if alpha_channel.shape[:2] != res_np_img.shape[:2]:
        print(f"liuyz_here_20_: {alpha_channel.shape} / {res_np_img.shape}")
        # cv2.resize takes dsize as (width, height).
        alpha_channel = cv2.resize(
            alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
        )
    res_np_img = np.concatenate(
        (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
    )
    print(f"process time_2_: {(time.time() - start) * 1000}ms, {alpha_channel.shape}, {res_np_img.shape} / {_probe(res_np_img)} / {res_np_img.dtype}")

    image = Image.fromarray(res_np_img)
    image.save('./result_image.png')
    return image
def model_process_2(input): #image, mask):
    """Inpaint an image/mask pair given as file paths and return a PIL image.

    Args:
        input: Mapping with file paths under the keys ``"image"`` and
            ``"mask"``, e.g.
            ``{'image': '/tmp/tmp8mn9xw93.png', 'mask': '/tmp/tmpn5ars4te.png'}``.

    Returns:
        ``PIL.Image.Image`` built from the model output, with the source
        image's alpha channel re-attached when ``load_img`` reported one, or
        ``None`` when ``model`` has not been loaded.

    Side effects:
        Writes ``./result_image.png`` and prints debug information to stdout.
    """
    global model

    def _probe(arr):
        """Return one sample pixel for the debug logs without going out of bounds."""
        row = min(250, arr.shape[0] - 1)
        col = min(250, arr.shape[1] - 1)
        return arr[row][col]

    origin_image_bytes = read_content(input["image"])
    print('origin_image_bytes = ', type(origin_image_bytes), len(origin_image_bytes))

    image, alpha_channel = load_img(origin_image_bytes)
    original_shape = image.shape
    interpolation = cv2.INTER_CUBIC

    # Equivalent of the "Original" size-limit setting: use the largest array
    # dimension as the limit passed to resize_max_size.
    size_limit = max(image.shape)
    print(f'size_limit_3_ = {size_limit}')

    config = Config(
        ldm_steps=25,
        ldm_sampler='plms',
        zits_wireframe=True,
        hd_strategy='Original',
        hd_strategy_crop_margin=196,
        hd_strategy_crop_trigger_size=1280,
        hd_strategy_resize_limit=2048,
        prompt='',
        use_croper=False,
        croper_x=0,
        croper_y=0,
        croper_height=512,
        croper_width=512,
        sd_mask_blur=5,
        sd_strength=0.75,
        sd_steps=50,
        sd_guidance_scale=7.5,
        sd_sampler='ddim',
        sd_seed=42,
        cv2_flag='INPAINT_NS',
        cv2_radius=5,
    )
    print(f'config/alpha_channel/size_limit = {config} / {alpha_channel} / {size_limit}')
    if config.sd_seed == -1:
        config.sd_seed = random.randint(1, 999999999)

    logger.info(f"Origin image shape: {original_shape}")
    print(f"Origin image shape: {original_shape} / {_probe(image)}")
    image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
    logger.info(f"Resized image shape: {image.shape} / {type(image)}")
    print(f"Resized image shape: {image.shape} / {_probe(image)}")

    # Discard the second return value here: binding it to alpha_channel (as
    # the code used to) threw away the alpha plane read from the source image.
    mask, _ = load_img(read_content(input["mask"]), gray=True)
    mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
    print(f"mask image shape: {mask.shape} / {type(mask)} / {_probe(mask)}")

    # Same guard as model_process: nothing to run until the model is loaded.
    if model is None:
        return None

    start = time.time()
    res_np_img = model(image, mask, config)
    logger.info(f"process time: {(time.time() - start) * 1000}ms")
    print(f"process time: {(time.time() - start) * 1000}ms, {res_np_img.shape} / {_probe(res_np_img)}")
    torch.cuda.empty_cache()

    # NOTE(review): the channel order of the model output is not visible in
    # this file; if it is BGR the saved image will have red and blue swapped.
    if alpha_channel is not None:
        print(f"liuyz_here_1_: {alpha_channel.shape}")
        if alpha_channel.shape[:2] != res_np_img.shape[:2]:
            # cv2.resize takes dsize as (width, height).
            alpha_channel = cv2.resize(
                alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
            )
        res_np_img = np.concatenate(
            (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
        )

    image = Image.fromarray(res_np_img)
    image.save('./result_image.png')
    return image
# Build the model handle once at import time: a ModelManager named 'lama' on
# the device selected earlier. The commented-out keyword arguments are options
# that are currently not passed.
model = ModelManager(
name='lama',
device=device,
# hf_access_token=HF_TOKEN_SD,
# sd_disable_nsfw=False,
# sd_cpu_textencoder=True,
# sd_run_local=True,
# callback=diffuser_callback,
)
# The bare string literal below is never executed; it holds a disabled
# DiffusionPipeline / torchvision transform experiment.
'''
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", dtype=torch.float16, revision="fp16", use_auth_token=auth_token).to(device)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
transforms.Resize((512, 512)),
])
'''
def read_content(file_path):
    """Return the raw bytes stored in the file at *file_path*."""
    with open(file_path, mode='rb') as handle:
        return handle.read()
def predict(input):
    """Click handler: inpaint the sketched region and return the result.

    Args:
        input: Value of the sketch image component; it is passed unchanged to
            ``model_process``, which reads its ``'image'`` and ``'mask'`` keys.

    Returns:
        Whatever ``model_process`` returns for this input.
    """
    print('liuyz_0_', input)
    # Return the processed image. The previous code overwrote it with
    # input["mask"], so the model's result was computed and then thrown away.
    output = model_process(input)
    return output
# Progress marker printed once the definitions above have been executed.
print(f'liuyz_500_here_')
# Stylesheet handed to gr.Blocks(css=...) below. Besides general layout rules
# it still contains selectors for elements the current layout does not create
# (e.g. #share-btn, #word_mask, #mask_radio).
css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
#image_upload{min-height:512px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 512px}
#mask_radio .gr-form{background:transparent; border: none}
#word_mask{margin-top: .75em !important}
#word_mask textarea:disabled{opacity: 0.3}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {
display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
}
#share-btn {
all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
'''
# The two bare string literals below are never executed. They preserve
# earlier gr.Interface experiments: a Sketchpad/ImageUplaod attempt, and an
# image-sketch interface followed by a "greet" demo.
'''
sketchpad = Sketchpad()
imageupload = ImageUplaod()
interface = gr.Interface(fn=predict, inputs="image", outputs="image", sketchpad, imageupload)
interface.launch(share=True)
'''
'''
# gr.Interface(fn=predict, inputs="image", outputs="image").launch(share=True)
image = gr.Image(source='upload', tool='sketch', type="pil", label="Upload")# .style(height=400)
image_blocks = gr.Interface(
fn=predict,
inputs=image,
outputs=image,
# examples=[["cheetah.jpg"]],
)
image_blocks.launch(inline=True)
import gradio as gr
def greet(dict, name, is_morning, temperature):
image = dict['image']
target_size = (image.shape[0], image.shape[1])
print(f'liuyz_1_', target_size)
salutation = "Good morning" if is_morning else "Good evening"
greeting = f"{salutation} {name}. It is {temperature} degrees today"
celsius = (temperature - 32) * 5 / 9
return image, greeting, round(celsius, 2)
image = gr.Image(source='upload', tool='sketch', label="上传")# .style(height=400)
demo = gr.Interface(
fn=greet,
inputs=[image, "text", "checkbox", gr.Slider(0, 100)],
outputs=['image', "text", "number"],
)
demo.launch()
'''
# Page layout: an upload/sketch image with a "Done!" button on the left and
# the output image on the right. Clicking the button calls predict() with the
# sketch component's value and shows the returned image in the output.
image_blocks = gr.Blocks(css=css)
with image_blocks as demo:
    # gr.HTML(read_content("header.html"))
    with gr.Group(), gr.Box():
        with gr.Row():
            # Left column: input image plus the action button.
            with gr.Column():
                upload_widget = gr.Image(source='upload', tool='sketch', label="Upload")
                image = upload_widget.style(height=512)

                prompt_row = gr.Row(elem_id="prompt-container")
                with prompt_row.style(mobile_collapse=False, equal_height=True):
                    # A prompt Textbox used to sit here; it is disabled.
                    btn = gr.Button("Done!").style(
                        margin=True,
                        rounded=(True,) * 4,
                        full_width=True,
                    )

            # Right column: result image. The community share-button group
            # that used to follow it is disabled.
            with gr.Column():
                image_out = gr.Image(label="Output").style(height=512)

        btn.click(fn=predict, inputs=[image], outputs=[image_out])
        # share_button.click(None, [], [], _js=share_js)

image_blocks.launch()