import concurrent.futures
import glob
import os
import re
import shutil

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image

from lama_cleaner.helper import (
    norm_img,
    get_cache_path_by_url,
    load_jit_model,
)
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config

LAMA_MODEL_URL = os.environ.get(
    "LAMA_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)
LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e500")


class LaMa(InpaintModel):
    name = "lama"
    pad_mod = 8

    def init_model(self, device, **kwargs):
        self.model = load_jit_model(LAMA_MODEL_URL, device, LAMA_MODEL_MD5).eval()

    @staticmethod
    def is_downloaded() -> bool:
        return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))

    def forward(self, image, mask, config: Config):
        """Input image and output image have the same size.

        image: [H, W, C] RGB
        mask: [H, W]
        return: BGR IMAGE
        """
        image = norm_img(image)
        mask = norm_img(mask)
        mask = (mask > 0) * 1

        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)

        inpainted_image = self.model(image, mask)
        cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
        cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
        cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
        return cur_res


# Load the TorchScript big-lama model once at startup and reuse it for every request.
lama_model = LaMa("cuda" if torch.cuda.is_available() else "cpu")
config = Config(ldm_steps=25, hd_strategy="Original", hd_strategy_crop_margin=196,
                hd_strategy_crop_trigger_size=1280, hd_strategy_resize_limit=2048)


def remove_image_watermark(inputs):
    alpha_channel = None
    image, mask = inputs["image"], inputs["mask"]
    if image.mode == "RGBA":
        image = np.array(image)
        alpha_channel = image[:, :, -1]
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    else:
        image = np.array(image)

    # Binarize the sketched mask: any stroke brighter than 127 marks a pixel to inpaint.
    mask = cv2.threshold(np.array(mask.convert("L")), 127, 255, cv2.THRESH_BINARY)[1]

    output = lama_model(image, mask, config)
    output = cv2.cvtColor(output.astype(np.uint8), cv2.COLOR_BGR2RGB)

    # Re-attach the original alpha channel if the input was RGBA.
    if alpha_channel is not None:
        if alpha_channel.shape[:2] != output.shape[:2]:
            alpha_channel = cv2.resize(
                alpha_channel, dsize=(output.shape[1], output.shape[0])
            )
        output = np.concatenate((output, alpha_channel[:, :, np.newaxis]), axis=-1)
    return Image.fromarray(output)


def process_image(mask_data, image_path):
    output = remove_image_watermark({"image": Image.open(image_path), "mask": mask_data})
    base, ext = os.path.splitext(os.path.basename(image_path))
    output_image_path = os.path.join("output_images", f"{base}_inpainted{ext}")
    output.save(output_image_path)
    return output_image_path


def remove_video_watermark(sketch, images_path="frames", output_path="output_images"):
    if os.path.exists("output_images"):
        shutil.rmtree("output_images")
    os.makedirs("output_images")

    image_paths = glob.glob(f"{images_path}/*.*")
    # Inpaint every extracted frame with the same drawn mask; a small thread pool
    # bounds the number of concurrent model calls.
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
        executor.map(lambda image_path: process_image(sketch["mask"], image_path), image_paths)

    return gr.File.update(value=convert_frames_to_video("output_images"), visible=True), gr.Button.update(value="Done!")


def convert_video_to_frames(video):
    if os.path.exists("input_video.mp4"):
        os.remove("input_video.mp4")

    # Save the uploaded video to the current directory from Gradio's temporary file.
    with open(video, "rb") as f:
        with open("input_video.mp4", "wb") as f2:
            f2.write(f.read())
    # os.system(f"ffmpeg -i {video} input_video.mp4")
    video_path = "input_video.mp4"

    if os.path.exists("frames"):
        shutil.rmtree("frames")
    os.makedirs("frames")

    # Dump every frame to frames/<video_name>_<n>.jpg, starting at 1.
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    vidcap = cv2.VideoCapture(video_path)
    success, image = vidcap.read()
    count = 1
    while success:
        cv2.imwrite(f"frames/{video_name}_{count}.jpg", image)
        success, image = vidcap.read()
        count += 1

    return gr.Image.update(value=f"{os.getcwd()}/frames/{video_name}_1.jpg", interactive=True), gr.Button.update(interactive=True)


def convert_frames_to_video(frames_path):
    if os.path.exists("output_video.mp4"):
        os.remove("output_video.mp4")

    img_array = []
    filelist = glob.glob(f"{frames_path}/*.jpg")

    # Sort frames by the frame number embedded in the filename.
    frame_numbers = [int(re.findall(r"\d+", os.path.basename(frame))[0]) for frame in filelist]
    sorted_frames = [frame for _, frame in sorted(zip(frame_numbers, filelist), key=lambda pair: pair[0])]

    for filename in sorted_frames:
        img = cv2.imread(filename)
        height, width, _ = img.shape
        size = (width, height)
        img_array.append(img)

    # Note: the output is re-encoded at a fixed 25 fps and without audio.
    out = cv2.VideoWriter("output_video.mp4", cv2.VideoWriter_fourcc(*"mp4v"), 25, size)
    for img in img_array:
        out.write(img)
    out.release()
    return "output_video.mp4"
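

# The Blocks layout that wires these functions into a UI is not shown in this
# file. The snippet below is a minimal, hypothetical sketch assuming the
# Gradio 3.x API; the component names, labels, and event choices are
# assumptions, not the app's actual interface.
with gr.Blocks() as demo:
    video_in = gr.Video(label="Input video")  # assumed upload component
    frame_sketch = gr.Image(
        label="Draw over the watermark",
        tool="sketch",  # sketch tool yields the {"image": ..., "mask": ...} dict used above
        type="pil",
        interactive=False,
    )
    run_btn = gr.Button("Remove watermark", interactive=False)
    result_file = gr.File(label="Cleaned video", visible=False)

    # Extract frames, then show the first frame as a sketchable mask editor.
    video_in.change(convert_video_to_frames, inputs=video_in, outputs=[frame_sketch, run_btn])
    # Inpaint every frame with the drawn mask and return the rebuilt video.
    run_btn.click(remove_video_watermark, inputs=frame_sketch, outputs=[result_file, run_btn])

demo.launch()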