rishi9440's picture
Duplicate from aryadytm/remove-photo-background
da406d5
raw
history blame contribute delete
No virus
2.98 kB
# Credits to https://github.com/ZHKKKe/MODNet for the model.
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import time
import os
from PIL import Image, ImageColor
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from src.models.modnet import MODNet
from src.st_style import apply_prod_style
# apply(st)
# Path of the pre-trained MODNet checkpoint that matte() passes to torch.load().
MODEL = "./assets/modnet_photographic_portrait_matting.ckpt"
def change_background(image, matte, background_alpha: float=1.0, background_hex: str="#000000"):
    """
    Composite ``image`` over a solid-colour background, using ``matte`` as the mask.

    image: PIL Image (converted to RGBA when it is in another mode)
    matte: PIL Image (grayscale, if 255 it is foreground); passed to
        ``Image.paste`` as the mask, so it is expected to match ``image`` in size
    background_alpha: float opacity of the background, clamped to [0.0, 1.0]
    background_hex: colour string accepted by ``ImageColor.getrgb``
    Returns a new RGBA PIL Image; ``image`` itself is not modified.
    """
    # convert() already returns a new image and paste() only reads its source,
    # so no defensive copy of the input is required.
    img = image if image.mode == "RGBA" else image.convert("RGBA")

    # getrgb() returns (r, g, b) or (r, g, b, a) depending on the colour string
    # (e.g. "#rrggbbaa"). Keep only RGB so exactly one alpha value is appended.
    background_color = ImageColor.getrgb(background_hex)[:3]

    # Clamp so that out-of-range opacities map to a valid 0..255 alpha byte.
    background_alpha = int(255 * min(max(background_alpha, 0.0), 1.0))

    background = Image.new("RGBA", img.size, color=background_color + (background_alpha,))
    background.paste(img, mask=matte)
    return background
def _load_modnet():
    """Create MODNet, load the checkpoint at MODEL and put the network in eval mode."""
    # The network is wrapped in DataParallel *before* the state dict is loaded;
    # presumably the checkpoint keys carry the DataParallel prefix -- confirm
    # against the checkpoint before changing this order.
    modnet = nn.DataParallel(MODNet(backbone_pretrained=False))
    if torch.cuda.is_available():
        modnet = modnet.cuda()
        weights = torch.load(MODEL)
    else:
        weights = torch.load(MODEL, map_location=torch.device('cpu'))
    modnet.load_state_dict(weights)
    modnet.eval()
    return modnet


def _to_rgb_image(image):
    """Return ``image`` as a 3-channel PIL image."""
    if isinstance(image, Image.Image):
        # convert() copes with every PIL mode (palette, LA, 1-bit, CMYK, ...),
        # which raw-array channel slicing does not.
        return image.convert("RGB")

    # Array-like input: unify the channel count to 3.
    pixels = np.asarray(image)
    if pixels.ndim == 2:
        pixels = pixels[:, :, None]
    if pixels.shape[2] == 1:
        pixels = np.repeat(pixels, 3, axis=2)
    elif pixels.shape[2] == 4:
        pixels = pixels[:, :, 0:3]
    return Image.fromarray(pixels)


def _inference_size(im_h, im_w, ref_size):
    """Return the (height, width) the input is resized to before inference."""
    if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
        # Scale so the shorter side equals ref_size, keeping the aspect ratio.
        if im_w >= im_h:
            im_rh = ref_size
            im_rw = int(im_w / im_h * ref_size)
        else:
            im_rw = ref_size
            im_rh = int(im_h / im_w * ref_size)
    else:
        im_rh, im_rw = im_h, im_w

    # Round down to a multiple of 32 (presumably required by the network's
    # downsampling -- confirm against MODNet), but never below 32: a side
    # shorter than 32 px would otherwise become 0 and break interpolation.
    im_rw = max(32, im_rw - im_rw % 32)
    im_rh = max(32, im_rh - im_rh % 32)
    return im_rh, im_rw


def matte(image):
    """
    Predict a foreground matte for ``image`` with MODNet.

    image: PIL Image, or an array-like with shape (H, W), (H, W, 1), (H, W, 3)
        or (H, W, 4)
    Returns a PIL Image in mode "L" with the same width and height as the
    input, holding the predicted matte scaled to 0..255.

    NOTE: the model is rebuilt and the checkpoint re-read on every call.
    """
    # define hyper-parameters
    ref_size = 512

    # define image to tensor transform
    im_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )

    # create MODNet and load the pre-trained ckpt
    modnet = _load_modnet()

    # convert image to a PyTorch tensor and add the mini-batch dim
    im = im_transform(_to_rgb_image(image))
    im = im[None, :, :, :]

    # resize image for input
    _, _, im_h, im_w = im.shape
    im = F.interpolate(im, size=_inference_size(im_h, im_w, ref_size), mode='area')

    # inference; no_grad() avoids building an autograd graph that is never used
    with torch.no_grad():
        _, _, alpha = modnet(im.cuda() if torch.cuda.is_available() else im, True)
        # resize the matte back to the original resolution
        alpha = F.interpolate(alpha, size=(im_h, im_w), mode='area')

    # "alpha" rather than "matte" so the local does not shadow this function
    alpha = alpha[0][0].cpu().numpy()
    return Image.fromarray(((alpha * 255).astype('uint8')), mode='L')