Rongjiehuang's picture
update
222619b
# '''
# https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py
# '''
#
# import torch
# import torch.jit
# import torch.nn.functional as F
#
#
# @torch.jit.script
# def create_window(window_size: int, sigma: float, channel: int):
# '''
# Create 1-D gauss kernel
# :param window_size: the size of gauss kernel
# :param sigma: sigma of normal distribution
# :param channel: input channel
# :return: 1D kernel
# '''
# coords = torch.arange(window_size, dtype=torch.float)
# coords -= window_size // 2
#
# g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
# g /= g.sum()
#
# g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1)
# return g
#
#
# @torch.jit.script
# def _gaussian_filter(x, window_1d, use_padding: bool):
# '''
# Blur input with 1-D kernel
# :param x: batch of tensors to be blured
# :param window_1d: 1-D gauss kernel
# :param use_padding: padding image before conv
# :return: blured tensors
# '''
# C = x.shape[1]
# padding = 0
# if use_padding:
# window_size = window_1d.shape[3]
# padding = window_size // 2
# out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C)
# out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C)
# return out
#
#
# @torch.jit.script
# def ssim(X, Y, window, data_range: float, use_padding: bool = False):
# '''
# Calculate ssim index for X and Y
# :param X: images [B, C, H, N_bins]
# :param Y: images [B, C, H, N_bins]
# :param window: 1-D gauss kernel
# :param data_range: value range of input images. (usually 1.0 or 255)
# :param use_padding: padding image before conv
# :return:
# '''
#
# K1 = 0.01
# K2 = 0.03
# compensation = 1.0
#
# C1 = (K1 * data_range) ** 2
# C2 = (K2 * data_range) ** 2
#
# mu1 = _gaussian_filter(X, window, use_padding)
# mu2 = _gaussian_filter(Y, window, use_padding)
# sigma1_sq = _gaussian_filter(X * X, window, use_padding)
# sigma2_sq = _gaussian_filter(Y * Y, window, use_padding)
# sigma12 = _gaussian_filter(X * Y, window, use_padding)
#
# mu1_sq = mu1.pow(2)
# mu2_sq = mu2.pow(2)
# mu1_mu2 = mu1 * mu2
#
# sigma1_sq = compensation * (sigma1_sq - mu1_sq)
# sigma2_sq = compensation * (sigma2_sq - mu2_sq)
# sigma12 = compensation * (sigma12 - mu1_mu2)
#
# cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
# # Fixed the issue that the negative value of cs_map caused ms_ssim to output Nan.
# cs_map = cs_map.clamp_min(0.)
# ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
#
# ssim_val = ssim_map.mean(dim=(1, 2, 3)) # reduce along CHW
# cs = cs_map.mean(dim=(1, 2, 3))
#
# return ssim_val, cs
#
#
# @torch.jit.script
# def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False, eps: float = 1e-8):
# '''
# interface of ms-ssim
# :param X: a batch of images, (N,C,H,W)
# :param Y: a batch of images, (N,C,H,W)
# :param window: 1-D gauss kernel
# :param data_range: value range of input images. (usually 1.0 or 255)
# :param weights: weights for different levels
# :param use_padding: padding image before conv
# :param eps: use for avoid grad nan.
# :return:
# '''
# levels = weights.shape[0]
# cs_vals = []
# ssim_vals = []
# for _ in range(levels):
# ssim_val, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding)
# # Workaround: when c = a ** b and a is 0, c.backward() makes a.grad inf, so clamp to eps first.
# ssim_val = ssim_val.clamp_min(eps)
# cs = cs.clamp_min(eps)
# cs_vals.append(cs)
#
# ssim_vals.append(ssim_val)
# padding = (X.shape[2] % 2, X.shape[3] % 2)
# X = F.avg_pool2d(X, kernel_size=2, stride=2, padding=padding)
# Y = F.avg_pool2d(Y, kernel_size=2, stride=2, padding=padding)
#
# cs_vals = torch.stack(cs_vals, dim=0)
# ms_ssim_val = torch.prod((cs_vals[:-1] ** weights[:-1].unsqueeze(1)) * (ssim_vals[-1] ** weights[-1]), dim=0)
# return ms_ssim_val
#
#
# class SSIM(torch.jit.ScriptModule):
# __constants__ = ['data_range', 'use_padding']
#
# def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False):
# '''
# :param window_size: the size of gauss kernel
# :param window_sigma: sigma of normal distribution
# :param data_range: value range of input images. (usually 1.0 or 255)
# :param channel: input channels (default: 3)
# :param use_padding: padding image before conv
# '''
# super().__init__()
# assert window_size % 2 == 1, 'Window size must be odd.'
# window = create_window(window_size, window_sigma, channel)
# self.register_buffer('window', window)
# self.data_range = data_range
# self.use_padding = use_padding
#
# @torch.jit.script_method
# def forward(self, X, Y):
# r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding)
# return r[0]
#
#
# class MS_SSIM(torch.jit.ScriptModule):
# __constants__ = ['data_range', 'use_padding', 'eps']
#
# def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False, weights=None,
# levels=None, eps=1e-8):
# '''
# class for ms-ssim
# :param window_size: the size of gauss kernel
# :param window_sigma: sigma of normal distribution
# :param data_range: value range of input images. (usually 1.0 or 255)
# :param channel: input channels
# :param use_padding: padding image before conv
# :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
# :param levels: number of downsampling
# :param eps: floor applied before exponentiation; when c = a ** b and a is 0, c.backward() makes a.grad inf.
# '''
# super().__init__()
# assert window_size % 2 == 1, 'Window size must be odd.'
# self.data_range = data_range
# self.use_padding = use_padding
# self.eps = eps
#
# window = create_window(window_size, window_sigma, channel)
# self.register_buffer('window', window)
#
# if weights is None:
# weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
# weights = torch.tensor(weights, dtype=torch.float)
#
# if levels is not None:
# weights = weights[:levels]
# weights = weights / weights.sum()
#
# self.register_buffer('weights', weights)
#
# @torch.jit.script_method
# def forward(self, X, Y):
# return ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights,
# use_padding=self.use_padding, eps=self.eps)
#
#
# if __name__ == '__main__':
# print('Simple Test')
# im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda')
# img1 = im / 255
# img2 = img1 * 0.5
#
# losser = SSIM(data_range=1.).cuda()
# loss = losser(img1, img2).mean()
#
# losser2 = MS_SSIM(data_range=1.).cuda()
# loss2 = losser2(img1, img2).mean()
#
# print(loss.item())
# print(loss2.item())
#
# if __name__ == '__main__':
# print('Training Test')
# import cv2
# import torch.optim
# import numpy as np
# import imageio
# import time
#
# out_test_video = False
# # Avoid writing a GIF directly (it gets very large); write an MKV first and convert it to GIF with ffmpeg.
# video_use_gif = False
#
# im = cv2.imread('test_img1.jpg', 1)
# t_im = torch.from_numpy(im).cuda().permute(2, 0, 1).float()[None] / 255.
#
# if out_test_video:
# if video_use_gif:
# fps = 0.5
# out_wh = (im.shape[1] // 2, im.shape[0] // 2)
# suffix = '.gif'
# else:
# fps = 5
# out_wh = (im.shape[1], im.shape[0])
# suffix = '.mkv'
# video_last_time = time.perf_counter()
# video = imageio.get_writer('ssim_test' + suffix, fps=fps)
#
# # Test SSIM
# print('Training SSIM')
# rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
# rand_im.requires_grad = True
# optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
# losser = SSIM(data_range=1., channel=t_im.shape[1]).cuda()
# ssim_score = 0
# while ssim_score < 0.999:
# optim.zero_grad()
# loss = losser(rand_im, t_im)
# (-loss).sum().backward()
# ssim_score = loss.item()
# optim.step()
# r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
# r_im = cv2.putText(r_im, 'ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
#
# if out_test_video:
# if time.perf_counter() - video_last_time > 1. / fps:
# video_last_time = time.perf_counter()
# out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
# out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
# if isinstance(out_frame, cv2.UMat):
# out_frame = out_frame.get()
# video.append_data(out_frame)
#
# cv2.imshow('ssim', r_im)
# cv2.setWindowTitle('ssim', 'ssim %f' % ssim_score)
# cv2.waitKey(1)
#
# if out_test_video:
# video.close()
#
# # Test MS-SSIM
# if out_test_video:
# if video_use_gif:
# fps = 0.5
# out_wh = (im.shape[1] // 2, im.shape[0] // 2)
# suffix = '.gif'
# else:
# fps = 5
# out_wh = (im.shape[1], im.shape[0])
# suffix = '.mkv'
# video_last_time = time.perf_counter()
# video = imageio.get_writer('ms_ssim_test' + suffix, fps=fps)
#
# print('Training MS_SSIM')
# rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
# rand_im.requires_grad = True
# optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
# losser = MS_SSIM(data_range=1., channel=t_im.shape[1]).cuda()
# ssim_score = 0
# while ssim_score < 0.999:
# optim.zero_grad()
# loss = losser(rand_im, t_im)
# (-loss).sum().backward()
# ssim_score = loss.item()
# optim.step()
# r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
# r_im = cv2.putText(r_im, 'ms_ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
#
# if out_test_video:
# if time.perf_counter() - video_last_time > 1. / fps:
# video_last_time = time.perf_counter()
# out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
# out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
# if isinstance(out_frame, cv2.UMat):
# out_frame = out_frame.get()
# video.append_data(out_frame)
#
# cv2.imshow('ms_ssim', r_im)
# cv2.setWindowTitle('ms_ssim', 'ms_ssim %f' % ssim_score)
# cv2.waitKey(1)
#
# if out_test_video:
# video.close()
"""
Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim
"""
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp
def gaussian(window_size, sigma):
    """Return a normalized 1-D Gaussian kernel of length ``window_size``.

    Tap ``i`` is ``exp(-(i - window_size // 2) ** 2 / (2 * sigma ** 2))``. The
    taps are stored in a float32 ``torch.Tensor`` and divided by their sum, so
    they add up to 1.

    Args:
        window_size: number of taps in the kernel.
        sigma: standard deviation of the Gaussian.
    """
    center = window_size // 2
    two_sigma_sq = float(2 * sigma ** 2)
    taps = [exp(-((i - center) ** 2) / two_sigma_sq) for i in range(window_size)]
    kernel = torch.Tensor(taps)
    return kernel / kernel.sum()
