import torch from PIL import Image from torch.autograd import Variable from torchvision import transforms import numpy as np # opens and returns image file as a PIL image (0-255) def load_image(filename): img = Image.open(filename) return img # assumes data comes in batch form (ch, h, w) def save_image(filename, data): std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) img = data.clone().numpy() img = ((img * std + mean).transpose(1, 2, 0)*255.0).clip(0, 255).astype("uint8") img = Image.fromarray(img) img.save(filename) # Calculate Gram matrix (G = FF^T) def gram(x): (bs, ch, h, w) = x.size() f = x.view(bs, ch, w*h) f_T = f.transpose(1, 2) G = f.bmm(f_T) / (ch * h * w) return G # using ImageNet values def normalize_tensor_transform(): return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])