# Credits to https://github.com/ZHKKKe/MODNet for the model.
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import time
import os
from PIL import Image, ImageColor
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

from src.models.modnet import MODNet
from src.st_style import apply_prod_style  # apply(st)

MODEL = "./assets/modnet_photographic_portrait_matting.ckpt"


def change_background(image, matte, background_alpha: float = 1.0, background_hex: str = "#000000"):
    """
    image: PIL Image (RGBA)
    matte: PIL Image (grayscale; 255 marks foreground)
    background_alpha: float, opacity of the background layer
    background_hex: string, background color as a hex code
    """
    img = deepcopy(image)
    if image.mode != "RGBA":
        img = img.convert("RGBA")

    background_color = ImageColor.getrgb(background_hex)
    background_alpha = int(255 * background_alpha)
    background = Image.new("RGBA", img.size, color=background_color + (background_alpha,))
    # paste the foreground over the colored background, using the matte as the mask
    background.paste(img, mask=matte)
    return background


def matte(image):
    # define hyper-parameters
    ref_size = 512

    # define image-to-tensor transform
    im_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    # create MODNet and load the pre-trained ckpt
    modnet = MODNet(backbone_pretrained=False)
    modnet = nn.DataParallel(modnet)
    if torch.cuda.is_available():
        modnet = modnet.cuda()
        weights = torch.load(MODEL)
    else:
        weights = torch.load(MODEL, map_location=torch.device('cpu'))
    modnet.load_state_dict(weights)
    modnet.eval()

    # read image
    im = deepcopy(image)

    # unify image channels to 3
    im = np.asarray(im)
    if len(im.shape) == 2:
        im = im[:, :, None]
    if im.shape[2] == 1:
        im = np.repeat(im, 3, axis=2)
    elif im.shape[2] == 4:
        im = im[:, :, 0:3]

    # convert image to PyTorch tensor
    im = Image.fromarray(im)
    im = im_transform(im)

    # add mini-batch dim
    im = im[None, :, :, :]

    # resize image for input
    im_b, im_c, im_h, im_w = im.shape
    if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
        if im_w >= im_h:
            im_rh = ref_size
            im_rw = int(im_w / im_h * ref_size)
        elif im_w < im_h:
            im_rw = ref_size
            im_rh = int(im_h / im_w * ref_size)
    else:
        im_rh = im_h
        im_rw = im_w

    # both sides must be multiples of 32 for the network
    im_rw = im_rw - im_rw % 32
    im_rh = im_rh - im_rh % 32
    im = F.interpolate(im, size=(im_rh, im_rw), mode='area')

    # inference (no gradients needed)
    with torch.no_grad():
        _, _, matte = modnet(im.cuda() if torch.cuda.is_available() else im, True)

    # resize the matte back to the original resolution
    matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
    matte = matte[0][0].data.cpu().numpy()
    return Image.fromarray((matte * 255).astype('uint8'), mode='L')
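

# --- Illustrative usage (assumption, not part of the original file) ---
# The Streamlit import above suggests these helpers back a small web UI.
# The sketch below shows one plausible way to wire matte() and
# change_background() together; widget labels, layout, and the main()
# entry point are assumptions, not the original app's code.
def main():
    st.title("Portrait Background Removal (MODNet)")
    uploaded = st.file_uploader("Upload a portrait", type=["png", "jpg", "jpeg"])
    if uploaded is not None:
        image = Image.open(uploaded)
        st.image(image, caption="Input")

        hex_color = st.color_picker("Background color", "#000000")
        alpha = st.slider("Background opacity", 0.0, 1.0, 1.0)

        with st.spinner("Running MODNet..."):
            alpha_matte = matte(image)  # grayscale foreground matte
            result = change_background(
                image, alpha_matte,
                background_alpha=alpha,
                background_hex=hex_color,
            )
        st.image(result, caption="Result")


if __name__ == "__main__":
    main()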