def create_window(window_size, channel, sigma=1.5):
    """Build a depthwise 2-D Gaussian window for ``F.conv2d(..., groups=channel)``.

    Args:
        window_size: side length of the square kernel.
        channel: number of input channels; the 2-D kernel is replicated once
            per channel.
        sigma: standard deviation of the Gaussian (default 1.5, the value this
            module has always used).

    Returns:
        A float32 tensor of shape ``(channel, 1, window_size, window_size)``.
        Each channel's weights sum to 1 (up to rounding), because the 1-D
        kernel is normalized.
    """
    _1D_window = gaussian(window_size, sigma).unsqueeze(1)
    # The outer product of the 1-D kernel with itself gives the 2-D Gaussian.
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # ``Variable`` has been a no-op wrapper since PyTorch 0.4; a plain tensor
    # (requires_grad=False) is equivalent.
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()
def _ssim(img1, img2, window, window_size, channel, size_average=True):
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
C1 = 0.01 ** 2
C2 = 0.03 ** 2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1)
class SSIM(torch.nn.Module):
    """Module wrapper around ``_ssim`` that caches its Gaussian window.

    The window is a plain tensor attribute, not a registered buffer. It is
    therefore not moved by ``module.to(...)`` / ``.cuda()`` and is not stored
    in the ``state_dict``. Instead, ``forward`` rebuilds it whenever the
    input's channel count, dtype, or device differs from the cached window.

    Args:
        window_size: side length of the Gaussian window (default 11).
        size_average: forwarded to ``_ssim``. True returns a scalar mean;
            False averages over dim 1 only.
    """

    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return the SSIM between ``img1`` and ``img2``.

        Both inputs must be 4-D; the channel count is read from dim 1.
        """
        (_, channel, _, _) = img1.size()
        window = self.window
        # Compare dtype and device directly. The former ``Tensor.type()`` string
        # check (e.g. 'torch.cuda.FloatTensor') ignored the CUDA device index,
        # so a window cached on cuda:0 was reused for inputs on cuda:1.
        if (channel != self.channel
                or window.dtype != img1.dtype
                or window.device != img1.device):
            window = create_window(self.window_size, channel).to(
                device=img1.device, dtype=img1.dtype)
            self.window = window
            self.channel = channel
        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
# Module-level cache of the Gaussian window used by ``ssim``. It is rebuilt
# whenever the requested window size, channel count, device, or dtype changes.
window = None


def ssim(img1, img2, window_size=11, size_average=True):
    """Functional SSIM between two image batches, reusing a cached window.

    Args:
        img1, img2: 4-D tensors; the channel count is read from dim 1.
        window_size: side length of the Gaussian window (default 11).
        size_average: True returns a scalar mean; False averages over dim 1
            only (see ``_ssim``).

    Returns:
        The tensor produced by ``_ssim``.
    """
    global window
    (_, channel, _, _) = img1.size()
    # Previously the first window was cached forever. A later call with a
    # different channel count, window_size, device, or dtype reused a stale
    # kernel. That raised a conv2d error, or gave a wrong output size when the
    # padding (window_size // 2) no longer matched the kernel.
    if (window is None
            or window.shape != (channel, 1, window_size, window_size)
            or window.device != img1.device
            or window.dtype != img1.dtype):
        window = create_window(window_size, channel).to(device=img1.device, dtype=img1.dtype)
    return _ssim(img1, img2, window, window_size, channel, size_average)