# Spaces status: Runtime error
import inspect
import os

import torch
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from model import DoubleConv, UNET
# Let the process keep running when more than one Intel OpenMP runtime gets
# loaded (otherwise it aborts with "OMP: Error #15"). This is an unsafe
# workaround; fixing the environment is preferable.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

MODEL_PATH = "Unet_acc_94.pth"
IMAGE_PATH = "104.jpg"
# PIL's resize takes (width, height), so the tensor becomes (1, 3, 160, 240).
INPUT_SIZE = (240, 160)

convert_tensor = transforms.ToTensor()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_model(path, target_device):
    """Load the segmentation network stored at *path*.

    The checkpoint may be a fully pickled ``nn.Module`` (what this script was
    originally written against) or a ``state_dict``. The state dict can be
    bare or nested under a ``"state_dict"`` key; it is loaded into a fresh
    ``UNET(in_channels=3, out_channels=1)``.

    Returns the model on *target_device*, switched to eval mode.

    Security: unpickling a full model can execute arbitrary code. Only load
    checkpoints from a trusted source.
    """
    load_kwargs = {"map_location": target_device}
    # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle a
    # whole nn.Module. torch < 1.13 does not accept the keyword at all.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    checkpoint = torch.load(path, **load_kwargs)
    if isinstance(checkpoint, dict):
        net = UNET(in_channels=3, out_channels=1)
        net.load_state_dict(checkpoint.get("state_dict", checkpoint))
    else:
        net = checkpoint
    # eval() switches layers such as BatchNorm/Dropout (if UNET uses any) to
    # inference behaviour.
    return net.to(target_device).eval()


def _load_image(path, size):
    """Return the image at *path* as a (1, 3, H, W) float tensor in [0, 1].

    The image is converted to RGB so that grayscale, RGBA and palette files
    still yield 3 channels. It is then resized to *size*, which PIL
    interprets as (width, height). The file handle is closed before
    returning.
    """
    with Image.open(path) as img:
        rgb = img.convert("RGB").resize(size)
    return convert_tensor(rgb).unsqueeze(0)


def _predict_mask(net, batch, threshold=0.5):
    """Return a binary float mask: 1.0 where sigmoid(net(batch)) > threshold."""
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = net(batch.float())
    return (torch.sigmoid(logits) > threshold).float()


model = _load_model(MODEL_PATH, device)
test_img = _load_image(IMAGE_PATH, INPUT_SIZE)
print(test_img.shape)

# The model lives on `device`; keep the CPU copy of test_img for plotting.
preds = _predict_mask(model, test_img.to(device))
print(preds.shape)

# (1, C, H, W) -> (H, W, C), back on the CPU so matplotlib can convert it.
im = preds.squeeze(0).permute(1, 2, 0).cpu()
print(im.shape)

fig, axs = plt.subplots(1, 2)
# Drop a trailing singleton channel so imshow gets a 2-D array; gray suits
# a 0/1 mask.
axs[0].imshow(im.squeeze(-1), cmap="gray")
axs[1].imshow(test_img.squeeze(0).permute(1, 2, 0))
plt.show()