# Heisenberg08 - "minor changes" (ca17fe1)
import imp
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import torch
import torch.nn as nn
import cv2
import numpy as np
import torch
import torch.nn as nnst
import torchvision.transforms.functional as TF
from torchvision import transforms
from model import DoubleConv,UNET
# Shared inference state used by predict().
# ToTensor: HxWxC uint8 ndarray -> CxHxW float tensor scaled to [0, 1].
convert_tensor = transforms.ToTensor()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The checkpoint holds a fully pickled module (it is called directly as
# model(x)), so torch.load returns the ready-to-use network; constructing a
# fresh UNET first was dead work that was immediately overwritten.
# It is mapped to the CPU regardless of ``device``, because predict() feeds it
# CPU tensors.
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted files.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects a pickled module; pass weights_only=False here if loading fails there.
model = torch.load("Unet_acc_94.pth", map_location=torch.device("cpu"))
# Inference mode: BatchNorm/Dropout layers (if the network has any) must not
# use training-time behaviour when segmenting single uploaded images.
model.eval()
# model=torch.load("src//Unet_acc_94.pth",map_location=device)
def predict(img):
    """Segment the foreground of ``img`` with the module-level ``model``.

    The image is resized to 240x160 (width x height), converted with
    ``convert_tensor`` and passed through the network; the sigmoid of the
    output is thresholded at 0.5.

    Args:
        img: HxWxC uint8 image array (e.g. the BGR output of ``cv2.imdecode``).

    Returns:
        float32 numpy array in (H, W, C) layout holding only 0.0 / 1.0 --
        presumably (160, 240, 1) for the single-channel UNET this app loads
        (TODO confirm the checkpoint preserves the input's spatial size).
    """
    resized = cv2.resize(img, (240, 160))  # cv2 dsize is (width, height)
    batch = convert_tensor(resized).unsqueeze(0).float()  # (1, C, 160, 240)
    # Pure inference: do not record an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(batch)
    mask = (torch.sigmoid(logits) > 0.5).float()
    # (1, C, H, W) -> (H, W, C). .cpu() is a no-op for CPU tensors but keeps
    # .numpy() working if the model is ever run on a GPU.
    return mask.squeeze(0).permute(1, 2, 0).cpu().numpy()
def blurr_image(input_image, preds):
    """Blur everything outside the predicted foreground ("portrait mode").

    Args:
        input_image: 3-channel BGR image such as the output of
            ``cv2.imdecode``; it is resized to the mask's size.
        preds: foreground mask from ``predict`` -- an (H, W) or (H, W, 1)
            array; pixels whose value is above 0.1 are kept sharp, all others
            are replaced by a 25x25 Gaussian-blurred version of the image.

    Returns:
        ``(original, blurred)``: the resized input and the composited image,
        both converted to RGB uint8 arrays of shape (H, W, 3).

    Raises:
        ValueError: if ``preds`` cannot be reshaped to (H, W) (e.g. it has
            more than one channel). The previous ``np.resize`` silently tiled
            or truncated the data in that case.
    """
    mask = np.asarray(preds)
    height, width = mask.shape[:2]
    # Collapse a trailing singleton channel, (H, W, 1) -> (H, W); unlike
    # np.resize this refuses mismatched sizes instead of tiling data.
    keep_sharp = mask.reshape(height, width) > 0.1
    image = cv2.resize(input_image, (width, height))  # dsize is (width, height)
    background = cv2.GaussianBlur(image, (25, 25), 0)
    # Broadcast the (H, W, 1) boolean mask across the colour channels.
    composited = np.where(keep_sharp[..., np.newaxis], image, background)
    original_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    blurred_rgb = cv2.cvtColor(composited, cv2.COLOR_BGR2RGB)
    return original_rgb, blurred_rgb
import streamlit as st

# Streamlit UI: upload an image, segment it, and show the background-blurred
# result (optionally with the raw mask).
st.title("AI Portrait Mode")
# No space between "]" and "(" -- otherwise Markdown does not render a link.
st.markdown("Creator: [Pranav Kushare](https://github.com/Pranav082001)")
# TODO: add a "Source code: [GitHub Repository](<link>)" line once available.
file = st.file_uploader("Please upload the image", type=["jpg", "jpeg", "png"])
check = st.checkbox("Display Mask", value=False)
if file is None:
    st.text("Please Upload an image")
else:
    # getvalue() returns the whole upload regardless of the stream position.
    file_bytes = np.asarray(bytearray(file.getvalue()), dtype=np.uint8)
    opencv_image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)  # BGR or None
    if opencv_image is None:
        # imdecode signals undecodable/corrupt data by returning None.
        st.error("Could not decode the uploaded file as an image.")
    else:
        pred = predict(opencv_image)
        inp_img, blurred = blurr_image(opencv_image, pred)
        st.text("Original")
        st.image(inp_img)
        if check:
            st.text("Mask!!")
            st.image(pred)
        st.text("Blurred")
        st.image(blurred)