Spaces:
Runtime error
Runtime error
import imp | |
import os | |
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" | |
import torch | |
import torch.nn as nn | |
import cv2 | |
import numpy as np | |
import torch | |
import torch.nn as nnst | |
import torchvision.transforms.functional as TF | |
from torchvision import transforms | |
from model import DoubleConv,UNET | |
# HWC uint8 image -> CHW float tensor in [0, 1].
convert_tensor = transforms.ToTensor()
# Preferred device. NOTE: inference currently runs on the CPU (the checkpoint is
# mapped there and predict() does not move tensors), so this value is unused.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint is a whole pickled nn.Module (its result is called directly),
# so unpickling needs the UNET/DoubleConv classes imported from `model` above.
# PyTorch >= 2.6 defaults to weights_only=True, which rejects such pickles, so
# full unpickling is requested explicitly.
# SECURITY: full unpickling can execute arbitrary code; acceptable only because
# this is the app's own bundled checkpoint, never a user-supplied file.
try:
    model = torch.load(
        "Unet_acc_94.pth",
        map_location=torch.device("cpu"),
        weights_only=False,
    )
except TypeError:
    # torch < 1.13 has no weights_only parameter.
    model = torch.load("Unet_acc_94.pth", map_location=torch.device("cpu"))
# Switch to inference mode. If UNET contains BatchNorm/Dropout layers (not
# visible from this file), train mode would use per-batch statistics / random
# dropout on a batch of one image.
model.eval()
def predict(img):
    """Run the segmentation model on an image and return a binary mask.

    Args:
        img: image array, presumably H x W x 3 uint8 as decoded by OpenCV.
            It is resized to 240 (width) x 160 (height) before inference.

    Returns:
        A float32 numpy array laid out (rows, cols, channels) holding only
        0.0 / 1.0. Assuming the network preserves spatial size, that is
        (160, 240, C) with C the model's output channels (presumably 1).
    """
    # cv2.resize takes (width, height): the network sees 160 rows x 240 cols.
    img = cv2.resize(img, (240, 160))
    # ToTensor gives CHW floats; unsqueeze(0) adds the batch axis.
    test_img = convert_tensor(img).unsqueeze(0)
    # Inference only: no_grad avoids building an autograd graph per upload,
    # which also makes the tensors directly convertible with .numpy().
    with torch.no_grad():
        preds = torch.sigmoid(model(test_img.float()))
        preds = (preds > 0.5).float()
        # Drop the batch axis and move channels last (CHW -> HWC).
        im = preds.squeeze(0).permute(1, 2, 0)
    return im.numpy()
def blurr_image(input_image, preds, ksize=(25, 25)):
    """Blur the background of an image while keeping the masked foreground sharp.

    Args:
        input_image: BGR image array (H x W x 3, uint8); resized here to
            240 (width) x 160 (height).
        preds: foreground mask holding exactly 160 * 240 values (e.g. the
            (160, 240, 1) output of ``predict``); values above 0.1 count as
            foreground.
        ksize: Gaussian kernel size as (width, height). Both values must be
            odd and positive; larger kernels give a stronger background blur.

    Returns:
        ``(original, blurred)``: the resized input and the portrait-mode
        result, both 160 x 240 x 3 uint8 images in RGB order.

    Raises:
        ValueError: if ``preds`` does not hold exactly 160 * 240 values.
    """
    # reshape (unlike np.resize) fails loudly on a wrongly-sized mask instead
    # of silently repeating or truncating it.
    mask = np.asarray(preds).reshape(160, 240) > 0.1
    # cv2.resize takes (width, height), matching the mask's 160 rows x 240 cols.
    image = cv2.resize(input_image, (240, 160))
    blurred_original_image = cv2.GaussianBlur(image, ksize, 0)
    # A trailing axis lets the 2-D mask broadcast across the colour channels.
    blurred_img = np.where(mask[..., np.newaxis], image, blurred_original_image)
    # OpenCV images are BGR; convert both outputs to RGB for display.
    blurred_img = cv2.cvtColor(blurred_img, cv2.COLOR_BGR2RGB)
    inp = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return inp, blurred_img
import streamlit as st

# Streamlit page: upload a photo, segment the subject, blur the background.
st.title("AI Portrait Mode")
st.markdown("Creator: [Pranav Kushare](https://github.com/Pranav082001)")

file = st.file_uploader("Please upload the image", type=["jpg", "jpeg", "png"])
check = st.checkbox("Display Mask", value=False)

if file is None:
    st.text("Please Upload an image")
else:
    file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8)
    # IMREAD_COLOR yields a 3-channel BGR image, or None if decoding fails.
    opencv_image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if opencv_image is None:
        # The extension filter does not guarantee valid image content.
        st.error("Could not decode the uploaded file as an image.")
    else:
        pred = predict(opencv_image)
        inp_img, blurred = blurr_image(opencv_image, pred)
        st.text("Original")
        st.image(inp_img)
        if check:
            st.text("Mask!!")
            st.image(pred)
        st.text("Blurred")
        st.image(blurred)