# ImgCleaner / app.py
# Last change: "Update app.py" (commit 4a29657) by yizhangliu.
import gradio as gr
import PIL
from PIL import Image
import numpy as np
import os
import uuid
import torch
from torch import autocast
import cv2
from io import BytesIO
import requests
import PIL
from PIL import Image
import numpy as np
import os
import uuid
import torch
from torch import autocast
import cv2
from matplotlib import pyplot as plt
from torchvision import transforms
from diffusers import DiffusionPipeline
import io
import logging
import multiprocessing
import random
import time
import imghdr
from pathlib import Path
from typing import Union
from loguru import logger
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config
# Disable TorchScript kernel fusion (CPU/GPU fusers, TensorExpr, nvFuser).
# These are private torch._C hooks whose availability differs between torch
# releases, so each one is applied best-effort and independently: a hook that
# is missing or unsupported in this build no longer prevents the remaining
# ones from being applied.
for _jit_flag_setter in (
    "_jit_override_can_fuse_on_cpu",
    "_jit_override_can_fuse_on_gpu",
    "_jit_set_texpr_fuser_enabled",
    "_jit_set_nvfuser_enabled",
):
    try:
        getattr(torch._C, _jit_flag_setter)(False)
    except Exception:
        # e.g. AttributeError when the hook does not exist in this torch build;
        # deliberately best-effort, but no longer swallows KeyboardInterrupt.
        pass
