Spaces:
Sleeping
Sleeping
from PIL import Image | |
import torch | |
import torchvision.transforms as transforms | |
from safetensors.torch import load_file | |
def preprocess_img(img, img_size, normalize=False):
    """Turn an image into a batched tensor for the model.

    Args:
        img: A ``PIL.Image.Image``, or a filesystem path (``str`` or any
            ``os.PathLike``) to an image file.
        img_size: Target size passed to ``transforms.Resize``. A scalar
            resizes to the square ``(img_size, img_size)``; a ``(height,
            width)`` list/tuple is used as-is.
        normalize: If True, convert the image to RGB and apply ImageNet
            mean/std normalization (the same constants ``postprocess_img``
            uses to denormalize).

    Returns:
        ``(tensor, original_size)``: ``tensor`` is the transformed image
        with a leading batch dimension of 1 added via ``unsqueeze(0)``;
        ``original_size`` is the PIL ``size`` (``(width, height)``) of the
        input image before resizing.
    """
    import os

    if isinstance(img, (str, os.PathLike)):
        # Image.open is lazy and holds the file handle open; copy() forces a
        # full decode so the with-block can release the handle right away.
        with Image.open(img) as opened:
            img = opened.copy()

    if normalize and img.mode != "RGB":
        # The ImageNet mean/std below are 3-channel; grayscale, RGBA or
        # palette images would otherwise make Normalize fail on a
        # channel-count mismatch.
        img = img.convert("RGB")

    original_size = img.size

    if isinstance(img_size, (list, tuple)):
        size = tuple(img_size)
    else:
        size = (img_size, img_size)

    steps = [transforms.Resize(size), transforms.ToTensor()]
    if normalize:
        steps.append(
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            )
        )
    img = transforms.Compose(steps)(img).unsqueeze(0)
    return img, original_size
def postprocess_img(img, original_size, normalize=False):
    """Convert a model output tensor back into a PIL image at a given size.

    Args:
        img: Output tensor. It is detached, moved to CPU, and a leading
            dimension of size 1 (if present) is squeezed away.
        original_size: Size handed to ``PIL.Image.resize``; ``preprocess_img``
            returns the input's PIL ``size`` (``(width, height)``) for this.
        normalize: If True, undo ImageNet mean/std normalization first.
            NOTE(review): this assumes a 3-channel ``(C, H, W)`` tensor
            after squeezing — confirm for any single-channel model outputs.

    Returns:
        A ``PIL.Image.Image`` resized with Lanczos resampling.
    """
    tensor = img.detach().cpu().squeeze(0)

    if normalize:
        # Inverse of Normalize: x * std + mean, broadcast over each channel.
        channel_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        tensor = tensor * channel_std + channel_mean

    # Keep values in [0, 1] before the float -> 8-bit conversion in ToPILImage.
    tensor = tensor.clamp(0, 1)
    pil_image = transforms.ToPILImage()(tensor)
    return pil_image.resize(original_size, Image.Resampling.LANCZOS)
def load_model_without_module(model, model_path, device):
    """Load a safetensors checkpoint into ``model``, stripping a ``module.`` key prefix.

    Keys beginning with ``module.`` have that prefix removed before loading;
    all other keys are used unchanged. Such a prefix is what
    ``DataParallel``/``DistributedDataParallel``-wrapped models save —
    presumably that is where these checkpoints come from (unconfirmed).

    Args:
        model: Module whose ``load_state_dict`` receives the cleaned weights.
        model_path: Path to the ``.safetensors`` file.
        device: Forwarded to ``safetensors.torch.load_file`` as its
            ``device`` argument.
    """
    prefix = 'module.'
    raw_weights = load_file(model_path, device=device)

    cleaned = {}
    for name, tensor in raw_weights.items():
        if name.startswith(prefix):
            name = name[len(prefix):]
        cleaned[name] = tensor

    model.load_state_dict(cleaned)