# Hugging Face Space — status: Running on Zero.
# File size: 2,046 bytes; commit d65c9b3.
import torch
from PIL import Image
from torchvision import transforms
from src.lpips import LPIPS
import torch.nn as nn
# Device that the module-level LPIPS model (``loss_fn``) is moved to.
dev = "cuda"
# PIL image -> float tensor of shape (C, H, W) via torchvision's ToTensor
# (which scales 8-bit images into [0, 1]).
to_tensor_transform = transforms.Compose(transforms=[transforms.ToTensor()])
# Shared criterion: mean of the squared element-wise differences.
mse_loss = nn.MSELoss(reduction="mean")
def calculate_l2_difference(image1, image2, device='cuda'):
    """Return the mean squared error between two images as a Python number.

    Despite the name, this is the *mean* of squared differences (``mse_loss``),
    not an L2 norm.

    Args:
        image1, image2: PIL images or tensors. PIL images are converted with
            ``to_tensor_transform`` and moved to ``device``; anything else is
            passed to ``mse_loss`` unchanged.
        device: device for tensors converted from PIL images.
    """
    def _as_tensor(img):
        # Only PIL inputs are converted; tensors are used exactly as given.
        if isinstance(img, Image.Image):
            return to_tensor_transform(img).to(device)
        return img

    return mse_loss(_as_tensor(image1), _as_tensor(image2)).item()
def calculate_psnr(image1, image2, device='cuda', max_value=1.0):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    PSNR = 10 * log10(max_value ** 2 / MSE), where MSE comes from ``mse_loss``.

    Args:
        image1, image2: PIL images or tensors. PIL images are converted with
            ``to_tensor_transform`` and moved to ``device``; tensors are used
            as-is.
        device: device for tensors converted from PIL images.
        max_value: largest possible pixel value of the inputs. The default 1.0
            matches the original hard-coded peak; pass e.g. 255.0 for float
            tensors on a [0, 255] scale.

    Returns:
        The PSNR as a Python number. It is ``inf`` when the images are
        identical, because MSE is 0.
    """
    if isinstance(image1, Image.Image):
        image1 = to_tensor_transform(image1).to(device)
    if isinstance(image2, Image.Image):
        image2 = to_tensor_transform(image2).to(device)
    mse = mse_loss(image1, image2)
    # Dividing by a zero-valued tensor yields inf rather than raising.
    return 10 * torch.log10(max_value ** 2 / mse).item()
# Module-level LPIPS model with a VGG backbone, built once at import time,
# moved to ``dev`` and switched to eval mode for metric computation.
loss_fn = (
    LPIPS(net_type='vgg')
    .to(dev)
    .eval()
)
def calculate_lpips(image1, image2, device='cuda'):
    """Return the LPIPS perceptual distance between two images as a Python number.

    Args:
        image1, image2: PIL images or tensors. PIL images are converted with
            ``to_tensor_transform`` and moved to ``device``; tensors are used
            as-is.
        device: device for tensors converted from PIL images.

    NOTE(review): ``device`` only affects the PIL conversion. The model
    ``loss_fn`` stays on the module-level ``dev``, so inputs on a different
    device will presumably fail inside the model — confirm.
    NOTE(review): ToTensor produces unbatched (C, H, W) tensors. Many LPIPS
    implementations expect (N, C, H, W) input, some scaled to [-1, 1]. Verify
    the input contract of ``src.lpips.LPIPS``.
    """
    if isinstance(image1, Image.Image):
        image1 = to_tensor_transform(image1).to(device)
    if isinstance(image2, Image.Image):
        image2 = to_tensor_transform(image2).to(device)
    # The result is reduced to a plain number with .item(), so gradients are
    # never used. no_grad avoids building an autograd graph through the VGG
    # network.
    with torch.no_grad():
        return loss_fn(image1, image2).item()
def calculate_metrics(image1, image2, device='cuda', size=(512, 512)):
    """Compute the L2 (MSE), PSNR and LPIPS metrics between two images.

    PIL inputs are first resized to ``size`` and then converted to tensors on
    ``device``. Non-PIL inputs (e.g. tensors) are passed through without
    resizing.

    Returns:
        dict with keys "l2", "psnr" and "lpips" holding the values returned
        by calculate_l2_difference, calculate_psnr and calculate_lpips.
    """
    def _prepare(img):
        # Resize-then-convert applies to PIL images only.
        if not isinstance(img, Image.Image):
            return img
        return to_tensor_transform(img.resize(size)).to(device)

    first, second = _prepare(image1), _prepare(image2)
    # Dict-literal values are evaluated in order: l2, then psnr, then lpips.
    return {
        "l2": calculate_l2_difference(first, second, device),
        "psnr": calculate_psnr(first, second, device),
        "lpips": calculate_lpips(first, second, device),
    }
def get_empty_metrics():
    """Return a new metrics dict with every metric ("l2", "psnr", "lpips") set to 0."""
    return dict.fromkeys(("l2", "psnr", "lpips"), 0)
def print_results(results):
    """Print a one-line, tab-separated summary of a metrics dict.

    Args:
        results: mapping with "l2", "psnr" and "lpips" entries, such as the
            dicts returned by calculate_metrics or get_empty_metrics.
    """
    print(f"Reconstruction Metrics: L2: {results['l2']},\t PSNR: {results['psnr']},\t LPIPS: {results['lpips']}")