from PIL import Image
import torch
import torchvision.transforms as transforms
def preprocess_img_from_path(path_to_image, img_size, normalize=False):
    # Load the image from disk and delegate to preprocess_img.
    img = Image.open(path_to_image)
    return preprocess_img(img, img_size, normalize)


def preprocess_img(img: Image.Image, img_size, normalize=False):
    # Keep the original (width, height) so postprocess_img can restore it.
    original_size = img.size
    # Ensure 3 channels so the 3-channel transforms below always apply cleanly.
    img = img.convert("RGB")

    if normalize:
        # Resize, convert to a tensor, and normalize with ImageNet statistics.
        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor()
        ])

    # Add a batch dimension: (C, H, W) -> (1, C, H, W).
    img = transform(img).unsqueeze(0)
    return img, original_size
def postprocess_img(img, original_size, normalize=False):
    # Drop the batch dimension and move the tensor to the CPU.
    img = img.detach().cpu().squeeze(0)

    # Undo the ImageNet normalization applied in preprocess_img.
    if normalize:
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        img = img * std + mean

    # Clamp to the valid [0, 1] range and convert back to a PIL image.
    img = torch.clamp(img, 0, 1)
    img = transforms.ToPILImage()(img)
    # Resize back to the original (width, height).
    img = img.resize(original_size, Image.Resampling.LANCZOS)
    return img
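

# Usage sketch: a minimal round-trip through the two functions above,
# assuming an image file "example.jpg" exists on disk; the filename and
# the 512 target size are illustrative placeholders, not fixed by the module.
if __name__ == "__main__":
    tensor, original_size = preprocess_img_from_path("example.jpg", img_size=512, normalize=True)
    print(tensor.shape)  # torch.Size([1, 3, 512, 512])
    restored = postprocess_img(tensor, original_size, normalize=True)
    restored.save("restored.jpg")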