# ImgCleaner / app.py — Hugging Face Space by yizhangliu
# (commit c5e6841, "Update app.py"; 11.9 kB)
import gradio as gr
import PIL
from PIL import Image
import numpy as np
import os
import uuid
import torch
from torch import autocast
import cv2
from io import BytesIO
import requests
import PIL
from PIL import Image
import numpy as np
import os
import uuid
import torch
from torch import autocast
import cv2
from matplotlib import pyplot as plt
from torchvision import transforms
from diffusers import DiffusionPipeline
import io
import logging
import multiprocessing
import random
import time
import imghdr
from pathlib import Path
from typing import Union
from loguru import logger
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config
def disable_torch_jit_fusers():
    """Turn off the TorchScript kernel fusers, best effort.

    Every setter is attempted independently. A setter that is missing from, or
    rejected by, the installed torch build is skipped without preventing the
    remaining ones from being applied.

    Returns:
        list[str]: names of the ``torch._C`` setters that could not be applied.
    """
    setters = (
        ("_jit_override_can_fuse_on_cpu", False),
        ("_jit_override_can_fuse_on_gpu", False),
        ("_jit_set_texpr_fuser_enabled", False),
        ("_jit_set_nvfuser_enabled", False),
    )
    failed = []
    for name, value in setters:
        try:
            getattr(torch._C, name)(value)
        except Exception:  # private torch API: may be absent or raise in some builds
            failed.append(name)
    return failed


disable_torch_jit_fusers()
from lama_cleaner.helper import (
load_img,
numpy_to_bytes,
resize_max_size,
)
# Thread count for the native math/threading backends: every logical CPU.
NUM_THREADS = str(multiprocessing.cpu_count())
# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
# NOTE(review): torch and numpy are imported earlier in this file, so some of
# these thread-count variables may be read before they are set here — confirm
# they actually take effect.
os.environ.update(
    dict.fromkeys(
        (
            "OMP_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "MKL_NUM_THREADS",
            "VECLIB_MAXIMUM_THREADS",
            "NUMEXPR_NUM_THREADS",
        ),
        NUM_THREADS,
    )
)
# When CACHE_DIR is set (and non-empty), use it as TORCH_HOME as well.
if os.environ.get("CACHE_DIR"):
    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
from share_btn import community_icon_html, loading_icon_html, share_js
# Optional Hugging Face token (None when the environment variable is unset).
HF_TOKEN_SD = os.environ.get('HF_TOKEN_SD')
# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f'device = {device}')
def get_image_ext(img_bytes):
    """Guess an image's format name from its leading bytes.

    Args:
        img_bytes: raw encoded image data.

    Returns:
        The format name reported by ``imghdr`` (e.g. ``"png"``), or ``"jpeg"``
        when the data is not recognised.
    """
    detected = imghdr.what("", img_bytes)
    return "jpeg" if detected is None else detected
def read_content(file_path):
    """Return the entire contents of ``file_path`` as bytes.

    Args:
        file_path: path of the file to read (opened in binary mode).
    """
    with open(file_path, mode='rb') as handle:
        return handle.read()
# The inpainting model; None until it is reassigned to a ModelManager later in
# this module.
model = None


