PAN / utility.py
Tellll's picture
Update code and model to support NHWC input format
90e4acb
raw
history blame
2.9 kB
import math
import numpy as np
from scipy import signal
def quantize(img, rgb_range):
    """Round ``img`` to the nearest of the 256 levels spanning ``rgb_range``.

    The values are rescaled so that ``rgb_range`` maps to 255, clamped to
    ``[0, 255]``, rounded, and scaled back to the original range.

    Args:
        img: Object exposing ``mul``/``clamp``/``round``/``div``
            (presumably a torch tensor -- confirm against callers).
        rgb_range: Value in ``img`` that corresponds to pixel level 255.

    Returns:
        A new object of the same kind as ``img`` holding quantized values.
    """
    levels_per_unit = 255 / rgb_range
    scaled = img.mul(levels_per_unit)
    snapped = scaled.clamp(0, 255).round()
    return snapped.div(levels_per_unit)
def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
    """Compute the peak signal-to-noise ratio (in dB) between ``sr`` and ``hr``.

    Args:
        sr: Reconstructed image batch, a 4-D tensor (presumably a torch
            tensor). NCHW layout is expected; the input is treated as NHWC
            and re-ordered when its last axis has size 3 and axis 1 is
            larger than 3.
        hr: Reference image batch in the same layout as ``sr``.
        scale: Upscaling factor; determines how many border pixels are
            excluded from the measurement.
        rgb_range: Maximum pixel value; the difference image is divided by it.
        benchmark: When true, ``scale`` border pixels are shaved and, for
            multi-channel input, only the weighted channel sum (luminance
            weights 65.738 / 129.057 / 25.064 over 256) is scored. Otherwise
            ``scale + 6`` border pixels are shaved and all channels are used.

    Returns:
        The PSNR as a float, or ``float("inf")`` when the compared regions
        are identical (zero mean squared error).
    """
    # NHWC -> NCHW. Tensor.transpose only swaps two axes, so a full axis
    # re-ordering has to go through permute.
    if sr.size(-1) == 3 and sr.size(1) > 3:
        sr = sr.permute(0, 3, 1, 2)
        hr = hr.permute(0, 3, 1, 2)

    # Crop sr to hr's spatial size when sr is larger along either axis.
    if sr.size(-2) > hr.size(-2) or sr.size(-1) > hr.size(-1):
        print("the dimention of sr image is not equal to hr's! ")
        sr = sr[:, :, :hr.size(-2), :hr.size(-1)]

    # .data drops autograd tracking; the subtraction already produced a new
    # tensor, so the caller's inputs are never modified.
    diff = (sr - hr).data.div(rgb_range)

    if benchmark:
        shave = scale
        if diff.size(1) > 1:
            gray_coeffs = [65.738, 129.057, 25.064]
            convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1)
            diff = diff.mul(convert).div(256).sum(dim=1, keepdim=True)
    else:
        shave = scale + 6

    valid = diff[:, :, shave:-shave, shave:-shave]
    mse = float(valid.pow(2).mean())
    if mse == 0:
        # log10(0) is undefined; identical images have unbounded PSNR.
        return float("inf")
    return -10 * math.log10(mse)
def matlab_style_gauss2D(shape=(3, 3), sigma=0.5):
    """
    Build a normalised 2D Gaussian kernel intended to match MATLAB's
    fspecial('gaussian', [shape], [sigma]).

    shape : (rows, cols) of the kernel
    sigma : standard deviation of the Gaussian

    Entries smaller than machine epsilon times the peak are zeroed before the
    kernel is scaled to sum to one (scaling is skipped if the sum is zero).

    Acknowledgement : https://stackoverflow.com/questions/17190649/how-to-obtain-a-gaussian-filter-in-python (Author@ali_m)
    """
    half_rows, half_cols = ((extent - 1.) / 2. for extent in shape)
    rows, cols = np.ogrid[-half_rows:half_rows + 1, -half_cols:half_cols + 1]

    kernel = np.exp(-(cols * cols + rows * rows) / (2. * sigma * sigma))

    # Suppress numerically negligible tails relative to the peak value.
    cutoff = np.finfo(kernel.dtype).eps * kernel.max()
    kernel[kernel < cutoff] = 0

    total = kernel.sum()
    if total == 0:
        return kernel
    kernel /= total
    return kernel
def calc_ssim(X, Y, scale, rgb_range, dataset=None, sigma=1.5, K1=0.01, K2=0.03, R=255):
    '''
    Mean structural similarity (SSIM) between two images.

    X : first image tensor; either a single luminance channel or 3 channels
        that are reduced to one by the weighted sum
        (65.738, 129.057, 25.064) / 256. NCHW is expected; input whose last
        axis has size 3 and whose axis 1 is larger than 3 is treated as NHWC
        and re-ordered first.
    Y : second image tensor, same layout as X
    scale : number of border pixels shaved from every side before comparing
    rgb_range, dataset : accepted but not used by this function
    sigma : standard deviation of the 11x11 Gaussian window
    K1, K2 : SSIM stabilising constants
    R : dynamic range used for C1 = (K1 * R) ** 2 and C2 = (K2 * R) ** 2

    Returns the mean of the per-pixel SSIM map (a NumPy float).

    NOTE(review): after squeeze() the data must be 2-D for convolve2d, so
    this presumably expects a batch of one image -- confirm against callers.
    '''
    window = matlab_style_gauss2D((11, 11), sigma)

    # NHWC -> NCHW, so the channel reduction below sees channels on axis 1.
    if X.size(-1) == 3 and X.size(1) > 3:
        X = X.permute(0, 3, 1, 2)
        Y = Y.permute(0, 3, 1, 2)

    shave = scale
    if X.size(1) > 1:
        gray_coeffs = [65.738, 129.057, 25.064]
        convert = X.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
        X = X.mul(convert).sum(dim=1)
        Y = Y.mul(convert).sum(dim=1)

    # detach() lets tensors that still require grad be converted to NumPy.
    X = X[..., shave:-shave, shave:-shave].squeeze().detach().cpu().numpy().astype(np.float64)
    Y = Y[..., shave:-shave, shave:-shave].squeeze().detach().cpu().numpy().astype(np.float64)

    def _local_mean(image):
        # Gaussian-weighted local average, output the same size as the input.
        return signal.convolve2d(image, window, mode='same', boundary='symm')

    ux = _local_mean(X)
    uy = _local_mean(Y)
    uxx = _local_mean(X * X)
    uyy = _local_mean(Y * Y)
    uxy = _local_mean(X * Y)

    # Local variances and covariance: E[ab] - E[a]E[b].
    vx = uxx - ux * ux
    vy = uyy - uy * uy
    vxy = uxy - ux * uy

    C1 = (K1 * R) ** 2
    C2 = (K2 * R) ** 2

    A1 = 2 * ux * uy + C1
    A2 = 2 * vxy + C2
    B1 = ux ** 2 + uy ** 2 + C1
    B2 = vx + vy + C2

    S = (A1 * A2) / (B1 * B2)
    return S.mean()