del _jit_flag_setter
from lama_cleaner.helper import (
load_img,
numpy_to_bytes,
resize_max_size,
)
NUM_THREADS = str(multiprocessing.cpu_count())
# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
# Size the thread pools of the common native math backends to the CPU count.
# NOTE(review): torch and numpy are imported earlier in this file, so some
# backends may already have read these variables — confirm they take effect.
for _thread_env_var in (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ[_thread_env_var] = NUM_THREADS
del _thread_env_var
# TORCH_HOME: an explicit CACHE_DIR takes precedence; "./" is only the
# fallback (assigning "./" unconditionally would make CACHE_DIR dead).
os.environ["TORCH_HOME"] = os.environ.get("CACHE_DIR") or "./"
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
from share_btn import community_icon_html, loading_icon_html, share_js
# Optional token read from the environment; None when HF_TOKEN_SD is unset.
HF_TOKEN_SD = os.environ.get('HF_TOKEN_SD')
# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f'device = {device}')
def get_image_ext(img_bytes):
    """Guess the image format of raw bytes, defaulting to "jpeg".

    Returns the format name reported by ``imghdr.what`` (e.g. "png", "gif"),
    or "jpeg" when the header is not recognised.

    NOTE: ``imghdr`` is deprecated (PEP 594) and removed in Python 3.13.
    """
    detected = imghdr.what("", img_bytes)
    return "jpeg" if detected is None else detected
def diffuser_callback(i, t, latents):
    """No-op progress hook.

    Accepts three positional values (presumably step index, timestep and
    latents) and ignores them; always returns None.
    """
    return None
def preprocess_image(image):
    """Convert a PIL image into a normalised float tensor.

    Both sides are truncated to a multiple of 32 and the image is resized with
    Lanczos filtering. Pixel values are divided by 255 and then mapped with
    ``2 * x - 1`` (so [0, 255] input lands in [-1, 1], assuming 8-bit data),
    and the axes are reordered from (H, W, C) to (1, C, H, W).

    Assumes the image has a channel axis (e.g. RGB); a single-band image
    would fail at the transpose.
    """
    width, height = image.size
    # Truncate each side down to a multiple of 32.
    width -= width % 32
    height -= height % 32
    resized = image.resize((width, height), resample=PIL.Image.LANCZOS)
    pixels = np.array(resized, dtype=np.float32) / 255.0
    # Add a leading batch axis, then move channels ahead of the spatial axes.
    batch = np.transpose(pixels[None], (0, 3, 1, 2))
    return torch.from_numpy(batch) * 2.0 - 1.0
def preprocess_mask(mask):
    """Convert a PIL mask into an inverted (1, 4, H/8, W/8) float tensor.

    The mask is converted to greyscale, its size is truncated to a multiple of
    32 and then reduced by a factor of 8 with nearest-neighbour resampling (so
    values are not blended). Values are scaled by 1/255, the single plane is
    replicated 4 times, a batch axis is added, and the result is inverted so
    that white areas become 0 (repainted) and black areas become 1 (kept).
    """
    gray = mask.convert("L")
    width, height = gray.size
    # Truncate each side down to a multiple of 32 before the /8 reduction.
    width -= width % 32
    height -= height % 32
    small = gray.resize((width // 8, height // 8), resample=PIL.Image.NEAREST)
    plane = np.array(small, dtype=np.float32) / 255.0
    # Four copies of the plane plus a leading batch axis; no axis reordering
    # is needed.
    stacked = np.tile(plane, (4, 1, 1))[None]
    # Invert: repaint white, keep black.
    return torch.from_numpy(1 - stacked)
# Module-level inpainting backend; it is replaced by a ModelManager instance
# later in this file. model_process() returns None while it is still unset.
model = None


def model_process(image, mask, size_limit=1080):
    """Inpaint the masked region of *image* with the global lama-cleaner model.

    Args:
        image: numpy image array (read via ``.shape`` and passed to
            ``resize_max_size``; presumably H x W x C uint8 from the Gradio
            sketch component — TODO confirm).
        mask: numpy array of the drawn mask; it is converted to a
            single-channel ("L") array and resized the same way as the image.
        size_limit: longest-side limit handed to ``resize_max_size``: an int
            (or int-like string), or ``"Original"`` to use
            ``max(image.shape)``. Defaults to 1080.

    Returns:
        A PIL ``Image`` built from the model output, or ``None`` when the
        global ``model`` has not been created yet.
    """
    # Bail out before doing any preprocessing work if there is no model.
    if model is None:
        return None

    original_shape = image.shape
    interpolation = cv2.INTER_CUBIC
    if size_limit == "Original":
        size_limit = max(image.shape)
    else:
        size_limit = int(size_limit)

    config = Config(
        ldm_steps=25,
        ldm_sampler='plms',
        zits_wireframe=True,
        hd_strategy='Original',
        hd_strategy_crop_margin=196,
        hd_strategy_crop_trigger_size=1280,
        hd_strategy_resize_limit=2048,
        prompt='',
        use_croper=False,
        croper_x=0,
        croper_y=0,
        croper_height=512,
        croper_width=512,
        sd_mask_blur=5,
        sd_strength=0.75,
        sd_steps=50,
        sd_guidance_scale=7.5,
        sd_sampler='ddim',
        sd_seed=42,
        cv2_flag='INPAINT_NS',
        cv2_radius=5,
    )
    # A seed of -1 means "pick a random seed" (kept for parity with
    # lama-cleaner's request handling; the fixed seed above never hits it).
    if config.sd_seed == -1:
        config.sd_seed = random.randint(1, 999999999)
    logger.info(f"config/size_limit = {config} / {size_limit}")

    logger.info(f"Origin image shape: {original_shape}")
    image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
    logger.info(f"Resized image shape: {image.shape}")

    mask = np.array(Image.fromarray(mask).convert("L"))
    mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
    logger.info(f"Mask shape: {mask.shape}")

    start = time.time()
    res_np_img = model(image, mask, config)
    elapsed_ms = (time.time() - start) * 1000
    logger.info(f"process time: {elapsed_ms}ms, {res_np_img.shape}")
    # Release cached, unused CUDA memory between requests.
    torch.cuda.empty_cache()
    # NOTE(review): lama_cleaner models may return BGR channel order and/or a
    # non-uint8 array; confirm the output colours before relying on them.
    return Image.fromarray(res_np_img)
# Build the LaMa inpainting backend once at import time, on the device chosen
# above; model_process() reads this module-level name on every request.
model = ModelManager(
    name='lama',
    device=device,
)
def read_content(file_path: str) -> str:
    """Return the entire text of *file_path*, decoded as UTF-8."""
    with open(file_path, encoding='utf-8') as handle:
        return handle.read()
def predict(dict):
    """Gradio click handler: inpaint the region sketched on the uploaded image.

    Args:
        dict: value of the ``gr.Image(tool='sketch')`` input — a mapping whose
            ``"image"`` and ``"mask"`` entries are forwarded unchanged to
            ``model_process`` (presumably numpy arrays, since the component
            sets no ``type=``; TODO confirm). The name shadows the builtin
            but is kept so the callback signature stays the same.

    Returns:
        Whatever ``model_process`` returns: a PIL image, or ``None`` when the
        model is not loaded.
    """
    return model_process(dict["image"], dict["mask"])
# Debug progress marker printed when module execution reaches this point.
print(f'liuyz_500_here_')
# Custom stylesheet for the Gradio UI; passed to gr.Blocks(css=css).
# The element ids it targets (e.g. #image_upload) must match elem_id values
# used by the components.
css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
#image_upload{min-height:512px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 512px}
#mask_radio .gr-form{background:transparent; border: none}
#word_mask{margin-top: .75em !important}
#word_mask textarea:disabled{opacity: 0.3}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {
display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
}
#share-btn {
all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
'''
# The application UI is the gr.Blocks layout built below; no gr.Interface
# front end is used.
# Two-column UI: upload an image and sketch over the area to remove (left),
# then press "Done!" to see the inpainted result (right).
image_blocks = gr.Blocks(css=css)
# Keep the conventional ``demo`` name bound to the Blocks app.
with image_blocks as demo:
    with gr.Group():
        with gr.Box():
            with gr.Row():
                with gr.Column():
                    # tool='sketch' lets the user paint a mask; predict() reads
                    # the value as dict["image"] / dict["mask"].
                    image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", label="Upload").style(height=512)
                    with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                        btn = gr.Button("Done!").style(
                            margin=True,
                            rounded=(True, True, True, True),
                            full_width=True,
                        )
                with gr.Column():
                    image_out = gr.Image(label="Output").style(height=512)
            btn.click(fn=predict, inputs=[image], outputs=[image_out])
image_blocks.launch()