def model_process(image, mask, alpha_channel, ext, *, keep_alpha=False):
    """Inpaint the masked region of ``image`` with the module-level ``model``.

    Args:
        image: image array, as produced by ``load_img`` or ``np.array`` of a
            PIL image (assumed H x W x C uint8 — TODO confirm for every caller).
        mask: mask array with the same height/width as ``image``.
        alpha_channel: H x W alpha plane. It is only used when ``keep_alpha``
            is true; otherwise it is ignored.
        ext: image extension handed to ``numpy_to_bytes`` to encode the result.
        keep_alpha: when true, re-attach ``alpha_channel`` (resized to the
            result if needed) as an extra channel. Defaults to False, which
            matches the previous behaviour of always discarding it.

    Returns:
        A ``PIL.Image.Image`` decoded from the encoded result, or None when no
        model has been loaded.
    """
    if model is None:
        return None

    original_shape = image.shape
    interpolation = cv2.INTER_CUBIC
    # "Original" size mode: the limit is the largest dimension of the input,
    # so resize_max_size is asked to keep the original resolution.
    size_limit = max(image.shape)

    config = Config(
        ldm_steps=25,
        ldm_sampler='plms',
        zits_wireframe=True,
        hd_strategy='Original',
        hd_strategy_crop_margin=196,
        hd_strategy_crop_trigger_size=1280,
        hd_strategy_resize_limit=2048,
        prompt='',
        use_croper=False,
        croper_x=0,
        croper_y=0,
        croper_height=512,
        croper_width=512,
        sd_mask_blur=5,
        sd_strength=0.75,
        sd_steps=50,
        sd_guidance_scale=7.5,
        sd_sampler='ddim',
        sd_seed=42,
        cv2_flag='INPAINT_NS',
        cv2_radius=5,
    )
    # -1 requests a random seed (the hard-coded config above uses 42).
    if config.sd_seed == -1:
        config.sd_seed = random.randint(1, 999999999)
    logger.debug(f"config/size_limit = {config} / {size_limit}")

    # The old debug prints indexed pixel [250][250], which raised IndexError
    # for any image smaller than 251 px on a side; only shapes are logged now.
    logger.info(f"Origin image shape: {original_shape}")
    image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
    logger.info(f"Resized image shape: {image.shape} / {type(image)}")
    mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
    logger.debug(f"mask image shape: {mask.shape} / {type(mask)}")

    start = time.time()
    res_np_img = model(image, mask, config)
    logger.info(f"process time: {(time.time() - start) * 1000}ms, {res_np_img.shape}")
    # Free cached GPU memory between requests (presumably a no-op without CUDA).
    torch.cuda.empty_cache()

    if keep_alpha and alpha_channel is not None:
        if alpha_channel.shape[:2] != res_np_img.shape[:2]:
            # cv2.resize takes dsize as (width, height).
            alpha_channel = cv2.resize(
                alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
            )
        res_np_img = np.concatenate(
            (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
        )

    logger.debug(f"result: {res_np_img.shape} / {res_np_img.dtype} / {ext}")
    return Image.open(io.BytesIO(numpy_to_bytes(res_np_img, ext)))
# Load the LaMa inpainting model once, at import time, on the selected device.
# Stable-Diffusion-related keyword arguments (hf_access_token=HF_TOKEN_SD,
# sd_disable_nsfw, sd_cpu_textencoder, sd_run_local, callback) were previously
# listed here as commented-out options and are intentionally not passed.
model = ModelManager(name="lama", device=device)
# Unused Stable Diffusion inpainting setup, kept as an inert string literal; it
# is never executed (it also references `auth_token`, which this file's visible
# code does not define).
'''
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", dtype=torch.float16, revision="fp16", use_auth_token=auth_token).to(device)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
transforms.Resize((512, 512)),
])
'''
# Input mode shared by gr.Image(type=...) and predict(): 'filepath' passes
# temp-file paths for the image and mask, 'pil' passes PIL images.
image_type = 'filepath' # alternative: 'pil'
def predict(input):
    """Inpaint one request from the sketch-enabled image input.

    Args:
        input: dict with ``"image"`` and ``"mask"`` entries. When the
            module-level ``image_type`` is ``'filepath'`` both are paths to
            image files; when it is ``'pil'`` both are PIL images.

    Returns:
        The result of ``model_process``: the inpainted PIL image, or None when
        no model is loaded.

    Raises:
        ValueError: if ``image_type`` is neither ``'filepath'`` nor ``'pil'``
            (previously this surfaced as an UnboundLocalError).
    """
    print('liuyz_0_', input)
    if image_type == 'filepath':
        # input: {'image': '/tmp/tmp8mn9xw93.png', 'mask': '/tmp/tmpn5ars4te.png'}
        origin_image_bytes = read_content(input["image"])
        print('origin_image_bytes = ', type(origin_image_bytes), len(origin_image_bytes))
        image, _ = load_img(origin_image_bytes)
        mask, _ = load_img(read_content(input["mask"]), gray=True)
        ext = get_image_ext(origin_image_bytes)
    elif image_type == 'pil':
        # input: {'image': pil, 'mask': pil}
        # Normalize to RGB so that RGBA/paletted uploads reach the model with
        # three channels, as in the filepath branch above.
        image = np.array(input['image'].convert("RGB"))
        mask = np.array(input['mask'].convert("L"))
        ext = 'png'
    else:
        raise ValueError(f"unsupported image_type: {image_type!r}")

    # Fully opaque alpha plane with the image's height and width.
    alpha_channel = np.full(image.shape[:2], 255, dtype=np.uint8)
    return model_process(image, mask, alpha_channel, ext)
# Stylesheet for the Blocks UI (passed as gr.Blocks(css=css)). Ids such as
# #image_upload match elem_id values used when building the UI; some selectors
# (e.g. #mask_radio, #word_mask, #share-btn) target elements that the
# visible UI does not currently create.
css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
#image_upload{min-height:512px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 512px}
#mask_radio .gr-form{background:transparent; border: none}
#word_mask{margin-top: .75em !important}
#word_mask textarea:disabled{opacity: 0.3}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {
display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
}
#share-btn {
all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
'''
# Earlier gr.Interface experiments, kept as inert string literals; neither is
# executed (the first is not even valid Python: positional arguments follow
# keyword arguments, and `ImageUplaod` is misspelled).
'''
sketchpad = Sketchpad()
imageupload = ImageUplaod()
interface = gr.Interface(fn=predict, inputs="image", outputs="image", sketchpad, imageupload)
interface.launch(share=True)
'''
'''
# gr.Interface(fn=predict, inputs="image", outputs="image").launch(share=True)
image = gr.Image(source='upload', tool='sketch', type="pil", label="Upload")# .style(height=400)
image_blocks = gr.Interface(
fn=predict,
inputs=image,
outputs=image,
# examples=[["cheetah.jpg"]],
)
image_blocks.launch(inline=True)
import gradio as gr
def greet(dict, name, is_morning, temperature):
image = dict['image']
target_size = (image.shape[0], image.shape[1])
print(f'liuyz_1_', target_size)
salutation = "Good morning" if is_morning else "Good evening"
greeting = f"{salutation} {name}. It is {temperature} degrees today"
celsius = (temperature - 32) * 5 / 9
return image, greeting, round(celsius, 2)
image = gr.Image(source='upload', tool='sketch', label="上传")# .style(height=400)
demo = gr.Interface(
fn=greet,
inputs=[image, "text", "checkbox", gr.Slider(0, 100)],
outputs=['image', "text", "number"],
)
demo.launch()
'''
# Build and launch the UI: a sketchable input image with a "Done!" button on the
# left and the inpainted output on the right.
image_blocks = gr.Blocks(css=css)
with image_blocks as demo:
    # gr.HTML(read_content("header.html"))
    with gr.Group():
        with gr.Box():
            with gr.Row():
                with gr.Column():
                    # `type` follows the module-level image_type, so predict()
                    # receives either file paths or PIL images.
                    # NOTE(review): confirm the installed gradio version treats
                    # tool='sketch,editor' as a sketch tool that yields the
                    # {"image": ..., "mask": ...} dict predict() indexes into.
                    image = gr.Image(source='upload', elem_id="image_upload", tool='sketch,editor', type=f'{image_type}', label="Upload").style(height=512)
                    with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                        # prompt = gr.Textbox(placeholder = 'Your prompt (what you want in place of what is erased)', show_label=False, elem_id="input-text")
                        btn_in = gr.Button("Done!").style(
                            margin=True,
                            rounded=(True, True, True, True),
                            full_width=True,
                        )
                with gr.Column():
                    image_out = gr.Image(label="Output", elem_id="image_output", visible=True).style(height=512)
                    # Disabled community-share widgets, kept as an inert string literal.
                    '''
with gr.Group(elem_id="share-btn-container"):
community_icon = gr.HTML(community_icon_html, visible=False)
loading_icon = gr.HTML(loading_icon_html, visible=False)
share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
'''
            # btn.click(fn=predict, inputs=[image, prompt], outputs=[image_out, community_icon, loading_icon, share_button])
            # Clicking "Done!" runs predict() on the input image and shows the result in image_out.
            btn_in.click(fn=predict, inputs=[image], outputs=[image_out]) #, community_icon, loading_icon, share_button])
            #share_button.click(None, [], [], _js=share_js)
image_blocks.launch()