diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..f64400b7ca62cceb317e7e4966391439fbfa516a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/pic/div2k_comparison.jpg filter=lfs diff=lfs merge=lfs -text +assets/pic/london2.jpg filter=lfs diff=lfs merge=lfs -text +assets/pic/main_framework.jpg filter=lfs diff=lfs merge=lfs -text +assets/pic/realsr_vis3.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/assets/mm-realsr/de_net.pth b/assets/mm-realsr/de_net.pth new file mode 100644 index 0000000000000000000000000000000000000000..7f74e7a76f9f2dafb1cf2d5f9c7b2ab0a162d059 --- /dev/null +++ b/assets/mm-realsr/de_net.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a6e77c1cb0dd51e01ef60bbf7b9d85b3f9399c3c1253889ddb73de1436231b1a +size 9424338 diff --git a/assets/pic/div2k_comparison.jpg b/assets/pic/div2k_comparison.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ff9d7a9cd189e351bc6a9325dd2d98dd86f8906e --- /dev/null +++ b/assets/pic/div2k_comparison.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d52daed3de6be211bda5e1c11b8d9352fee84fca53af53e10b3cddc86036db99 +size 6224264 diff --git a/assets/pic/gradio.png b/assets/pic/gradio.png new file mode 100644 index 0000000000000000000000000000000000000000..77314c7983ded13f13ea92cd7dd72857b3854112 Binary files /dev/null and b/assets/pic/gradio.png differ diff --git a/assets/pic/london2.jpg b/assets/pic/london2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7f0d621c00e06874b6e3a789b028ad9e69211ffc --- /dev/null +++ b/assets/pic/london2.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6055d986cff9fe82d97b60f2ed728f00fb81600aa69541a97db80aa73cd906fa +size 6388509 diff --git a/assets/pic/main_framework.jpg b/assets/pic/main_framework.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2232fe62b3b14cafc99ad1234c5872b4acce6264 --- /dev/null +++ b/assets/pic/main_framework.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f089c6f09a44300d36cb3c9b6b4905ca4f801d63ae1907fa83ab66aec70ec893 +size 4684687 diff --git a/assets/pic/realsr_vis3.jpg b/assets/pic/realsr_vis3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b94ebf4ed0c82dcbb63b3fd3152d568799123fd8 --- /dev/null +++ b/assets/pic/realsr_vis3.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b07d0ffe838adf48958e3e910b467e1726d4a7a2f1623a9ae530f8db622fbb06 +size 5329675 diff --git a/basicsr/__init__.py b/basicsr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..28437544a254656cca7fb7021ef7bbf724cf2879 --- /dev/null +++ b/basicsr/__init__.py @@ -0,0 +1,12 @@ +# https://github.com/xinntao/BasicSR +# flake8: noqa +from .archs import * +from .data import * +from .losses import * +from .metrics import * +from .models import * +from .ops import * +from .test import * +from .train import * +from .utils import * +# from .version import __gitsha__, __version__ diff --git a/basicsr/archs/__init__.py b/basicsr/archs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..af6bcbd97bb3e4914c3c91dc53e0708bcac66075 --- /dev/null +++ b/basicsr/archs/__init__.py @@ -0,0 +1,24 @@ +import importlib 
+from copy import deepcopy +from os import path as osp + +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.registry import ARCH_REGISTRY + +__all__ = ['build_network'] + +# automatically scan and import arch modules for registry +# scan all the files under the 'archs' folder and collect files ending with '_arch.py' +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] +# import all the arch modules +_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] + + +def build_network(opt): + opt = deepcopy(opt) + network_type = opt.pop('type') + net = ARCH_REGISTRY.get(network_type)(**opt) + logger = get_root_logger() + logger.info(f'Network [{net.__class__.__name__}] is created.') + return net diff --git a/basicsr/archs/__pycache__/__init__.cpython-310.pyc b/basicsr/archs/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42cc1633c7eac9adb8918a45ab3f1750d5586c7c Binary files /dev/null and b/basicsr/archs/__pycache__/__init__.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/arch_util.cpython-310.pyc b/basicsr/archs/__pycache__/arch_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..daa40709b2db10cc5e0d72c611a756f009f8cf8b Binary files /dev/null and b/basicsr/archs/__pycache__/arch_util.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/basicvsr_arch.cpython-310.pyc b/basicsr/archs/__pycache__/basicvsr_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd89bdf1ade589e50162e45f0c7846234f27c7c8 Binary files /dev/null and b/basicsr/archs/__pycache__/basicvsr_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/basicvsrpp_arch.cpython-310.pyc b/basicsr/archs/__pycache__/basicvsrpp_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2ca4bf0f8a5884810417320f39a773d9ad4f9f6 Binary files /dev/null and b/basicsr/archs/__pycache__/basicvsrpp_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/degradat_arch.cpython-310.pyc b/basicsr/archs/__pycache__/degradat_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46b169e1f8ed6753fc0ae43a00e4eba15cc6dc1f Binary files /dev/null and b/basicsr/archs/__pycache__/degradat_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/dfdnet_arch.cpython-310.pyc b/basicsr/archs/__pycache__/dfdnet_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2f2508819ab09acb8be786f7d12ff7bd1778694 Binary files /dev/null and b/basicsr/archs/__pycache__/dfdnet_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/dfdnet_util.cpython-310.pyc b/basicsr/archs/__pycache__/dfdnet_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5795c182b4625d4b8f09c1afd14535c58b079e9e Binary files /dev/null and b/basicsr/archs/__pycache__/dfdnet_util.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/discriminator_arch.cpython-310.pyc b/basicsr/archs/__pycache__/discriminator_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27f6981aa4d1ef80414ed6ffc16a5b6ee37396b2 Binary files /dev/null and b/basicsr/archs/__pycache__/discriminator_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/duf_arch.cpython-310.pyc 
b/basicsr/archs/__pycache__/duf_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdf303bb2727c9745b7f0514e1c8818b8734388e Binary files /dev/null and b/basicsr/archs/__pycache__/duf_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/ecbsr_arch.cpython-310.pyc b/basicsr/archs/__pycache__/ecbsr_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..370f1151128f00644f94748692a2d49a64976354 Binary files /dev/null and b/basicsr/archs/__pycache__/ecbsr_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/edsr_arch.cpython-310.pyc b/basicsr/archs/__pycache__/edsr_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba9d41be046570128c961352d76e3cb242536f1d Binary files /dev/null and b/basicsr/archs/__pycache__/edsr_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/edvr_arch.cpython-310.pyc b/basicsr/archs/__pycache__/edvr_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50c9b74e1ea2520e762a73f0233061cf657aaee6 Binary files /dev/null and b/basicsr/archs/__pycache__/edvr_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/hifacegan_arch.cpython-310.pyc b/basicsr/archs/__pycache__/hifacegan_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f1a15d865cac226c02d1e6df6ef6ca08dab49f7 Binary files /dev/null and b/basicsr/archs/__pycache__/hifacegan_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/hifacegan_util.cpython-310.pyc b/basicsr/archs/__pycache__/hifacegan_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ec83d83b4392a7aae3aa37b4a1050c61e9a50af Binary files /dev/null and b/basicsr/archs/__pycache__/hifacegan_util.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/rcan_arch.cpython-310.pyc b/basicsr/archs/__pycache__/rcan_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..624274a671dd1f7244d6e36cb955a31fd22102c5 Binary files /dev/null and b/basicsr/archs/__pycache__/rcan_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/ridnet_arch.cpython-310.pyc b/basicsr/archs/__pycache__/ridnet_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60b9c40c34382099eab3b639cf8c6f63b2751d8a Binary files /dev/null and b/basicsr/archs/__pycache__/ridnet_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/rrdbnet_arch.cpython-310.pyc b/basicsr/archs/__pycache__/rrdbnet_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72b0e92593f0a2b4dc0aa33d78dbe85553778f60 Binary files /dev/null and b/basicsr/archs/__pycache__/rrdbnet_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/spynet_arch.cpython-310.pyc b/basicsr/archs/__pycache__/spynet_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1163d4f73b0bb7acd1c41b6b205f43a2b2ca69bf Binary files /dev/null and b/basicsr/archs/__pycache__/spynet_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/srresnet_arch.cpython-310.pyc b/basicsr/archs/__pycache__/srresnet_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..baf26142f38dc736cf702d2083f33402f55e9d73 Binary files /dev/null and b/basicsr/archs/__pycache__/srresnet_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/srvgg_arch.cpython-310.pyc 
b/basicsr/archs/__pycache__/srvgg_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55dba13fb752495c0ef9769d1b1b74499f5f0e1c Binary files /dev/null and b/basicsr/archs/__pycache__/srvgg_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/stylegan2_arch.cpython-310.pyc b/basicsr/archs/__pycache__/stylegan2_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bea2f163422d4c8a1dc4002622188276c15b5e6e Binary files /dev/null and b/basicsr/archs/__pycache__/stylegan2_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/stylegan2_bilinear_arch.cpython-310.pyc b/basicsr/archs/__pycache__/stylegan2_bilinear_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2759ea197d29ec6e60a2026ad90fa705e4fde70f Binary files /dev/null and b/basicsr/archs/__pycache__/stylegan2_bilinear_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/swinir_arch.cpython-310.pyc b/basicsr/archs/__pycache__/swinir_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccfafa0b0c9f270ada130831a489c34e31285a5d Binary files /dev/null and b/basicsr/archs/__pycache__/swinir_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/tof_arch.cpython-310.pyc b/basicsr/archs/__pycache__/tof_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b292810333816f0589ea6a48282fb761715d25c Binary files /dev/null and b/basicsr/archs/__pycache__/tof_arch.cpython-310.pyc differ diff --git a/basicsr/archs/__pycache__/vgg_arch.cpython-310.pyc b/basicsr/archs/__pycache__/vgg_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49c328416326c40e0a93165d20fc842e6eb39140 Binary files /dev/null and b/basicsr/archs/__pycache__/vgg_arch.cpython-310.pyc differ diff --git a/basicsr/archs/arch_util.py b/basicsr/archs/arch_util.py new file mode 100644 index 0000000000000000000000000000000000000000..07fd7762814136c0bf5d34432f46722723e68e3f --- /dev/null +++ b/basicsr/archs/arch_util.py @@ -0,0 +1,355 @@ +import collections.abc +import math +import torch +import torchvision +import warnings +from distutils.version import LooseVersion +from itertools import repeat +from torch import nn as nn +from torch.nn import functional as F +from torch.nn import init as init +from torch.nn.modules.batchnorm import _BatchNorm + +from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv +from basicsr.utils import get_root_logger + + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. 
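+
+    Example:
+        >>> # illustrative usage; scale < 1 is the typical choice for residual blocks
+        >>> conv = nn.Conv2d(3, 64, 3, 1, 1)
+        >>> default_init_weights(conv, scale=0.1)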
+ """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + +class PixelShufflePack(nn.Module): + """Pixel Shuffle upsample layer. + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + scale_factor (int): Upsample ratio. + upsample_kernel (int): Kernel size of Conv layer to expand channels. + Returns: + Upsampled feature map. + """ + + def __init__(self, in_channels, out_channels, scale_factor, + upsample_kernel): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.scale_factor = scale_factor + self.upsample_kernel = upsample_kernel + self.upsample_conv = nn.Conv2d( + self.in_channels, + self.out_channels * scale_factor * scale_factor, + self.upsample_kernel, + padding=(self.upsample_kernel - 1) // 2) + self.init_weights() + + def init_weights(self): + """Initialize weights for PixelShufflePack.""" + default_init_weights(self, 1) + + def forward(self, x): + """Forward function for PixelShufflePack. + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + Returns: + Tensor: Forward results. + """ + x = self.upsample_conv(x) + x = F.pixel_shuffle(x, self.scale_factor) + return x + +class ResidualBlockNoBN(nn.Module): + """Residual block without BN. + + Args: + num_feat (int): Channel number of intermediate features. + Default: 64. + res_scale (float): Residual scale. Default: 1. + pytorch_init (bool): If set to True, use pytorch default init, + otherwise, use default_init_weights. Default: False. + """ + + def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): + super(ResidualBlockNoBN, self).__init__() + self.res_scale = res_scale + self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.relu = nn.ReLU() + + if not pytorch_init: + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + x = self.conv1(x) + x = self.relu(x) + out = self.conv2(x) + # out = self.conv2(self.relu(self.conv1(x))) + return identity + out * self.res_scale + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. 
+ """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + Returns: + Tensor: Warped image or feature map. + """ + assert x.size()[-2:] == flow.size()[1:3] + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) + + # TODO, what if align_corners=False + return output + + +def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False): + """Resize a flow according to ratio or shape. + + Args: + flow (Tensor): Precomputed flow. shape [N, 2, H, W]. + size_type (str): 'ratio' or 'shape'. + sizes (list[int | float]): the ratio for resizing or the final output + shape. + 1) The order of ratio should be [ratio_h, ratio_w]. For + downsampling, the ratio should be smaller than 1.0 (i.e., ratio + < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., + ratio > 1.0). + 2) The order of output_size should be [out_h, out_w]. + interp_mode (str): The mode of interpolation for resizing. + Default: 'bilinear'. + align_corners (bool): Whether align corners. Default: False. + + Returns: + Tensor: Resized flow. + """ + _, _, flow_h, flow_w = flow.size() + if size_type == 'ratio': + output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) + elif size_type == 'shape': + output_h, output_w = sizes[0], sizes[1] + else: + raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.') + + input_flow = flow.clone() + ratio_h = output_h / flow_h + ratio_w = output_w / flow_w + input_flow[:, 0, :, :] *= ratio_w + input_flow[:, 1, :, :] *= ratio_h + resized_flow = F.interpolate( + input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners) + return resized_flow + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. 
+ """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +class DCNv2Pack(ModulatedDeformConvPack): + """Modulated deformable conv for deformable alignment. + + Different from the official DCNv2Pack, which generates offsets and masks + from the preceding features, this DCNv2Pack takes another different + features to generate offsets and masks. + + ``Paper: Delving Deep into Deformable Alignment in Video Super-Resolution`` + """ + + def forward(self, x, feat): + out = self.conv_offset(feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + + offset_absmean = torch.mean(torch.abs(offset)) + if offset_absmean > 50: + logger = get_root_logger() + logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.') + + if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'): + return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, + self.dilation, mask) + else: + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups, self.deformable_groups) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' + 'The distribution of values may be incorrect.', + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + low = norm_cdf((a - mean) / std) + up = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [low, up], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * low - 1, 2 * up - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + + The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +# From PyTorch +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple diff --git a/basicsr/archs/basicvsr_arch.py b/basicsr/archs/basicvsr_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..ed7b824eae108a9bcca57f1c14dd0d8afafc4f58 --- /dev/null +++ b/basicsr/archs/basicvsr_arch.py @@ -0,0 +1,336 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import ResidualBlockNoBN, flow_warp, make_layer +from .edvr_arch import PCDAlignment, TSAFusion +from .spynet_arch import SpyNet + + +@ARCH_REGISTRY.register() +class BasicVSR(nn.Module): + """A recurrent network for video SR. Now only x4 is supported. + + Args: + num_feat (int): Number of channels. Default: 64. + num_block (int): Number of residual blocks for each branch. Default: 15 + spynet_path (str): Path to the pretrained weights of SPyNet. Default: None. + """ + + def __init__(self, num_feat=64, num_block=15, spynet_path=None): + super().__init__() + self.num_feat = num_feat + + # alignment + self.spynet = SpyNet(spynet_path) + + # propagation + self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block) + self.forward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block) + + # reconstruction + self.fusion = nn.Conv2d(num_feat * 2, num_feat, 1, 1, 0, bias=True) + self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True) + self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True) + self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1) + self.conv_last = nn.Conv2d(64, 3, 3, 1, 1) + + self.pixel_shuffle = nn.PixelShuffle(2) + + # activation functions + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def get_flow(self, x): + b, n, c, h, w = x.size() + + x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w) + x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w) + + flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w) + flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w) + + return flows_forward, flows_backward + + def forward(self, x): + """Forward function of BasicVSR. + + Args: + x: Input frames with shape (b, n, c, h, w). n is the temporal dimension / number of frames. 
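+
+        Returns:
+            Tensor: Output frames with shape (b, n, c, 4h, 4w).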
+ """ + flows_forward, flows_backward = self.get_flow(x) + b, n, _, h, w = x.size() + + # backward branch + out_l = [] + feat_prop = x.new_zeros(b, self.num_feat, h, w) + for i in range(n - 1, -1, -1): + x_i = x[:, i, :, :, :] + if i < n - 1: + flow = flows_backward[:, i, :, :, :] + feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1)) + feat_prop = torch.cat([x_i, feat_prop], dim=1) + feat_prop = self.backward_trunk(feat_prop) + out_l.insert(0, feat_prop) + + # forward branch + feat_prop = torch.zeros_like(feat_prop) + for i in range(0, n): + x_i = x[:, i, :, :, :] + if i > 0: + flow = flows_forward[:, i - 1, :, :, :] + feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1)) + + feat_prop = torch.cat([x_i, feat_prop], dim=1) + feat_prop = self.forward_trunk(feat_prop) + + # upsample + out = torch.cat([out_l[i], feat_prop], dim=1) + out = self.lrelu(self.fusion(out)) + out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) + out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) + out = self.lrelu(self.conv_hr(out)) + out = self.conv_last(out) + base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False) + out += base + out_l[i] = out + + return torch.stack(out_l, dim=1) + + +class ConvResidualBlocks(nn.Module): + """Conv and residual block used in BasicVSR. + + Args: + num_in_ch (int): Number of input channels. Default: 3. + num_out_ch (int): Number of output channels. Default: 64. + num_block (int): Number of residual blocks. Default: 15. + """ + + def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15): + super().__init__() + self.main = nn.Sequential( + nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True), + make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch)) + + def forward(self, fea): + return self.main(fea) + + +@ARCH_REGISTRY.register() +class IconVSR(nn.Module): + """IconVSR, proposed also in the BasicVSR paper. + + Args: + num_feat (int): Number of channels. Default: 64. + num_block (int): Number of residual blocks for each branch. Default: 15. + keyframe_stride (int): Keyframe stride. Default: 5. + temporal_padding (int): Temporal padding. Default: 2. + spynet_path (str): Path to the pretrained weights of SPyNet. Default: None. + edvr_path (str): Path to the pretrained EDVR model. Default: None. + """ + + def __init__(self, + num_feat=64, + num_block=15, + keyframe_stride=5, + temporal_padding=2, + spynet_path=None, + edvr_path=None): + super().__init__() + + self.num_feat = num_feat + self.temporal_padding = temporal_padding + self.keyframe_stride = keyframe_stride + + # keyframe_branch + self.edvr = EDVRFeatureExtractor(temporal_padding * 2 + 1, num_feat, edvr_path) + # alignment + self.spynet = SpyNet(spynet_path) + + # propagation + self.backward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True) + self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block) + + self.forward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True) + self.forward_trunk = ConvResidualBlocks(2 * num_feat + 3, num_feat, num_block) + + # reconstruction + self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True) + self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True) + self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1) + self.conv_last = nn.Conv2d(64, 3, 3, 1, 1) + + self.pixel_shuffle = nn.PixelShuffle(2) + + # activation functions + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def pad_spatial(self, x): + """Apply padding spatially. 
+ + Since the PCD module in EDVR requires that the resolution is a multiple + of 4, we apply padding to the input LR images if their resolution is + not divisible by 4. + + Args: + x (Tensor): Input LR sequence with shape (n, t, c, h, w). + Returns: + Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad). + """ + n, t, c, h, w = x.size() + + pad_h = (4 - h % 4) % 4 + pad_w = (4 - w % 4) % 4 + + # padding + x = x.view(-1, c, h, w) + x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect') + + return x.view(n, t, c, h + pad_h, w + pad_w) + + def get_flow(self, x): + b, n, c, h, w = x.size() + + x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w) + x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w) + + flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w) + flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w) + + return flows_forward, flows_backward + + def get_keyframe_feature(self, x, keyframe_idx): + if self.temporal_padding == 2: + x = [x[:, [4, 3]], x, x[:, [-4, -5]]] + elif self.temporal_padding == 3: + x = [x[:, [6, 5, 4]], x, x[:, [-5, -6, -7]]] + x = torch.cat(x, dim=1) + + num_frames = 2 * self.temporal_padding + 1 + feats_keyframe = {} + for i in keyframe_idx: + feats_keyframe[i] = self.edvr(x[:, i:i + num_frames].contiguous()) + return feats_keyframe + + def forward(self, x): + b, n, _, h_input, w_input = x.size() + + x = self.pad_spatial(x) + h, w = x.shape[3:] + + keyframe_idx = list(range(0, n, self.keyframe_stride)) + if keyframe_idx[-1] != n - 1: + keyframe_idx.append(n - 1) # last frame is a keyframe + + # compute flow and keyframe features + flows_forward, flows_backward = self.get_flow(x) + feats_keyframe = self.get_keyframe_feature(x, keyframe_idx) + + # backward branch + out_l = [] + feat_prop = x.new_zeros(b, self.num_feat, h, w) + for i in range(n - 1, -1, -1): + x_i = x[:, i, :, :, :] + if i < n - 1: + flow = flows_backward[:, i, :, :, :] + feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1)) + if i in keyframe_idx: + feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1) + feat_prop = self.backward_fusion(feat_prop) + feat_prop = torch.cat([x_i, feat_prop], dim=1) + feat_prop = self.backward_trunk(feat_prop) + out_l.insert(0, feat_prop) + + # forward branch + feat_prop = torch.zeros_like(feat_prop) + for i in range(0, n): + x_i = x[:, i, :, :, :] + if i > 0: + flow = flows_forward[:, i - 1, :, :, :] + feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1)) + if i in keyframe_idx: + feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1) + feat_prop = self.forward_fusion(feat_prop) + + feat_prop = torch.cat([x_i, out_l[i], feat_prop], dim=1) + feat_prop = self.forward_trunk(feat_prop) + + # upsample + out = self.lrelu(self.pixel_shuffle(self.upconv1(feat_prop))) + out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) + out = self.lrelu(self.conv_hr(out)) + out = self.conv_last(out) + base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False) + out += base + out_l[i] = out + + return torch.stack(out_l, dim=1)[..., :4 * h_input, :4 * w_input] + + +class EDVRFeatureExtractor(nn.Module): + """EDVR feature extractor used in IconVSR. + + Args: + num_input_frame (int): Number of input frames. + num_feat (int): Number of feature channels + load_path (str): Path to the pretrained weights of EDVR. Default: None. 
+ """ + + def __init__(self, num_input_frame, num_feat, load_path): + + super(EDVRFeatureExtractor, self).__init__() + + self.center_frame_idx = num_input_frame // 2 + + # extract pyramid features + self.conv_first = nn.Conv2d(3, num_feat, 3, 1, 1) + self.feature_extraction = make_layer(ResidualBlockNoBN, 5, num_feat=num_feat) + self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + + # pcd and tsa module + self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=8) + self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_input_frame, center_frame_idx=self.center_frame_idx) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + if load_path: + self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params']) + + def forward(self, x): + b, n, c, h, w = x.size() + + # extract features for each frame + # L1 + feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w))) + feat_l1 = self.feature_extraction(feat_l1) + # L2 + feat_l2 = self.lrelu(self.conv_l2_1(feat_l1)) + feat_l2 = self.lrelu(self.conv_l2_2(feat_l2)) + # L3 + feat_l3 = self.lrelu(self.conv_l3_1(feat_l2)) + feat_l3 = self.lrelu(self.conv_l3_2(feat_l3)) + + feat_l1 = feat_l1.view(b, n, -1, h, w) + feat_l2 = feat_l2.view(b, n, -1, h // 2, w // 2) + feat_l3 = feat_l3.view(b, n, -1, h // 4, w // 4) + + # PCD alignment + ref_feat_l = [ # reference feature list + feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(), + feat_l3[:, self.center_frame_idx, :, :, :].clone() + ] + aligned_feat = [] + for i in range(n): + nbr_feat_l = [ # neighboring feature list + feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone() + ] + aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l)) + aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w) + + # TSA fusion + return self.fusion(aligned_feat) diff --git a/basicsr/archs/basicvsrpp_arch.py b/basicsr/archs/basicvsrpp_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..2a9952e4b441de0030d665a3db141774184f332f --- /dev/null +++ b/basicsr/archs/basicvsrpp_arch.py @@ -0,0 +1,417 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +import warnings + +from basicsr.archs.arch_util import flow_warp +from basicsr.archs.basicvsr_arch import ConvResidualBlocks +from basicsr.archs.spynet_arch import SpyNet +from basicsr.ops.dcn import ModulatedDeformConvPack +from basicsr.utils.registry import ARCH_REGISTRY + + +@ARCH_REGISTRY.register() +class BasicVSRPlusPlus(nn.Module): + """BasicVSR++ network structure. + + Support either x4 upsampling or same size output. Since DCN is used in this + model, it can only be used with CUDA enabled. If CUDA is not enabled, + feature alignment will be skipped. Besides, we adopt the official DCN + implementation and the version of torch need to be higher than 1.9. + + ``Paper: BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment`` + + Args: + mid_channels (int, optional): Channel number of the intermediate + features. Default: 64. + num_blocks (int, optional): The number of residual blocks in each + propagation branch. Default: 7. + max_residue_magnitude (int): The maximum magnitude of the offset + residue (Eq. 6 in paper). 
Default: 10. + is_low_res_input (bool, optional): Whether the input is low-resolution + or not. If False, the output resolution is equal to the input + resolution. Default: True. + spynet_path (str): Path to the pretrained weights of SPyNet. Default: None. + cpu_cache_length (int, optional): When the length of sequence is larger + than this value, the intermediate features are sent to CPU. This + saves GPU memory, but slows down the inference speed. You can + increase this number if you have a GPU with large memory. + Default: 100. + """ + + def __init__(self, + mid_channels=64, + num_blocks=7, + max_residue_magnitude=10, + is_low_res_input=True, + spynet_path=None, + cpu_cache_length=100): + + super().__init__() + self.mid_channels = mid_channels + self.is_low_res_input = is_low_res_input + self.cpu_cache_length = cpu_cache_length + + # optical flow + self.spynet = SpyNet(spynet_path) + + # feature extraction module + if is_low_res_input: + self.feat_extract = ConvResidualBlocks(3, mid_channels, 5) + else: + self.feat_extract = nn.Sequential( + nn.Conv2d(3, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(mid_channels, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True), + ConvResidualBlocks(mid_channels, mid_channels, 5)) + + # propagation branches + self.deform_align = nn.ModuleDict() + self.backbone = nn.ModuleDict() + modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2'] + for i, module in enumerate(modules): + if torch.cuda.is_available(): + self.deform_align[module] = SecondOrderDeformableAlignment( + 2 * mid_channels, + mid_channels, + 3, + padding=1, + deformable_groups=16, + max_residue_magnitude=max_residue_magnitude) + self.backbone[module] = ConvResidualBlocks((2 + i) * mid_channels, mid_channels, num_blocks) + + # upsampling module + self.reconstruction = ConvResidualBlocks(5 * mid_channels, mid_channels, 5) + + self.upconv1 = nn.Conv2d(mid_channels, mid_channels * 4, 3, 1, 1, bias=True) + self.upconv2 = nn.Conv2d(mid_channels, 64 * 4, 3, 1, 1, bias=True) + + self.pixel_shuffle = nn.PixelShuffle(2) + + self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1) + self.conv_last = nn.Conv2d(64, 3, 3, 1, 1) + self.img_upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + # check if the sequence is augmented by flipping + self.is_mirror_extended = False + + if len(self.deform_align) > 0: + self.is_with_alignment = True + else: + self.is_with_alignment = False + warnings.warn('Deformable alignment module is not added. ' + 'Probably your CUDA is not configured correctly. DCN can only ' + 'be used with CUDA enabled. Alignment is skipped now.') + + def check_if_mirror_extended(self, lqs): + """Check whether the input is a mirror-extended sequence. + + If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the (t-1-i)-th frame. + + Args: + lqs (tensor): Input low quality (LQ) sequence with shape (n, t, c, h, w). + """ + + if lqs.size(1) % 2 == 0: + lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1) + if torch.norm(lqs_1 - lqs_2.flip(1)) == 0: + self.is_mirror_extended = True + + def compute_flow(self, lqs): + """Compute optical flow using SPyNet for feature alignment. + + Note that if the input is an mirror-extended sequence, 'flows_forward' + is not needed, since it is equal to 'flows_backward.flip(1)'. + + Args: + lqs (tensor): Input low quality (LQ) sequence with + shape (n, t, c, h, w). 
+ + Return: + tuple(Tensor): Optical flow. 'flows_forward' corresponds to the flows used for forward-time propagation \ + (current to previous). 'flows_backward' corresponds to the flows used for backward-time \ + propagation (current to next). + """ + + n, t, c, h, w = lqs.size() + lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w) + lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w) + + flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w) + + if self.is_mirror_extended: # flows_forward = flows_backward.flip(1) + flows_forward = flows_backward.flip(1) + else: + flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w) + + if self.cpu_cache: + flows_backward = flows_backward.cpu() + flows_forward = flows_forward.cpu() + + return flows_forward, flows_backward + + def propagate(self, feats, flows, module_name): + """Propagate the latent features throughout the sequence. + + Args: + feats dict(list[tensor]): Features from previous branches. Each + component is a list of tensors with shape (n, c, h, w). + flows (tensor): Optical flows with shape (n, t - 1, 2, h, w). + module_name (str): The name of the propgation branches. Can either + be 'backward_1', 'forward_1', 'backward_2', 'forward_2'. + + Return: + dict(list[tensor]): A dictionary containing all the propagated \ + features. Each key in the dictionary corresponds to a \ + propagation branch, which is represented by a list of tensors. + """ + + n, t, _, h, w = flows.size() + + frame_idx = range(0, t + 1) + flow_idx = range(-1, t) + mapping_idx = list(range(0, len(feats['spatial']))) + mapping_idx += mapping_idx[::-1] + + if 'backward' in module_name: + frame_idx = frame_idx[::-1] + flow_idx = frame_idx + + feat_prop = flows.new_zeros(n, self.mid_channels, h, w) + for i, idx in enumerate(frame_idx): + feat_current = feats['spatial'][mapping_idx[idx]] + if self.cpu_cache: + feat_current = feat_current.cuda() + feat_prop = feat_prop.cuda() + # second-order deformable alignment + if i > 0 and self.is_with_alignment: + flow_n1 = flows[:, flow_idx[i], :, :, :] + if self.cpu_cache: + flow_n1 = flow_n1.cuda() + + cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1)) + + # initialize second-order features + feat_n2 = torch.zeros_like(feat_prop) + flow_n2 = torch.zeros_like(flow_n1) + cond_n2 = torch.zeros_like(cond_n1) + + if i > 1: # second-order features + feat_n2 = feats[module_name][-2] + if self.cpu_cache: + feat_n2 = feat_n2.cuda() + + flow_n2 = flows[:, flow_idx[i - 1], :, :, :] + if self.cpu_cache: + flow_n2 = flow_n2.cuda() + + flow_n2 = flow_n1 + flow_warp(flow_n2, flow_n1.permute(0, 2, 3, 1)) + cond_n2 = flow_warp(feat_n2, flow_n2.permute(0, 2, 3, 1)) + + # flow-guided deformable convolution + cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1) + feat_prop = torch.cat([feat_prop, feat_n2], dim=1) + feat_prop = self.deform_align[module_name](feat_prop, cond, flow_n1, flow_n2) + + # concatenate and residual blocks + feat = [feat_current] + [feats[k][idx] for k in feats if k not in ['spatial', module_name]] + [feat_prop] + if self.cpu_cache: + feat = [f.cuda() for f in feat] + + feat = torch.cat(feat, dim=1) + feat_prop = feat_prop + self.backbone[module_name](feat) + feats[module_name].append(feat_prop) + + if self.cpu_cache: + feats[module_name][-1] = feats[module_name][-1].cpu() + torch.cuda.empty_cache() + + if 'backward' in module_name: + feats[module_name] = feats[module_name][::-1] + + return feats + + def upsample(self, lqs, feats): + """Compute the output image given the features. 
+ + Args: + lqs (tensor): Input low quality (LQ) sequence with + shape (n, t, c, h, w). + feats (dict): The features from the propagation branches. + + Returns: + Tensor: Output HR sequence with shape (n, t, c, 4h, 4w). + """ + + outputs = [] + num_outputs = len(feats['spatial']) + + mapping_idx = list(range(0, num_outputs)) + mapping_idx += mapping_idx[::-1] + + for i in range(0, lqs.size(1)): + hr = [feats[k].pop(0) for k in feats if k != 'spatial'] + hr.insert(0, feats['spatial'][mapping_idx[i]]) + hr = torch.cat(hr, dim=1) + if self.cpu_cache: + hr = hr.cuda() + + hr = self.reconstruction(hr) + hr = self.lrelu(self.pixel_shuffle(self.upconv1(hr))) + hr = self.lrelu(self.pixel_shuffle(self.upconv2(hr))) + hr = self.lrelu(self.conv_hr(hr)) + hr = self.conv_last(hr) + if self.is_low_res_input: + hr += self.img_upsample(lqs[:, i, :, :, :]) + else: + hr += lqs[:, i, :, :, :] + + if self.cpu_cache: + hr = hr.cpu() + torch.cuda.empty_cache() + + outputs.append(hr) + + return torch.stack(outputs, dim=1) + + def forward(self, lqs): + """Forward function for BasicVSR++. + + Args: + lqs (tensor): Input low quality (LQ) sequence with + shape (n, t, c, h, w). + + Returns: + Tensor: Output HR sequence with shape (n, t, c, 4h, 4w). + """ + + n, t, c, h, w = lqs.size() + + # whether to cache the features in CPU + self.cpu_cache = True if t > self.cpu_cache_length else False + + if self.is_low_res_input: + lqs_downsample = lqs.clone() + else: + lqs_downsample = F.interpolate( + lqs.view(-1, c, h, w), scale_factor=0.25, mode='bicubic').view(n, t, c, h // 4, w // 4) + + # check whether the input is an extended sequence + self.check_if_mirror_extended(lqs) + + feats = {} + # compute spatial features + if self.cpu_cache: + feats['spatial'] = [] + for i in range(0, t): + feat = self.feat_extract(lqs[:, i, :, :, :]).cpu() + feats['spatial'].append(feat) + torch.cuda.empty_cache() + else: + feats_ = self.feat_extract(lqs.view(-1, c, h, w)) + h, w = feats_.shape[2:] + feats_ = feats_.view(n, t, -1, h, w) + feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)] + + # compute optical flow using the low-res inputs + assert lqs_downsample.size(3) >= 64 and lqs_downsample.size(4) >= 64, ( + 'The height and width of low-res inputs must be at least 64, ' + f'but got {h} and {w}.') + flows_forward, flows_backward = self.compute_flow(lqs_downsample) + + # feature propgation + for iter_ in [1, 2]: + for direction in ['backward', 'forward']: + module = f'{direction}_{iter_}' + + feats[module] = [] + + if direction == 'backward': + flows = flows_backward + elif flows_forward is not None: + flows = flows_forward + else: + flows = flows_backward.flip(1) + + feats = self.propagate(feats, flows, module) + if self.cpu_cache: + del flows + torch.cuda.empty_cache() + + return self.upsample(lqs, feats) + + +class SecondOrderDeformableAlignment(ModulatedDeformConvPack): + """Second-order deformable alignment module. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + max_residue_magnitude (int): The maximum magnitude of the offset + residue (Eq. 6 in paper). Default: 10. 
+ """ + + def __init__(self, *args, **kwargs): + self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10) + + super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Sequential( + nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, 27 * self.deformable_groups, 3, 1, 1), + ) + + self.init_offset() + + def init_offset(self): + + def _constant_init(module, val, bias=0): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.constant_(module.weight, val) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + _constant_init(self.conv_offset[-1], val=0, bias=0) + + def forward(self, x, extra_feat, flow_1, flow_2): + extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1) + out = self.conv_offset(extra_feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + + # offset + offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1)) + offset_1, offset_2 = torch.chunk(offset, 2, dim=1) + offset_1 = offset_1 + flow_1.flip(1).repeat(1, offset_1.size(1) // 2, 1, 1) + offset_2 = offset_2 + flow_2.flip(1).repeat(1, offset_2.size(1) // 2, 1, 1) + offset = torch.cat([offset_1, offset_2], dim=1) + + # mask + mask = torch.sigmoid(mask) + + return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, + self.dilation, mask) + + +# if __name__ == '__main__': +# spynet_path = 'experiments/pretrained_models/flownet/spynet_sintel_final-3d2a1287.pth' +# model = BasicVSRPlusPlus(spynet_path=spynet_path).cuda() +# input = torch.rand(1, 2, 3, 64, 64).cuda() +# output = model(input) +# print('===================') +# print(output.shape) diff --git a/basicsr/archs/degradat_arch.py b/basicsr/archs/degradat_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..ce09ad666a90f175fb6268435073b314df543813 --- /dev/null +++ b/basicsr/archs/degradat_arch.py @@ -0,0 +1,90 @@ +from torch import nn as nn + +from basicsr.archs.arch_util import ResidualBlockNoBN, default_init_weights +from basicsr.utils.registry import ARCH_REGISTRY + +@ARCH_REGISTRY.register() +class DEResNet(nn.Module): + """Degradation Estimator with ResNetNoBN arch. v2.1, no vector anymore + As shown in paper 'Towards Flexible Blind JPEG Artifacts Removal', + resnet arch works for image quality estimation. + Args: + num_in_ch (int): channel number of inputs. Default: 3. + num_degradation (int): num of degradation the DE should estimate. Default: 2(blur+noise). + degradation_embed_size (int): embedding size of each degradation vector. + degradation_degree_actv (int): activation function for degradation degree scalar. Default: sigmoid. + num_feats (list): channel number of each stage. + num_blocks (list): residual block of each stage. + downscales (list): downscales of each stage. 
+ """ + + def __init__(self, + num_in_ch=3, + num_degradation=2, + degradation_degree_actv='sigmoid', + num_feats=(64, 128, 256, 512), + num_blocks=(2, 2, 2, 2), + downscales=(2, 2, 2, 1)): + super(DEResNet, self).__init__() + + assert isinstance(num_feats, list) + assert isinstance(num_blocks, list) + assert isinstance(downscales, list) + assert len(num_feats) == len(num_blocks) and len(num_feats) == len(downscales) + + num_stage = len(num_feats) + + self.conv_first = nn.ModuleList() + for _ in range(num_degradation): + self.conv_first.append(nn.Conv2d(num_in_ch, num_feats[0], 3, 1, 1)) + self.body = nn.ModuleList() + for _ in range(num_degradation): + body = list() + for stage in range(num_stage): + for _ in range(num_blocks[stage]): + body.append(ResidualBlockNoBN(num_feats[stage])) + if downscales[stage] == 1: + if stage < num_stage - 1 and num_feats[stage] != num_feats[stage + 1]: + body.append(nn.Conv2d(num_feats[stage], num_feats[stage + 1], 3, 1, 1)) + continue + elif downscales[stage] == 2: + body.append(nn.Conv2d(num_feats[stage], num_feats[min(stage + 1, num_stage - 1)], 3, 2, 1)) + else: + raise NotImplementedError + self.body.append(nn.Sequential(*body)) + + # self.body = nn.Sequential(*body) + + self.num_degradation = num_degradation + self.fc_degree = nn.ModuleList() + if degradation_degree_actv == 'sigmoid': + actv = nn.Sigmoid + elif degradation_degree_actv == 'tanh': + actv = nn.Tanh + else: + raise NotImplementedError(f'only sigmoid and tanh are supported for degradation_degree_actv, ' + f'{degradation_degree_actv} is not supported yet.') + for _ in range(num_degradation): + self.fc_degree.append( + nn.Sequential( + nn.Linear(num_feats[-1], 512), + nn.ReLU(inplace=True), + nn.Linear(512, 1), + actv(), + )) + + self.avg_pool = nn.AdaptiveAvgPool2d(1) + + default_init_weights([self.conv_first, self.body, self.fc_degree], 0.1) + + def forward(self, x): + degrees = [] + for i in range(self.num_degradation): + x_out = self.conv_first[i](x) + feat = self.body[i](x_out) + feat = self.avg_pool(feat) + feat = feat.squeeze(-1).squeeze(-1) + # for i in range(self.num_degradation): + degrees.append(self.fc_degree[i](feat).squeeze(-1)) + + return degrees diff --git a/basicsr/archs/dfdnet_arch.py b/basicsr/archs/dfdnet_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..4751434c2f17efbb682d9344951604602d853aaa --- /dev/null +++ b/basicsr/archs/dfdnet_arch.py @@ -0,0 +1,169 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.spectral_norm import spectral_norm + +from basicsr.utils.registry import ARCH_REGISTRY +from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization +from .vgg_arch import VGGFeatureExtractor + + +class SFTUpBlock(nn.Module): + """Spatial feature transform (SFT) with upsampling block. + + Args: + in_channel (int): Number of input channels. + out_channel (int): Number of output channels. + kernel_size (int): Kernel size in convolutions. Default: 3. + padding (int): Padding in convolutions. Default: 1. 
+ """ + + def __init__(self, in_channel, out_channel, kernel_size=3, padding=1): + super(SFTUpBlock, self).__init__() + self.conv1 = nn.Sequential( + Blur(in_channel), + spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)), + nn.LeakyReLU(0.04, True), + # The official codes use two LeakyReLU here, so 0.04 for equivalent + ) + self.convup = nn.Sequential( + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)), + nn.LeakyReLU(0.2, True), + ) + + # for SFT scale and shift + self.scale_block = nn.Sequential( + spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True), + spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1))) + self.shift_block = nn.Sequential( + spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True), + spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), nn.Sigmoid()) + # The official codes use sigmoid for shift block, do not know why + + def forward(self, x, updated_feat): + out = self.conv1(x) + # SFT + scale = self.scale_block(updated_feat) + shift = self.shift_block(updated_feat) + out = out * scale + shift + # upsample + out = self.convup(out) + return out + + +@ARCH_REGISTRY.register() +class DFDNet(nn.Module): + """DFDNet: Deep Face Dictionary Network. + + It only processes faces with 512x512 size. + + Args: + num_feat (int): Number of feature channels. + dict_path (str): Path to the facial component dictionary. + """ + + def __init__(self, num_feat, dict_path): + super().__init__() + self.parts = ['left_eye', 'right_eye', 'nose', 'mouth'] + # part_sizes: [80, 80, 50, 110] + channel_sizes = [128, 256, 512, 512] + self.feature_sizes = np.array([256, 128, 64, 32]) + self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4'] + self.flag_dict_device = False + + # dict + self.dict = torch.load(dict_path) + + # vgg face extractor + self.vgg_extractor = VGGFeatureExtractor( + layer_name_list=self.vgg_layers, + vgg_type='vgg19', + use_input_norm=True, + range_norm=True, + requires_grad=False) + + # attention block for fusing dictionary features and input features + self.attn_blocks = nn.ModuleDict() + for idx, feat_size in enumerate(self.feature_sizes): + for name in self.parts: + self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(channel_sizes[idx]) + + # multi scale dilation block + self.multi_scale_dilation = MSDilationBlock(num_feat * 8, dilation=[4, 3, 2, 1]) + + # upsampling and reconstruction + self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8) + self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4) + self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2) + self.upsample3 = SFTUpBlock(num_feat * 2, num_feat) + self.upsample4 = nn.Sequential( + spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat), + UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh()) + + def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size): + """swap the features from the dictionary.""" + # get the original vgg features + part_feat = vgg_feat[:, :, location[1]:location[3], location[0]:location[2]].clone() + # resize original vgg features + part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False) + # use adaptive instance normalization to adjust color and illuminations + dict_feat = adaptive_instance_normalization(dict_feat, 
part_resize_feat) + # get similarity scores + similarity_score = F.conv2d(part_resize_feat, dict_feat) + similarity_score = F.softmax(similarity_score.view(-1), dim=0) + # select the most similar features in the dict (after norm) + select_idx = torch.argmax(similarity_score) + swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], part_feat.size()[2:4]) + # attention + attn = self.attn_blocks[f'{part_name}_' + str(f_size)](swap_feat - part_feat) + attn_feat = attn * swap_feat + # update features + updated_feat[:, :, location[1]:location[3], location[0]:location[2]] = attn_feat + part_feat + return updated_feat + + def put_dict_to_device(self, x): + if self.flag_dict_device is False: + for k, v in self.dict.items(): + for kk, vv in v.items(): + self.dict[k][kk] = vv.to(x) + self.flag_dict_device = True + + def forward(self, x, part_locations): + """ + Now only support testing with batch size = 0. + + Args: + x (Tensor): Input faces with shape (b, c, 512, 512). + part_locations (list[Tensor]): Part locations. + """ + self.put_dict_to_device(x) + # extract vggface features + vgg_features = self.vgg_extractor(x) + # update vggface features using the dictionary for each part + updated_vgg_features = [] + batch = 0 # only supports testing with batch size = 0 + for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes): + dict_features = self.dict[f'{f_size}'] + vgg_feat = vgg_features[vgg_layer] + updated_feat = vgg_feat.clone() + + # swap features from dictionary + for part_idx, part_name in enumerate(self.parts): + location = (part_locations[part_idx][batch] // (512 / f_size)).int() + updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name, + f_size) + + updated_vgg_features.append(updated_feat) + + vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4']) + # use updated vgg features to modulate the upsampled features with + # SFT (Spatial Feature Transform) scaling and shifting manner. 
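+        # updated_vgg_features is ordered from shallow to deep
+        # (relu2_2 @ 256, relu3_4 @ 128, relu4_4 @ 64, conv5_4 @ 32),
+        # so decoding starts from index 3, the deepest 32x32 feature.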
+ upsampled_feat = self.upsample0(vgg_feat_dilation, updated_vgg_features[3]) + upsampled_feat = self.upsample1(upsampled_feat, updated_vgg_features[2]) + upsampled_feat = self.upsample2(upsampled_feat, updated_vgg_features[1]) + upsampled_feat = self.upsample3(upsampled_feat, updated_vgg_features[0]) + out = self.upsample4(upsampled_feat) + + return out diff --git a/basicsr/archs/dfdnet_util.py b/basicsr/archs/dfdnet_util.py new file mode 100644 index 0000000000000000000000000000000000000000..b4dc0ff738c76852e830b32fffbe65bffb5ddf50 --- /dev/null +++ b/basicsr/archs/dfdnet_util.py @@ -0,0 +1,162 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function +from torch.nn.utils.spectral_norm import spectral_norm + + +class BlurFunctionBackward(Function): + + @staticmethod + def forward(ctx, grad_output, kernel, kernel_flip): + ctx.save_for_backward(kernel, kernel_flip) + grad_input = F.conv2d(grad_output, kernel_flip, padding=1, groups=grad_output.shape[1]) + return grad_input + + @staticmethod + def backward(ctx, gradgrad_output): + kernel, _ = ctx.saved_tensors + grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1]) + return grad_input, None, None + + +class BlurFunction(Function): + + @staticmethod + def forward(ctx, x, kernel, kernel_flip): + ctx.save_for_backward(kernel, kernel_flip) + output = F.conv2d(x, kernel, padding=1, groups=x.shape[1]) + return output + + @staticmethod + def backward(ctx, grad_output): + kernel, kernel_flip = ctx.saved_tensors + grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip) + return grad_input, None, None + + +blur = BlurFunction.apply + + +class Blur(nn.Module): + + def __init__(self, channel): + super().__init__() + kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32) + kernel = kernel.view(1, 1, 3, 3) + kernel = kernel / kernel.sum() + kernel_flip = torch.flip(kernel, [2, 3]) + + self.kernel = kernel.repeat(channel, 1, 1, 1) + self.kernel_flip = kernel_flip.repeat(channel, 1, 1, 1) + + def forward(self, x): + return blur(x, self.kernel.type_as(x), self.kernel_flip.type_as(x)) + + +def calc_mean_std(feat, eps=1e-5): + """Calculate mean and std for adaptive_instance_normalization. + + Args: + feat (Tensor): 4D tensor. + eps (float): A small value added to the variance to avoid + divide-by-zero. Default: 1e-5. + """ + size = feat.size() + assert len(size) == 4, 'The input feature should be 4D tensor.' + n, c = size[:2] + feat_var = feat.view(n, c, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(n, c, 1, 1) + feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1) + return feat_mean, feat_std + + +def adaptive_instance_normalization(content_feat, style_feat): + """Adaptive instance normalization. + + Adjust the reference features to have the similar color and illuminations + as those in the degradate features. + + Args: + content_feat (Tensor): The reference feature. + style_feat (Tensor): The degradate features. 
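+
+    Example (illustrative; assumes ``torch`` is imported):
+        >>> content = torch.randn(1, 8, 16, 16)
+        >>> style = torch.randn(1, 8, 16, 16) * 3 + 5
+        >>> out = adaptive_instance_normalization(content, style)
+        >>> out.shape == content.shape  # per-channel mean/std now follow ``style``
+        True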
+ """ + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +def AttentionBlock(in_channel): + return nn.Sequential( + spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True), + spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1))) + + +def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True): + """Conv block used in MSDilationBlock.""" + + return nn.Sequential( + spectral_norm( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=((kernel_size - 1) // 2) * dilation, + bias=bias)), + nn.LeakyReLU(0.2), + spectral_norm( + nn.Conv2d( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=((kernel_size - 1) // 2) * dilation, + bias=bias)), + ) + + +class MSDilationBlock(nn.Module): + """Multi-scale dilation block.""" + + def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True): + super(MSDilationBlock, self).__init__() + + self.conv_blocks = nn.ModuleList() + for i in range(4): + self.conv_blocks.append(conv_block(in_channels, in_channels, kernel_size, dilation=dilation[i], bias=bias)) + self.conv_fusion = spectral_norm( + nn.Conv2d( + in_channels * 4, + in_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=bias)) + + def forward(self, x): + out = [] + for i in range(4): + out.append(self.conv_blocks[i](x)) + out = torch.cat(out, 1) + out = self.conv_fusion(out) + x + return out + + +class UpResBlock(nn.Module): + + def __init__(self, in_channel): + super(UpResBlock, self).__init__() + self.body = nn.Sequential( + nn.Conv2d(in_channel, in_channel, 3, 1, 1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(in_channel, in_channel, 3, 1, 1), + ) + + def forward(self, x): + out = x + self.body(x) + return out diff --git a/basicsr/archs/discriminator_arch.py b/basicsr/archs/discriminator_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..33f9a8f1b25c2052cd3ba801534861a425752e69 --- /dev/null +++ b/basicsr/archs/discriminator_arch.py @@ -0,0 +1,150 @@ +from torch import nn as nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm + +from basicsr.utils.registry import ARCH_REGISTRY + + +@ARCH_REGISTRY.register() +class VGGStyleDiscriminator(nn.Module): + """VGG style discriminator with input size 128 x 128 or 256 x 256. + + It is used to train SRGAN, ESRGAN, and VideoGAN. + + Args: + num_in_ch (int): Channel number of inputs. Default: 3. + num_feat (int): Channel number of base intermediate features.Default: 64. 
+ """ + + def __init__(self, num_in_ch, num_feat, input_size=128): + super(VGGStyleDiscriminator, self).__init__() + self.input_size = input_size + assert self.input_size == 128 or self.input_size == 256, ( + f'input size must be 128 or 256, but received {input_size}') + + self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True) + self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False) + self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True) + + self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False) + self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True) + self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False) + self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True) + + self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False) + self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True) + self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False) + self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True) + + self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False) + self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True) + self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False) + self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True) + + self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False) + self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True) + self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False) + self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True) + + if self.input_size == 256: + self.conv5_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False) + self.bn5_0 = nn.BatchNorm2d(num_feat * 8, affine=True) + self.conv5_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False) + self.bn5_1 = nn.BatchNorm2d(num_feat * 8, affine=True) + + self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100) + self.linear2 = nn.Linear(100, 1) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.') + + feat = self.lrelu(self.conv0_0(x)) + feat = self.lrelu(self.bn0_1(self.conv0_1(feat))) # output spatial size: /2 + + feat = self.lrelu(self.bn1_0(self.conv1_0(feat))) + feat = self.lrelu(self.bn1_1(self.conv1_1(feat))) # output spatial size: /4 + + feat = self.lrelu(self.bn2_0(self.conv2_0(feat))) + feat = self.lrelu(self.bn2_1(self.conv2_1(feat))) # output spatial size: /8 + + feat = self.lrelu(self.bn3_0(self.conv3_0(feat))) + feat = self.lrelu(self.bn3_1(self.conv3_1(feat))) # output spatial size: /16 + + feat = self.lrelu(self.bn4_0(self.conv4_0(feat))) + feat = self.lrelu(self.bn4_1(self.conv4_1(feat))) # output spatial size: /32 + + if self.input_size == 256: + feat = self.lrelu(self.bn5_0(self.conv5_0(feat))) + feat = self.lrelu(self.bn5_1(self.conv5_1(feat))) # output spatial size: / 64 + + # spatial size: (4, 4) + feat = feat.view(feat.size(0), -1) + feat = self.lrelu(self.linear1(feat)) + out = self.linear2(feat) + return out + + +@ARCH_REGISTRY.register(suffix='basicsr') +class UNetDiscriminatorSN(nn.Module): + """Defines a U-Net discriminator with spectral normalization (SN) + + It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + Arg: + num_in_ch (int): Channel number of inputs. Default: 3. + num_feat (int): Channel number of base intermediate features. Default: 64. 
+ skip_connection (bool): Whether to use skip connections between U-Net. Default: True. + """ + + def __init__(self, num_in_ch, num_feat=64, skip_connection=True): + super(UNetDiscriminatorSN, self).__init__() + self.skip_connection = skip_connection + norm = spectral_norm + # the first convolution + self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) + # downsample + self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) + self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) + self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) + # upsample + self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) + self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) + self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) + # extra convolutions + self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) + self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) + self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) + + def forward(self, x): + # downsample + x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) + x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) + x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) + x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True) + + # upsample + x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False) + x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x4 = x4 + x2 + x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) + x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x5 = x5 + x1 + x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False) + x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x6 = x6 + x0 + + # extra convolutions + out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) + out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) + out = self.conv9(out) + + return out diff --git a/basicsr/archs/duf_arch.py b/basicsr/archs/duf_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..e2b3ab7df4d890c9220d74ed8c461ad9d155120a --- /dev/null +++ b/basicsr/archs/duf_arch.py @@ -0,0 +1,276 @@ +import numpy as np +import torch +from torch import nn as nn +from torch.nn import functional as F + +from basicsr.utils.registry import ARCH_REGISTRY + + +class DenseBlocksTemporalReduce(nn.Module): + """A concatenation of 3 dense blocks with reduction in temporal dimension. + + Note that the output temporal dimension is 6 fewer the input temporal dimension, since there are 3 blocks. + + Args: + num_feat (int): Number of channels in the blocks. Default: 64. + num_grow_ch (int): Growing factor of the dense blocks. Default: 32 + adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation. + Set to false if you want to train from scratch. Default: False. 
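+
+    Example (illustrative shape sketch):
+        >>> block = DenseBlocksTemporalReduce(num_feat=64, num_grow_ch=32)
+        >>> out = block(torch.randn(1, 64, 7, 16, 16))
+        >>> tuple(out.shape)  # channels: 64 + 3 * 32; temporal dim: 7 - 6 = 1
+        (1, 160, 1, 16, 16)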
+ """ + + def __init__(self, num_feat=64, num_grow_ch=32, adapt_official_weights=False): + super(DenseBlocksTemporalReduce, self).__init__() + if adapt_official_weights: + eps = 1e-3 + momentum = 1e-3 + else: # pytorch default values + eps = 1e-05 + momentum = 0.1 + + self.temporal_reduce1 = nn.Sequential( + nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True), + nn.Conv3d(num_feat, num_feat, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True), + nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True), + nn.Conv3d(num_feat, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)) + + self.temporal_reduce2 = nn.Sequential( + nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True), + nn.Conv3d( + num_feat + num_grow_ch, + num_feat + num_grow_ch, (1, 1, 1), + stride=(1, 1, 1), + padding=(0, 0, 0), + bias=True), nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True), + nn.Conv3d(num_feat + num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)) + + self.temporal_reduce3 = nn.Sequential( + nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True), + nn.Conv3d( + num_feat + 2 * num_grow_ch, + num_feat + 2 * num_grow_ch, (1, 1, 1), + stride=(1, 1, 1), + padding=(0, 0, 0), + bias=True), nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum), + nn.ReLU(inplace=True), + nn.Conv3d( + num_feat + 2 * num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)) + + def forward(self, x): + """ + Args: + x (Tensor): Input tensor with shape (b, num_feat, t, h, w). + + Returns: + Tensor: Output with shape (b, num_feat + num_grow_ch * 3, 1, h, w). + """ + x1 = self.temporal_reduce1(x) + x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1) + + x2 = self.temporal_reduce2(x1) + x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1) + + x3 = self.temporal_reduce3(x2) + x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1) + + return x3 + + +class DenseBlocks(nn.Module): + """ A concatenation of N dense blocks. + + Args: + num_feat (int): Number of channels in the blocks. Default: 64. + num_grow_ch (int): Growing factor of the dense blocks. Default: 32. + num_block (int): Number of dense blocks. The values are: + DUF-S (16 layers): 3 + DUF-M (18 layers): 9 + DUF-L (52 layers): 21 + adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation. + Set to false if you want to train from scratch. Default: False. + """ + + def __init__(self, num_block, num_feat=64, num_grow_ch=16, adapt_official_weights=False): + super(DenseBlocks, self).__init__() + if adapt_official_weights: + eps = 1e-3 + momentum = 1e-3 + else: # pytorch default values + eps = 1e-05 + momentum = 0.1 + + self.dense_blocks = nn.ModuleList() + for i in range(0, num_block): + self.dense_blocks.append( + nn.Sequential( + nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True), + nn.Conv3d( + num_feat + i * num_grow_ch, + num_feat + i * num_grow_ch, (1, 1, 1), + stride=(1, 1, 1), + padding=(0, 0, 0), + bias=True), nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum), + nn.ReLU(inplace=True), + nn.Conv3d( + num_feat + i * num_grow_ch, + num_grow_ch, (3, 3, 3), + stride=(1, 1, 1), + padding=(1, 1, 1), + bias=True))) + + def forward(self, x): + """ + Args: + x (Tensor): Input tensor with shape (b, num_feat, t, h, w). 
+ + Returns: + Tensor: Output with shape (b, num_feat + num_block * num_grow_ch, t, h, w). + """ + for i in range(0, len(self.dense_blocks)): + y = self.dense_blocks[i](x) + x = torch.cat((x, y), 1) + return x + + +class DynamicUpsamplingFilter(nn.Module): + """Dynamic upsampling filter used in DUF. + + Reference: https://github.com/yhjo09/VSR-DUF + + It only supports input with 3 channels. And it applies the same filters to 3 channels. + + Args: + filter_size (tuple): Filter size of generated filters. The shape is (kh, kw). Default: (5, 5). + """ + + def __init__(self, filter_size=(5, 5)): + super(DynamicUpsamplingFilter, self).__init__() + if not isinstance(filter_size, tuple): + raise TypeError(f'The type of filter_size must be tuple, but got type{filter_size}') + if len(filter_size) != 2: + raise ValueError(f'The length of filter size must be 2, but got {len(filter_size)}.') + # generate a local expansion filter, similar to im2col + self.filter_size = filter_size + filter_prod = np.prod(filter_size) + expansion_filter = torch.eye(int(filter_prod)).view(filter_prod, 1, *filter_size) # (kh*kw, 1, kh, kw) + self.expansion_filter = expansion_filter.repeat(3, 1, 1, 1) # repeat for all the 3 channels + + def forward(self, x, filters): + """Forward function for DynamicUpsamplingFilter. + + Args: + x (Tensor): Input image with 3 channels. The shape is (n, 3, h, w). + filters (Tensor): Generated dynamic filters. The shape is (n, filter_prod, upsampling_square, h, w). + filter_prod: prod of filter kernel size, e.g., 1*5*5=25. + upsampling_square: similar to pixel shuffle, upsampling_square = upsampling * upsampling. + e.g., for x 4 upsampling, upsampling_square= 4*4 = 16 + + Returns: + Tensor: Filtered image with shape (n, 3*upsampling_square, h, w) + """ + n, filter_prod, upsampling_square, h, w = filters.size() + kh, kw = self.filter_size + expanded_input = F.conv2d( + x, self.expansion_filter.to(x), padding=(kh // 2, kw // 2), groups=3) # (n, 3*filter_prod, h, w) + expanded_input = expanded_input.view(n, 3, filter_prod, h, w).permute(0, 3, 4, 1, + 2) # (n, h, w, 3, filter_prod) + filters = filters.permute(0, 3, 4, 1, 2) # (n, h, w, filter_prod, upsampling_square] + out = torch.matmul(expanded_input, filters) # (n, h, w, 3, upsampling_square) + return out.permute(0, 3, 4, 1, 2).view(n, 3 * upsampling_square, h, w) + + +@ARCH_REGISTRY.register() +class DUF(nn.Module): + """Network architecture for DUF + + ``Paper: Deep Video Super-Resolution Network Using Dynamic Upsampling Filters Without Explicit Motion Compensation`` + + Reference: https://github.com/yhjo09/VSR-DUF + + For all the models below, 'adapt_official_weights' is only necessary when + loading the weights converted from the official TensorFlow weights. + Please set it to False if you are training the model from scratch. + + There are three models with different model size: DUF16Layers, DUF28Layers, + and DUF52Layers. This class is the base class for these models. + + Args: + scale (int): The upsampling factor. Default: 4. + num_layer (int): The number of layers. Default: 52. + adapt_official_weights_weights (bool): Whether to adapt the weights + translated from the official implementation. Set to false if you + want to train from scratch. Default: False. 
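+
+    Example (illustrative; 28-layer variant, assumes ``torch`` is imported):
+        >>> net = DUF(scale=4, num_layer=28, adapt_official_weights=False)
+        >>> out = net(torch.randn(1, 7, 3, 32, 32))  # 7 input frames, 3 channels each
+        >>> tuple(out.shape)  # center frame upsampled x4
+        (1, 3, 128, 128)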
+ """ + + def __init__(self, scale=4, num_layer=52, adapt_official_weights=False): + super(DUF, self).__init__() + self.scale = scale + if adapt_official_weights: + eps = 1e-3 + momentum = 1e-3 + else: # pytorch default values + eps = 1e-05 + momentum = 0.1 + + self.conv3d1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True) + self.dynamic_filter = DynamicUpsamplingFilter((5, 5)) + + if num_layer == 16: + num_block = 3 + num_grow_ch = 32 + elif num_layer == 28: + num_block = 9 + num_grow_ch = 16 + elif num_layer == 52: + num_block = 21 + num_grow_ch = 16 + else: + raise ValueError(f'Only supported (16, 28, 52) layers, but got {num_layer}.') + + self.dense_block1 = DenseBlocks( + num_block=num_block, num_feat=64, num_grow_ch=num_grow_ch, + adapt_official_weights=adapt_official_weights) # T = 7 + self.dense_block2 = DenseBlocksTemporalReduce( + 64 + num_grow_ch * num_block, num_grow_ch, adapt_official_weights=adapt_official_weights) # T = 1 + channels = 64 + num_grow_ch * num_block + num_grow_ch * 3 + self.bn3d2 = nn.BatchNorm3d(channels, eps=eps, momentum=momentum) + self.conv3d2 = nn.Conv3d(channels, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True) + + self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True) + self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True) + + self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True) + self.conv3d_f2 = nn.Conv3d( + 512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True) + + def forward(self, x): + """ + Args: + x (Tensor): Input with shape (b, 7, c, h, w) + + Returns: + Tensor: Output with shape (b, c, h * scale, w * scale) + """ + num_batches, num_imgs, _, h, w = x.size() + + x = x.permute(0, 2, 1, 3, 4) # (b, c, 7, h, w) for Conv3D + x_center = x[:, :, num_imgs // 2, :, :] + + x = self.conv3d1(x) + x = self.dense_block1(x) + x = self.dense_block2(x) + x = F.relu(self.bn3d2(x), inplace=True) + x = F.relu(self.conv3d2(x), inplace=True) + + # residual image + res = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True)) + + # filter + filter_ = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True)) + filter_ = F.softmax(filter_.view(num_batches, 25, self.scale**2, h, w), dim=1) + + # dynamic filter + out = self.dynamic_filter(x_center, filter_) + out += res.squeeze_(2) + out = F.pixel_shuffle(out, self.scale) + + return out diff --git a/basicsr/archs/ecbsr_arch.py b/basicsr/archs/ecbsr_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..fe20e772587d74c67fffb40f3b4731cf4f42268b --- /dev/null +++ b/basicsr/archs/ecbsr_arch.py @@ -0,0 +1,275 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from basicsr.utils.registry import ARCH_REGISTRY + + +class SeqConv3x3(nn.Module): + """The re-parameterizable block used in the ECBSR architecture. + + ``Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices`` + + Reference: https://github.com/xindongzhang/ECBSR + + Args: + seq_type (str): Sequence type, option: conv1x1-conv3x3 | conv1x1-sobelx | conv1x1-sobely | conv1x1-laplacian. + in_channels (int): Channel number of input. + out_channels (int): Channel number of output. + depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1. 
+ """ + + def __init__(self, seq_type, in_channels, out_channels, depth_multiplier=1): + super(SeqConv3x3, self).__init__() + self.seq_type = seq_type + self.in_channels = in_channels + self.out_channels = out_channels + + if self.seq_type == 'conv1x1-conv3x3': + self.mid_planes = int(out_channels * depth_multiplier) + conv0 = torch.nn.Conv2d(self.in_channels, self.mid_planes, kernel_size=1, padding=0) + self.k0 = conv0.weight + self.b0 = conv0.bias + + conv1 = torch.nn.Conv2d(self.mid_planes, self.out_channels, kernel_size=3) + self.k1 = conv1.weight + self.b1 = conv1.bias + + elif self.seq_type == 'conv1x1-sobelx': + conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0) + self.k0 = conv0.weight + self.b0 = conv0.bias + + # init scale and bias + scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3 + self.scale = nn.Parameter(scale) + bias = torch.randn(self.out_channels) * 1e-3 + bias = torch.reshape(bias, (self.out_channels, )) + self.bias = nn.Parameter(bias) + # init mask + self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32) + for i in range(self.out_channels): + self.mask[i, 0, 0, 0] = 1.0 + self.mask[i, 0, 1, 0] = 2.0 + self.mask[i, 0, 2, 0] = 1.0 + self.mask[i, 0, 0, 2] = -1.0 + self.mask[i, 0, 1, 2] = -2.0 + self.mask[i, 0, 2, 2] = -1.0 + self.mask = nn.Parameter(data=self.mask, requires_grad=False) + + elif self.seq_type == 'conv1x1-sobely': + conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0) + self.k0 = conv0.weight + self.b0 = conv0.bias + + # init scale and bias + scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3 + self.scale = nn.Parameter(torch.FloatTensor(scale)) + bias = torch.randn(self.out_channels) * 1e-3 + bias = torch.reshape(bias, (self.out_channels, )) + self.bias = nn.Parameter(torch.FloatTensor(bias)) + # init mask + self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32) + for i in range(self.out_channels): + self.mask[i, 0, 0, 0] = 1.0 + self.mask[i, 0, 0, 1] = 2.0 + self.mask[i, 0, 0, 2] = 1.0 + self.mask[i, 0, 2, 0] = -1.0 + self.mask[i, 0, 2, 1] = -2.0 + self.mask[i, 0, 2, 2] = -1.0 + self.mask = nn.Parameter(data=self.mask, requires_grad=False) + + elif self.seq_type == 'conv1x1-laplacian': + conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0) + self.k0 = conv0.weight + self.b0 = conv0.bias + + # init scale and bias + scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3 + self.scale = nn.Parameter(torch.FloatTensor(scale)) + bias = torch.randn(self.out_channels) * 1e-3 + bias = torch.reshape(bias, (self.out_channels, )) + self.bias = nn.Parameter(torch.FloatTensor(bias)) + # init mask + self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32) + for i in range(self.out_channels): + self.mask[i, 0, 0, 1] = 1.0 + self.mask[i, 0, 1, 0] = 1.0 + self.mask[i, 0, 1, 2] = 1.0 + self.mask[i, 0, 2, 1] = 1.0 + self.mask[i, 0, 1, 1] = -4.0 + self.mask = nn.Parameter(data=self.mask, requires_grad=False) + else: + raise ValueError('The type of seqconv is not supported!') + + def forward(self, x): + if self.seq_type == 'conv1x1-conv3x3': + # conv-1x1 + y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1) + # explicitly padding with bias + y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0) + b0_pad = self.b0.view(1, -1, 1, 1) + y0[:, :, 0:1, :] = b0_pad + y0[:, :, -1:, :] = b0_pad + y0[:, :, :, 0:1] = b0_pad + y0[:, :, :, -1:] = b0_pad + # conv-3x3 + y1 = F.conv2d(input=y0, 
weight=self.k1, bias=self.b1, stride=1) + else: + y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1) + # explicitly padding with bias + y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0) + b0_pad = self.b0.view(1, -1, 1, 1) + y0[:, :, 0:1, :] = b0_pad + y0[:, :, -1:, :] = b0_pad + y0[:, :, :, 0:1] = b0_pad + y0[:, :, :, -1:] = b0_pad + # conv-3x3 + y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_channels) + return y1 + + def rep_params(self): + device = self.k0.get_device() + if device < 0: + device = None + + if self.seq_type == 'conv1x1-conv3x3': + # re-param conv kernel + rep_weight = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3)) + # re-param conv bias + rep_bias = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1) + rep_bias = F.conv2d(input=rep_bias, weight=self.k1).view(-1, ) + self.b1 + else: + tmp = self.scale * self.mask + k1 = torch.zeros((self.out_channels, self.out_channels, 3, 3), device=device) + for i in range(self.out_channels): + k1[i, i, :, :] = tmp[i, 0, :, :] + b1 = self.bias + # re-param conv kernel + rep_weight = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3)) + # re-param conv bias + rep_bias = torch.ones(1, self.out_channels, 3, 3, device=device) * self.b0.view(1, -1, 1, 1) + rep_bias = F.conv2d(input=rep_bias, weight=k1).view(-1, ) + b1 + return rep_weight, rep_bias + + +class ECB(nn.Module): + """The ECB block used in the ECBSR architecture. + + Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices + Ref git repo: https://github.com/xindongzhang/ECBSR + + Args: + in_channels (int): Channel number of input. + out_channels (int): Channel number of output. + depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1. + act_type (str): Activation type. Option: prelu | relu | rrelu | softplus | linear. Default: prelu. + with_idt (bool): Whether to use identity connection. Default: False. 
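+
+    Example (illustrative; the eval-mode forward uses the fused kernel from ``rep_params``):
+        >>> ecb = ECB(in_channels=8, out_channels=8, depth_multiplier=2.0, with_idt=True)
+        >>> x = torch.randn(1, 8, 24, 24)
+        >>> y_train = ecb(x)   # multi-branch forward (module starts in training mode)
+        >>> _ = ecb.eval()
+        >>> y_eval = ecb(x)    # single fused 3x3 conv
+        >>> torch.allclose(y_train, y_eval, atol=1e-4)
+        True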
+ """ + + def __init__(self, in_channels, out_channels, depth_multiplier, act_type='prelu', with_idt=False): + super(ECB, self).__init__() + + self.depth_multiplier = depth_multiplier + self.in_channels = in_channels + self.out_channels = out_channels + self.act_type = act_type + + if with_idt and (self.in_channels == self.out_channels): + self.with_idt = True + else: + self.with_idt = False + + self.conv3x3 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1) + self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.in_channels, self.out_channels, self.depth_multiplier) + self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.in_channels, self.out_channels) + self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.in_channels, self.out_channels) + self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.in_channels, self.out_channels) + + if self.act_type == 'prelu': + self.act = nn.PReLU(num_parameters=self.out_channels) + elif self.act_type == 'relu': + self.act = nn.ReLU(inplace=True) + elif self.act_type == 'rrelu': + self.act = nn.RReLU(lower=-0.05, upper=0.05) + elif self.act_type == 'softplus': + self.act = nn.Softplus() + elif self.act_type == 'linear': + pass + else: + raise ValueError('The type of activation if not support!') + + def forward(self, x): + if self.training: + y = self.conv3x3(x) + self.conv1x1_3x3(x) + self.conv1x1_sbx(x) + self.conv1x1_sby(x) + self.conv1x1_lpl(x) + if self.with_idt: + y += x + else: + rep_weight, rep_bias = self.rep_params() + y = F.conv2d(input=x, weight=rep_weight, bias=rep_bias, stride=1, padding=1) + if self.act_type != 'linear': + y = self.act(y) + return y + + def rep_params(self): + weight0, bias0 = self.conv3x3.weight, self.conv3x3.bias + weight1, bias1 = self.conv1x1_3x3.rep_params() + weight2, bias2 = self.conv1x1_sbx.rep_params() + weight3, bias3 = self.conv1x1_sby.rep_params() + weight4, bias4 = self.conv1x1_lpl.rep_params() + rep_weight, rep_bias = (weight0 + weight1 + weight2 + weight3 + weight4), ( + bias0 + bias1 + bias2 + bias3 + bias4) + + if self.with_idt: + device = rep_weight.get_device() + if device < 0: + device = None + weight_idt = torch.zeros(self.out_channels, self.out_channels, 3, 3, device=device) + for i in range(self.out_channels): + weight_idt[i, i, 1, 1] = 1.0 + bias_idt = 0.0 + rep_weight, rep_bias = rep_weight + weight_idt, rep_bias + bias_idt + return rep_weight, rep_bias + + +@ARCH_REGISTRY.register() +class ECBSR(nn.Module): + """ECBSR architecture. + + Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices + Ref git repo: https://github.com/xindongzhang/ECBSR + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_block (int): Block number in the trunk network. + num_channel (int): Channel number. + with_idt (bool): Whether use identity in convolution layers. + act_type (str): Activation type. + scale (int): Upsampling factor. 
+ """ + + def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale): + super(ECBSR, self).__init__() + self.num_in_ch = num_in_ch + self.scale = scale + + backbone = [] + backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)] + for _ in range(num_block): + backbone += [ECB(num_channel, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)] + backbone += [ + ECB(num_channel, num_out_ch * scale * scale, depth_multiplier=2.0, act_type='linear', with_idt=with_idt) + ] + + self.backbone = nn.Sequential(*backbone) + self.upsampler = nn.PixelShuffle(scale) + + def forward(self, x): + if self.num_in_ch > 1: + shortcut = torch.repeat_interleave(x, self.scale * self.scale, dim=1) + else: + shortcut = x # will repeat the input in the channel dimension (repeat scale * scale times) + y = self.backbone(x) + shortcut + y = self.upsampler(y) + return y diff --git a/basicsr/archs/edsr_arch.py b/basicsr/archs/edsr_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..b80566f11fbd4782d68eee8fbf7da686f89dc4e7 --- /dev/null +++ b/basicsr/archs/edsr_arch.py @@ -0,0 +1,61 @@ +import torch +from torch import nn as nn + +from basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer +from basicsr.utils.registry import ARCH_REGISTRY + + +@ARCH_REGISTRY.register() +class EDSR(nn.Module): + """EDSR network structure. + + Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution. + Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_feat (int): Channel number of intermediate features. + Default: 64. + num_block (int): Block number in the trunk network. Default: 16. + upscale (int): Upsampling factor. Support 2^n and 3. + Default: 4. + res_scale (float): Used to scale the residual in residual block. + Default: 1. + img_range (float): Image range. Default: 255. + rgb_mean (tuple[float]): Image mean in RGB orders. + Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. 
+ """ + + def __init__(self, + num_in_ch, + num_out_ch, + num_feat=64, + num_block=16, + upscale=4, + res_scale=1, + img_range=255., + rgb_mean=(0.4488, 0.4371, 0.4040)): + super(EDSR, self).__init__() + + self.img_range = img_range + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True) + self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + def forward(self, x): + self.mean = self.mean.type_as(x) + + x = (x - self.mean) * self.img_range + x = self.conv_first(x) + res = self.conv_after_body(self.body(x)) + res += x + + x = self.conv_last(self.upsample(res)) + x = x / self.img_range + self.mean + + return x diff --git a/basicsr/archs/edvr_arch.py b/basicsr/archs/edvr_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..b0c4f47deb383d4fe6108b97436c9dfb1e541583 --- /dev/null +++ b/basicsr/archs/edvr_arch.py @@ -0,0 +1,382 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import DCNv2Pack, ResidualBlockNoBN, make_layer + + +class PCDAlignment(nn.Module): + """Alignment module using Pyramid, Cascading and Deformable convolution + (PCD). It is used in EDVR. + + ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks`` + + Args: + num_feat (int): Channel number of middle features. Default: 64. + deformable_groups (int): Deformable groups. Defaults: 8. + """ + + def __init__(self, num_feat=64, deformable_groups=8): + super(PCDAlignment, self).__init__() + + # Pyramid has three levels: + # L3: level 3, 1/4 spatial size + # L2: level 2, 1/2 spatial size + # L1: level 1, original spatial size + self.offset_conv1 = nn.ModuleDict() + self.offset_conv2 = nn.ModuleDict() + self.offset_conv3 = nn.ModuleDict() + self.dcn_pack = nn.ModuleDict() + self.feat_conv = nn.ModuleDict() + + # Pyramids + for i in range(3, 0, -1): + level = f'l{i}' + self.offset_conv1[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1) + if i == 3: + self.offset_conv2[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + else: + self.offset_conv2[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1) + self.offset_conv3[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.dcn_pack[level] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups) + + if i < 3: + self.feat_conv[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1) + + # Cascading dcn + self.cas_offset_conv1 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1) + self.cas_offset_conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.cas_dcnpack = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups) + + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, nbr_feat_l, ref_feat_l): + """Align neighboring frame features to the reference frame features. + + Args: + nbr_feat_l (list[Tensor]): Neighboring feature list. It + contains three pyramid levels (L1, L2, L3), + each with shape (b, c, h, w). + ref_feat_l (list[Tensor]): Reference feature list. It + contains three pyramid levels (L1, L2, L3), + each with shape (b, c, h, w). + + Returns: + Tensor: Aligned features. 
+ """ + # Pyramids + upsampled_offset, upsampled_feat = None, None + for i in range(3, 0, -1): + level = f'l{i}' + offset = torch.cat([nbr_feat_l[i - 1], ref_feat_l[i - 1]], dim=1) + offset = self.lrelu(self.offset_conv1[level](offset)) + if i == 3: + offset = self.lrelu(self.offset_conv2[level](offset)) + else: + offset = self.lrelu(self.offset_conv2[level](torch.cat([offset, upsampled_offset], dim=1))) + offset = self.lrelu(self.offset_conv3[level](offset)) + + feat = self.dcn_pack[level](nbr_feat_l[i - 1], offset) + if i < 3: + feat = self.feat_conv[level](torch.cat([feat, upsampled_feat], dim=1)) + if i > 1: + feat = self.lrelu(feat) + + if i > 1: # upsample offset and features + # x2: when we upsample the offset, we should also enlarge + # the magnitude. + upsampled_offset = self.upsample(offset) * 2 + upsampled_feat = self.upsample(feat) + + # Cascading + offset = torch.cat([feat, ref_feat_l[0]], dim=1) + offset = self.lrelu(self.cas_offset_conv2(self.lrelu(self.cas_offset_conv1(offset)))) + feat = self.lrelu(self.cas_dcnpack(feat, offset)) + return feat + + +class TSAFusion(nn.Module): + """Temporal Spatial Attention (TSA) fusion module. + + Temporal: Calculate the correlation between center frame and + neighboring frames; + Spatial: It has 3 pyramid levels, the attention is similar to SFT. + (SFT: Recovering realistic texture in image super-resolution by deep + spatial feature transform.) + + Args: + num_feat (int): Channel number of middle features. Default: 64. + num_frame (int): Number of frames. Default: 5. + center_frame_idx (int): The index of center frame. Default: 2. + """ + + def __init__(self, num_feat=64, num_frame=5, center_frame_idx=2): + super(TSAFusion, self).__init__() + self.center_frame_idx = center_frame_idx + # temporal attention (before fusion conv) + self.temporal_attn1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.temporal_attn2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.feat_fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1) + + # spatial attention (after fusion conv) + self.max_pool = nn.MaxPool2d(3, stride=2, padding=1) + self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1) + self.spatial_attn1 = nn.Conv2d(num_frame * num_feat, num_feat, 1) + self.spatial_attn2 = nn.Conv2d(num_feat * 2, num_feat, 1) + self.spatial_attn3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.spatial_attn4 = nn.Conv2d(num_feat, num_feat, 1) + self.spatial_attn5 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.spatial_attn_l1 = nn.Conv2d(num_feat, num_feat, 1) + self.spatial_attn_l2 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1) + self.spatial_attn_l3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.spatial_attn_add1 = nn.Conv2d(num_feat, num_feat, 1) + self.spatial_attn_add2 = nn.Conv2d(num_feat, num_feat, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + + def forward(self, aligned_feat): + """ + Args: + aligned_feat (Tensor): Aligned features with shape (b, t, c, h, w). + + Returns: + Tensor: Features after TSA with the shape (b, c, h, w). 
+ """ + b, t, c, h, w = aligned_feat.size() + # temporal attention + embedding_ref = self.temporal_attn1(aligned_feat[:, self.center_frame_idx, :, :, :].clone()) + embedding = self.temporal_attn2(aligned_feat.view(-1, c, h, w)) + embedding = embedding.view(b, t, -1, h, w) # (b, t, c, h, w) + + corr_l = [] # correlation list + for i in range(t): + emb_neighbor = embedding[:, i, :, :, :] + corr = torch.sum(emb_neighbor * embedding_ref, 1) # (b, h, w) + corr_l.append(corr.unsqueeze(1)) # (b, 1, h, w) + corr_prob = torch.sigmoid(torch.cat(corr_l, dim=1)) # (b, t, h, w) + corr_prob = corr_prob.unsqueeze(2).expand(b, t, c, h, w) + corr_prob = corr_prob.contiguous().view(b, -1, h, w) # (b, t*c, h, w) + aligned_feat = aligned_feat.view(b, -1, h, w) * corr_prob + + # fusion + feat = self.lrelu(self.feat_fusion(aligned_feat)) + + # spatial attention + attn = self.lrelu(self.spatial_attn1(aligned_feat)) + attn_max = self.max_pool(attn) + attn_avg = self.avg_pool(attn) + attn = self.lrelu(self.spatial_attn2(torch.cat([attn_max, attn_avg], dim=1))) + # pyramid levels + attn_level = self.lrelu(self.spatial_attn_l1(attn)) + attn_max = self.max_pool(attn_level) + attn_avg = self.avg_pool(attn_level) + attn_level = self.lrelu(self.spatial_attn_l2(torch.cat([attn_max, attn_avg], dim=1))) + attn_level = self.lrelu(self.spatial_attn_l3(attn_level)) + attn_level = self.upsample(attn_level) + + attn = self.lrelu(self.spatial_attn3(attn)) + attn_level + attn = self.lrelu(self.spatial_attn4(attn)) + attn = self.upsample(attn) + attn = self.spatial_attn5(attn) + attn_add = self.spatial_attn_add2(self.lrelu(self.spatial_attn_add1(attn))) + attn = torch.sigmoid(attn) + + # after initialization, * 2 makes (attn * 2) to be close to 1. + feat = feat * attn * 2 + attn_add + return feat + + +class PredeblurModule(nn.Module): + """Pre-dublur module. + + Args: + num_in_ch (int): Channel number of input image. Default: 3. + num_feat (int): Channel number of intermediate features. Default: 64. + hr_in (bool): Whether the input has high resolution. Default: False. 
+ """ + + def __init__(self, num_in_ch=3, num_feat=64, hr_in=False): + super(PredeblurModule, self).__init__() + self.hr_in = hr_in + + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + if self.hr_in: + # downsample x4 by stride conv + self.stride_conv_hr1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + self.stride_conv_hr2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + + # generate feature pyramid + self.stride_conv_l2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + self.stride_conv_l3 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + + self.resblock_l3 = ResidualBlockNoBN(num_feat=num_feat) + self.resblock_l2_1 = ResidualBlockNoBN(num_feat=num_feat) + self.resblock_l2_2 = ResidualBlockNoBN(num_feat=num_feat) + self.resblock_l1 = nn.ModuleList([ResidualBlockNoBN(num_feat=num_feat) for i in range(5)]) + + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, x): + feat_l1 = self.lrelu(self.conv_first(x)) + if self.hr_in: + feat_l1 = self.lrelu(self.stride_conv_hr1(feat_l1)) + feat_l1 = self.lrelu(self.stride_conv_hr2(feat_l1)) + + # generate feature pyramid + feat_l2 = self.lrelu(self.stride_conv_l2(feat_l1)) + feat_l3 = self.lrelu(self.stride_conv_l3(feat_l2)) + + feat_l3 = self.upsample(self.resblock_l3(feat_l3)) + feat_l2 = self.resblock_l2_1(feat_l2) + feat_l3 + feat_l2 = self.upsample(self.resblock_l2_2(feat_l2)) + + for i in range(2): + feat_l1 = self.resblock_l1[i](feat_l1) + feat_l1 = feat_l1 + feat_l2 + for i in range(2, 5): + feat_l1 = self.resblock_l1[i](feat_l1) + return feat_l1 + + +@ARCH_REGISTRY.register() +class EDVR(nn.Module): + """EDVR network structure for video super-resolution. + + Now only support X4 upsampling factor. + + ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks`` + + Args: + num_in_ch (int): Channel number of input image. Default: 3. + num_out_ch (int): Channel number of output image. Default: 3. + num_feat (int): Channel number of intermediate features. Default: 64. + num_frame (int): Number of input frames. Default: 5. + deformable_groups (int): Deformable groups. Defaults: 8. + num_extract_block (int): Number of blocks for feature extraction. + Default: 5. + num_reconstruct_block (int): Number of blocks for reconstruction. + Default: 10. + center_frame_idx (int): The index of center frame. Frame counting from + 0. Default: Middle of input frames. + hr_in (bool): Whether the input has high resolution. Default: False. + with_predeblur (bool): Whether has predeblur module. + Default: False. + with_tsa (bool): Whether has TSA module. Default: True. 
+ """ + + def __init__(self, + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_frame=5, + deformable_groups=8, + num_extract_block=5, + num_reconstruct_block=10, + center_frame_idx=None, + hr_in=False, + with_predeblur=False, + with_tsa=True): + super(EDVR, self).__init__() + if center_frame_idx is None: + self.center_frame_idx = num_frame // 2 + else: + self.center_frame_idx = center_frame_idx + self.hr_in = hr_in + self.with_predeblur = with_predeblur + self.with_tsa = with_tsa + + # extract features for each frame + if self.with_predeblur: + self.predeblur = PredeblurModule(num_feat=num_feat, hr_in=self.hr_in) + self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1) + else: + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + + # extract pyramid features + self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat) + self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1) + self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + + # pcd and tsa module + self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=deformable_groups) + if self.with_tsa: + self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_frame, center_frame_idx=self.center_frame_idx) + else: + self.fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1) + + # reconstruction + self.reconstruction = make_layer(ResidualBlockNoBN, num_reconstruct_block, num_feat=num_feat) + # upsample + self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) + self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1) + self.pixel_shuffle = nn.PixelShuffle(2) + self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1) + self.conv_last = nn.Conv2d(64, 3, 3, 1, 1) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, x): + b, t, c, h, w = x.size() + if self.hr_in: + assert h % 16 == 0 and w % 16 == 0, ('The height and width must be multiple of 16.') + else: + assert h % 4 == 0 and w % 4 == 0, ('The height and width must be multiple of 4.') + + x_center = x[:, self.center_frame_idx, :, :, :].contiguous() + + # extract features for each frame + # L1 + if self.with_predeblur: + feat_l1 = self.conv_1x1(self.predeblur(x.view(-1, c, h, w))) + if self.hr_in: + h, w = h // 4, w // 4 + else: + feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w))) + + feat_l1 = self.feature_extraction(feat_l1) + # L2 + feat_l2 = self.lrelu(self.conv_l2_1(feat_l1)) + feat_l2 = self.lrelu(self.conv_l2_2(feat_l2)) + # L3 + feat_l3 = self.lrelu(self.conv_l3_1(feat_l2)) + feat_l3 = self.lrelu(self.conv_l3_2(feat_l3)) + + feat_l1 = feat_l1.view(b, t, -1, h, w) + feat_l2 = feat_l2.view(b, t, -1, h // 2, w // 2) + feat_l3 = feat_l3.view(b, t, -1, h // 4, w // 4) + + # PCD alignment + ref_feat_l = [ # reference feature list + feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(), + feat_l3[:, self.center_frame_idx, :, :, :].clone() + ] + aligned_feat = [] + for i in range(t): + nbr_feat_l = [ # neighboring feature list + feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone() + ] + aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l)) + aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w) + + if not self.with_tsa: + aligned_feat = aligned_feat.view(b, -1, h, w) + feat = self.fusion(aligned_feat) + + out = self.reconstruction(feat) + out = 
self.lrelu(self.pixel_shuffle(self.upconv1(out))) + out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) + out = self.lrelu(self.conv_hr(out)) + out = self.conv_last(out) + if self.hr_in: + base = x_center + else: + base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False) + out += base + return out diff --git a/basicsr/archs/hifacegan_arch.py b/basicsr/archs/hifacegan_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..098e3ed4306eb19ae9da705c0af580a6f74c6cb9 --- /dev/null +++ b/basicsr/archs/hifacegan_arch.py @@ -0,0 +1,260 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from basicsr.utils.registry import ARCH_REGISTRY +from .hifacegan_util import BaseNetwork, LIPEncoder, SPADEResnetBlock, get_nonspade_norm_layer + + +class SPADEGenerator(BaseNetwork): + """Generator with SPADEResBlock""" + + def __init__(self, + num_in_ch=3, + num_feat=64, + use_vae=False, + z_dim=256, + crop_size=512, + norm_g='spectralspadesyncbatch3x3', + is_train=True, + init_train_phase=3): # progressive training disabled + super().__init__() + self.nf = num_feat + self.input_nc = num_in_ch + self.is_train = is_train + self.train_phase = init_train_phase + + self.scale_ratio = 5 # hardcoded now + self.sw = crop_size // (2**self.scale_ratio) + self.sh = self.sw # 20210519: By default use square image, aspect_ratio = 1.0 + + if use_vae: + # In case of VAE, we will sample from random z vector + self.fc = nn.Linear(z_dim, 16 * self.nf * self.sw * self.sh) + else: + # Otherwise, we make the network deterministic by starting with + # downsampled segmentation map instead of random z + self.fc = nn.Conv2d(num_in_ch, 16 * self.nf, 3, padding=1) + + self.head_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g) + + self.g_middle_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g) + self.g_middle_1 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g) + + self.ups = nn.ModuleList([ + SPADEResnetBlock(16 * self.nf, 8 * self.nf, norm_g), + SPADEResnetBlock(8 * self.nf, 4 * self.nf, norm_g), + SPADEResnetBlock(4 * self.nf, 2 * self.nf, norm_g), + SPADEResnetBlock(2 * self.nf, 1 * self.nf, norm_g) + ]) + + self.to_rgbs = nn.ModuleList([ + nn.Conv2d(8 * self.nf, 3, 3, padding=1), + nn.Conv2d(4 * self.nf, 3, 3, padding=1), + nn.Conv2d(2 * self.nf, 3, 3, padding=1), + nn.Conv2d(1 * self.nf, 3, 3, padding=1) + ]) + + self.up = nn.Upsample(scale_factor=2) + + def encode(self, input_tensor): + """ + Encode input_tensor into feature maps, can be overridden in derived classes + Default: nearest downsampling of 2**5 = 32 times + """ + h, w = input_tensor.size()[-2:] + sh, sw = h // 2**self.scale_ratio, w // 2**self.scale_ratio + x = F.interpolate(input_tensor, size=(sh, sw)) + return self.fc(x) + + def forward(self, x): + # In oroginal SPADE, seg means a segmentation map, but here we use x instead. + seg = x + + x = self.encode(x) + x = self.head_0(x, seg) + + x = self.up(x) + x = self.g_middle_0(x, seg) + x = self.g_middle_1(x, seg) + + if self.is_train: + phase = self.train_phase + 1 + else: + phase = len(self.to_rgbs) + + for i in range(phase): + x = self.up(x) + x = self.ups[i](x, seg) + + x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1)) + x = torch.tanh(x) + + return x + + def mixed_guidance_forward(self, input_x, seg=None, n=0, mode='progressive'): + """ + A helper class for subspace visualization. Input and seg are different images. + For the first n levels (including encoder) we use input, for the rest we use seg. 
+ + If mode = 'progressive', the output's like: AAABBB + If mode = 'one_plug', the output's like: AAABAA + If mode = 'one_ablate', the output's like: BBBABB + """ + + if seg is None: + return self.forward(input_x) + + if self.is_train: + phase = self.train_phase + 1 + else: + phase = len(self.to_rgbs) + + if mode == 'progressive': + n = max(min(n, 4 + phase), 0) + guide_list = [input_x] * n + [seg] * (4 + phase - n) + elif mode == 'one_plug': + n = max(min(n, 4 + phase - 1), 0) + guide_list = [seg] * (4 + phase) + guide_list[n] = input_x + elif mode == 'one_ablate': + if n > 3 + phase: + return self.forward(input_x) + guide_list = [input_x] * (4 + phase) + guide_list[n] = seg + + x = self.encode(guide_list[0]) + x = self.head_0(x, guide_list[1]) + + x = self.up(x) + x = self.g_middle_0(x, guide_list[2]) + x = self.g_middle_1(x, guide_list[3]) + + for i in range(phase): + x = self.up(x) + x = self.ups[i](x, guide_list[4 + i]) + + x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1)) + x = torch.tanh(x) + + return x + + +@ARCH_REGISTRY.register() +class HiFaceGAN(SPADEGenerator): + """ + HiFaceGAN: SPADEGenerator with a learnable feature encoder + Current encoder design: LIPEncoder + """ + + def __init__(self, + num_in_ch=3, + num_feat=64, + use_vae=False, + z_dim=256, + crop_size=512, + norm_g='spectralspadesyncbatch3x3', + is_train=True, + init_train_phase=3): + super().__init__(num_in_ch, num_feat, use_vae, z_dim, crop_size, norm_g, is_train, init_train_phase) + self.lip_encoder = LIPEncoder(num_in_ch, num_feat, self.sw, self.sh, self.scale_ratio) + + def encode(self, input_tensor): + return self.lip_encoder(input_tensor) + + +@ARCH_REGISTRY.register() +class HiFaceGANDiscriminator(BaseNetwork): + """ + Inspired by pix2pixHD multiscale discriminator. + + Args: + num_in_ch (int): Channel number of inputs. Default: 3. + num_out_ch (int): Channel number of outputs. Default: 3. + conditional_d (bool): Whether use conditional discriminator. + Default: True. + num_d (int): Number of Multiscale discriminators. Default: 3. + n_layers_d (int): Number of downsample layers in each D. Default: 4. + num_feat (int): Channel number of base intermediate features. + Default: 64. + norm_d (str): String to determine normalization layers in D. + Choices: [spectral][instance/batch/syncbatch] + Default: 'spectralinstance'. + keep_features (bool): Keep intermediate features for matching loss, etc. + Default: True. + """ + + def __init__(self, + num_in_ch=3, + num_out_ch=3, + conditional_d=True, + num_d=2, + n_layers_d=4, + num_feat=64, + norm_d='spectralinstance', + keep_features=True): + super().__init__() + self.num_d = num_d + + input_nc = num_in_ch + if conditional_d: + input_nc += num_out_ch + + for i in range(num_d): + subnet_d = NLayerDiscriminator(input_nc, n_layers_d, num_feat, norm_d, keep_features) + self.add_module(f'discriminator_{i}', subnet_d) + + def downsample(self, x): + return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False) + + # Returns list of lists of discriminator outputs. 
+ # The final result is of size opt.num_d x opt.n_layers_D + def forward(self, x): + result = [] + for _, _net_d in self.named_children(): + out = _net_d(x) + result.append(out) + x = self.downsample(x) + + return result + + +class NLayerDiscriminator(BaseNetwork): + """Defines the PatchGAN discriminator with the specified arguments.""" + + def __init__(self, input_nc, n_layers_d, num_feat, norm_d, keep_features): + super().__init__() + kw = 4 + padw = int(np.ceil((kw - 1.0) / 2)) + nf = num_feat + self.keep_features = keep_features + + norm_layer = get_nonspade_norm_layer(norm_d) + sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]] + + for n in range(1, n_layers_d): + nf_prev = nf + nf = min(nf * 2, 512) + stride = 1 if n == n_layers_d - 1 else 2 + sequence += [[ + norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)), + nn.LeakyReLU(0.2, False) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] + + # We divide the layers into groups to extract intermediate layer outputs + for n in range(len(sequence)): + self.add_module('model' + str(n), nn.Sequential(*sequence[n])) + + def forward(self, x): + results = [x] + for submodel in self.children(): + intermediate_output = submodel(results[-1]) + results.append(intermediate_output) + + if self.keep_features: + return results[1:] + else: + return results[-1] diff --git a/basicsr/archs/hifacegan_util.py b/basicsr/archs/hifacegan_util.py new file mode 100644 index 0000000000000000000000000000000000000000..35cbef3f532fcc6aab0fa57ab316a546d3a17bd5 --- /dev/null +++ b/basicsr/archs/hifacegan_util.py @@ -0,0 +1,255 @@ +import re +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init +# Warning: spectral norm could be buggy +# under eval mode and multi-GPU inference +# A workaround is sticking to single-GPU inference and train mode +from torch.nn.utils import spectral_norm + + +class SPADE(nn.Module): + + def __init__(self, config_text, norm_nc, label_nc): + super().__init__() + + assert config_text.startswith('spade') + parsed = re.search('spade(\\D+)(\\d)x\\d', config_text) + param_free_norm_type = str(parsed.group(1)) + ks = int(parsed.group(2)) + + if param_free_norm_type == 'instance': + self.param_free_norm = nn.InstanceNorm2d(norm_nc) + elif param_free_norm_type == 'syncbatch': + print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead') + self.param_free_norm = nn.InstanceNorm2d(norm_nc) + elif param_free_norm_type == 'batch': + self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False) + else: + raise ValueError(f'{param_free_norm_type} is not a recognized param-free norm type in SPADE') + + # The dimension of the intermediate embedding space. Yes, hardcoded. + nhidden = 128 if norm_nc > 128 else norm_nc + + pw = ks // 2 + self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU()) + self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False) + self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False) + + def forward(self, x, segmap): + + # Part 1. generate parameter-free normalized activations + normalized = self.param_free_norm(x) + + # Part 2. 
produce scaling and bias conditioned on semantic map + segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') + actv = self.mlp_shared(segmap) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + + # apply scale and bias + out = normalized * gamma + beta + + return out + + +class SPADEResnetBlock(nn.Module): + """ + ResNet block that uses SPADE. It differs from the ResNet block of pix2pixHD in that + it takes in the segmentation map as input, learns the skip connection if necessary, + and applies normalization first and then convolution. + This architecture seemed like a standard architecture for unconditional or + class-conditional GAN architecture using residual block. + The code was inspired from https://github.com/LMescheder/GAN_stability. + """ + + def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', semantic_nc=3): + super().__init__() + # Attributes + self.learned_shortcut = (fin != fout) + fmiddle = min(fin, fout) + + # create conv layers + self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1) + self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1) + if self.learned_shortcut: + self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) + + # apply spectral norm if specified + if 'spectral' in norm_g: + self.conv_0 = spectral_norm(self.conv_0) + self.conv_1 = spectral_norm(self.conv_1) + if self.learned_shortcut: + self.conv_s = spectral_norm(self.conv_s) + + # define normalization layers + spade_config_str = norm_g.replace('spectral', '') + self.norm_0 = SPADE(spade_config_str, fin, semantic_nc) + self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc) + if self.learned_shortcut: + self.norm_s = SPADE(spade_config_str, fin, semantic_nc) + + # note the resnet block with SPADE also takes in |seg|, + # the semantic segmentation map as input + def forward(self, x, seg): + x_s = self.shortcut(x, seg) + dx = self.conv_0(self.act(self.norm_0(x, seg))) + dx = self.conv_1(self.act(self.norm_1(dx, seg))) + out = x_s + dx + return out + + def shortcut(self, x, seg): + if self.learned_shortcut: + x_s = self.conv_s(self.norm_s(x, seg)) + else: + x_s = x + return x_s + + def act(self, x): + return F.leaky_relu(x, 2e-1) + + +class BaseNetwork(nn.Module): + """ A basis for hifacegan archs with custom initialization """ + + def init_weights(self, init_type='normal', gain=0.02): + + def init_func(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm2d') != -1: + if hasattr(m, 'weight') and m.weight is not None: + init.normal_(m.weight.data, 1.0, gain) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'xavier_uniform': + init.xavier_uniform_(m.weight.data, gain=1.0) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=gain) + elif init_type == 'none': # uses pytorch's default init method + m.reset_parameters() + else: + raise NotImplementedError(f'initialization method [{init_type}] is not implemented') + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + + # propagate to children + for m in self.children(): + if hasattr(m, 'init_weights'): + m.init_weights(init_type, 
gain) + + def forward(self, x): + pass + + +def lip2d(x, logit, kernel=3, stride=2, padding=1): + weight = logit.exp() + return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding) + + +class SoftGate(nn.Module): + COEFF = 12.0 + + def forward(self, x): + return torch.sigmoid(x).mul(self.COEFF) + + +class SimplifiedLIP(nn.Module): + + def __init__(self, channels): + super(SimplifiedLIP, self).__init__() + self.logit = nn.Sequential( + nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.InstanceNorm2d(channels, affine=True), + SoftGate()) + + def init_layer(self): + self.logit[0].weight.data.fill_(0.0) + + def forward(self, x): + frac = lip2d(x, self.logit(x)) + return frac + + +class LIPEncoder(BaseNetwork): + """Local Importance-based Pooling (Ziteng Gao et.al.,ICCV 2019)""" + + def __init__(self, input_nc, ngf, sw, sh, n_2xdown, norm_layer=nn.InstanceNorm2d): + super().__init__() + self.sw = sw + self.sh = sh + self.max_ratio = 16 + # 20200310: Several Convolution (stride 1) + LIP blocks, 4 fold + kw = 3 + pw = (kw - 1) // 2 + + model = [ + nn.Conv2d(input_nc, ngf, kw, stride=1, padding=pw, bias=False), + norm_layer(ngf), + nn.ReLU(), + ] + cur_ratio = 1 + for i in range(n_2xdown): + next_ratio = min(cur_ratio * 2, self.max_ratio) + model += [ + SimplifiedLIP(ngf * cur_ratio), + nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw), + norm_layer(ngf * next_ratio), + ] + cur_ratio = next_ratio + if i < n_2xdown - 1: + model += [nn.ReLU(inplace=True)] + + self.model = nn.Sequential(*model) + + def forward(self, x): + return self.model(x) + + +def get_nonspade_norm_layer(norm_type='instance'): + # helper function to get # output channels of the previous layer + def get_out_channel(layer): + if hasattr(layer, 'out_channels'): + return getattr(layer, 'out_channels') + return layer.weight.size(0) + + # this function will be returned + def add_norm_layer(layer): + nonlocal norm_type + if norm_type.startswith('spectral'): + layer = spectral_norm(layer) + subnorm_type = norm_type[len('spectral'):] + + if subnorm_type == 'none' or len(subnorm_type) == 0: + return layer + + # remove bias in the previous layer, which is meaningless + # since it has no effect after normalization + if getattr(layer, 'bias', None) is not None: + delattr(layer, 'bias') + layer.register_parameter('bias', None) + + if subnorm_type == 'batch': + norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) + elif subnorm_type == 'sync_batch': + print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead') + # norm_layer = SynchronizedBatchNorm2d( + # get_out_channel(layer), affine=True) + norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) + elif subnorm_type == 'instance': + norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) + else: + raise ValueError(f'normalization layer {subnorm_type} is not recognized') + + return nn.Sequential(layer, norm_layer) + + print('This is a legacy from nvlabs/SPADE, and will be removed in future versions.') + return add_norm_layer diff --git a/basicsr/archs/inception.py b/basicsr/archs/inception.py new file mode 100644 index 0000000000000000000000000000000000000000..de1abef67270dc1aba770943b53577029141f527 --- /dev/null +++ b/basicsr/archs/inception.py @@ -0,0 +1,307 @@ +# Modified from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py # noqa: E501 +# For FID metric + +import os +import torch +import torch.nn as nn 
+import torch.nn.functional as F +from torch.utils.model_zoo import load_url +from torchvision import models + +# Inception weights ported to Pytorch from +# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 +LOCAL_FID_WEIGHTS = 'experiments/pretrained_models/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 + + +class InceptionV3(nn.Module): + """Pretrained InceptionV3 network returning feature maps""" + + # Index of default block of inception to return, + # corresponds to output of final average pooling + DEFAULT_BLOCK_INDEX = 3 + + # Maps feature dimensionality to their output blocks indices + BLOCK_INDEX_BY_DIM = { + 64: 0, # First max pooling features + 192: 1, # Second max pooling features + 768: 2, # Pre-aux classifier features + 2048: 3 # Final average pooling features + } + + def __init__(self, + output_blocks=(DEFAULT_BLOCK_INDEX), + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True): + """Build pretrained InceptionV3. + + Args: + output_blocks (list[int]): Indices of blocks to return features of. + Possible values are: + - 0: corresponds to output of first max pooling + - 1: corresponds to output of second max pooling + - 2: corresponds to output which is fed to aux classifier + - 3: corresponds to output of final average pooling + resize_input (bool): If true, bilinearly resizes input to width and + height 299 before feeding input to model. As the network + without fully connected layers is fully convolutional, it + should be able to handle inputs of arbitrary size, so resizing + might not be strictly needed. Default: True. + normalize_input (bool): If true, scales the input from range (0, 1) + to the range the pretrained Inception network expects, + namely (-1, 1). Default: True. + requires_grad (bool): If true, parameters of the model require + gradients. Possibly useful for finetuning the network. + Default: False. + use_fid_inception (bool): If true, uses the pretrained Inception + model used in Tensorflow's FID implementation. + If false, uses the pretrained Inception model available in + torchvision. The FID Inception model has different weights + and a slightly different structure from torchvision's + Inception model. If you want to compute FID scores, you are + strongly advised to set this parameter to true to get + comparable results. Default: True. 
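# Usage sketch for this wrapper (an assumption-laden example, not repo-provided code):
# it presumes the basicsr package added in this diff is importable and that pretrained
# weights can be downloaded on first use. Block index 3 yields the 2048-dim pooled
# features that FID is computed from.
import torch
from basicsr.archs.inception import InceptionV3

model = InceptionV3(output_blocks=[3]).eval()
imgs = torch.rand(4, 3, 256, 256)            # values in (0, 1); normalize_input maps them to (-1, 1)
with torch.no_grad():
    feats = model(imgs)[0]                   # one output per requested block
print(feats.squeeze(-1).squeeze(-1).shape)   # torch.Size([4, 2048])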
+ """ + super(InceptionV3, self).__init__() + + self.resize_input = resize_input + self.normalize_input = normalize_input + self.output_blocks = sorted(output_blocks) + self.last_needed_block = max(output_blocks) + + assert self.last_needed_block <= 3, ('Last possible output block index is 3') + + self.blocks = nn.ModuleList() + + if use_fid_inception: + inception = fid_inception_v3() + else: + try: + inception = models.inception_v3(pretrained=True, init_weights=False) + except TypeError: + # pytorch < 1.5 does not have init_weights for inception_v3 + inception = models.inception_v3(pretrained=True) + + # Block 0: input to maxpool1 + block0 = [ + inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block0)) + + # Block 1: maxpool1 to maxpool2 + if self.last_needed_block >= 1: + block1 = [inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, nn.MaxPool2d(kernel_size=3, stride=2)] + self.blocks.append(nn.Sequential(*block1)) + + # Block 2: maxpool2 to aux classifier + if self.last_needed_block >= 2: + block2 = [ + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e, + ] + self.blocks.append(nn.Sequential(*block2)) + + # Block 3: aux classifier to final avgpool + if self.last_needed_block >= 3: + block3 = [ + inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1)) + ] + self.blocks.append(nn.Sequential(*block3)) + + for param in self.parameters(): + param.requires_grad = requires_grad + + def forward(self, x): + """Get Inception feature maps. + + Args: + x (Tensor): Input tensor of shape (b, 3, h, w). + Values are expected to be in range (-1, 1). You can also input + (0, 1) with setting normalize_input = True. + + Returns: + list[Tensor]: Corresponding to the selected output block, sorted + ascending by index. + """ + output = [] + + if self.resize_input: + x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False) + + if self.normalize_input: + x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) + + for idx, block in enumerate(self.blocks): + x = block(x) + if idx in self.output_blocks: + output.append(x) + + if idx == self.last_needed_block: + break + + return output + + +def fid_inception_v3(): + """Build pretrained Inception model for FID computation. + + The Inception model for FID computation uses a different set of weights + and has a slightly different structure than torchvision's Inception. + + This method first constructs torchvision's Inception and then patches the + necessary parts that are different in the FID Inception model. 
+ """ + try: + inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False, init_weights=False) + except TypeError: + # pytorch < 1.5 does not have init_weights for inception_v3 + inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False) + + inception.Mixed_5b = FIDInceptionA(192, pool_features=32) + inception.Mixed_5c = FIDInceptionA(256, pool_features=64) + inception.Mixed_5d = FIDInceptionA(288, pool_features=64) + inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) + inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) + inception.Mixed_7b = FIDInceptionE_1(1280) + inception.Mixed_7c = FIDInceptionE_2(2048) + + if os.path.exists(LOCAL_FID_WEIGHTS): + state_dict = torch.load(LOCAL_FID_WEIGHTS, map_location=lambda storage, loc: storage) + else: + state_dict = load_url(FID_WEIGHTS_URL, progress=True) + + inception.load_state_dict(state_dict) + return inception + + +class FIDInceptionA(models.inception.InceptionA): + """InceptionA block patched for FID computation""" + + def __init__(self, in_channels, pool_features): + super(FIDInceptionA, self).__init__(in_channels, pool_features) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionC(models.inception.InceptionC): + """InceptionC block patched for FID computation""" + + def __init__(self, in_channels, channels_7x7): + super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_1(models.inception.InceptionE): + """First InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_1, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: 
Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_2(models.inception.InceptionE): + """Second InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_2, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: The FID Inception model uses max pooling instead of average + # pooling. This is likely an error in this specific Inception + # implementation, as other Inception models use average pooling here + # (which matches the description in the paper). + branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) diff --git a/basicsr/archs/rcan_arch.py b/basicsr/archs/rcan_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..48872e6800006d885f56f90dd2f0a2bd16e513d9 --- /dev/null +++ b/basicsr/archs/rcan_arch.py @@ -0,0 +1,135 @@ +import torch +from torch import nn as nn + +from basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import Upsample, make_layer + + +class ChannelAttention(nn.Module): + """Channel attention used in RCAN. + + Args: + num_feat (int): Channel number of intermediate features. + squeeze_factor (int): Channel squeeze factor. Default: 16. + """ + + def __init__(self, num_feat, squeeze_factor=16): + super(ChannelAttention, self).__init__() + self.attention = nn.Sequential( + nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0), + nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid()) + + def forward(self, x): + y = self.attention(x) + return x * y + + +class RCAB(nn.Module): + """Residual Channel Attention Block (RCAB) used in RCAN. + + Args: + num_feat (int): Channel number of intermediate features. + squeeze_factor (int): Channel squeeze factor. Default: 16. + res_scale (float): Scale the residual. Default: 1. + """ + + def __init__(self, num_feat, squeeze_factor=16, res_scale=1): + super(RCAB, self).__init__() + self.res_scale = res_scale + + self.rcab = nn.Sequential( + nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1), + ChannelAttention(num_feat, squeeze_factor)) + + def forward(self, x): + res = self.rcab(x) * self.res_scale + return res + x + + +class ResidualGroup(nn.Module): + """Residual Group of RCAB. + + Args: + num_feat (int): Channel number of intermediate features. + num_block (int): Block number in the body network. + squeeze_factor (int): Channel squeeze factor. Default: 16. + res_scale (float): Scale the residual. Default: 1. 
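# make_layer is imported from .arch_util; the stand-in below is an assumption about
# its behaviour (not the verbatim implementation), showing how ResidualGroup stacks
# num_block RCABs into one sequential body.
import torch.nn as nn

def make_layer_sketch(block_cls, num_blocks, **kwargs):
    return nn.Sequential(*[block_cls(**kwargs) for _ in range(num_blocks)])

# i.e. residual_group is essentially num_block copies of
# RCAB(num_feat=..., squeeze_factor=..., res_scale=...) chained in sequence.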
+ """ + + def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1): + super(ResidualGroup, self).__init__() + + self.residual_group = make_layer( + RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale) + self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + + def forward(self, x): + res = self.conv(self.residual_group(x)) + return res + x + + +@ARCH_REGISTRY.register() +class RCAN(nn.Module): + """Residual Channel Attention Networks. + + ``Paper: Image Super-Resolution Using Very Deep Residual Channel Attention Networks`` + + Reference: https://github.com/yulunzhang/RCAN + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_feat (int): Channel number of intermediate features. + Default: 64. + num_group (int): Number of ResidualGroup. Default: 10. + num_block (int): Number of RCAB in ResidualGroup. Default: 16. + squeeze_factor (int): Channel squeeze factor. Default: 16. + upscale (int): Upsampling factor. Support 2^n and 3. + Default: 4. + res_scale (float): Used to scale the residual in residual block. + Default: 1. + img_range (float): Image range. Default: 255. + rgb_mean (tuple[float]): Image mean in RGB orders. + Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. + """ + + def __init__(self, + num_in_ch, + num_out_ch, + num_feat=64, + num_group=10, + num_block=16, + squeeze_factor=16, + upscale=4, + res_scale=1, + img_range=255., + rgb_mean=(0.4488, 0.4371, 0.4040)): + super(RCAN, self).__init__() + + self.img_range = img_range + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer( + ResidualGroup, + num_group, + num_feat=num_feat, + num_block=num_block, + squeeze_factor=squeeze_factor, + res_scale=res_scale) + self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + def forward(self, x): + self.mean = self.mean.type_as(x) + + x = (x - self.mean) * self.img_range + x = self.conv_first(x) + res = self.conv_after_body(self.body(x)) + res += x + + x = self.conv_last(self.upsample(res)) + x = x / self.img_range + self.mean + + return x diff --git a/basicsr/archs/ridnet_arch.py b/basicsr/archs/ridnet_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..85bb9ae0348e27dd6c797c03f8d9ec43f8b0b829 --- /dev/null +++ b/basicsr/archs/ridnet_arch.py @@ -0,0 +1,180 @@ +import torch +import torch.nn as nn + +from basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import ResidualBlockNoBN, make_layer + + +class MeanShift(nn.Conv2d): + """ Data normalization with mean and std. + + Args: + rgb_range (int): Maximum value of RGB. + rgb_mean (list[float]): Mean for RGB channels. + rgb_std (list[float]): Std for RGB channels. + sign (int): For subtraction, sign is -1, for addition, sign is 1. + Default: -1. + requires_grad (bool): Whether to update the self.weight and self.bias. + Default: True. 
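# Quick numeric sketch of the MeanShift layer documented here: with sign=-1 it acts as
# a fixed 1x1 convolution computing (x - rgb_range * mean) / std. Assumes the basicsr
# package from this diff is importable.
import torch
from basicsr.archs.ridnet_arch import MeanShift

rgb_mean, rgb_std, rgb_range = (0.4488, 0.4371, 0.4040), (1.0, 1.0, 1.0), 255.
sub_mean = MeanShift(rgb_range, rgb_mean, rgb_std)  # sign defaults to -1
x = torch.rand(1, 3, 8, 8) * rgb_range
expected = (x - rgb_range * torch.tensor(rgb_mean).view(1, 3, 1, 1)) / torch.tensor(rgb_std).view(1, 3, 1, 1)
print(torch.allclose(sub_mean(x), expected, atol=1e-5))  # True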
+ """ + + def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1, requires_grad=True): + super(MeanShift, self).__init__(3, 3, kernel_size=1) + std = torch.Tensor(rgb_std) + self.weight.data = torch.eye(3).view(3, 3, 1, 1) + self.weight.data.div_(std.view(3, 1, 1, 1)) + self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) + self.bias.data.div_(std) + self.requires_grad = requires_grad + + +class EResidualBlockNoBN(nn.Module): + """Enhanced Residual block without BN. + + There are three convolution layers in residual branch. + """ + + def __init__(self, in_channels, out_channels): + super(EResidualBlockNoBN, self).__init__() + + self.body = nn.Sequential( + nn.Conv2d(in_channels, out_channels, 3, 1, 1), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, 3, 1, 1), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, 1, 1, 0), + ) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.body(x) + out = self.relu(out + x) + return out + + +class MergeRun(nn.Module): + """ Merge-and-run unit. + + This unit contains two branches with different dilated convolutions, + followed by a convolution to process the concatenated features. + + Paper: Real Image Denoising with Feature Attention + Ref git repo: https://github.com/saeed-anwar/RIDNet + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): + super(MergeRun, self).__init__() + + self.dilation1 = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size, stride, 2, 2), nn.ReLU(inplace=True)) + self.dilation2 = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, stride, 3, 3), nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size, stride, 4, 4), nn.ReLU(inplace=True)) + + self.aggregation = nn.Sequential( + nn.Conv2d(out_channels * 2, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True)) + + def forward(self, x): + dilation1 = self.dilation1(x) + dilation2 = self.dilation2(x) + out = torch.cat([dilation1, dilation2], dim=1) + out = self.aggregation(out) + out = out + x + return out + + +class ChannelAttention(nn.Module): + """Channel attention. + + Args: + num_feat (int): Channel number of intermediate features. + squeeze_factor (int): Channel squeeze factor. Default: + """ + + def __init__(self, mid_channels, squeeze_factor=16): + super(ChannelAttention, self).__init__() + self.attention = nn.Sequential( + nn.AdaptiveAvgPool2d(1), nn.Conv2d(mid_channels, mid_channels // squeeze_factor, 1, padding=0), + nn.ReLU(inplace=True), nn.Conv2d(mid_channels // squeeze_factor, mid_channels, 1, padding=0), nn.Sigmoid()) + + def forward(self, x): + y = self.attention(x) + return x * y + + +class EAM(nn.Module): + """Enhancement attention modules (EAM) in RIDNet. + + This module contains a merge-and-run unit, a residual block, + an enhanced residual block and a feature attention unit. + + Attributes: + merge: The merge-and-run unit. + block1: The residual block. + block2: The enhanced residual block. + ca: The feature/channel attention unit. + """ + + def __init__(self, in_channels, mid_channels, out_channels): + super(EAM, self).__init__() + + self.merge = MergeRun(in_channels, mid_channels) + self.block1 = ResidualBlockNoBN(mid_channels) + self.block2 = EResidualBlockNoBN(mid_channels, out_channels) + self.ca = ChannelAttention(out_channels) + # The residual block in the paper contains a relu after addition. 
+ self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.merge(x) + out = self.relu(self.block1(out)) + out = self.block2(out) + out = self.ca(out) + return out + + +@ARCH_REGISTRY.register() +class RIDNet(nn.Module): + """RIDNet: Real Image Denoising with Feature Attention. + + Ref git repo: https://github.com/saeed-anwar/RIDNet + + Args: + in_channels (int): Channel number of inputs. + mid_channels (int): Channel number of EAM modules. + Default: 64. + out_channels (int): Channel number of outputs. + num_block (int): Number of EAM. Default: 4. + img_range (float): Image range. Default: 255. + rgb_mean (tuple[float]): Image mean in RGB orders. + Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. + """ + + def __init__(self, + in_channels, + mid_channels, + out_channels, + num_block=4, + img_range=255., + rgb_mean=(0.4488, 0.4371, 0.4040), + rgb_std=(1.0, 1.0, 1.0)): + super(RIDNet, self).__init__() + + self.sub_mean = MeanShift(img_range, rgb_mean, rgb_std) + self.add_mean = MeanShift(img_range, rgb_mean, rgb_std, 1) + + self.head = nn.Conv2d(in_channels, mid_channels, 3, 1, 1) + self.body = make_layer( + EAM, num_block, in_channels=mid_channels, mid_channels=mid_channels, out_channels=mid_channels) + self.tail = nn.Conv2d(mid_channels, out_channels, 3, 1, 1) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + res = self.sub_mean(x) + res = self.tail(self.body(self.relu(self.head(res)))) + res = self.add_mean(res) + + out = x + res + return out diff --git a/basicsr/archs/rrdbnet_arch.py b/basicsr/archs/rrdbnet_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..63d07080c2ec1305090c59b7bfbbda2b003b18e4 --- /dev/null +++ b/basicsr/archs/rrdbnet_arch.py @@ -0,0 +1,119 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import default_init_weights, make_layer, pixel_unshuffle + + +class ResidualDenseBlock(nn.Module): + """Residual Dense Block. + + Used in RRDB block in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat=64, num_grow_ch=32): + super(ResidualDenseBlock, self).__init__() + self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) + self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + # Empirically, we use 0.2 to scale the residual for better performance + return x5 * 0.2 + x + + +class RRDB(nn.Module): + """Residual in Residual Dense Block. + + Used in RRDB-Net in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. 
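# Shape sketch for the dense block above: conv_i consumes num_feat + (i - 1) * num_grow_ch
# channels, and the scaled residual keeps the output at num_feat channels. Assumes the
# basicsr package from this diff is importable.
import torch
from basicsr.archs.rrdbnet_arch import ResidualDenseBlock

rdb = ResidualDenseBlock(num_feat=64, num_grow_ch=32)
x = torch.randn(2, 64, 32, 32)
print(rdb(x).shape)  # torch.Size([2, 64, 32, 32])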
+ """ + + def __init__(self, num_feat, num_grow_ch=32): + super(RRDB, self).__init__() + self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) + + def forward(self, x): + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + # Empirically, we use 0.2 to scale the residual for better performance + return out * 0.2 + x + + +@ARCH_REGISTRY.register() +class RRDBNet(nn.Module): + """Networks consisting of Residual in Residual Dense Block, which is used + in ESRGAN. + + ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. + + We extend ESRGAN for scale x2 and scale x1. + Note: This is one option for scale 1, scale 2 in RRDBNet. + We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size + and enlarge the channel size before feeding inputs into the main ESRGAN architecture. + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_feat (int): Channel number of intermediate features. + Default: 64 + num_block (int): Block number in the trunk network. Defaults: 23 + num_grow_ch (int): Channels for each growth. Default: 32. + """ + + def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): + super(RRDBNet, self).__init__() + self.scale = scale + if scale == 2: + num_in_ch = num_in_ch * 4 + elif scale == 1: + num_in_ch = num_in_ch * 16 + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) + self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + # upsample + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + if self.scale == 2: + feat = pixel_unshuffle(x, scale=2) + elif self.scale == 1: + feat = pixel_unshuffle(x, scale=4) + else: + feat = x + feat = self.conv_first(feat) + body_feat = self.conv_body(self.body(feat)) + feat = feat + body_feat + # upsample + feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) + feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.conv_hr(feat))) + return out diff --git a/basicsr/archs/spynet_arch.py b/basicsr/archs/spynet_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..4c7af133daef0496b79a57517e1942d06f2d0061 --- /dev/null +++ b/basicsr/archs/spynet_arch.py @@ -0,0 +1,96 @@ +import math +import torch +from torch import nn as nn +from torch.nn import functional as F + +from basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import flow_warp + + +class BasicModule(nn.Module): + """Basic Module for SpyNet. 
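# Hedged sketch of why RRDBNet above multiplies num_in_ch by 4 (scale 2) or 16 (scale 1):
# pixel-unshuffle folds space into channels so the same x4 trunk can serve smaller scales.
# arch_util's pixel_unshuffle is assumed here to behave like F.pixel_unshuffle.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 64, 64)
print(F.pixel_unshuffle(x, downscale_factor=2).shape)  # torch.Size([1, 12, 32, 32]) -> x4 net gives x2 overall
print(F.pixel_unshuffle(x, downscale_factor=4).shape)  # torch.Size([1, 48, 16, 16]) -> x4 net gives x1 overall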
+ """ + + def __init__(self): + super(BasicModule, self).__init__() + + self.basic_module = nn.Sequential( + nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) + + def forward(self, tensor_input): + return self.basic_module(tensor_input) + + +@ARCH_REGISTRY.register() +class SpyNet(nn.Module): + """SpyNet architecture. + + Args: + load_path (str): path for pretrained SpyNet. Default: None. + """ + + def __init__(self, load_path=None): + super(SpyNet, self).__init__() + self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)]) + if load_path: + self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params']) + + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def preprocess(self, tensor_input): + tensor_output = (tensor_input - self.mean) / self.std + return tensor_output + + def process(self, ref, supp): + flow = [] + + ref = [self.preprocess(ref)] + supp = [self.preprocess(supp)] + + for level in range(5): + ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False)) + supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False)) + + flow = ref[0].new_zeros( + [ref[0].size(0), 2, + int(math.floor(ref[0].size(2) / 2.0)), + int(math.floor(ref[0].size(3) / 2.0))]) + + for level in range(len(ref)): + upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 + + if upsampled_flow.size(2) != ref[level].size(2): + upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate') + if upsampled_flow.size(3) != ref[level].size(3): + upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate') + + flow = self.basic_module[level](torch.cat([ + ref[level], + flow_warp( + supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'), + upsampled_flow + ], 1)) + upsampled_flow + + return flow + + def forward(self, ref, supp): + assert ref.size() == supp.size() + + h, w = ref.size(2), ref.size(3) + w_floor = math.floor(math.ceil(w / 32.0) * 32.0) + h_floor = math.floor(math.ceil(h / 32.0) * 32.0) + + ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False) + supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False) + + flow = F.interpolate(input=self.process(ref, supp), size=(h, w), mode='bilinear', align_corners=False) + + flow[:, 0, :, :] *= float(w) / float(w_floor) + flow[:, 1, :, :] *= float(h) / float(h_floor) + + return flow diff --git a/basicsr/archs/srresnet_arch.py b/basicsr/archs/srresnet_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..7f571557cd7d9ba8791bd6462fccf648c57186d2 --- /dev/null +++ b/basicsr/archs/srresnet_arch.py @@ -0,0 +1,65 @@ +from torch import nn as nn +from torch.nn import functional as F + +from basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import ResidualBlockNoBN, 
default_init_weights, make_layer + + +@ARCH_REGISTRY.register() +class MSRResNet(nn.Module): + """Modified SRResNet. + + A compacted version modified from SRResNet in + "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network" + It uses residual blocks without BN, similar to EDSR. + Currently, it supports x2, x3 and x4 upsampling scale factor. + + Args: + num_in_ch (int): Channel number of inputs. Default: 3. + num_out_ch (int): Channel number of outputs. Default: 3. + num_feat (int): Channel number of intermediate features. Default: 64. + num_block (int): Block number in the body network. Default: 16. + upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4. + """ + + def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4): + super(MSRResNet, self).__init__() + self.upscale = upscale + + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat) + + # upsampling + if self.upscale in [2, 3]: + self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1) + self.pixel_shuffle = nn.PixelShuffle(self.upscale) + elif self.upscale == 4: + self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) + self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) + self.pixel_shuffle = nn.PixelShuffle(2) + + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + # initialization + default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1) + if self.upscale == 4: + default_init_weights(self.upconv2, 0.1) + + def forward(self, x): + feat = self.lrelu(self.conv_first(x)) + out = self.body(feat) + + if self.upscale == 4: + out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) + out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) + elif self.upscale in [2, 3]: + out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) + + out = self.conv_last(self.lrelu(self.conv_hr(out))) + base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False) + out += base + return out diff --git a/basicsr/archs/srvgg_arch.py b/basicsr/archs/srvgg_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..d8fe5ceb40ed9edd35d81ee17aff86f2e3d9adb4 --- /dev/null +++ b/basicsr/archs/srvgg_arch.py @@ -0,0 +1,70 @@ +from torch import nn as nn +from torch.nn import functional as F + +from basicsr.utils.registry import ARCH_REGISTRY + + +@ARCH_REGISTRY.register(suffix='basicsr') +class SRVGGNetCompact(nn.Module): + """A compact VGG-style network structure for super-resolution. + + It is a compact network structure, which performs upsampling in the last layer and no convolution is + conducted on the HR feature space. + + Args: + num_in_ch (int): Channel number of inputs. Default: 3. + num_out_ch (int): Channel number of outputs. Default: 3. + num_feat (int): Channel number of intermediate features. Default: 64. + num_conv (int): Number of convolution layers in the body network. Default: 16. + upscale (int): Upsampling factor. Default: 4. + act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu. 
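# Shape sketch for the MSRResNet upsampling head above: a 3x3 conv expands channels by
# r*r, then PixelShuffle(r) trades them back for an r-times larger feature map (applied
# twice for the x4 branch).
import torch
import torch.nn as nn

feat = torch.randn(1, 64, 48, 48)
upconv = nn.Conv2d(64, 64 * 4, 3, 1, 1)        # r = 2
print(nn.PixelShuffle(2)(upconv(feat)).shape)  # torch.Size([1, 64, 96, 96])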
+ """ + + def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): + super(SRVGGNetCompact, self).__init__() + self.num_in_ch = num_in_ch + self.num_out_ch = num_out_ch + self.num_feat = num_feat + self.num_conv = num_conv + self.upscale = upscale + self.act_type = act_type + + self.body = nn.ModuleList() + # the first conv + self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) + # the first activation + if act_type == 'relu': + activation = nn.ReLU(inplace=True) + elif act_type == 'prelu': + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == 'leakyrelu': + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the body structure + for _ in range(num_conv): + self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) + # activation + if act_type == 'relu': + activation = nn.ReLU(inplace=True) + elif act_type == 'prelu': + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == 'leakyrelu': + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the last conv + self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) + # upsample + self.upsampler = nn.PixelShuffle(upscale) + + def forward(self, x): + out = x + for i in range(0, len(self.body)): + out = self.body[i](out) + + out = self.upsampler(out) + # add the nearest upsampled image, so that the network learns the residual + base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') + out += base + return out diff --git a/basicsr/archs/stylegan2_arch.py b/basicsr/archs/stylegan2_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..9ab37f5a33a2ef21641de35109c16b511a6df163 --- /dev/null +++ b/basicsr/archs/stylegan2_arch.py @@ -0,0 +1,799 @@ +import math +import random +import torch +from torch import nn +from torch.nn import functional as F + +from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu +from basicsr.ops.upfirdn2d import upfirdn2d +from basicsr.utils.registry import ARCH_REGISTRY + + +class NormStyleCode(nn.Module): + + def forward(self, x): + """Normalize the style codes. + + Args: + x (Tensor): Style codes with shape (b, c). + + Returns: + Tensor: Normalized tensor. + """ + return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8) + + +def make_resample_kernel(k): + """Make resampling kernel for UpFirDn. + + Args: + k (list[int]): A list indicating the 1D resample kernel magnitude. + + Returns: + Tensor: 2D resampled kernel. + """ + k = torch.tensor(k, dtype=torch.float32) + if k.ndim == 1: + k = k[None, :] * k[:, None] # to 2D kernel, outer product + # normalize + k /= k.sum() + return k + + +class UpFirDnUpsample(nn.Module): + """Upsample, FIR filter, and downsample (upsampole version). + + References: + 1. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.upfirdn.html # noqa: E501 + 2. http://www.ece.northwestern.edu/local-apps/matlabhelp/toolbox/signal/upfirdn.html # noqa: E501 + + Args: + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. + factor (int): Upsampling scale factor. Default: 2. 
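# Sketch of make_resample_kernel on the default (1, 3, 3, 1) taps: the 1D kernel is
# turned into a 2D kernel by an outer product and normalized to sum to 1; the upsample
# variant above then rescales it by factor**2 to compensate for zero insertion.
import torch

k = torch.tensor([1., 3., 3., 1.])
k2d = k[None, :] * k[:, None]
k2d /= k2d.sum()
print(k2d.shape, float(k2d.sum()))  # torch.Size([4, 4]) 1.0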
+ """ + + def __init__(self, resample_kernel, factor=2): + super(UpFirDnUpsample, self).__init__() + self.kernel = make_resample_kernel(resample_kernel) * (factor**2) + self.factor = factor + + pad = self.kernel.shape[0] - factor + self.pad = ((pad + 1) // 2 + factor - 1, pad // 2) + + def forward(self, x): + out = upfirdn2d(x, self.kernel.type_as(x), up=self.factor, down=1, pad=self.pad) + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(factor={self.factor})') + + +class UpFirDnDownsample(nn.Module): + """Upsample, FIR filter, and downsample (downsampole version). + + Args: + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. + factor (int): Downsampling scale factor. Default: 2. + """ + + def __init__(self, resample_kernel, factor=2): + super(UpFirDnDownsample, self).__init__() + self.kernel = make_resample_kernel(resample_kernel) + self.factor = factor + + pad = self.kernel.shape[0] - factor + self.pad = ((pad + 1) // 2, pad // 2) + + def forward(self, x): + out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=self.factor, pad=self.pad) + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(factor={self.factor})') + + +class UpFirDnSmooth(nn.Module): + """Upsample, FIR filter, and downsample (smooth version). + + Args: + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. + upsample_factor (int): Upsampling scale factor. Default: 1. + downsample_factor (int): Downsampling scale factor. Default: 1. + kernel_size (int): Kernel size: Default: 1. + """ + + def __init__(self, resample_kernel, upsample_factor=1, downsample_factor=1, kernel_size=1): + super(UpFirDnSmooth, self).__init__() + self.upsample_factor = upsample_factor + self.downsample_factor = downsample_factor + self.kernel = make_resample_kernel(resample_kernel) + if upsample_factor > 1: + self.kernel = self.kernel * (upsample_factor**2) + + if upsample_factor > 1: + pad = (self.kernel.shape[0] - upsample_factor) - (kernel_size - 1) + self.pad = ((pad + 1) // 2 + upsample_factor - 1, pad // 2 + 1) + elif downsample_factor > 1: + pad = (self.kernel.shape[0] - downsample_factor) + (kernel_size - 1) + self.pad = ((pad + 1) // 2, pad // 2) + else: + raise NotImplementedError + + def forward(self, x): + out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=1, pad=self.pad) + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(upsample_factor={self.upsample_factor}' + f', downsample_factor={self.downsample_factor})') + + +class EqualLinear(nn.Module): + """Equalized Linear as StyleGAN2. + + Args: + in_channels (int): Size of each sample. + out_channels (int): Size of each output sample. + bias (bool): If set to ``False``, the layer will not learn an additive + bias. Default: ``True``. + bias_init_val (float): Bias initialized value. Default: 0. + lr_mul (float): Learning rate multiplier. Default: 1. + activation (None | str): The activation after ``linear`` operation. + Supported: 'fused_lrelu', None. Default: None. 
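# Numeric sketch of the equalized-learning-rate trick documented here: weights are
# stored divided by lr_mul and rescaled by lr_mul / sqrt(fan_in) on every forward, so
# activations keep roughly unit variance regardless of fan_in.
import math
import torch
import torch.nn.functional as F

in_ch, out_ch, lr_mul = 512, 256, 0.01
weight = torch.randn(out_ch, in_ch) / lr_mul  # stored parameter
scale = (1 / math.sqrt(in_ch)) * lr_mul       # applied at forward time
x = torch.randn(1024, in_ch)
print(F.linear(x, weight * scale).std())      # ~= 1.0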
+ """ + + def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None): + super(EqualLinear, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.lr_mul = lr_mul + self.activation = activation + if self.activation not in ['fused_lrelu', None]: + raise ValueError(f'Wrong activation value in EqualLinear: {activation}' + "Supported ones are: ['fused_lrelu', None].") + self.scale = (1 / math.sqrt(in_channels)) * lr_mul + + self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val)) + else: + self.register_parameter('bias', None) + + def forward(self, x): + if self.bias is None: + bias = None + else: + bias = self.bias * self.lr_mul + if self.activation == 'fused_lrelu': + out = F.linear(x, self.weight * self.scale) + out = fused_leaky_relu(out, bias) + else: + out = F.linear(x, self.weight * self.scale, bias=bias) + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, bias={self.bias is not None})') + + +class ModulatedConv2d(nn.Module): + """Modulated Conv2d used in StyleGAN2. + + There is no bias in ModulatedConv2d. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether to demodulate in the conv layer. + Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. + Default: None. + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. Default: (1, 3, 3, 1). + eps (float): A value added to the denominator for numerical stability. + Default: 1e-8. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + resample_kernel=(1, 3, 3, 1), + eps=1e-8): + super(ModulatedConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.demodulate = demodulate + self.sample_mode = sample_mode + self.eps = eps + + if self.sample_mode == 'upsample': + self.smooth = UpFirDnSmooth( + resample_kernel, upsample_factor=2, downsample_factor=1, kernel_size=kernel_size) + elif self.sample_mode == 'downsample': + self.smooth = UpFirDnSmooth( + resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size) + elif self.sample_mode is None: + pass + else: + raise ValueError(f'Wrong sample mode {self.sample_mode}, ' + "supported ones are ['upsample', 'downsample', None].") + + self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + # modulation inside each modulated conv + self.modulation = EqualLinear( + num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None) + + self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)) + self.padding = kernel_size // 2 + + def forward(self, x, style): + """Forward function. + + Args: + x (Tensor): Tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + + Returns: + Tensor: Modulated tensor after convolution. 
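# Hedged sketch of the per-sample weight trick implemented just below: the modulated
# weights gain a batch dimension, the batch is folded into the channel axis, and a
# single grouped convolution applies a different kernel to every sample.
import torch
import torch.nn.functional as F

b, c_in, c_out, k, h, w = 2, 8, 16, 3, 32, 32
x = torch.randn(b, c_in, h, w)
style = torch.rand(b, 1, c_in, 1, 1)                        # per-sample channel scales
weight = torch.randn(1, c_out, c_in, k, k) * style          # (b, c_out, c_in, k, k)
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)    # demodulation
weight = (weight * demod.view(b, c_out, 1, 1, 1)).view(b * c_out, c_in, k, k)
out = F.conv2d(x.view(1, b * c_in, h, w), weight, padding=k // 2, groups=b)
print(out.view(b, c_out, h, w).shape)                       # torch.Size([2, 16, 32, 32])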
+ """ + b, c, h, w = x.shape # c = c_in + # weight modulation + style = self.modulation(style).view(b, 1, c, 1, 1) + # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1) + weight = self.scale * self.weight * style # (b, c_out, c_in, k, k) + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps) + weight = weight * demod.view(b, self.out_channels, 1, 1, 1) + + weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size) + + if self.sample_mode == 'upsample': + x = x.view(1, b * c, h, w) + weight = weight.view(b, self.out_channels, c, self.kernel_size, self.kernel_size) + weight = weight.transpose(1, 2).reshape(b * c, self.out_channels, self.kernel_size, self.kernel_size) + out = F.conv_transpose2d(x, weight, padding=0, stride=2, groups=b) + out = out.view(b, self.out_channels, *out.shape[2:4]) + out = self.smooth(out) + elif self.sample_mode == 'downsample': + x = self.smooth(x) + x = x.view(1, b * c, *x.shape[2:4]) + out = F.conv2d(x, weight, padding=0, stride=2, groups=b) + out = out.view(b, self.out_channels, *out.shape[2:4]) + else: + x = x.view(1, b * c, h, w) + # weight: (b*c_out, c_in, k, k), groups=b + out = F.conv2d(x, weight, padding=self.padding, groups=b) + out = out.view(b, self.out_channels, *out.shape[2:4]) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, ' + f'kernel_size={self.kernel_size}, ' + f'demodulate={self.demodulate}, sample_mode={self.sample_mode})') + + +class StyleConv(nn.Module): + """Style conv. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. + Default: None. + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. Default: (1, 3, 3, 1). + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + resample_kernel=(1, 3, 3, 1)): + super(StyleConv, self).__init__() + self.modulated_conv = ModulatedConv2d( + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=demodulate, + sample_mode=sample_mode, + resample_kernel=resample_kernel) + self.weight = nn.Parameter(torch.zeros(1)) # for noise injection + self.activate = FusedLeakyReLU(out_channels) + + def forward(self, x, style, noise=None): + # modulate + out = self.modulated_conv(x, style) + # noise injection + if noise is None: + b, _, h, w = out.shape + noise = out.new_empty(b, 1, h, w).normal_() + out = out + self.weight * noise + # activation (with bias) + out = self.activate(out) + return out + + +class ToRGB(nn.Module): + """To RGB from features. + + Args: + in_channels (int): Channel number of input. + num_style_feat (int): Channel number of style features. + upsample (bool): Whether to upsample. Default: True. + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. Default: (1, 3, 3, 1). 
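# Sketch of the noise-injection step in StyleConv above: a single learned scalar
# (initialised to zero) gates a one-channel spatial noise map that broadcasts over all
# feature channels before the activation.
import torch

out = torch.randn(2, 64, 16, 16)              # activations after the modulated conv
noise_weight = torch.zeros(1)                 # an nn.Parameter in the real layer
noise = out.new_empty(2, 1, 16, 16).normal_()
print((out + noise_weight * noise).shape)     # torch.Size([2, 64, 16, 16])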
+ """ + + def __init__(self, in_channels, num_style_feat, upsample=True, resample_kernel=(1, 3, 3, 1)): + super(ToRGB, self).__init__() + if upsample: + self.upsample = UpFirDnUpsample(resample_kernel, factor=2) + else: + self.upsample = None + self.modulated_conv = ModulatedConv2d( + in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, x, style, skip=None): + """Forward function. + + Args: + x (Tensor): Feature tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + skip (Tensor): Base/skip tensor. Default: None. + + Returns: + Tensor: RGB images. + """ + out = self.modulated_conv(x, style) + out = out + self.bias + if skip is not None: + if self.upsample: + skip = self.upsample(skip) + out = out + skip + return out + + +class ConstantInput(nn.Module): + """Constant input. + + Args: + num_channel (int): Channel number of constant input. + size (int): Spatial size of constant input. + """ + + def __init__(self, num_channel, size): + super(ConstantInput, self).__init__() + self.weight = nn.Parameter(torch.randn(1, num_channel, size, size)) + + def forward(self, batch): + out = self.weight.repeat(batch, 1, 1, 1) + return out + + +@ARCH_REGISTRY.register() +class StyleGAN2Generator(nn.Module): + """StyleGAN2 Generator. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of + StyleGAN2. Default: 2. + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. A cross production will be applied to extent 1D resample + kernel to 2D resample kernel. Default: (1, 3, 3, 1). + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + narrow (float): Narrow ratio for channels. Default: 1.0. 
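# Re-deriving the size bookkeeping the constructor below performs: for out_size=1024
# there are log2(1024)=10 resolution levels, 17 noise-receiving style convs and 18
# per-layer latent slots (the familiar 18x512 W+ space).
import math

out_size = 1024
log_size = int(math.log(out_size, 2))
num_layers = (log_size - 2) * 2 + 1
num_latent = log_size * 2 - 2
print(log_size, num_layers, num_latent)  # 10 17 18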
+ """ + + def __init__(self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + resample_kernel=(1, 3, 3, 1), + lr_mlp=0.01, + narrow=1): + super(StyleGAN2Generator, self).__init__() + # Style MLP layers + self.num_style_feat = num_style_feat + style_mlp_layers = [NormStyleCode()] + for i in range(num_mlp): + style_mlp_layers.append( + EqualLinear( + num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp, + activation='fused_lrelu')) + self.style_mlp = nn.Sequential(*style_mlp_layers) + + channels = { + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) + } + self.channels = channels + + self.constant_input = ConstantInput(channels['4'], size=4) + self.style_conv1 = StyleConv( + channels['4'], + channels['4'], + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + resample_kernel=resample_kernel) + self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, resample_kernel=resample_kernel) + + self.log_size = int(math.log(out_size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + self.num_latent = self.log_size * 2 - 2 + + self.style_convs = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channels = channels['4'] + # noise + for layer_idx in range(self.num_layers): + resolution = 2**((layer_idx + 5) // 2) + shape = [1, 1, resolution, resolution] + self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape)) + # style convs and to_rgbs + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + self.style_convs.append( + StyleConv( + in_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode='upsample', + resample_kernel=resample_kernel, + )) + self.style_convs.append( + StyleConv( + out_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + resample_kernel=resample_kernel)) + self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True, resample_kernel=resample_kernel)) + in_channels = out_channels + + def make_noise(self): + """Make noise for noise injection.""" + device = self.constant_input.weight.device + noises = [torch.randn(1, 1, 4, 4, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) + + return noises + + def get_latent(self, x): + return self.style_mlp(x) + + def mean_latent(self, num_latent): + latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device) + latent = self.style_mlp(latent_in).mean(0, keepdim=True) + return latent + + def forward(self, + styles, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2Generator. + + Args: + styles (list[Tensor]): Sample codes of styles. + input_is_latent (bool): Whether input is latent style. + Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is + False. Default: True. + truncation (float): TODO. Default: 1. 
+ truncation_latent (Tensor | None): TODO. Default: None. + inject_index (int | None): The injection index for mixing noise. + Default: None. + return_latents (bool): Whether to return style latents. + Default: False. + """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) + styles = style_truncation + # get style latent with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], + noise[2::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +class ScaledLeakyReLU(nn.Module): + """Scaled LeakyReLU. + + Args: + negative_slope (float): Negative slope. Default: 0.2. + """ + + def __init__(self, negative_slope=0.2): + super(ScaledLeakyReLU, self).__init__() + self.negative_slope = negative_slope + + def forward(self, x): + out = F.leaky_relu(x, negative_slope=self.negative_slope) + return out * math.sqrt(2) + + +class EqualConv2d(nn.Module): + """Equalized Linear as StyleGAN2. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + stride (int): Stride of the convolution. Default: 1 + padding (int): Zero-padding added to both sides of the input. + Default: 0. + bias (bool): If ``True``, adds a learnable bias to the output. + Default: ``True``. + bias_init_val (float): Bias initialized value. Default: 0. 
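The style-truncation step in StyleGAN2Generator.forward above simply pulls each mapped latent toward a mean latent by a factor truncation < 1; a standalone sketch, with tensor shapes chosen only for illustration:

import torch

def truncate_latents(styles, truncation_latent, truncation=0.7):
    # styles: list of (b, num_style_feat) latents already mapped by the style MLP
    return [truncation_latent + truncation * (s - truncation_latent) for s in styles]

mean_w = torch.zeros(1, 512)                       # stand-in for mean_latent()
w_trunc = truncate_latents([torch.randn(4, 512)], mean_w, truncation=0.7)
print(w_trunc[0].shape)                            # torch.Size([4, 512])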
+ """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0): + super(EqualConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + + self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val)) + else: + self.register_parameter('bias', None) + + def forward(self, x): + out = F.conv2d( + x, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, ' + f'kernel_size={self.kernel_size},' + f' stride={self.stride}, padding={self.padding}, ' + f'bias={self.bias is not None})') + + +class ConvLayer(nn.Sequential): + """Conv Layer used in StyleGAN2 Discriminator. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Kernel size. + downsample (bool): Whether downsample by a factor of 2. + Default: False. + resample_kernel (list[int]): A list indicating the 1D resample + kernel magnitude. A cross production will be applied to + extent 1D resample kernel to 2D resample kernel. + Default: (1, 3, 3, 1). + bias (bool): Whether with bias. Default: True. + activate (bool): Whether use activateion. Default: True. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + downsample=False, + resample_kernel=(1, 3, 3, 1), + bias=True, + activate=True): + layers = [] + # downsample + if downsample: + layers.append( + UpFirDnSmooth(resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size)) + stride = 2 + self.padding = 0 + else: + stride = 1 + self.padding = kernel_size // 2 + # conv + layers.append( + EqualConv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias + and not activate)) + # activation + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channels)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super(ConvLayer, self).__init__(*layers) + + +class ResBlock(nn.Module): + """Residual block used in StyleGAN2 Discriminator. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + resample_kernel (list[int]): A list indicating the 1D resample + kernel magnitude. A cross production will be applied to + extent 1D resample kernel to 2D resample kernel. + Default: (1, 3, 3, 1). + """ + + def __init__(self, in_channels, out_channels, resample_kernel=(1, 3, 3, 1)): + super(ResBlock, self).__init__() + + self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True) + self.conv2 = ConvLayer( + in_channels, out_channels, 3, downsample=True, resample_kernel=resample_kernel, bias=True, activate=True) + self.skip = ConvLayer( + in_channels, out_channels, 1, downsample=True, resample_kernel=resample_kernel, bias=False, activate=False) + + def forward(self, x): + out = self.conv1(x) + out = self.conv2(out) + skip = self.skip(x) + out = (out + skip) / math.sqrt(2) + return out + + +@ARCH_REGISTRY.register() +class StyleGAN2Discriminator(nn.Module): + """StyleGAN2 Discriminator. + + Args: + out_size (int): The spatial size of outputs. 
+ channel_multiplier (int): Channel multiplier for large networks of + StyleGAN2. Default: 2. + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. A cross production will be applied to extent 1D resample + kernel to 2D resample kernel. Default: (1, 3, 3, 1). + stddev_group (int): For group stddev statistics. Default: 4. + narrow (float): Narrow ratio for channels. Default: 1.0. + """ + + def __init__(self, out_size, channel_multiplier=2, resample_kernel=(1, 3, 3, 1), stddev_group=4, narrow=1): + super(StyleGAN2Discriminator, self).__init__() + + channels = { + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) + } + + log_size = int(math.log(out_size, 2)) + + conv_body = [ConvLayer(3, channels[f'{out_size}'], 1, bias=True, activate=True)] + + in_channels = channels[f'{out_size}'] + for i in range(log_size, 2, -1): + out_channels = channels[f'{2**(i - 1)}'] + conv_body.append(ResBlock(in_channels, out_channels, resample_kernel)) + in_channels = out_channels + self.conv_body = nn.Sequential(*conv_body) + + self.final_conv = ConvLayer(in_channels + 1, channels['4'], 3, bias=True, activate=True) + self.final_linear = nn.Sequential( + EqualLinear( + channels['4'] * 4 * 4, channels['4'], bias=True, bias_init_val=0, lr_mul=1, activation='fused_lrelu'), + EqualLinear(channels['4'], 1, bias=True, bias_init_val=0, lr_mul=1, activation=None), + ) + self.stddev_group = stddev_group + self.stddev_feat = 1 + + def forward(self, x): + out = self.conv_body(x) + + b, c, h, w = out.shape + # concatenate a group stddev statistics to out + group = min(b, self.stddev_group) # Minibatch must be divisible by (or smaller than) group_size + stddev = out.view(group, -1, self.stddev_feat, c // self.stddev_feat, h, w) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, h, w) + out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + out = out.view(b, -1) + out = self.final_linear(out) + + return out diff --git a/basicsr/archs/stylegan2_bilinear_arch.py b/basicsr/archs/stylegan2_bilinear_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..2395170411f9d11f2798ac03cf6ec6eb32fe5e43 --- /dev/null +++ b/basicsr/archs/stylegan2_bilinear_arch.py @@ -0,0 +1,614 @@ +import math +import random +import torch +from torch import nn +from torch.nn import functional as F + +from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu +from basicsr.utils.registry import ARCH_REGISTRY + + +class NormStyleCode(nn.Module): + + def forward(self, x): + """Normalize the style codes. + + Args: + x (Tensor): Style codes with shape (b, c). + + Returns: + Tensor: Normalized tensor. + """ + return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8) + + +class EqualLinear(nn.Module): + """Equalized Linear as StyleGAN2. + + Args: + in_channels (int): Size of each sample. + out_channels (int): Size of each output sample. + bias (bool): If set to ``False``, the layer will not learn an additive + bias. Default: ``True``. + bias_init_val (float): Bias initialized value. Default: 0. + lr_mul (float): Learning rate multiplier. Default: 1. 
+ activation (None | str): The activation after ``linear`` operation. + Supported: 'fused_lrelu', None. Default: None. + """ + + def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None): + super(EqualLinear, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.lr_mul = lr_mul + self.activation = activation + if self.activation not in ['fused_lrelu', None]: + raise ValueError(f'Wrong activation value in EqualLinear: {activation}' + "Supported ones are: ['fused_lrelu', None].") + self.scale = (1 / math.sqrt(in_channels)) * lr_mul + + self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val)) + else: + self.register_parameter('bias', None) + + def forward(self, x): + if self.bias is None: + bias = None + else: + bias = self.bias * self.lr_mul + if self.activation == 'fused_lrelu': + out = F.linear(x, self.weight * self.scale) + out = fused_leaky_relu(out, bias) + else: + out = F.linear(x, self.weight * self.scale, bias=bias) + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, bias={self.bias is not None})') + + +class ModulatedConv2d(nn.Module): + """Modulated Conv2d used in StyleGAN2. + + There is no bias in ModulatedConv2d. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether to demodulate in the conv layer. + Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. + Default: None. + eps (float): A value added to the denominator for numerical stability. + Default: 1e-8. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + eps=1e-8, + interpolation_mode='bilinear'): + super(ModulatedConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.demodulate = demodulate + self.sample_mode = sample_mode + self.eps = eps + self.interpolation_mode = interpolation_mode + if self.interpolation_mode == 'nearest': + self.align_corners = None + else: + self.align_corners = False + + self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + # modulation inside each modulated conv + self.modulation = EqualLinear( + num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None) + + self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)) + self.padding = kernel_size // 2 + + def forward(self, x, style): + """Forward function. + + Args: + x (Tensor): Tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + + Returns: + Tensor: Modulated tensor after convolution. 
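EqualLinear above implements the equalized-learning-rate idea: weights are stored at unit scale and rescaled by 1/sqrt(fan_in) (times lr_mul) at run time, so the effective update size stays comparable across layers. A standalone sketch with lr_mul = 1:

import math
import torch
import torch.nn.functional as F

def equal_linear(x, weight, bias=None, lr_mul=1.0):
    # weight: (out_features, in_features), drawn from N(0, 1)
    scale = (1.0 / math.sqrt(weight.shape[1])) * lr_mul
    return F.linear(x, weight * scale, bias)

w = torch.randn(512, 512)
print(equal_linear(torch.randn(4, 512), w).std())  # roughly 1, regardless of fan-in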
+ """ + b, c, h, w = x.shape # c = c_in + # weight modulation + style = self.modulation(style).view(b, 1, c, 1, 1) + # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1) + weight = self.scale * self.weight * style # (b, c_out, c_in, k, k) + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps) + weight = weight * demod.view(b, self.out_channels, 1, 1, 1) + + weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size) + + if self.sample_mode == 'upsample': + x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners) + elif self.sample_mode == 'downsample': + x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners) + + b, c, h, w = x.shape + x = x.view(1, b * c, h, w) + # weight: (b*c_out, c_in, k, k), groups=b + out = F.conv2d(x, weight, padding=self.padding, groups=b) + out = out.view(b, self.out_channels, *out.shape[2:4]) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, ' + f'kernel_size={self.kernel_size}, ' + f'demodulate={self.demodulate}, sample_mode={self.sample_mode})') + + +class StyleConv(nn.Module): + """Style conv. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + interpolation_mode='bilinear'): + super(StyleConv, self).__init__() + self.modulated_conv = ModulatedConv2d( + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=demodulate, + sample_mode=sample_mode, + interpolation_mode=interpolation_mode) + self.weight = nn.Parameter(torch.zeros(1)) # for noise injection + self.activate = FusedLeakyReLU(out_channels) + + def forward(self, x, style, noise=None): + # modulate + out = self.modulated_conv(x, style) + # noise injection + if noise is None: + b, _, h, w = out.shape + noise = out.new_empty(b, 1, h, w).normal_() + out = out + self.weight * noise + # activation (with bias) + out = self.activate(out) + return out + + +class ToRGB(nn.Module): + """To RGB from features. + + Args: + in_channels (int): Channel number of input. + num_style_feat (int): Channel number of style features. + upsample (bool): Whether to upsample. Default: True. + """ + + def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'): + super(ToRGB, self).__init__() + self.upsample = upsample + self.interpolation_mode = interpolation_mode + if self.interpolation_mode == 'nearest': + self.align_corners = None + else: + self.align_corners = False + self.modulated_conv = ModulatedConv2d( + in_channels, + 3, + kernel_size=1, + num_style_feat=num_style_feat, + demodulate=False, + sample_mode=None, + interpolation_mode=interpolation_mode) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, x, style, skip=None): + """Forward function. + + Args: + x (Tensor): Feature tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + skip (Tensor): Base/skip tensor. Default: None. 
+ + Returns: + Tensor: RGB images. + """ + out = self.modulated_conv(x, style) + out = out + self.bias + if skip is not None: + if self.upsample: + skip = F.interpolate( + skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners) + out = out + skip + return out + + +class ConstantInput(nn.Module): + """Constant input. + + Args: + num_channel (int): Channel number of constant input. + size (int): Spatial size of constant input. + """ + + def __init__(self, num_channel, size): + super(ConstantInput, self).__init__() + self.weight = nn.Parameter(torch.randn(1, num_channel, size, size)) + + def forward(self, batch): + out = self.weight.repeat(batch, 1, 1, 1) + return out + + +@ARCH_REGISTRY.register(suffix='basicsr') +class StyleGAN2GeneratorBilinear(nn.Module): + """StyleGAN2 Generator. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of + StyleGAN2. Default: 2. + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + narrow (float): Narrow ratio for channels. Default: 1.0. + """ + + def __init__(self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + lr_mlp=0.01, + narrow=1, + interpolation_mode='bilinear'): + super(StyleGAN2GeneratorBilinear, self).__init__() + # Style MLP layers + self.num_style_feat = num_style_feat + style_mlp_layers = [NormStyleCode()] + for i in range(num_mlp): + style_mlp_layers.append( + EqualLinear( + num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp, + activation='fused_lrelu')) + self.style_mlp = nn.Sequential(*style_mlp_layers) + + channels = { + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) + } + self.channels = channels + + self.constant_input = ConstantInput(channels['4'], size=4) + self.style_conv1 = StyleConv( + channels['4'], + channels['4'], + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + interpolation_mode=interpolation_mode) + self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode) + + self.log_size = int(math.log(out_size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + self.num_latent = self.log_size * 2 - 2 + + self.style_convs = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channels = channels['4'] + # noise + for layer_idx in range(self.num_layers): + resolution = 2**((layer_idx + 5) // 2) + shape = [1, 1, resolution, resolution] + self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape)) + # style convs and to_rgbs + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + self.style_convs.append( + StyleConv( + in_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode='upsample', + interpolation_mode=interpolation_mode)) + self.style_convs.append( + StyleConv( + out_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + interpolation_mode=interpolation_mode)) + 
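The noise buffers registered just above use a resolution of 2**((layer_idx + 5) // 2), i.e. one 4x4 buffer followed by two buffers per resolution; a quick illustrative expansion for out_size = 64 (log_size = 6):

num_layers = (6 - 2) * 2 + 1
print([2 ** ((idx + 5) // 2) for idx in range(num_layers)])
# [4, 8, 8, 16, 16, 32, 32, 64, 64]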
self.to_rgbs.append( + ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode)) + in_channels = out_channels + + def make_noise(self): + """Make noise for noise injection.""" + device = self.constant_input.weight.device + noises = [torch.randn(1, 1, 4, 4, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) + + return noises + + def get_latent(self, x): + return self.style_mlp(x) + + def mean_latent(self, num_latent): + latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device) + latent = self.style_mlp(latent_in).mean(0, keepdim=True) + return latent + + def forward(self, + styles, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2Generator. + + Args: + styles (list[Tensor]): Sample codes of styles. + input_is_latent (bool): Whether input is latent style. + Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is + False. Default: True. + truncation (float): TODO. Default: 1. + truncation_latent (Tensor | None): TODO. Default: None. + inject_index (int | None): The injection index for mixing noise. + Default: None. + return_latents (bool): Whether to return style latents. + Default: False. + """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) + styles = style_truncation + # get style latent with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], + noise[2::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +class ScaledLeakyReLU(nn.Module): + """Scaled LeakyReLU. + + Args: + negative_slope (float): Negative slope. Default: 0.2. 
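Two-style mixing in forward() above splits the per-layer latents at inject_index: the first inject_index layers take the first style and the remaining layers take the second. A shape-only sketch (num_latent = 14 would correspond to a 256x256 generator):

import torch

num_latent, inject_index = 14, 5
w1 = torch.randn(2, 512).unsqueeze(1).repeat(1, inject_index, 1)
w2 = torch.randn(2, 512).unsqueeze(1).repeat(1, num_latent - inject_index, 1)
latent = torch.cat([w1, w2], dim=1)
print(latent.shape)  # torch.Size([2, 14, 512])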
+ """ + + def __init__(self, negative_slope=0.2): + super(ScaledLeakyReLU, self).__init__() + self.negative_slope = negative_slope + + def forward(self, x): + out = F.leaky_relu(x, negative_slope=self.negative_slope) + return out * math.sqrt(2) + + +class EqualConv2d(nn.Module): + """Equalized Linear as StyleGAN2. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + stride (int): Stride of the convolution. Default: 1 + padding (int): Zero-padding added to both sides of the input. + Default: 0. + bias (bool): If ``True``, adds a learnable bias to the output. + Default: ``True``. + bias_init_val (float): Bias initialized value. Default: 0. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0): + super(EqualConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + + self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val)) + else: + self.register_parameter('bias', None) + + def forward(self, x): + out = F.conv2d( + x, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, ' + f'kernel_size={self.kernel_size},' + f' stride={self.stride}, padding={self.padding}, ' + f'bias={self.bias is not None})') + + +class ConvLayer(nn.Sequential): + """Conv Layer used in StyleGAN2 Discriminator. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Kernel size. + downsample (bool): Whether downsample by a factor of 2. + Default: False. + bias (bool): Whether with bias. Default: True. + activate (bool): Whether use activateion. Default: True. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + downsample=False, + bias=True, + activate=True, + interpolation_mode='bilinear'): + layers = [] + self.interpolation_mode = interpolation_mode + # downsample + if downsample: + if self.interpolation_mode == 'nearest': + self.align_corners = None + else: + self.align_corners = False + + layers.append( + torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners)) + stride = 1 + self.padding = kernel_size // 2 + # conv + layers.append( + EqualConv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias + and not activate)) + # activation + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channels)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super(ConvLayer, self).__init__(*layers) + + +class ResBlock(nn.Module): + """Residual block used in StyleGAN2 Discriminator. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. 
+ """ + + def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'): + super(ResBlock, self).__init__() + + self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True) + self.conv2 = ConvLayer( + in_channels, + out_channels, + 3, + downsample=True, + interpolation_mode=interpolation_mode, + bias=True, + activate=True) + self.skip = ConvLayer( + in_channels, + out_channels, + 1, + downsample=True, + interpolation_mode=interpolation_mode, + bias=False, + activate=False) + + def forward(self, x): + out = self.conv1(x) + out = self.conv2(out) + skip = self.skip(x) + out = (out + skip) / math.sqrt(2) + return out diff --git a/basicsr/archs/swinir_arch.py b/basicsr/archs/swinir_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..3917fa2c7408e1f5b55b9930c643a9af920a4d81 --- /dev/null +++ b/basicsr/archs/swinir_arch.py @@ -0,0 +1,956 @@ +# Modified from https://github.com/JingyunLiang/SwinIR +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. + +import math +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import to_2tuple, trunc_normal_ + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (b, h, w, c) + window_size (int): window size + + Returns: + windows: (num_windows*b, window_size, window_size, c) + """ + b, h, w, c = x.shape + x = x.view(b, h // window_size, window_size, w // window_size, window_size, c) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c) + return windows + + +def window_reverse(windows, window_size, h, w): + """ + Args: + windows: (num_windows*b, window_size, window_size, c) + window_size (int): Window size + h (int): Height of image + w (int): Width of image + + Returns: + x: (b, h, w, c) + """ + b = int(windows.shape[0] / (h * w / window_size / window_size)) + x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer('relative_position_index', relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*b, n, c) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + b_, n, c = x.shape + qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nw = mask.shape[0] + attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, n, n) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(b_, n, c) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, n): + # calculate flops for 1 window with token length of n + flops = 0 + # qkv = self.qkv(x) + flops += n * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * n * (self.dim // self.num_heads) * n + # x = (attn @ v) + flops += self.num_heads * n * n * (self.dim // self.num_heads) + # x = self.proj(x) + flops += n * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. 
+ window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer('attn_mask', attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + h, w = x_size + img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1 + h_slices = (slice(0, -self.window_size), slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + h, w = x_size + b, _, c = x.shape + # assert seq_len == h * w, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(b, h, w, c) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c + x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if 
self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nw*b, window_size*window_size, c + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c) + shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(b, h * w, c) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return (f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' + f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}') + + def flops(self): + flops = 0 + h, w = self.input_resolution + # norm1 + flops += self.dim * h * w + # W-MSA/SW-MSA + nw = h * w / self.window_size / self.window_size + flops += nw * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * h * w * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * h * w + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: b, h*w, c + """ + h, w = self.input_resolution + b, seq_len, c = x.shape + assert seq_len == h * w, 'input feature has wrong size' + assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) are not even.' + + x = x.view(b, h, w, c) + + x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c + x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c + x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c + x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c + x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c + x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f'input_resolution={self.input_resolution}, dim={self.dim}' + + def flops(self): + h, w = self.input_resolution + flops = h * w * self.dim + flops += (h // 2) * (w // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}' + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
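window_partition / window_reverse defined earlier in swinir_arch.py are exact inverses of each other when h and w are multiples of the window size; a standalone round-trip check (local copies of the two helpers, shapes illustrative):

import torch

def window_partition(x, window_size):
    b, h, w, c = x.shape
    x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)

def window_reverse(windows, window_size, h, w):
    b = int(windows.shape[0] / (h * w / window_size / window_size))
    x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)

x = torch.randn(2, 16, 16, 32)            # (b, h, w, c) with 8x8 windows
wins = window_partition(x, 8)             # (2 * 4, 8, 8, 32)
print(torch.allclose(window_reverse(wins, 8, 16, 16), x))  # True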
+ """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + img_size=224, + patch_size=4, + resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer( + dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential( + nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + h, w = self.input_resolution + flops += h * w * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # b Ph*Pw c + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + h, w = self.img_size + if self.norm is not None: + flops += h * w * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + h, w = self.input_resolution + flops = h * w * self.num_feat * 3 * 9 + return flops + + +@ARCH_REGISTRY.register() +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 
2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, + img_size=64, + patch_size=1, + in_chans=3, + embed_dim=96, + depths=(6, 6, 6, 6), + num_heads=(6, 6, 6, 6), + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + upscale=2, + img_range=1., + upsampler='', + resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + + # ------------------------- 1, shallow feature extraction ------------------------- # + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + # ------------------------- 2, deep feature extraction ------------------------- # + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=embed_dim, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=embed_dim, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB( + dim=embed_dim, + input_resolution=(patches_resolution[0], patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential( + nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + 
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + # ------------------------- 3, high quality image reconstruction ------------------------- # + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # b seq_len c + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x + + def flops(self): + flops = 0 + h, w = self.patches_resolution + flops += h * w * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for layer 
in self.layers: + flops += layer.flops() + flops += h * w * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = SwinIR( + upscale=2, + img_size=(height, width), + window_size=window_size, + img_range=1., + depths=[6, 6, 6, 6], + embed_dim=60, + num_heads=[6, 6, 6, 6], + mlp_ratio=2, + upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) diff --git a/basicsr/archs/tof_arch.py b/basicsr/archs/tof_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..a90a64d89386e19f92c987bbe2133472991d764a --- /dev/null +++ b/basicsr/archs/tof_arch.py @@ -0,0 +1,172 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import flow_warp + + +class BasicModule(nn.Module): + """Basic module of SPyNet. + + Note that unlike the architecture in spynet_arch.py, the basic module + here contains batch normalization. + """ + + def __init__(self): + super(BasicModule, self).__init__() + self.basic_module = nn.Sequential( + nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False), + nn.BatchNorm2d(32), nn.ReLU(inplace=True), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3, bias=False), + nn.BatchNorm2d(64), nn.ReLU(inplace=True), + nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False), + nn.BatchNorm2d(32), nn.ReLU(inplace=True), + nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3, bias=False), + nn.BatchNorm2d(16), nn.ReLU(inplace=True), + nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) + + def forward(self, tensor_input): + """ + Args: + tensor_input (Tensor): Input tensor with shape (b, 8, h, w). + 8 channels contain: + [reference image (3), neighbor image (3), initial flow (2)]. + + Returns: + Tensor: Estimated flow with shape (b, 2, h, w) + """ + return self.basic_module(tensor_input) + + +class SPyNetTOF(nn.Module): + """SPyNet architecture for TOF. + + Note that this implementation is specifically for TOFlow. Please use :file:`spynet_arch.py` for general use. + They differ in the following aspects: + + 1. The basic modules here contain BatchNorm. + 2. Normalization and denormalization are not done here, as they are done in TOFlow. + + ``Paper: Optical Flow Estimation using a Spatial Pyramid Network`` + + Reference: https://github.com/Coldog2333/pytoflow + + Args: + load_path (str): Path for pretrained SPyNet. Default: None. + """ + + def __init__(self, load_path=None): + super(SPyNetTOF, self).__init__() + + self.basic_module = nn.ModuleList([BasicModule() for _ in range(4)]) + if load_path: + self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params']) + + def forward(self, ref, supp): + """ + Args: + ref (Tensor): Reference image with shape of (b, 3, h, w). + supp: The supporting image to be warped: (b, 3, h, w). + + Returns: + Tensor: Estimated optical flow: (b, 2, h, w). 
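        Example (illustrative sketch; the shapes are assumed, and h, w should
        be multiples of 16 so the 4-level pyramid aligns):
            >>> spynet = SPyNetTOF()
            >>> ref = torch.rand(1, 3, 64, 64)
            >>> supp = torch.rand(1, 3, 64, 64)
            >>> spynet(ref, supp).shape
            torch.Size([1, 2, 64, 64])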
+ """ + num_batches, _, h, w = ref.size() + ref = [ref] + supp = [supp] + + # generate downsampled frames + for _ in range(3): + ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False)) + supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False)) + + # flow computation + flow = ref[0].new_zeros(num_batches, 2, h // 16, w // 16) + for i in range(4): + flow_up = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 + flow = flow_up + self.basic_module[i]( + torch.cat([ref[i], flow_warp(supp[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1)) + return flow + + +@ARCH_REGISTRY.register() +class TOFlow(nn.Module): + """PyTorch implementation of TOFlow. + + In TOFlow, the LR frames are pre-upsampled and have the same size with the GT frames. + + ``Paper: Video Enhancement with Task-Oriented Flow`` + + Reference: https://github.com/anchen1011/toflow + + Reference: https://github.com/Coldog2333/pytoflow + + Args: + adapt_official_weights (bool): Whether to adapt the weights translated + from the official implementation. Set to false if you want to + train from scratch. Default: False + """ + + def __init__(self, adapt_official_weights=False): + super(TOFlow, self).__init__() + self.adapt_official_weights = adapt_official_weights + self.ref_idx = 0 if adapt_official_weights else 3 + + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + # flow estimation module + self.spynet = SPyNetTOF() + + # reconstruction module + self.conv_1 = nn.Conv2d(3 * 7, 64, 9, 1, 4) + self.conv_2 = nn.Conv2d(64, 64, 9, 1, 4) + self.conv_3 = nn.Conv2d(64, 64, 1) + self.conv_4 = nn.Conv2d(64, 3, 1) + + # activation function + self.relu = nn.ReLU(inplace=True) + + def normalize(self, img): + return (img - self.mean) / self.std + + def denormalize(self, img): + return img * self.std + self.mean + + def forward(self, lrs): + """ + Args: + lrs: Input lr frames: (b, 7, 3, h, w). + + Returns: + Tensor: SR frame: (b, 3, h, w). 
+ """ + # In the official implementation, the 0-th frame is the reference frame + if self.adapt_official_weights: + lrs = lrs[:, [3, 0, 1, 2, 4, 5, 6], :, :, :] + + num_batches, num_lrs, _, h, w = lrs.size() + + lrs = self.normalize(lrs.view(-1, 3, h, w)) + lrs = lrs.view(num_batches, num_lrs, 3, h, w) + + lr_ref = lrs[:, self.ref_idx, :, :, :] + lr_aligned = [] + for i in range(7): # 7 frames + if i == self.ref_idx: + lr_aligned.append(lr_ref) + else: + lr_supp = lrs[:, i, :, :, :] + flow = self.spynet(lr_ref, lr_supp) + lr_aligned.append(flow_warp(lr_supp, flow.permute(0, 2, 3, 1))) + + # reconstruction + hr = torch.stack(lr_aligned, dim=1) + hr = hr.view(num_batches, -1, h, w) + hr = self.relu(self.conv_1(hr)) + hr = self.relu(self.conv_2(hr)) + hr = self.relu(self.conv_3(hr)) + hr = self.conv_4(hr) + lr_ref + + return self.denormalize(hr) diff --git a/basicsr/archs/vgg_arch.py b/basicsr/archs/vgg_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..05200334e477e59feefd1e4a0b5e94204e4eb2fa --- /dev/null +++ b/basicsr/archs/vgg_arch.py @@ -0,0 +1,161 @@ +import os +import torch +from collections import OrderedDict +from torch import nn as nn +from torchvision.models import vgg as vgg + +from basicsr.utils.registry import ARCH_REGISTRY + +VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' +NAMES = { + 'vgg11': [ + 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', + 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', + 'pool5' + ], + 'vgg13': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', + 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' + ], + 'vgg16': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', + 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', + 'pool5' + ], + 'vgg19': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', + 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', + 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' + ] +} + + +def insert_bn(names): + """Insert bn layer after each conv. + + Args: + names (list): The list of layer names. + + Returns: + list: The list of layer names with bn layers. + """ + names_bn = [] + for name in names: + names_bn.append(name) + if 'conv' in name: + position = name.replace('conv', '') + names_bn.append('bn' + position) + return names_bn + + +@ARCH_REGISTRY.register() +class VGGFeatureExtractor(nn.Module): + """VGG network for feature extraction. + + In this implementation, we allow users to choose whether use normalization + in the input feature and the type of vgg network. Note that the pretrained + path must fit the vgg type. + + Args: + layer_name_list (list[str]): Forward function returns the corresponding + features according to the layer_name_list. + Example: {'relu1_1', 'relu2_1', 'relu3_1'}. + vgg_type (str): Set the type of vgg network. 
Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image. Importantly, + the input feature must in the range [0, 1]. Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + requires_grad (bool): If true, the parameters of VGG network will be + optimized. Default: False. + remove_pooling (bool): If true, the max pooling operations in VGG net + will be removed. Default: False. + pooling_stride (int): The stride of max pooling operation. Default: 2. + """ + + def __init__(self, + layer_name_list, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + requires_grad=False, + remove_pooling=False, + pooling_stride=2): + super(VGGFeatureExtractor, self).__init__() + + self.layer_name_list = layer_name_list + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + self.names = NAMES[vgg_type.replace('_bn', '')] + if 'bn' in vgg_type: + self.names = insert_bn(self.names) + + # only borrow layers that will be used to avoid unused params + max_idx = 0 + for v in layer_name_list: + idx = self.names.index(v) + if idx > max_idx: + max_idx = idx + + if os.path.exists(VGG_PRETRAIN_PATH): + vgg_net = getattr(vgg, vgg_type)(pretrained=False) + state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) + vgg_net.load_state_dict(state_dict) + else: + vgg_net = getattr(vgg, vgg_type)(pretrained=True) + + features = vgg_net.features[:max_idx + 1] + + modified_net = OrderedDict() + for k, v in zip(self.names, features): + if 'pool' in k: + # if remove_pooling is true, pooling operation will be removed + if remove_pooling: + continue + else: + # in some cases, we may want to change the default stride + modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) + else: + modified_net[k] = v + + self.vgg_net = nn.Sequential(modified_net) + + if not requires_grad: + self.vgg_net.eval() + for param in self.parameters(): + param.requires_grad = False + else: + self.vgg_net.train() + for param in self.parameters(): + param.requires_grad = True + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. 
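        Example (illustrative sketch; the output is a dict keyed by the
        requested layer names, and the torchvision VGG weights are downloaded
        unless VGG_PRETRAIN_PATH already exists):
            >>> extractor = VGGFeatureExtractor(layer_name_list=['relu1_1', 'relu3_1'])
            >>> feats = extractor(torch.rand(1, 3, 224, 224))
            >>> sorted(feats.keys())
            ['relu1_1', 'relu3_1']
            >>> feats['relu3_1'].shape
            torch.Size([1, 256, 56, 56])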
+ """ + if self.range_norm: + x = (x + 1) / 2 + if self.use_input_norm: + x = (x - self.mean) / self.std + + output = {} + for key, layer in self.vgg_net._modules.items(): + x = layer(x) + if key in self.layer_name_list: + output[key] = x.clone() + + return output diff --git a/basicsr/data/__init__.py b/basicsr/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..510df16771d153f61fbf2126baac24f69d3de7e4 --- /dev/null +++ b/basicsr/data/__init__.py @@ -0,0 +1,101 @@ +import importlib +import numpy as np +import random +import torch +import torch.utils.data +from copy import deepcopy +from functools import partial +from os import path as osp + +from basicsr.data.prefetch_dataloader import PrefetchDataLoader +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.dist_util import get_dist_info +from basicsr.utils.registry import DATASET_REGISTRY + +__all__ = ['build_dataset', 'build_dataloader'] + +# automatically scan and import dataset modules for registry +# scan all the files under the data folder with '_dataset' in file names +data_folder = osp.dirname(osp.abspath(__file__)) +dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] +# import all the dataset modules +_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames] + + +def build_dataset(dataset_opt): + """Build dataset from options. + + Args: + dataset_opt (dict): Configuration for dataset. It must contain: + name (str): Dataset name. + type (str): Dataset type. + """ + dataset_opt = deepcopy(dataset_opt) + dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) + logger = get_root_logger() + logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.') + return dataset + + +def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): + """Build dataloader. + + Args: + dataset (torch.utils.data.Dataset): Dataset. + dataset_opt (dict): Dataset options. It contains the following keys: + phase (str): 'train' or 'val'. + num_worker_per_gpu (int): Number of workers for each GPU. + batch_size_per_gpu (int): Training batch size for each GPU. + num_gpu (int): Number of GPUs. Used only in the train phase. + Default: 1. + dist (bool): Whether in distributed training. Used only in the train + phase. Default: False. + sampler (torch.utils.data.sampler): Data sampler. Default: None. + seed (int | None): Seed. Default: None + """ + phase = dataset_opt['phase'] + rank, _ = get_dist_info() + if phase == 'train': + if dist: # distributed training + batch_size = dataset_opt['batch_size_per_gpu'] + num_workers = dataset_opt['num_worker_per_gpu'] + else: # non-distributed training + multiplier = 1 if num_gpu == 0 else num_gpu + batch_size = dataset_opt['batch_size_per_gpu'] * multiplier + num_workers = dataset_opt['num_worker_per_gpu'] * multiplier + dataloader_args = dict( + dataset=dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + sampler=sampler, + drop_last=True) + if sampler is None: + dataloader_args['shuffle'] = True + dataloader_args['worker_init_fn'] = partial( + worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None + elif phase in ['val', 'test']: # validation + dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) + else: + raise ValueError(f"Wrong dataset phase: {phase}. 
Supported ones are 'train', 'val' and 'test'.") + + dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) + dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) + + prefetch_mode = dataset_opt.get('prefetch_mode') + if prefetch_mode == 'cpu': # CPUPrefetcher + num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) + logger = get_root_logger() + logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') + return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) + else: + # prefetch_mode=None: Normal dataloader + # prefetch_mode='cuda': dataloader for CUDAPrefetcher + return torch.utils.data.DataLoader(**dataloader_args) + + +def worker_init_fn(worker_id, num_workers, rank, seed): + # Set the worker seed to num_workers * rank + worker_id + seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) diff --git a/basicsr/data/__pycache__/__init__.cpython-310.pyc b/basicsr/data/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c7f9cfdf882bfb9d0627d31ee728ed5c645d428 Binary files /dev/null and b/basicsr/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/basicsr/data/__pycache__/data_sampler.cpython-310.pyc b/basicsr/data/__pycache__/data_sampler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7ac9e4ac3973c193e220e4459b7fcd8c549a654 Binary files /dev/null and b/basicsr/data/__pycache__/data_sampler.cpython-310.pyc differ diff --git a/basicsr/data/__pycache__/data_util.cpython-310.pyc b/basicsr/data/__pycache__/data_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40ef9471e96a8317ae1a6a793e08cce784688d19 Binary files /dev/null and b/basicsr/data/__pycache__/data_util.cpython-310.pyc differ diff --git a/basicsr/data/__pycache__/degradations.cpython-310.pyc b/basicsr/data/__pycache__/degradations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c94bbd66b6b316e0eb35cd70c95a093c4556668 Binary files /dev/null and b/basicsr/data/__pycache__/degradations.cpython-310.pyc differ diff --git a/basicsr/data/__pycache__/ffhq_dataset.cpython-310.pyc b/basicsr/data/__pycache__/ffhq_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d263bfc6d1308ec9ca4dab514c17f1c8a9f31366 Binary files /dev/null and b/basicsr/data/__pycache__/ffhq_dataset.cpython-310.pyc differ diff --git a/basicsr/data/__pycache__/ffhq_degradation_dataset.cpython-310.pyc b/basicsr/data/__pycache__/ffhq_degradation_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49d19b3e137ff62c17beea5d13ad04e8aa83d7d4 Binary files /dev/null and b/basicsr/data/__pycache__/ffhq_degradation_dataset.cpython-310.pyc differ diff --git a/basicsr/data/__pycache__/paired_image_dataset.cpython-310.pyc b/basicsr/data/__pycache__/paired_image_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4ab8b0f0b82e30c2f21ee78ccbcf3ef77be4783 Binary files /dev/null and b/basicsr/data/__pycache__/paired_image_dataset.cpython-310.pyc differ diff --git a/basicsr/data/__pycache__/prefetch_dataloader.cpython-310.pyc b/basicsr/data/__pycache__/prefetch_dataloader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..196fd1a8987c762e39c2d77a36da5763b100410e Binary files /dev/null 
and b/basicsr/data/__pycache__/prefetch_dataloader.cpython-310.pyc differ diff --git a/basicsr/data/__pycache__/realesrgan_dataset.cpython-310.pyc b/basicsr/data/__pycache__/realesrgan_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f7b511a2ac470286f8125434fcb10e4b7481b23 Binary files /dev/null and b/basicsr/data/__pycache__/realesrgan_dataset.cpython-310.pyc differ diff --git a/basicsr/data/__pycache__/realesrgan_paired_dataset.cpython-310.pyc b/basicsr/data/__pycache__/realesrgan_paired_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd702193fbe239420cab4137136ad872d3882257 Binary files /dev/null and b/basicsr/data/__pycache__/realesrgan_paired_dataset.cpython-310.pyc differ diff --git a/basicsr/data/__pycache__/reds_dataset.cpython-310.pyc b/basicsr/data/__pycache__/reds_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc10848e313a329293fb3411f15d8844e962033a Binary files /dev/null and b/basicsr/data/__pycache__/reds_dataset.cpython-310.pyc differ diff --git a/basicsr/data/__pycache__/single_image_dataset.cpython-310.pyc b/basicsr/data/__pycache__/single_image_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49f2b0117e6230bd9f45b78cfde824bf40a9e5bb Binary files /dev/null and b/basicsr/data/__pycache__/single_image_dataset.cpython-310.pyc differ diff --git a/basicsr/data/__pycache__/transforms.cpython-310.pyc b/basicsr/data/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..996a496d3c5ef61da47a8477358ef07712c93658 Binary files /dev/null and b/basicsr/data/__pycache__/transforms.cpython-310.pyc differ diff --git a/basicsr/data/__pycache__/video_test_dataset.cpython-310.pyc b/basicsr/data/__pycache__/video_test_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32bedb9ec89d091a5d0ec73dc2256e3991c23bb2 Binary files /dev/null and b/basicsr/data/__pycache__/video_test_dataset.cpython-310.pyc differ diff --git a/basicsr/data/__pycache__/vimeo90k_dataset.cpython-310.pyc b/basicsr/data/__pycache__/vimeo90k_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db49fa38ec3d1855ce1058aca0175d71952640a2 Binary files /dev/null and b/basicsr/data/__pycache__/vimeo90k_dataset.cpython-310.pyc differ diff --git a/basicsr/data/data_sampler.py b/basicsr/data/data_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..575452d9f844a928f7f42296c81635cfbadec7c2 --- /dev/null +++ b/basicsr/data/data_sampler.py @@ -0,0 +1,48 @@ +import math +import torch +from torch.utils.data.sampler import Sampler + + +class EnlargedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + Modified from torch.utils.data.distributed.DistributedSampler + Support enlarging the dataset for iteration-based training, for saving + time when restart the dataloader after each epoch + + Args: + dataset (torch.utils.data.Dataset): Dataset used for sampling. + num_replicas (int | None): Number of processes participating in + the training. It is usually the world_size. + rank (int | None): Rank of the current process within num_replicas. + ratio (int): Enlarging ratio. Default: 1. 
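    Example (illustrative sketch; `dataset` stands in for any map-style
    Dataset with, say, 100 items, running on a single process):
        >>> sampler = EnlargedSampler(dataset, num_replicas=1, rank=0, ratio=10)
        >>> sampler.set_epoch(0)
        >>> len(sampler)        # ceil(100 * 10 / 1)
        1000
        >>> max(sampler) < 100  # enlarged indices wrap back into the dataset
        True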
+ """ + + def __init__(self, dataset, num_replicas, rank, ratio=1): + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.total_size, generator=g).tolist() + + dataset_size = len(self.dataset) + indices = [v % dataset_size for v in indices] + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/basicsr/data/data_util.py b/basicsr/data/data_util.py new file mode 100644 index 0000000000000000000000000000000000000000..dce2562fb9f99475c44e9185f50018a428859214 --- /dev/null +++ b/basicsr/data/data_util.py @@ -0,0 +1,362 @@ +import cv2 +import numpy as np +import torch +from os import path as osp +from torch.nn import functional as F + +from basicsr.data.transforms import mod_crop +from basicsr.utils import img2tensor, scandir + + +def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False): + """Read a sequence of images from a given folder path. + + Args: + path (list[str] | str): List of image paths or image folder path. + require_mod_crop (bool): Require mod crop for each image. + Default: False. + scale (int): Scale factor for mod_crop. Default: 1. + return_imgname(bool): Whether return image names. Default False. + + Returns: + Tensor: size (t, c, h, w), RGB, [0, 1]. + list[str]: Returned image name list. + """ + if isinstance(path, list): + img_paths = path + else: + img_paths = sorted(list(scandir(path, full_path=True))) + imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths] + + if require_mod_crop: + imgs = [mod_crop(img, scale) for img in imgs] + imgs = img2tensor(imgs, bgr2rgb=True, float32=True) + imgs = torch.stack(imgs, dim=0) + + if return_imgname: + imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths] + return imgs, imgnames + else: + return imgs + + +def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'): + """Generate an index list for reading `num_frames` frames from a sequence + of images. + + Args: + crt_idx (int): Current center index. + max_frame_num (int): Max number of the sequence of images (from 1). + num_frames (int): Reading num_frames frames. + padding (str): Padding mode, one of + 'replicate' | 'reflection' | 'reflection_circle' | 'circle' + Examples: current_idx = 0, num_frames = 5 + The generated frame indices under different padding mode: + replicate: [0, 0, 0, 1, 2] + reflection: [2, 1, 0, 1, 2] + reflection_circle: [4, 3, 0, 1, 2] + circle: [3, 4, 0, 1, 2] + + Returns: + list[int]: A list of indices. + """ + assert num_frames % 2 == 1, 'num_frames should be an odd number.' + assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.' 
+ + max_frame_num = max_frame_num - 1 # start from 0 + num_pad = num_frames // 2 + + indices = [] + for i in range(crt_idx - num_pad, crt_idx + num_pad + 1): + if i < 0: + if padding == 'replicate': + pad_idx = 0 + elif padding == 'reflection': + pad_idx = -i + elif padding == 'reflection_circle': + pad_idx = crt_idx + num_pad - i + else: + pad_idx = num_frames + i + elif i > max_frame_num: + if padding == 'replicate': + pad_idx = max_frame_num + elif padding == 'reflection': + pad_idx = max_frame_num * 2 - i + elif padding == 'reflection_circle': + pad_idx = (crt_idx - num_pad) - (i - max_frame_num) + else: + pad_idx = i - num_frames + else: + pad_idx = i + indices.append(pad_idx) + return indices + + +def paired_paths_from_lmdb(folders, keys): + """Generate paired paths from lmdb files. + + Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is: + + :: + + lq.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records + 1)image name (with extension), + 2)image shape, + 3)compression level, separated by a white space. + Example: `baboon.png (120,125,3) 1` + + We use the image name without extension as the lmdb key. + Note that we use the same key for the corresponding lq and gt images. + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + Note that this key is different from lmdb keys. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}' + input_folder, gt_folder = folders + input_key, gt_key = keys + + if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')): + raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb ' + f'formats. But received {input_key}: {input_folder}; ' + f'{gt_key}: {gt_folder}') + # ensure that the two meta_info files are the same + with open(osp.join(input_folder, 'meta_info.txt')) as fin: + input_lmdb_keys = [line.split('.')[0] for line in fin] + with open(osp.join(gt_folder, 'meta_info.txt')) as fin: + gt_lmdb_keys = [line.split('.')[0] for line in fin] + if set(input_lmdb_keys) != set(gt_lmdb_keys): + raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.') + else: + paths = [] + for lmdb_key in sorted(input_lmdb_keys): + paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)])) + return paths + + +def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl): + """Generate paired paths from an meta information file. + + Each line in the meta information file contains the image names and + image shape (usually for gt), separated by a white space. + + Example of an meta information file: + ``` + 0001_s001.png (480,480,3) + 0001_s002.png (480,480,3) + ``` + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. 
+ keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + meta_info_file (str): Path to the meta information file. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Usually the filename_tmpl is + for files in the input folder. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}' + input_folder, gt_folder = folders + input_key, gt_key = keys + + with open(meta_info_file, 'r') as fin: + gt_names = [line.strip().split(' ')[0] for line in fin] + + paths = [] + for gt_name in gt_names: + basename, ext = osp.splitext(osp.basename(gt_name)) + input_name = f'{filename_tmpl.format(basename)}{ext}' + input_path = osp.join(input_folder, input_name) + gt_path = osp.join(gt_folder, gt_name) + paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)])) + return paths + +def paired_paths_from_meta_info_file_2(folders, keys, meta_info_file, filename_tmpl): + """Generate paired paths from an meta information file. + + Each line in the meta information file contains the image names and + image shape (usually for gt), separated by a white space. + + Example of an meta information file: + ``` + 0001_s001.png (480,480,3) + 0001_s002.png (480,480,3) + ``` + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + meta_info_file (str): Path to the meta information file. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Usually the filename_tmpl is + for files in the input folder. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}' + input_folder, gt_folder = folders + input_key, gt_key = keys + + with open(meta_info_file, 'r') as fin: + gt_names = [line.strip().split(' ')[0] for line in fin] + with open(meta_info_file, 'r') as fin: + input_names = [line.strip().split(' ')[1] for line in fin] + paths = [] + for i in range(len(gt_names)): + gt_name = gt_names[i] + lq_name = input_names[i] + basename, ext = osp.splitext(osp.basename(gt_name)) + basename = gt_name[:-len(ext)] + gt_path = osp.join(gt_folder, gt_name) + basename, ext = osp.splitext(osp.basename(lq_name)) + basename = lq_name[:-len(ext)] + input_path = osp.join(input_folder, lq_name) + paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)])) + return paths + +def paired_paths_from_folder(folders, keys, filename_tmpl): + """Generate paired paths from folders. + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Usually the filename_tmpl is + for files in the input folder. + + Returns: + list[str]: Returned path list. 
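    Example (illustrative sketch; the folder layout and file names are
    assumed, and each returned entry is a dict holding the two paths):
        >>> folders = ['datasets/train/lq', 'datasets/train/gt']
        >>> paths = paired_paths_from_folder(folders, ['lq', 'gt'], '{}')
        >>> paths[0]
        {'lq_path': 'datasets/train/lq/0001.png', 'gt_path': 'datasets/train/gt/0001.png'}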
+ """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}' + input_folder, gt_folder = folders + input_key, gt_key = keys + + input_paths = list(scandir(input_folder)) + gt_paths = list(scandir(gt_folder)) + assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: ' + f'{len(input_paths)}, {len(gt_paths)}.') + paths = [] + for gt_path in gt_paths: + basename, ext = osp.splitext(osp.basename(gt_path)) + input_name = f'{filename_tmpl.format(basename)}{ext}' + input_path = osp.join(input_folder, input_name) + assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.' + gt_path = osp.join(gt_folder, gt_path) + paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)])) + return paths + + +def paths_from_folder(folder): + """Generate paths from folder. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. + """ + + paths = list(scandir(folder)) + paths = [osp.join(folder, path) for path in paths] + return paths + + +def paths_from_lmdb(folder): + """Generate paths from lmdb. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. + """ + if not folder.endswith('.lmdb'): + raise ValueError(f'Folder {folder}folder should in lmdb format.') + with open(osp.join(folder, 'meta_info.txt')) as fin: + paths = [line.split('.')[0] for line in fin] + return paths + + +def generate_gaussian_kernel(kernel_size=13, sigma=1.6): + """Generate Gaussian kernel used in `duf_downsample`. + + Args: + kernel_size (int): Kernel size. Default: 13. + sigma (float): Sigma of the Gaussian kernel. Default: 1.6. + + Returns: + np.array: The Gaussian kernel. + """ + from scipy.ndimage import filters as filters + kernel = np.zeros((kernel_size, kernel_size)) + # set element at the middle to one, a dirac delta + kernel[kernel_size // 2, kernel_size // 2] = 1 + # gaussian-smooth the dirac, resulting in a gaussian filter + return filters.gaussian_filter(kernel, sigma) + + +def duf_downsample(x, kernel_size=13, scale=4): + """Downsamping with Gaussian kernel used in the DUF official code. + + Args: + x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w). + kernel_size (int): Kernel size. Default: 13. + scale (int): Downsampling factor. Supported scale: (2, 3, 4). + Default: 4. + + Returns: + Tensor: DUF downsampled frames. + """ + assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.' 
+ + squeeze_flag = False + if x.ndim == 4: + squeeze_flag = True + x = x.unsqueeze(0) + b, t, c, h, w = x.size() + x = x.view(-1, 1, h, w) + pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2 + x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect') + + gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale) + gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0) + x = F.conv2d(x, gaussian_filter, stride=scale) + x = x[:, :, 2:-2, 2:-2] + x = x.view(b, t, c, x.size(2), x.size(3)) + if squeeze_flag: + x = x.squeeze(0) + return x diff --git a/basicsr/data/degradations.py b/basicsr/data/degradations.py new file mode 100644 index 0000000000000000000000000000000000000000..5db40fb080908e9a0de503b9c9518710f89e2e0d --- /dev/null +++ b/basicsr/data/degradations.py @@ -0,0 +1,935 @@ +import cv2 +import math +import numpy as np +import random +import torch +from scipy import special +from scipy.stats import multivariate_normal +from torchvision.transforms.functional_tensor import rgb_to_grayscale + +# -------------------------------------------------------------------- # +# --------------------------- blur kernels --------------------------- # +# -------------------------------------------------------------------- # + + +# --------------------------- util functions --------------------------- # +def sigma_matrix2(sig_x, sig_y, theta): + """Calculate the rotated sigma matrix (two dimensional matrix). + + Args: + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + + Returns: + ndarray: Rotated sigma matrix. + """ + d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]]) + u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T)) + + +def mesh_grid(kernel_size): + """Generate the mesh grid, centering at zero. + + Args: + kernel_size (int): + + Returns: + xy (ndarray): with the shape (kernel_size, kernel_size, 2) + xx (ndarray): with the shape (kernel_size, kernel_size) + yy (ndarray): with the shape (kernel_size, kernel_size) + """ + ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.) + xx, yy = np.meshgrid(ax, ax) + xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size, + 1))).reshape(kernel_size, kernel_size, 2) + return xy, xx, yy + + +def pdf2(sigma_matrix, grid): + """Calculate PDF of the bivariate Gaussian distribution. + + Args: + sigma_matrix (ndarray): with the shape (2, 2) + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + + Returns: + kernel (ndarrray): un-normalized kernel. + """ + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2)) + return kernel + + +def cdf2(d_matrix, grid): + """Calculate the CDF of the standard bivariate Gaussian distribution. + Used in skewed Gaussian distribution. + + Args: + d_matrix (ndarrasy): skew matrix. + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + + Returns: + cdf (ndarray): skewed cdf. + """ + rv = multivariate_normal([0, 0], [[1, 0], [0, 1]]) + grid = np.dot(grid, d_matrix) + cdf = rv.cdf(grid) + return cdf + + +def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True): + """Generate a bivariate isotropic or anisotropic Gaussian kernel. + + In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. 
+ + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + isotropic (bool): + + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + if isotropic: + sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) + else: + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + kernel = pdf2(sigma_matrix, grid) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True): + """Generate a bivariate generalized Gaussian kernel. + + ``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions`` + + In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. + + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + if isotropic: + sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) + else: + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta)) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True): + """Generate a plateau-like anisotropic kernel. + + 1 / (1+x^(beta)) + + Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution + + In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. + + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + if isotropic: + sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) + else: + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1) + kernel = kernel / np.sum(kernel) + return kernel + + +def random_bivariate_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=None, + isotropic=True, + return_sigma=False): + """Randomly generate bivariate isotropic or anisotropic Gaussian kernels. + + In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. + + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' 
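    # Sampling scheme: sigma_x is always drawn from sigma_x_range; in the
    # anisotropic case sigma_y and the rotation angle are drawn independently,
    # while the isotropic case ties sigma_y to sigma_x and fixes the rotation
    # to 0. Optional multiplicative noise perturbs the kernel, which is then
    # re-normalized so its entries sum to 1.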
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + if isotropic is False: + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + else: + sigma_y = sigma_x + rotation = 0 + + kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + if not return_sigma: + return kernel + else: + return kernel, [sigma_x, sigma_y] + + +def random_bivariate_generalized_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + beta_range, + noise_range=None, + isotropic=True, + return_sigma=False): + """Randomly generate bivariate generalized Gaussian kernels. + + In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. + + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + beta_range (tuple): [0.5, 8] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + if isotropic is False: + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + else: + sigma_y = sigma_x + rotation = 0 + + # assume beta_range[0] < 1 < beta_range[1] + if np.random.uniform() < 0.5: + beta = np.random.uniform(beta_range[0], 1) + else: + beta = np.random.uniform(1, beta_range[1]) + + kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + if not return_sigma: + return kernel + else: + return kernel, [sigma_x, sigma_y] + + +def random_bivariate_plateau(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + beta_range, + noise_range=None, + isotropic=True, + return_sigma=False): + """Randomly generate bivariate plateau kernels. + + In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. + + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi/2, math.pi/2] + beta_range (tuple): [1, 4] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' 
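    # Same sigma/rotation sampling as the Gaussian case; the extra shape
    # parameter beta is drawn from either side of 1 with equal probability
    # (this implicitly assumes beta_range straddles 1, as the TODO below
    # notes), giving sharper, more peaked kernels for beta < 1 and flatter,
    # plateau-like tops for beta > 1.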
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + if isotropic is False: + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + else: + sigma_y = sigma_x + rotation = 0 + + # TODO: this may be not proper + if np.random.uniform() < 0.5: + beta = np.random.uniform(beta_range[0], 1) + else: + beta = np.random.uniform(1, beta_range[1]) + + kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic) + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + + if not return_sigma: + return kernel + else: + return kernel, [sigma_x, sigma_y] + + +def random_mixed_kernels(kernel_list, + kernel_prob, + kernel_size=21, + sigma_x_range=(0.6, 5), + sigma_y_range=(0.6, 5), + rotation_range=(-math.pi, math.pi), + betag_range=(0.5, 8), + betap_range=(0.5, 8), + noise_range=None, + return_sigma=False): + """Randomly generate mixed kernels. + + Args: + kernel_list (tuple): a list name of kernel types, + support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', + 'plateau_aniso'] + kernel_prob (tuple): corresponding kernel probability for each + kernel type + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + beta_range (tuple): [0.5, 8] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. 
Default: None + + Returns: + kernel (ndarray): + """ + kernel_type = random.choices(kernel_list, kernel_prob)[0] + if not return_sigma: + if kernel_type == 'iso': + kernel = random_bivariate_Gaussian( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True, return_sigma=return_sigma) + elif kernel_type == 'aniso': + kernel = random_bivariate_Gaussian( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False, return_sigma=return_sigma) + elif kernel_type == 'generalized_iso': + kernel = random_bivariate_generalized_Gaussian( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + betag_range, + noise_range=noise_range, + isotropic=True, + return_sigma=return_sigma) + elif kernel_type == 'generalized_aniso': + kernel = random_bivariate_generalized_Gaussian( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + betag_range, + noise_range=noise_range, + isotropic=False, + return_sigma=return_sigma) + elif kernel_type == 'plateau_iso': + kernel = random_bivariate_plateau( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True, return_sigma=return_sigma) + elif kernel_type == 'plateau_aniso': + kernel = random_bivariate_plateau( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False, return_sigma=return_sigma) + return kernel + else: + if kernel_type == 'iso': + kernel, sigma_list = random_bivariate_Gaussian( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True, return_sigma=return_sigma) + elif kernel_type == 'aniso': + kernel, sigma_list = random_bivariate_Gaussian( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False, return_sigma=return_sigma) + elif kernel_type == 'generalized_iso': + kernel, sigma_list = random_bivariate_generalized_Gaussian( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + betag_range, + noise_range=noise_range, + isotropic=True, + return_sigma=return_sigma) + elif kernel_type == 'generalized_aniso': + kernel, sigma_list = random_bivariate_generalized_Gaussian( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + betag_range, + noise_range=noise_range, + isotropic=False, + return_sigma=return_sigma) + elif kernel_type == 'plateau_iso': + kernel, sigma_list = random_bivariate_plateau( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True, return_sigma=return_sigma) + elif kernel_type == 'plateau_aniso': + kernel, sigma_list = random_bivariate_plateau( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False, return_sigma=return_sigma) + return kernel, sigma_list + + +np.seterr(divide='ignore', invalid='ignore') + + +def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0): + """2D sinc filter + + Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter + + Args: + cutoff (float): cutoff frequency in radians (pi is max) + kernel_size (int): horizontal and vertical size, must be odd. + pad_to (int): pad kernel size to desired size, must be odd or zero. + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' 
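    # The kernel below is the radially symmetric ideal low-pass response
    #     k(r) = cutoff * J1(cutoff * r) / (2 * pi * r),
    # where J1 is the first-order Bessel function of the first kind and r is
    # the distance from the kernel centre. The centre tap is set to the
    # r -> 0 limit, cutoff**2 / (4 * pi), the kernel is normalized to sum to
    # 1, and it is optionally zero-padded up to `pad_to`.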
+ kernel = np.fromfunction( + lambda x, y: cutoff * special.j1(cutoff * np.sqrt( + (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt( + (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size]) + kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi) + kernel = kernel / np.sum(kernel) + if pad_to > kernel_size: + pad_size = (pad_to - kernel_size) // 2 + kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) + return kernel + + +# ------------------------------------------------------------- # +# --------------------------- noise --------------------------- # +# ------------------------------------------------------------- # + +# ----------------------- Gaussian Noise ----------------------- # + + +def generate_gaussian_noise(img, sigma=10, gray_noise=False): + """Generate Gaussian noise. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + sigma (float): Noise scale (measured in range 255). Default: 10. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + if gray_noise: + noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255. + noise = np.expand_dims(noise, axis=2).repeat(3, axis=2) + else: + noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255. + return noise + + +def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False): + """Add Gaussian noise. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + sigma (float): Noise scale (measured in range 255). Default: 10. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + noise = generate_gaussian_noise(img, sigma, gray_noise) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0): + """Add Gaussian noise (PyTorch version). + + Args: + img (Tensor): Shape (b, c, h, w), range[0, 1], float32. + scale (float | Tensor): Noise scale. Default: 1.0. + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + b, _, h, w = img.size() + if not isinstance(sigma, (float, int)): + sigma = sigma.view(img.size(0), 1, 1, 1) + if isinstance(gray_noise, (float, int)): + cal_gray_noise = gray_noise > 0 + else: + gray_noise = gray_noise.view(b, 1, 1, 1) + cal_gray_noise = torch.sum(gray_noise) > 0 + + if cal_gray_noise: + noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255. + noise_gray = noise_gray.view(b, 1, h, w) + + # always calculate color noise + noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255. + + if cal_gray_noise: + noise = noise * (1 - gray_noise) + noise_gray * gray_noise + return noise + + +def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False): + """Add Gaussian noise (PyTorch version). + + Args: + img (Tensor): Shape (b, c, h, w), range[0, 1], float32. + scale (float | Tensor): Noise scale. Default: 1.0. + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + noise = generate_gaussian_noise_pt(img, sigma, gray_noise) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. 
+ elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +# ----------------------- Random Gaussian Noise ----------------------- # +def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0, return_sigma=False): + sigma = np.random.uniform(sigma_range[0], sigma_range[1]) + if np.random.uniform() < gray_prob: + gray_noise = True + else: + gray_noise = False + if return_sigma: + return generate_gaussian_noise(img, sigma, gray_noise), sigma + else: + return generate_gaussian_noise(img, sigma, gray_noise) + + +def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False, return_sigma=False): + if return_sigma: + noise, sigma = random_generate_gaussian_noise(img, sigma_range, gray_prob, return_sigma=return_sigma) + else: + noise = random_generate_gaussian_noise(img, sigma_range, gray_prob, return_sigma=return_sigma) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + if return_sigma: + return out, sigma + else: + return out + + +def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0): + sigma = torch.rand( + img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0] + gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device) + gray_noise = (gray_noise < gray_prob).float() + return generate_gaussian_noise_pt(img, sigma, gray_noise) + + +def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + +# ----------------------- Poisson (Shot) Noise ----------------------- # + + +def generate_poisson_noise(img, scale=1.0, gray_noise=False): + """Generate poisson noise. + + Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219 + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + scale (float): Noise scale. Default: 1.0. + gray_noise (bool): Whether generate gray noise. Default: False. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + if gray_noise: + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + # round and clip image for counting vals correctly + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = len(np.unique(img)) + vals = 2**np.ceil(np.log2(vals)) + out = np.float32(np.random.poisson(img * vals) / float(vals)) + noise = out - img + if gray_noise: + noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2) + return noise * scale + + +def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False): + """Add poisson noise. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + scale (float): Noise scale. Default: 1.0. + gray_noise (bool): Whether generate gray noise. Default: False. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + noise = generate_poisson_noise(img, scale, gray_noise) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. 
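# Illustrative usage sketch for the batched PyTorch wrapper above: sigma_range is given on
# the 0-255 scale and gray_prob is the per-sample chance of grayscale noise. The batch shape
# and values here are arbitrary examples.
import torch
from basicsr.data.degradations import random_add_gaussian_noise_pt

batch = torch.rand(4, 3, 64, 64)  # (b, c, h, w), range [0, 1]
noisy_batch = random_add_gaussian_noise_pt(batch, sigma_range=(1, 30), gray_prob=0.4, clip=True, rounds=False)
assert noisy_batch.shape == batch.shape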
+ elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0): + """Generate a batch of poisson noise (PyTorch version) + + Args: + img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32. + scale (float | Tensor): Noise scale. Number or Tensor with shape (b). + Default: 1.0. + gray_noise (float | Tensor): 0-1 number or Tensor with shape (b). + 0 for False, 1 for True. Default: 0. + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + b, _, h, w = img.size() + if isinstance(gray_noise, (float, int)): + cal_gray_noise = gray_noise > 0 + else: + gray_noise = gray_noise.view(b, 1, 1, 1) + cal_gray_noise = torch.sum(gray_noise) > 0 + if cal_gray_noise: + img_gray = rgb_to_grayscale(img, num_output_channels=1) + # round and clip image for counting vals correctly + img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255. + # use for-loop to get the unique values for each sample + vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)] + vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list] + vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1) + out = torch.poisson(img_gray * vals) / vals + noise_gray = out - img_gray + noise_gray = noise_gray.expand(b, 3, h, w) + + # always calculate color noise + # round and clip image for counting vals correctly + img = torch.clamp((img * 255.0).round(), 0, 255) / 255. + # use for-loop to get the unique values for each sample + vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)] + vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list] + vals = img.new_tensor(vals_list).view(b, 1, 1, 1) + out = torch.poisson(img * vals) / vals + noise = out - img + if cal_gray_noise: + noise = noise * (1 - gray_noise) + noise_gray * gray_noise + if not isinstance(scale, (float, int)): + scale = scale.view(b, 1, 1, 1) + return noise * scale + + +def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0): + """Add poisson noise to a batch of images (PyTorch version). + + Args: + img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32. + scale (float | Tensor): Noise scale. Number or Tensor with shape (b). + Default: 1.0. + gray_noise (float | Tensor): 0-1 number or Tensor with shape (b). + 0 for False, 1 for True. Default: 0. + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + noise = generate_poisson_noise_pt(img, scale, gray_noise) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +# ----------------------- Random Poisson (Shot) Noise ----------------------- # + + +def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0): + scale = np.random.uniform(scale_range[0], scale_range[1]) + if np.random.uniform() < gray_prob: + gray_noise = True + else: + gray_noise = False + return generate_poisson_noise(img, scale, gray_noise) + + +def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + noise = random_generate_poisson_noise(img, scale_range, gray_prob) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. 
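# Illustrative usage sketch for the numpy shot-noise wrapper defined here. The scale_range
# and gray_prob values are arbitrary examples of the kind of small ranges such pipelines
# typically sample from.
import numpy as np
from basicsr.data.degradations import random_add_poisson_noise

img = np.random.rand(64, 64, 3).astype(np.float32)
noisy = random_add_poisson_noise(img, scale_range=(0.05, 3), gray_prob=0.4, clip=True, rounds=False)
assert noisy.shape == img.shape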
+ return out + + +def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0): + scale = torch.rand( + img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0] + gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device) + gray_noise = (gray_noise < gray_prob).float() + return generate_poisson_noise_pt(img, scale, gray_noise) + + +def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + +# ----------------------- Random speckle Noise ----------------------- # + +def random_add_speckle_noise(imgs, speckle_std): + std_range = speckle_std + std_l = std_range[0] + std_r = std_range[1] + mean=0 + std=random.uniform(std_l/255.,std_r/255.) + + outputs = [] + for img in imgs: + gauss=np.random.normal(loc=mean,scale=std,size=img.shape) + noisy=img+gauss*img + noisy=np.clip(noisy,0,1).astype(np.float32) + + outputs.append(noisy) + + return outputs + + +def random_add_speckle_noise_pt(img, speckle_std): + std_range = speckle_std + std_l = std_range[0] + std_r = std_range[1] + mean=0 + std=random.uniform(std_l/255.,std_r/255.) + gauss=torch.normal(mean=mean,std=std,size=img.size()).to(img.device) + noisy=img+gauss*img + noisy=torch.clamp(noisy,0,1) + return noisy + +# ----------------------- Random saltpepper Noise ----------------------- # + +def random_add_saltpepper_noise(imgs, saltpepper_amount, saltpepper_svsp): + p_range = saltpepper_amount + p = random.uniform(p_range[0], p_range[1]) + q_range = saltpepper_svsp + q = random.uniform(q_range[0], q_range[1]) + + outputs = [] + for img in imgs: + out = img.copy() + flipped = np.random.choice([True, False], size=img.shape, + p=[p, 1 - p]) + salted = np.random.choice([True, False], size=img.shape, + p=[q, 1 - q]) + peppered = ~salted + out[flipped & salted] = 1 + out[flipped & peppered] = 0. + noisy = np.clip(out, 0, 1).astype(np.float32) + + outputs.append(noisy) + + return outputs + +def random_add_saltpepper_noise_pt(imgs, saltpepper_amount, saltpepper_svsp): + p_range = saltpepper_amount + p = random.uniform(p_range[0], p_range[1]) + q_range = saltpepper_svsp + q = random.uniform(q_range[0], q_range[1]) + + imgs = imgs.permute(0,2,3,1) + + outputs = [] + for i in range(imgs.size(0)): + img = imgs[i] + out = img.clone() + flipped = np.random.choice([True, False], size=img.shape, + p=[p, 1 - p]) + salted = np.random.choice([True, False], size=img.shape, + p=[q, 1 - q]) + peppered = ~salted + temp = flipped & salted + out[flipped & salted] = 1 + out[flipped & peppered] = 0. 
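# Illustrative usage sketch for the list-based speckle and salt-and-pepper helpers above.
# Both take a list of float32 images in [0, 1]; the ranges below are arbitrary examples
# (the speckle std is sampled in [low, high] / 255, and the amount and salt-vs-pepper ratio
# are sampled uniformly per call).
import numpy as np
from basicsr.data.degradations import random_add_saltpepper_noise, random_add_speckle_noise

imgs = [np.random.rand(32, 32, 3).astype(np.float32) for _ in range(2)]
speckled = random_add_speckle_noise(imgs, speckle_std=(5, 40))
salted = random_add_saltpepper_noise(imgs, saltpepper_amount=(0.01, 0.06), saltpepper_svsp=(0.3, 0.7))
assert len(speckled) == len(salted) == len(imgs)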
+ noisy = torch.clamp(out, 0, 1) + + outputs.append(noisy.permute(2,0,1)) + if len(outputs)>1: + return torch.cat(outputs, dim=0) + else: + return outputs[0].unsqueeze(0) + +# ----------------------- Random screen Noise ----------------------- # + +def random_add_screen_noise(imgs, linewidth, space): + #screen_noise = np.random.uniform() < self.params['noise_prob'][0] + linewidth = linewidth + linewidth = int(np.random.uniform(linewidth[0], linewidth[1])) + space = space + space = int(np.random.uniform(space[0], space[1])) + center_color = [213,230,230] # RGB + outputs = [] + for img in imgs: + noise = img.copy() + + tmp_mask = np.zeros((img.shape[1], img.shape[0]), dtype=np.float32) + for i in range(0, img.shape[0], int((space+linewidth))): + tmp_mask[:, i:(i+linewidth)] = 1 + colour_masks = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.float32) + colour_masks[:,:,0] = (center_color[0] + np.random.uniform(-20, 20))/255. + colour_masks[:,:,1] = (center_color[1] + np.random.uniform(0, 20))/255. + colour_masks[:,:,2] = (center_color[2] + np.random.uniform(0, 20))/255. + noise_color = cv2.addWeighted(noise, 0.6, colour_masks, 0.4, 0.0) + noise = noise*(1-(tmp_mask[:,:,np.newaxis])) + noise_color*(tmp_mask[:,:,np.newaxis]) + + outputs.append(noise) + + return outputs + + +# ------------------------------------------------------------------------ # +# --------------------------- JPEG compression --------------------------- # +# ------------------------------------------------------------------------ # + + +def add_jpg_compression(img, quality=90): + """Add JPG compression artifacts. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + quality (float): JPG compression quality. 0 for lowest quality, 100 for + best quality. Default: 90. + + Returns: + (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1], + float32. + """ + img = np.clip(img, 0, 1) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)] + _, encimg = cv2.imencode('.jpg', img * 255., encode_param) + img = np.float32(cv2.imdecode(encimg, 1)) / 255. + return img + + +def random_add_jpg_compression(img, quality_range=(90, 100), return_q=False): + """Randomly add JPG compression artifacts. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + quality_range (tuple[float] | list[float]): JPG compression quality + range. 0 for lowest quality, 100 for best quality. + Default: (90, 100). + + Returns: + (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1], + float32. + """ + quality = np.random.uniform(quality_range[0], quality_range[1]) + if return_q: + return add_jpg_compression(img, quality), quality + else: + return add_jpg_compression(img, quality) diff --git a/basicsr/data/ffhq_dataset.py b/basicsr/data/ffhq_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..23992eb877f6b7b46cf5f40ed3667fc10916269b --- /dev/null +++ b/basicsr/data/ffhq_dataset.py @@ -0,0 +1,80 @@ +import random +import time +from os import path as osp +from torch.utils import data as data +from torchvision.transforms.functional import normalize + +from basicsr.data.transforms import augment +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY + + +@DATASET_REGISTRY.register() +class FFHQDataset(data.Dataset): + """FFHQ dataset for StyleGAN. + + Args: + opt (dict): Config for train datasets. 
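# Illustrative usage sketch for the JPEG helpers above: one random encode/decode round trip
# through OpenCV on a float32 image in [0, 1]. The quality range is an arbitrary example.
import numpy as np
from basicsr.data.degradations import random_add_jpg_compression

img = np.random.rand(64, 64, 3).astype(np.float32)
jpeg_img, quality = random_add_jpg_compression(img, quality_range=(30, 95), return_q=True)
assert jpeg_img.shape == img.shape and 30 <= quality <= 95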
It contains the following keys: + dataroot_gt (str): Data root path for gt. + io_backend (dict): IO backend type and other kwarg. + mean (list | tuple): Image mean. + std (list | tuple): Image std. + use_hflip (bool): Whether to horizontally flip. + + """ + + def __init__(self, opt): + super(FFHQDataset, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + + self.gt_folder = opt['dataroot_gt'] + self.mean = opt['mean'] + self.std = opt['std'] + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = self.gt_folder + if not self.gt_folder.endswith('.lmdb'): + raise ValueError("'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") + with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: + self.paths = [line.split('.')[0] for line in fin] + else: + # FFHQ has 70000 images in total + self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)] + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # load gt image + gt_path = self.paths[index] + # avoid errors caused by high latency in reading files + retry = 3 + while retry > 0: + try: + img_bytes = self.file_client.get(gt_path) + except Exception as e: + logger = get_root_logger() + logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}') + # change another file to read + index = random.randint(0, self.__len__()) + gt_path = self.paths[index] + time.sleep(1) # sleep 1s for occasional server congestion + else: + break + finally: + retry -= 1 + img_gt = imfrombytes(img_bytes, float32=True) + + # random horizontal flip + img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False) + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True) + # normalize + normalize(img_gt, self.mean, self.std, inplace=True) + return {'gt': img_gt, 'gt_path': gt_path} + + def __len__(self): + return len(self.paths) diff --git a/basicsr/data/ffhq_degradation_dataset.py b/basicsr/data/ffhq_degradation_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d8c934ddd68816f2d2f0038191530056fe912c --- /dev/null +++ b/basicsr/data/ffhq_degradation_dataset.py @@ -0,0 +1,232 @@ +import cv2 +import math +import numpy as np +import os.path as osp +import torch +import torch.utils.data as data +import random +from basicsr.data import degradations as degradations +from basicsr.data.data_util import paths_from_folder +from basicsr.data.transforms import augment +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY +from pathlib import Path +from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation, + normalize) + +@DATASET_REGISTRY.register() +class FFHQDegradationDataset(data.Dataset): + """FFHQ dataset for GFPGAN. + It reads high resolution images, and then generate low-quality (LQ) images on-the-fly. + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + io_backend (dict): IO backend type and other kwarg. + mean (list | tuple): Image mean. + std (list | tuple): Image std. + use_hflip (bool): Whether to horizontally flip. + Please see more options in the codes. 
+ """ + + def __init__(self, opt): + super(FFHQDegradationDataset, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + if 'image_type' not in opt: + opt['image_type'] = 'png' + + self.gt_folder = opt['dataroot_gt'] + self.mean = opt['mean'] + self.std = opt['std'] + self.out_size = opt['out_size'] + + self.crop_components = opt.get('crop_components', False) # facial components + self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1) # whether enlarge eye regions + + if self.crop_components: + # load component list from a pre-process pth files + self.components_list = torch.load(opt.get('component_path')) + + # file client (lmdb io backend) + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = self.gt_folder + if not self.gt_folder.endswith('.lmdb'): + raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") + with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: + self.paths = [line.split('.')[0] for line in fin] + else: + # disk backend: scan file list from a folder + self.paths = self.paths = sorted([str(x) for x in Path(self.gt_folder).glob('*.'+opt['image_type'])]) + + # degradation configurations + self.blur_kernel_size = opt['blur_kernel_size'] + self.kernel_list = opt['kernel_list'] + self.kernel_prob = opt['kernel_prob'] + self.blur_sigma = opt['blur_sigma'] + self.downsample_range = opt['downsample_range'] + self.noise_range = opt['noise_range'] + self.jpeg_range = opt['jpeg_range'] + + # color jitter + self.color_jitter_prob = opt.get('color_jitter_prob') + self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob') + self.color_jitter_shift = opt.get('color_jitter_shift', 20) + # to gray + self.gray_prob = opt.get('gray_prob') + + logger = get_root_logger() + logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]') + logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]') + logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]') + logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]') + + if self.color_jitter_prob is not None: + logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}') + if self.gray_prob is not None: + logger.info(f'Use random gray. Prob: {self.gray_prob}') + if self.color_jitter_shift is not None: + self.color_jitter_shift /= 255. 
+ + @staticmethod + def color_jitter(img, shift): + """jitter color: randomly jitter the RGB values, in numpy formats""" + jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32) + img = img + jitter_val + img = np.clip(img, 0, 1) + return img + + @staticmethod + def color_jitter_pt(img, brightness, contrast, saturation, hue): + """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats""" + fn_idx = torch.randperm(4) + for fn_id in fn_idx: + if fn_id == 0 and brightness is not None: + brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() + img = adjust_brightness(img, brightness_factor) + + if fn_id == 1 and contrast is not None: + contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() + img = adjust_contrast(img, contrast_factor) + + if fn_id == 2 and saturation is not None: + saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() + img = adjust_saturation(img, saturation_factor) + + if fn_id == 3 and hue is not None: + hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() + img = adjust_hue(img, hue_factor) + return img + + def get_component_coordinates(self, index, status): + """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file""" + components_bbox = self.components_list[f'{index:08d}'] + if status[0]: # hflip + # exchange right and left eye + tmp = components_bbox['left_eye'] + components_bbox['left_eye'] = components_bbox['right_eye'] + components_bbox['right_eye'] = tmp + # modify the width coordinate + components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0] + components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0] + components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0] + + # get coordinates + locations = [] + for part in ['left_eye', 'right_eye', 'mouth']: + mean = components_bbox[part][0:2] + half_len = components_bbox[part][2] + if 'eye' in part: + half_len *= self.eye_enlarge_ratio + loc = np.hstack((mean - half_len + 1, mean + half_len)) + loc = torch.from_numpy(loc).float() + locations.append(loc) + return locations + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # load gt image + # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32. 
+ gt_path = self.paths[index] + img_bytes = self.file_client.get(gt_path) + img_gt = imfrombytes(img_bytes, float32=True) + + # random horizontal flip + img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True) + h, w, _ = img_gt.shape + + # get facial component coordinates + if self.crop_components: + locations = self.get_component_coordinates(index, status) + loc_left_eye, loc_right_eye, loc_mouth = locations + + # ------------------------ generate lq image ------------------------ # + # blur + kernel = degradations.random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + self.blur_kernel_size, + self.blur_sigma, + self.blur_sigma, [-math.pi, math.pi], + noise_range=None) + img_lq = cv2.filter2D(img_gt, -1, kernel) + # downsample + scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1]) + img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR) + # noise + if self.noise_range is not None: + img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range) + # jpeg compression + if self.jpeg_range is not None: + img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range) + + # resize to original size + img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR) + + # random color jitter (only for lq) + if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob): + img_lq = self.color_jitter(img_lq, self.color_jitter_shift) + # random to gray (only for lq) + if self.gray_prob and np.random.uniform() < self.gray_prob: + img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY) + img_lq = np.tile(img_lq[:, :, None], [1, 1, 3]) + if self.opt.get('gt_gray'): # whether convert GT to gray images + img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY) + img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) # repeat the color channels + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) + + # random color jitter (pytorch version) (only for lq) + if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob): + brightness = self.opt.get('brightness', (0.5, 1.5)) + contrast = self.opt.get('contrast', (0.5, 1.5)) + saturation = self.opt.get('saturation', (0, 1.5)) + hue = self.opt.get('hue', (-0.1, 0.1)) + img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue) + + # round and clip + img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255. 
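# Illustrative sketch: the same blur -> downsample -> noise -> JPEG -> resize-back chain as a
# standalone function on a numpy image, handy for inspecting the degradation outside the
# Dataset class. The default parameter values are assumptions, not this repo's settings.
import math
import cv2
import numpy as np
from basicsr.data import degradations

def degrade_like_ffhq(img_gt, blur_kernel_size=41, blur_sigma=(0.1, 10),
                      downsample_range=(0.8, 8), noise_range=(0, 20), jpeg_range=(60, 100)):
    h, w, _ = img_gt.shape
    kernel = degradations.random_mixed_kernels(
        ['iso', 'aniso'], [0.5, 0.5], blur_kernel_size,
        blur_sigma, blur_sigma, [-math.pi, math.pi], noise_range=None)
    img_lq = cv2.filter2D(img_gt, -1, kernel)
    scale = np.random.uniform(downsample_range[0], downsample_range[1])
    img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
    img_lq = degradations.random_add_gaussian_noise(img_lq, noise_range)
    img_lq = degradations.random_add_jpg_compression(img_lq, jpeg_range)
    return cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)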
+ + # normalize + normalize(img_gt, self.mean, self.std, inplace=True) + normalize(img_lq, self.mean, self.std, inplace=True) + + if self.crop_components: + return_dict = { + 'lq': img_lq, + 'gt': img_gt, + 'gt_path': gt_path, + 'loc_left_eye': loc_left_eye, + 'loc_right_eye': loc_right_eye, + 'loc_mouth': loc_mouth + } + return return_dict + else: + return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path} + + def __len__(self): + return len(self.paths) diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..41965cd159ec539aca3d60f5a5ccd84736e13d61 --- /dev/null +++ b/basicsr/data/paired_image_dataset.py @@ -0,0 +1,115 @@ +from torch.utils import data as data +from torchvision.transforms.functional import normalize + +from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file, paired_paths_from_meta_info_file_2 +from basicsr.data.transforms import augment, paired_random_crop +from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY +import cv2 + + +@DATASET_REGISTRY.register() +class PairedImageDataset(data.Dataset): + """Paired image dataset for image restoration. + + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. + + There are three modes: + + 1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb. + 2. **meta_info_file**: Use meta information file to generate paths. \ + If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. + 3. **folder**: Scan folders to generate paths. The rest. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + meta_info_file (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. + Default: '{}'. + gt_size (int): Cropped patched size for gt patches. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). + scale (bool): Scale, which will be added automatically. + phase (str): 'train' or 'val'. 
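# Illustrative sketch of a minimal 'folder'-mode options dict covering the keys documented
# above. The paths, patch size, and scale are assumptions for demonstration only.
paired_opt_example = {
    'dataroot_gt': 'datasets/DIV2K/HR',
    'dataroot_lq': 'datasets/DIV2K/LR_x4',
    'io_backend': {'type': 'disk'},
    'filename_tmpl': '{}',
    'gt_size': 256,
    'use_hflip': True,
    'use_rot': True,
    'scale': 4,
    'phase': 'train',
}
# dataset = PairedImageDataset(paired_opt_example)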
+ """ + + def __init__(self, opt): + super(PairedImageDataset, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None + + self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] + if 'filename_tmpl' in opt: + self.filename_tmpl = opt['filename_tmpl'] + else: + self.filename_tmpl = '{}' + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] + self.io_backend_opt['client_keys'] = ['lq', 'gt'] + self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) + elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None: + self.paths = paired_paths_from_meta_info_file_2([self.lq_folder, self.gt_folder], ['lq', 'gt'], + self.opt['meta_info_file'], self.filename_tmpl) + else: + self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + scale = self.opt['scale'] + + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. + gt_path = self.paths[index]['gt_path'] + img_bytes = self.file_client.get(gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + lq_path = self.paths[index]['lq_path'] + img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + + h, w = img_gt.shape[0:2] + # pad + if h < self.opt['gt_size'] or w < self.opt['gt_size']: + pad_h = max(0, self.opt['gt_size'] - h) + pad_w = max(0, self.opt['gt_size'] - w) + img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101) + img_lq = cv2.copyMakeBorder(img_lq, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101) + + # augmentation for training + if self.opt['phase'] == 'train': + gt_size = self.opt['gt_size'] + # random crop + img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) + # flip, rotation + img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) + + # color space transform + if 'color' in self.opt and self.opt['color'] == 'y': + img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None] + img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None] + + # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets + # TODO: It is better to update the datasets, rather than force to crop + if self.opt['phase'] != 'train': + img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :] + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) + # normalize + if self.mean is not None or self.std is not None: + normalize(img_lq, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) + + return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} + + def __len__(self): + return len(self.paths) diff --git a/basicsr/data/prefetch_dataloader.py b/basicsr/data/prefetch_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..332abd32fcb004e6892d12dc69848a4454e3c503 --- /dev/null +++ b/basicsr/data/prefetch_dataloader.py @@ -0,0 +1,122 @@ +import queue as Queue +import threading +import torch +from torch.utils.data import 
DataLoader + + +class PrefetchGenerator(threading.Thread): + """A general prefetch generator. + + Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch + + Args: + generator: Python generator. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, generator, num_prefetch_queue): + threading.Thread.__init__(self) + self.queue = Queue.Queue(num_prefetch_queue) + self.generator = generator + self.daemon = True + self.start() + + def run(self): + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def __next__(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class PrefetchDataLoader(DataLoader): + """Prefetch version of dataloader. + + Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# + + TODO: + Need to test on single gpu and ddp (multi-gpu). There is a known issue in + ddp. + + Args: + num_prefetch_queue (int): Number of prefetch queue. + kwargs (dict): Other arguments for dataloader. + """ + + def __init__(self, num_prefetch_queue, **kwargs): + self.num_prefetch_queue = num_prefetch_queue + super(PrefetchDataLoader, self).__init__(**kwargs) + + def __iter__(self): + return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) + + +class CPUPrefetcher(): + """CPU prefetcher. + + Args: + loader: Dataloader. + """ + + def __init__(self, loader): + self.ori_loader = loader + self.loader = iter(loader) + + def next(self): + try: + return next(self.loader) + except StopIteration: + return None + + def reset(self): + self.loader = iter(self.ori_loader) + + +class CUDAPrefetcher(): + """CUDA prefetcher. + + Reference: https://github.com/NVIDIA/apex/issues/304# + + It may consume more GPU memory. + + Args: + loader: Dataloader. + opt (dict): Options. 
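# Illustrative sketch: the prefetchers above share a next()/reset() protocol. Draining a
# small DataLoader through CPUPrefetcher (CUDAPrefetcher additionally moves dict batches to
# the GPU on a side stream); the toy dataset below is an assumption for demonstration.
import torch
from torch.utils.data import DataLoader, TensorDataset
from basicsr.data.prefetch_dataloader import CPUPrefetcher

toy_loader = DataLoader(TensorDataset(torch.rand(8, 3, 16, 16)), batch_size=4)
prefetcher = CPUPrefetcher(toy_loader)
batch = prefetcher.next()
while batch is not None:
    # ... one training/validation step on `batch` ...
    batch = prefetcher.next()
prefetcher.reset()  # rewind for the next epoch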
+ """ + + def __init__(self, loader, opt): + self.ori_loader = loader + self.loader = iter(loader) + self.opt = opt + self.stream = torch.cuda.Stream() + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.preload() + + def preload(self): + try: + self.batch = next(self.loader) # self.batch is a dict + except StopIteration: + self.batch = None + return None + # put tensors to gpu + with torch.cuda.stream(self.stream): + for k, v in self.batch.items(): + if torch.is_tensor(v): + self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) + + def next(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + self.preload() + return batch + + def reset(self): + self.loader = iter(self.ori_loader) + self.preload() diff --git a/basicsr/data/realesrgan_dataset.py b/basicsr/data/realesrgan_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9b7c0603d8353f5457b0dd96f9a9a876a192d113 --- /dev/null +++ b/basicsr/data/realesrgan_dataset.py @@ -0,0 +1,242 @@ +import cv2 +import math +import numpy as np +import os +import os.path as osp +import random +import time +import torch +from pathlib import Path +from torch.utils import data as data + +from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels +from basicsr.data.transforms import augment +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY + +@DATASET_REGISTRY.register(suffix='basicsr') +class RealESRGANDataset(data.Dataset): + """Modified dataset based on the dataset used for Real-ESRGAN model: + Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It loads gt (Ground-Truth) images, and augments them. + It also generates blur kernels and sinc kernels for generating low-quality images. + Note that the low-quality images are processed in tensors on GPUS for faster processing. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + meta_info (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). + Please see more options in the codes. 
+ """ + + def __init__(self, opt): + super(RealESRGANDataset, self).__init__() + self.opt = opt + self.file_client = None + self.io_backend_opt = opt['io_backend'] + if 'crop_size' in opt: + self.crop_size = opt['crop_size'] + else: + self.crop_size = 512 + if 'image_type' not in opt: + opt['image_type'] = 'png' + + # support multiple type of data: file path and meta data, remove support of lmdb + self.paths = [] + if 'meta_info' in opt: + with open(self.opt['meta_info']) as fin: + paths = [line.strip().split(' ')[0] for line in fin] + self.paths = [v for v in paths] + if 'meta_num' in opt: + self.paths = sorted(self.paths)[:opt['meta_num']] + if 'gt_path' in opt: + if isinstance(opt['gt_path'], str): + self.paths.extend(sorted([str(x) for x in Path(opt['gt_path']).glob('*.'+opt['image_type'])])) + else: + self.paths.extend(sorted([str(x) for x in Path(opt['gt_path'][0]).glob('*.'+opt['image_type'])])) + if len(opt['gt_path']) > 1: + for i in range(len(opt['gt_path'])-1): + self.paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i+1]).glob('*.'+opt['image_type'])])) + if 'imagenet_path' in opt: + class_list = os.listdir(opt['imagenet_path']) + for class_file in class_list: + self.paths.extend(sorted([str(x) for x in Path(os.path.join(opt['imagenet_path'], class_file)).glob('*.'+'JPEG')])) + if 'face_gt_path' in opt: + if isinstance(opt['face_gt_path'], str): + face_list = sorted([str(x) for x in Path(opt['face_gt_path']).glob('*.'+opt['image_type'])]) + self.paths.extend(face_list[:opt['num_face']]) + else: + face_list = sorted([str(x) for x in Path(opt['face_gt_path'][0]).glob('*.'+opt['image_type'])]) + self.paths.extend(face_list[:opt['num_face']]) + if len(opt['face_gt_path']) > 1: + for i in range(len(opt['face_gt_path'])-1): + self.paths.extend(sorted([str(x) for x in Path(opt['face_gt_path'][0]).glob('*.'+opt['image_type'])])[:opt['num_face']]) + + # limit number of pictures for test + if 'num_pic' in opt: + if 'val' or 'test' in opt: + random.shuffle(self.paths) + self.paths = self.paths[:opt['num_pic']] + else: + self.paths = self.paths[:opt['num_pic']] + + if 'mul_num' in opt: + self.paths = self.paths * opt['mul_num'] + # print('>>>>>>>>>>>>>>>>>>>>>') + # print(self.paths) + + # blur settings for the first degradation + self.blur_kernel_size = opt['blur_kernel_size'] + self.kernel_list = opt['kernel_list'] + self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability + self.blur_sigma = opt['blur_sigma'] + self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels + self.betap_range = opt['betap_range'] # betap used in plateau blur kernels + self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters + + # blur settings for the second degradation + self.blur_kernel_size2 = opt['blur_kernel_size2'] + self.kernel_list2 = opt['kernel_list2'] + self.kernel_prob2 = opt['kernel_prob2'] + self.blur_sigma2 = opt['blur_sigma2'] + self.betag_range2 = opt['betag_range2'] + self.betap_range2 = opt['betap_range2'] + self.sinc_prob2 = opt['sinc_prob2'] + + # a final sinc filter + self.final_sinc_prob = opt['final_sinc_prob'] + + self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21 + # TODO: kernel range is now hard-coded, should be in the configure file + self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect + self.pulse_tensor[10, 10] = 1 + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = 
FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # -------------------------------- Load gt images -------------------------------- # + # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32. + gt_path = self.paths[index] + # avoid errors caused by high latency in reading files + retry = 3 + while retry > 0: + try: + img_bytes = self.file_client.get(gt_path, 'gt') + except (IOError, OSError) as e: + # logger = get_root_logger() + # logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}') + # change another file to read + index = random.randint(0, self.__len__()-1) + gt_path = self.paths[index] + time.sleep(1) # sleep 1s for occasional server congestion + else: + break + finally: + retry -= 1 + img_gt = imfrombytes(img_bytes, float32=True) + # filter the dataset and remove images with too low quality + img_size = os.path.getsize(gt_path) + img_size = img_size/1024 + + while img_gt.shape[0] * img_gt.shape[1] < 384*384 or img_size<100: + index = random.randint(0, self.__len__()-1) + gt_path = self.paths[index] + + time.sleep(0.1) # sleep 1s for occasional server congestion + img_bytes = self.file_client.get(gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + img_size = os.path.getsize(gt_path) + img_size = img_size/1024 + + # -------------------- Do augmentation for training: flip, rotation -------------------- # + img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot']) + + # crop or pad to 400 + # TODO: 400 is hard-coded. You may change it accordingly + h, w = img_gt.shape[0:2] + crop_pad_size = self.crop_size + # pad + if h < crop_pad_size or w < crop_pad_size: + pad_h = max(0, crop_pad_size - h) + pad_w = max(0, crop_pad_size - w) + img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101) + # crop + if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size: + h, w = img_gt.shape[0:2] + # randomly choose top and left coordinates + top = random.randint(0, h - crop_pad_size) + left = random.randint(0, w - crop_pad_size) + # top = (h - crop_pad_size) // 2 -1 + # left = (w - crop_pad_size) // 2 -1 + img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...] 
+ + # ------------------------ Generate kernels (used in the first degradation) ------------------------ # + kernel_size = random.choice(self.kernel_range) + if np.random.uniform() < self.opt['sinc_prob']: + # this sinc filter setting is for kernels ranging from [7, 21] + if kernel_size < 13: + omega_c = np.random.uniform(np.pi / 3, np.pi) + else: + omega_c = np.random.uniform(np.pi / 5, np.pi) + kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) + else: + kernel = random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + kernel_size, + self.blur_sigma, + self.blur_sigma, [-math.pi, math.pi], + self.betag_range, + self.betap_range, + noise_range=None) + # pad kernel + pad_size = (21 - kernel_size) // 2 + kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) + + # ------------------------ Generate kernels (used in the second degradation) ------------------------ # + kernel_size = random.choice(self.kernel_range) + if np.random.uniform() < self.opt['sinc_prob2']: + if kernel_size < 13: + omega_c = np.random.uniform(np.pi / 3, np.pi) + else: + omega_c = np.random.uniform(np.pi / 5, np.pi) + kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) + else: + kernel2 = random_mixed_kernels( + self.kernel_list2, + self.kernel_prob2, + kernel_size, + self.blur_sigma2, + self.blur_sigma2, [-math.pi, math.pi], + self.betag_range2, + self.betap_range2, + noise_range=None) + + # pad kernel + pad_size = (21 - kernel_size) // 2 + kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size))) + + # ------------------------------------- the final sinc kernel ------------------------------------- # + if np.random.uniform() < self.opt['final_sinc_prob']: + kernel_size = random.choice(self.kernel_range) + omega_c = np.random.uniform(np.pi / 3, np.pi) + sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21) + sinc_kernel = torch.FloatTensor(sinc_kernel) + else: + sinc_kernel = self.pulse_tensor + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0] + kernel = torch.FloatTensor(kernel) + kernel2 = torch.FloatTensor(kernel2) + + return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path} + return return_d + + def __len__(self): + return len(self.paths) diff --git a/basicsr/data/realesrgan_paired_dataset.py b/basicsr/data/realesrgan_paired_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3d0c6159d448f26fc8a256d6a9d0c51096b78fe0 --- /dev/null +++ b/basicsr/data/realesrgan_paired_dataset.py @@ -0,0 +1,114 @@ +import os +from torch.utils import data as data +from torchvision.transforms.functional import normalize + +from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb +from basicsr.data.transforms import augment, paired_random_crop +from basicsr.utils import FileClient, imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY + + +@DATASET_REGISTRY.register(suffix='basicsr') +class RealESRGANPairedDataset(data.Dataset): + """Paired image dataset for image restoration. + + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. + + There are three modes: + + 1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb. + 2. **meta_info_file**: Use meta information file to generate paths. \ + If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. + 3. **folder**: Scan folders to generate paths. The rest. 
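# Illustrative sketch: RealESRGANDataset above returns the 21x21 blur kernels alongside 'gt';
# the degradation itself is applied later, on the GPU, by the training code. A minimal
# depthwise convolution that applies such a kernel to one (c, h, w) tensor, with reflection
# padding assumed:
import torch
import torch.nn.functional as F

def apply_blur_kernel(gt, kernel):
    """gt: float tensor (c, h, w); kernel: float tensor (k, k)."""
    c = gt.size(0)
    k = kernel.size(-1)
    weight = kernel.repeat(c, 1, 1, 1)                       # same kernel for every channel
    padded = F.pad(gt.unsqueeze(0), (k // 2,) * 4, mode='reflect')
    return F.conv2d(padded, weight, groups=c).squeeze(0)

# sample = RealESRGANDataset(opt)[0]
# blurred = apply_blur_kernel(sample['gt'], sample['kernel1'])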
+ + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + meta_info (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. + Default: '{}'. + gt_size (int): Cropped patched size for gt patches. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). + scale (bool): Scale, which will be added automatically. + phase (str): 'train' or 'val'. + """ + + def __init__(self, opt): + super(RealESRGANPairedDataset, self).__init__() + self.opt = opt + self.file_client = None + self.io_backend_opt = opt['io_backend'] + # mean and std for normalizing the input images + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None + + self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] + self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}' + + # file client (lmdb io backend) + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] + self.io_backend_opt['client_keys'] = ['lq', 'gt'] + self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) + elif 'meta_info' in self.opt and self.opt['meta_info'] is not None: + # disk backend with meta_info + # Each line in the meta_info describes the relative path to an image + with open(self.opt['meta_info']) as fin: + paths = [line.strip() for line in fin] + self.paths = [] + for path in paths: + gt_path, lq_path = path.split(', ') + gt_path = os.path.join(self.gt_folder, gt_path) + lq_path = os.path.join(self.lq_folder, lq_path) + self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)])) + else: + # disk backend + # it will scan the whole folder to get meta info + # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file + self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) + + if 'num_pic' in self.opt: + self.paths = self.paths[:self.opt['num_pic']] + if 'phase' not in self.opt: + self.opt['phase'] = 'test' + if 'scale' not in self.opt: + self.opt['scale'] = 1 + + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + scale = self.opt['scale'] + + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. 
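# Illustrative sketch of the paired 'meta_info' text format parsed in __init__ above: one
# "gt_name, lq_name" pair of relative paths per line, split on ', '. The file names below
# are assumptions.
example_meta_info_lines = [
    '0001.png, 0001_x4.png',
    '0002.png, 0002_x4.png',
]
# with open('meta_info_paired_example.txt', 'w') as fout:
#     fout.write('\n'.join(example_meta_info_lines) + '\n')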
+ gt_path = self.paths[index]['gt_path'] + img_bytes = self.file_client.get(gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + lq_path = self.paths[index]['lq_path'] + img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + + # augmentation for training + if self.opt['phase'] == 'train': + gt_size = self.opt['gt_size'] + # random crop + img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) + # flip, rotation + img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) + # normalize + if self.mean is not None or self.std is not None: + normalize(img_lq, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) + + return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} + + def __len__(self): + return len(self.paths) diff --git a/basicsr/data/reds_dataset.py b/basicsr/data/reds_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fabef1d7e80866888f3b57ecfeb4d97c93bcb5cd --- /dev/null +++ b/basicsr/data/reds_dataset.py @@ -0,0 +1,352 @@ +import numpy as np +import random +import torch +from pathlib import Path +from torch.utils import data as data + +from basicsr.data.transforms import augment, paired_random_crop +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils.flow_util import dequantize_flow +from basicsr.utils.registry import DATASET_REGISTRY + + +@DATASET_REGISTRY.register() +class REDSDataset(data.Dataset): + """REDS dataset for training. + + The keys are generated from a meta info txt file. + basicsr/data/meta_info/meta_info_REDS_GT.txt + + Each line contains: + 1. subfolder (clip) name; 2. frame number; 3. image shape, separated by + a white space. + Examples: + 000 100 (720,1280,3) + 001 100 (720,1280,3) + ... + + Key examples: "000/00000000" + GT (gt): Ground-Truth; + LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames. + + Args: + opt (dict): Config for train dataset. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + dataroot_flow (str, optional): Data root path for flow. + meta_info_file (str): Path for meta information file. + val_partition (str): Validation partition types. 'REDS4' or 'official'. + io_backend (dict): IO backend type and other kwarg. + num_frame (int): Window size for input frames. + gt_size (int): Cropped patched size for gt patches. + interval_list (list): Interval list for temporal augmentation. + random_reverse (bool): Random reverse input frames. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). + scale (bool): Scale, which will be added automatically. 
+ """ + + def __init__(self, opt): + super(REDSDataset, self).__init__() + self.opt = opt + self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq']) + self.flow_root = Path(opt['dataroot_flow']) if opt['dataroot_flow'] is not None else None + assert opt['num_frame'] % 2 == 1, (f'num_frame should be odd number, but got {opt["num_frame"]}') + self.num_frame = opt['num_frame'] + self.num_half_frames = opt['num_frame'] // 2 + + self.keys = [] + with open(opt['meta_info_file'], 'r') as fin: + for line in fin: + folder, frame_num, _ = line.split(' ') + self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))]) + + # remove the video clips used in validation + if opt['val_partition'] == 'REDS4': + val_partition = ['000', '011', '015', '020'] + elif opt['val_partition'] == 'official': + val_partition = [f'{v:03d}' for v in range(240, 270)] + else: + raise ValueError(f'Wrong validation partition {opt["val_partition"]}.' + f"Supported ones are ['official', 'REDS4'].") + self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition] + + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.is_lmdb = False + if self.io_backend_opt['type'] == 'lmdb': + self.is_lmdb = True + if self.flow_root is not None: + self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root] + self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow'] + else: + self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root] + self.io_backend_opt['client_keys'] = ['lq', 'gt'] + + # temporal augmentation configs + self.interval_list = opt['interval_list'] + self.random_reverse = opt['random_reverse'] + interval_str = ','.join(str(x) for x in opt['interval_list']) + logger = get_root_logger() + logger.info(f'Temporal augmentation interval list: [{interval_str}]; ' + f'random reverse is {self.random_reverse}.') + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + scale = self.opt['scale'] + gt_size = self.opt['gt_size'] + key = self.keys[index] + clip_name, frame_name = key.split('/') # key example: 000/00000000 + center_frame_idx = int(frame_name) + + # determine the neighboring frames + interval = random.choice(self.interval_list) + + # ensure not exceeding the borders + start_frame_idx = center_frame_idx - self.num_half_frames * interval + end_frame_idx = center_frame_idx + self.num_half_frames * interval + # each clip has 100 frames starting from 0 to 99 + while (start_frame_idx < 0) or (end_frame_idx > 99): + center_frame_idx = random.randint(0, 99) + start_frame_idx = (center_frame_idx - self.num_half_frames * interval) + end_frame_idx = center_frame_idx + self.num_half_frames * interval + frame_name = f'{center_frame_idx:08d}' + neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval)) + # random reverse + if self.random_reverse and random.random() < 0.5: + neighbor_list.reverse() + + assert len(neighbor_list) == self.num_frame, (f'Wrong length of neighbor list: {len(neighbor_list)}') + + # get the GT frame (as the center frame) + if self.is_lmdb: + img_gt_path = f'{clip_name}/{frame_name}' + else: + img_gt_path = self.gt_root / clip_name / f'{frame_name}.png' + img_bytes = self.file_client.get(img_gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + + # get the neighboring LQ frames + img_lqs = [] + for neighbor in neighbor_list: + if self.is_lmdb: + img_lq_path = 
f'{clip_name}/{neighbor:08d}' + else: + img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png' + img_bytes = self.file_client.get(img_lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + img_lqs.append(img_lq) + + # get flows + if self.flow_root is not None: + img_flows = [] + # read previous flows + for i in range(self.num_half_frames, 0, -1): + if self.is_lmdb: + flow_path = f'{clip_name}/{frame_name}_p{i}' + else: + flow_path = (self.flow_root / clip_name / f'{frame_name}_p{i}.png') + img_bytes = self.file_client.get(flow_path, 'flow') + cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False) # uint8, [0, 255] + dx, dy = np.split(cat_flow, 2, axis=0) + flow = dequantize_flow(dx, dy, max_val=20, denorm=False) # we use max_val 20 here. + img_flows.append(flow) + # read next flows + for i in range(1, self.num_half_frames + 1): + if self.is_lmdb: + flow_path = f'{clip_name}/{frame_name}_n{i}' + else: + flow_path = (self.flow_root / clip_name / f'{frame_name}_n{i}.png') + img_bytes = self.file_client.get(flow_path, 'flow') + cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False) # uint8, [0, 255] + dx, dy = np.split(cat_flow, 2, axis=0) + flow = dequantize_flow(dx, dy, max_val=20, denorm=False) # we use max_val 20 here. + img_flows.append(flow) + + # for random crop, here, img_flows and img_lqs have the same + # spatial size + img_lqs.extend(img_flows) + + # randomly crop + img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path) + if self.flow_root is not None: + img_lqs, img_flows = img_lqs[:self.num_frame], img_lqs[self.num_frame:] + + # augmentation - flip, rotate + img_lqs.append(img_gt) + if self.flow_root is not None: + img_results, img_flows = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'], img_flows) + else: + img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot']) + + img_results = img2tensor(img_results) + img_lqs = torch.stack(img_results[0:-1], dim=0) + img_gt = img_results[-1] + + if self.flow_root is not None: + img_flows = img2tensor(img_flows) + # add the zero center flow + img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0])) + img_flows = torch.stack(img_flows, dim=0) + + # img_lqs: (t, c, h, w) + # img_flows: (t, 2, h, w) + # img_gt: (c, h, w) + # key: str + if self.flow_root is not None: + return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key} + else: + return {'lq': img_lqs, 'gt': img_gt, 'key': key} + + def __len__(self): + return len(self.keys) + + +@DATASET_REGISTRY.register() +class REDSRecurrentDataset(data.Dataset): + """REDS dataset for training recurrent networks. + + The keys are generated from a meta info txt file. + basicsr/data/meta_info/meta_info_REDS_GT.txt + + Each line contains: + 1. subfolder (clip) name; 2. frame number; 3. image shape, separated by + a white space. + Examples: + 000 100 (720,1280,3) + 001 100 (720,1280,3) + ... + + Key examples: "000/00000000" + GT (gt): Ground-Truth; + LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames. + + Args: + opt (dict): Config for train dataset. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + dataroot_flow (str, optional): Data root path for flow. + meta_info_file (str): Path for meta information file. + val_partition (str): Validation partition types. 'REDS4' or 'official'. + io_backend (dict): IO backend type and other kwarg. + num_frame (int): Window size for input frames. 
+ gt_size (int): Cropped patched size for gt patches. + interval_list (list): Interval list for temporal augmentation. + random_reverse (bool): Random reverse input frames. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). + scale (bool): Scale, which will be added automatically. + """ + + def __init__(self, opt): + super(REDSRecurrentDataset, self).__init__() + self.opt = opt + self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq']) + self.num_frame = opt['num_frame'] + + self.keys = [] + with open(opt['meta_info_file'], 'r') as fin: + for line in fin: + folder, frame_num, _ = line.split(' ') + self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))]) + + # remove the video clips used in validation + if opt['val_partition'] == 'REDS4': + val_partition = ['000', '011', '015', '020'] + elif opt['val_partition'] == 'official': + val_partition = [f'{v:03d}' for v in range(240, 270)] + else: + raise ValueError(f'Wrong validation partition {opt["val_partition"]}.' + f"Supported ones are ['official', 'REDS4'].") + if opt['test_mode']: + self.keys = [v for v in self.keys if v.split('/')[0] in val_partition] + else: + self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition] + + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.is_lmdb = False + if self.io_backend_opt['type'] == 'lmdb': + self.is_lmdb = True + if hasattr(self, 'flow_root') and self.flow_root is not None: + self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root] + self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow'] + else: + self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root] + self.io_backend_opt['client_keys'] = ['lq', 'gt'] + + # temporal augmentation configs + self.interval_list = opt.get('interval_list', [1]) + self.random_reverse = opt.get('random_reverse', False) + interval_str = ','.join(str(x) for x in self.interval_list) + logger = get_root_logger() + logger.info(f'Temporal augmentation interval list: [{interval_str}]; ' + f'random reverse is {self.random_reverse}.') + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + scale = self.opt['scale'] + gt_size = self.opt['gt_size'] + key = self.keys[index] + clip_name, frame_name = key.split('/') # key example: 000/00000000 + + # determine the neighboring frames + interval = random.choice(self.interval_list) + + # ensure not exceeding the borders + start_frame_idx = int(frame_name) + if start_frame_idx > 100 - self.num_frame * interval: + start_frame_idx = random.randint(0, 100 - self.num_frame * interval) + end_frame_idx = start_frame_idx + self.num_frame * interval + + neighbor_list = list(range(start_frame_idx, end_frame_idx, interval)) + + # random reverse + if self.random_reverse and random.random() < 0.5: + neighbor_list.reverse() + + # get the neighboring LQ and GT frames + img_lqs = [] + img_gts = [] + for neighbor in neighbor_list: + if self.is_lmdb: + img_lq_path = f'{clip_name}/{neighbor:08d}' + img_gt_path = f'{clip_name}/{neighbor:08d}' + else: + img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png' + img_gt_path = self.gt_root / clip_name / f'{neighbor:08d}.png' + + # get LQ + img_bytes = self.file_client.get(img_lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + img_lqs.append(img_lq) + + # get GT + img_bytes 
= self.file_client.get(img_gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + img_gts.append(img_gt) + + # randomly crop + img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path) + + # augmentation - flip, rotate + img_lqs.extend(img_gts) + img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot']) + + img_results = img2tensor(img_results) + img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0) + img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0) + + # img_lqs: (t, c, h, w) + # img_gts: (t, c, h, w) + # key: str + return {'lq': img_lqs, 'gt': img_gts, 'key': key} + + def __len__(self): + return len(self.keys) diff --git a/basicsr/data/single_image_dataset.py b/basicsr/data/single_image_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e8d1a94d1723fb832b0c6fc897e72e0081c4a399 --- /dev/null +++ b/basicsr/data/single_image_dataset.py @@ -0,0 +1,164 @@ +from os import path as osp +from torch.utils import data as data +from torchvision.transforms.functional import normalize + +from basicsr.data.data_util import paths_from_lmdb +from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir +from basicsr.utils.registry import DATASET_REGISTRY + +from pathlib import Path +import random +import cv2 +import numpy as np +import torch + +@DATASET_REGISTRY.register() +class SingleImageDataset(data.Dataset): + """Read only lq images in the test phase. + + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). + + There are two modes: + 1. 'meta_info_file': Use meta information file to generate paths. + 2. 'folder': Scan folders to generate paths. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_lq (str): Data root path for lq. + meta_info_file (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. 
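+
+        Example (a minimal sketch; the folder path below is an illustrative
+        assumption, not one of the original configs):
+
+        ::
+
+            opt = {
+                'dataroot_lq': 'datasets/my_test_lq',   # hypothetical folder of LQ images
+                'io_backend': {'type': 'disk'},
+            }
+            dataset = SingleImageDataset(opt)
+            item = dataset[0]
+            # item['lq'] is a float32 RGB tensor of shape (c, h, w);
+            # item['lq_path'] is the path of the source image.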
+ """ + + def __init__(self, opt): + super(SingleImageDataset, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None + self.lq_folder = opt['dataroot_lq'] + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.lq_folder] + self.io_backend_opt['client_keys'] = ['lq'] + self.paths = paths_from_lmdb(self.lq_folder) + elif 'meta_info_file' in self.opt: + with open(self.opt['meta_info_file'], 'r') as fin: + self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin] + else: + self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # load lq image + lq_path = self.paths[index] + img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + + # color space transform + if 'color' in self.opt and self.opt['color'] == 'y': + img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] + + # BGR to RGB, HWC to CHW, numpy to tensor + img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) + # normalize + if self.mean is not None or self.std is not None: + normalize(img_lq, self.mean, self.std, inplace=True) + return {'lq': img_lq, 'lq_path': lq_path} + + def __len__(self): + return len(self.paths) + +@DATASET_REGISTRY.register() +class SingleImageNPDataset(data.Dataset): + """Read only lq images in the test phase. + + Read diffusion generated data for training CFW. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + gt_path: Data root path for training data. The path needs to contain the following folders: + gts: Ground-truth images. + inputs: Input LQ images. + latents: The corresponding HQ latent code generated by diffusion model given the input LQ image. + samples: The corresponding HQ image given the HQ latent code, just for verification. + io_backend (dict): IO backend type and other kwarg. 
+ """ + + def __init__(self, opt): + super(SingleImageNPDataset, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None + if 'image_type' not in opt: + opt['image_type'] = 'png' + + if isinstance(opt['gt_path'], str): + self.gt_paths = sorted([str(x) for x in Path(opt['gt_path']+'/gts').glob('*.'+opt['image_type'])]) + self.lq_paths = sorted([str(x) for x in Path(opt['gt_path']+'/inputs').glob('*.'+opt['image_type'])]) + self.np_paths = sorted([str(x) for x in Path(opt['gt_path']+'/latents').glob('*.npy')]) + self.sample_paths = sorted([str(x) for x in Path(opt['gt_path']+'/samples').glob('*.'+opt['image_type'])]) + else: + self.gt_paths = sorted([str(x) for x in Path(opt['gt_path'][0]+'/gts').glob('*.'+opt['image_type'])]) + self.lq_paths = sorted([str(x) for x in Path(opt['gt_path'][0]+'/inputs').glob('*.'+opt['image_type'])]) + self.np_paths = sorted([str(x) for x in Path(opt['gt_path'][0]+'/latents').glob('*.npy')]) + self.sample_paths = sorted([str(x) for x in Path(opt['gt_path'][0]+'/samples').glob('*.'+opt['image_type'])]) + if len(opt['gt_path']) > 1: + for i in range(len(opt['gt_path'])-1): + self.gt_paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i+1]+'/gts').glob('*.'+opt['image_type'])])) + self.lq_paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i+1]+'/inputs').glob('*.'+opt['image_type'])])) + self.np_paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i+1]+'/latents').glob('*.npy')])) + self.sample_paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i+1]+'/samples').glob('*.'+opt['image_type'])])) + + assert len(self.gt_paths) == len(self.lq_paths) + assert len(self.gt_paths) == len(self.np_paths) + assert len(self.gt_paths) == len(self.sample_paths) + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # load lq image + lq_path = self.lq_paths[index] + gt_path = self.gt_paths[index] + sample_path = self.sample_paths[index] + np_path = self.np_paths[index] + + img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + + img_bytes_gt = self.file_client.get(gt_path, 'gt') + img_gt = imfrombytes(img_bytes_gt, float32=True) + + img_bytes_sample = self.file_client.get(sample_path, 'sample') + img_sample = imfrombytes(img_bytes_sample, float32=True) + + latent_np = np.load(np_path) + + # color space transform + if 'color' in self.opt and self.opt['color'] == 'y': + img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] + img_gt = rgb2ycbcr(img_gt, y_only=True)[..., None] + img_sample = rgb2ycbcr(img_sample, y_only=True)[..., None] + + # BGR to RGB, HWC to CHW, numpy to tensor + img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) + img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True) + img_sample = img2tensor(img_sample, bgr2rgb=True, float32=True) + latent_np = torch.from_numpy(latent_np).float() + latent_np = latent_np.to(img_gt.device) + # normalize + if self.mean is not None or self.std is not None: + normalize(img_lq, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) + normalize(img_sample, self.mean, self.std, inplace=True) + return {'lq': img_lq, 'lq_path': lq_path, 'gt': img_gt, 'gt_path': gt_path, 'latent': latent_np[0], 'latent_path': np_path, 'sample': img_sample, 'sample_path': 
sample_path} + + def __len__(self): + return len(self.gt_paths) diff --git a/basicsr/data/transforms.py b/basicsr/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..c700a399bb737a2286ea705fcebd937e6fb54ca7 --- /dev/null +++ b/basicsr/data/transforms.py @@ -0,0 +1,240 @@ +import cv2 +import random +import torch + + +def mod_crop(img, scale): + """Mod crop images, used during testing. + + Args: + img (ndarray): Input image. + scale (int): Scale factor. + + Returns: + ndarray: Result image. + """ + img = img.copy() + if img.ndim in (2, 3): + h, w = img.shape[0], img.shape[1] + h_remainder, w_remainder = h % scale, w % scale + img = img[:h - h_remainder, :w - w_remainder, ...] + else: + raise ValueError(f'Wrong img ndim: {img.ndim}.') + return img + + +def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None): + """Paired random crop. Support Numpy array and Tensor inputs. + + It crops lists of lq and gt images with corresponding locations. + + Args: + img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + img_lqs (list[ndarray] | ndarray): LQ images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + gt_patch_size (int): GT patch size. + scale (int): Scale factor. + gt_path (str): Path to ground-truth. Default: None. + + Returns: + list[ndarray] | ndarray: GT images and LQ images. If returned results + only have one element, just return ndarray. + """ + + if not isinstance(img_gts, list): + img_gts = [img_gts] + if not isinstance(img_lqs, list): + img_lqs = [img_lqs] + + # determine input type: Numpy array or Tensor + input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy' + + if input_type == 'Tensor': + h_lq, w_lq = img_lqs[0].size()[-2:] + h_gt, w_gt = img_gts[0].size()[-2:] + else: + h_lq, w_lq = img_lqs[0].shape[0:2] + h_gt, w_gt = img_gts[0].shape[0:2] + lq_patch_size = gt_patch_size // scale + + if h_gt != h_lq * scale or w_gt != w_lq * scale: + raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', + f'multiplication of LQ ({h_lq}, {w_lq}).') + if h_lq < lq_patch_size or w_lq < lq_patch_size: + raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' + f'({lq_patch_size}, {lq_patch_size}). ' + f'Please remove {gt_path}.') + + # randomly choose top and left coordinates for lq patch + top = random.randint(0, h_lq - lq_patch_size) + left = random.randint(0, w_lq - lq_patch_size) + + # crop lq patch + if input_type == 'Tensor': + img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs] + else: + img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] + + # crop corresponding gt patch + top_gt, left_gt = int(top * scale), int(left * scale) + if input_type == 'Tensor': + img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts] + else: + img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] 
for v in img_gts] + if len(img_gts) == 1: + img_gts = img_gts[0] + if len(img_lqs) == 1: + img_lqs = img_lqs[0] + return img_gts, img_lqs + +def triplet_random_crop(img_gts, img_lqs, img_segs, gt_patch_size, scale, gt_path=None): + + if not isinstance(img_gts, list): + img_gts = [img_gts] + if not isinstance(img_lqs, list): + img_lqs = [img_lqs] + if not isinstance(img_segs, list): + img_segs = [img_segs] + + # determine input type: Numpy array or Tensor + input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy' + + if input_type == 'Tensor': + h_lq, w_lq = img_lqs[0].size()[-2:] + h_gt, w_gt = img_gts[0].size()[-2:] + h_seg, w_seg = img_segs[0].size()[-2:] + else: + h_lq, w_lq = img_lqs[0].shape[0:2] + h_gt, w_gt = img_gts[0].shape[0:2] + h_seg, w_seg = img_segs[0].shape[0:2] + lq_patch_size = gt_patch_size // scale + + if h_gt != h_lq * scale or w_gt != w_lq * scale: + raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', + f'multiplication of LQ ({h_lq}, {w_lq}).') + if h_lq < lq_patch_size or w_lq < lq_patch_size: + raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' + f'({lq_patch_size}, {lq_patch_size}). ' + f'Please remove {gt_path}.') + + # randomly choose top and left coordinates for lq patch + top = random.randint(0, h_lq - lq_patch_size) + left = random.randint(0, w_lq - lq_patch_size) + + # crop lq patch + if input_type == 'Tensor': + img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs] + else: + img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] + + # crop corresponding gt patch + top_gt, left_gt = int(top * scale), int(left * scale) + if input_type == 'Tensor': + img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts] + else: + img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] + + if input_type == 'Tensor': + img_segs = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_segs] + else: + img_segs = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_segs] + + if len(img_gts) == 1: + img_gts = img_gts[0] + if len(img_lqs) == 1: + img_lqs = img_lqs[0] + if len(img_segs) == 1: + img_segs = img_segs[0] + + return img_gts, img_lqs, img_segs + + +def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): + """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). + + We use vertical flip and transpose for rotation implementation. + All the images in the list use the same augmentation. + + Args: + imgs (list[ndarray] | ndarray): Images to be augmented. If the input + is an ndarray, it will be transformed to a list. + hflip (bool): Horizontal flip. Default: True. + rotation (bool): Ratotation. Default: True. + flows (list[ndarray]: Flows to be augmented. If the input is an + ndarray, it will be transformed to a list. + Dimension is (h, w, 2). Default: None. + return_status (bool): Return the status of flip and rotation. + Default: False. + + Returns: + list[ndarray] | ndarray: Augmented images and flows. If returned + results only have one element, just return ndarray. 
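+
+        Example (a minimal sketch with dummy arrays, not taken from the
+        original tests):
+
+        ::
+
+            import numpy as np
+            imgs = [np.random.rand(64, 64, 3).astype(np.float32) for _ in range(2)]
+            out, status = augment(imgs, hflip=True, rotation=True, return_status=True)
+            # both arrays in `out` share the same random flip/rotation;
+            # `status` is the applied (hflip, vflip, rot90) tuple.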
+ + """ + hflip = hflip and random.random() < 0.5 + vflip = rotation and random.random() < 0.5 + rot90 = rotation and random.random() < 0.5 + + def _augment(img): + if hflip: # horizontal + cv2.flip(img, 1, img) + if vflip: # vertical + cv2.flip(img, 0, img) + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: # horizontal + cv2.flip(flow, 1, flow) + flow[:, :, 0] *= -1 + if vflip: # vertical + cv2.flip(flow, 0, flow) + flow[:, :, 1] *= -1 + if rot90: + flow = flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + if not isinstance(imgs, list): + imgs = [imgs] + imgs = [_augment(img) for img in imgs] + if len(imgs) == 1: + imgs = imgs[0] + + if flows is not None: + if not isinstance(flows, list): + flows = [flows] + flows = [_augment_flow(flow) for flow in flows] + if len(flows) == 1: + flows = flows[0] + return imgs, flows + else: + if return_status: + return imgs, (hflip, vflip, rot90) + else: + return imgs + + +def img_rotate(img, angle, center=None, scale=1.0): + """Rotate image. + + Args: + img (ndarray): Image to be rotated. + angle (float): Rotation angle in degrees. Positive values mean + counter-clockwise rotation. + center (tuple[int]): Rotation center. If the center is None, + initialize it as the center of the image. Default: None. + scale (float): Isotropic scale factor. Default: 1.0. + """ + (h, w) = img.shape[:2] + + if center is None: + center = (w // 2, h // 2) + + matrix = cv2.getRotationMatrix2D(center, angle, scale) + rotated_img = cv2.warpAffine(img, matrix, (w, h)) + return rotated_img diff --git a/basicsr/data/video_test_dataset.py b/basicsr/data/video_test_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..929f7d97472a0eb810e33e694d5362a6749ab4b6 --- /dev/null +++ b/basicsr/data/video_test_dataset.py @@ -0,0 +1,283 @@ +import glob +import torch +from os import path as osp +from torch.utils import data as data + +from basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.registry import DATASET_REGISTRY + + +@DATASET_REGISTRY.register() +class VideoTestDataset(data.Dataset): + """Video test dataset. + + Supported datasets: Vid4, REDS4, REDSofficial. + More generally, it supports testing dataset with following structures: + + :: + + dataroot + ├── subfolder1 + ├── frame000 + ├── frame001 + ├── ... + ├── subfolder2 + ├── frame000 + ├── frame001 + ├── ... + ├── ... + + For testing datasets, there is no need to prepare LMDB files. + + Args: + opt (dict): Config for train dataset. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + io_backend (dict): IO backend type and other kwarg. + cache_data (bool): Whether to cache testing datasets. + name (str): Dataset name. + meta_info_file (str): The path to the file storing the list of test folders. If not provided, all the folders + in the dataroot will be used. + num_frame (int): Window size for input frames. + padding (str): Padding mode. 
+ """ + + def __init__(self, opt): + super(VideoTestDataset, self).__init__() + self.opt = opt + self.cache_data = opt['cache_data'] + self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq'] + self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []} + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.' + + logger = get_root_logger() + logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}') + self.imgs_lq, self.imgs_gt = {}, {} + if 'meta_info_file' in opt: + with open(opt['meta_info_file'], 'r') as fin: + subfolders = [line.split(' ')[0] for line in fin] + subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders] + subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders] + else: + subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*'))) + subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*'))) + + if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']: + for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt): + # get frame list for lq and gt + subfolder_name = osp.basename(subfolder_lq) + img_paths_lq = sorted(list(scandir(subfolder_lq, full_path=True))) + img_paths_gt = sorted(list(scandir(subfolder_gt, full_path=True))) + + max_idx = len(img_paths_lq) + assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})' + f' and gt folders ({len(img_paths_gt)})') + + self.data_info['lq_path'].extend(img_paths_lq) + self.data_info['gt_path'].extend(img_paths_gt) + self.data_info['folder'].extend([subfolder_name] * max_idx) + for i in range(max_idx): + self.data_info['idx'].append(f'{i}/{max_idx}') + border_l = [0] * max_idx + for i in range(self.opt['num_frame'] // 2): + border_l[i] = 1 + border_l[max_idx - i - 1] = 1 + self.data_info['border'].extend(border_l) + + # cache data or save the frame list + if self.cache_data: + logger.info(f'Cache {subfolder_name} for VideoTestDataset...') + self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq) + self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt) + else: + self.imgs_lq[subfolder_name] = img_paths_lq + self.imgs_gt[subfolder_name] = img_paths_gt + else: + raise ValueError(f'Non-supported video test dataset: {type(opt["name"])}') + + def __getitem__(self, index): + folder = self.data_info['folder'][index] + idx, max_idx = self.data_info['idx'][index].split('/') + idx, max_idx = int(idx), int(max_idx) + border = self.data_info['border'][index] + lq_path = self.data_info['lq_path'][index] + + select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding']) + + if self.cache_data: + imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx)) + img_gt = self.imgs_gt[folder][idx] + else: + img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx] + imgs_lq = read_img_seq(img_paths_lq) + img_gt = read_img_seq([self.imgs_gt[folder][idx]]) + img_gt.squeeze_(0) + + return { + 'lq': imgs_lq, # (t, c, h, w) + 'gt': img_gt, # (c, h, w) + 'folder': folder, # folder name + 'idx': self.data_info['idx'][index], # e.g., 0/99 + 'border': border, # 1 for border, 0 for non-border + 'lq_path': lq_path # center frame + } + + def __len__(self): + return len(self.data_info['gt_path']) + + +@DATASET_REGISTRY.register() +class VideoTestVimeo90KDataset(data.Dataset): + """Video test dataset for Vimeo90k-Test dataset. 
+ + It only keeps the center frame for testing. + For testing datasets, there is no need to prepare LMDB files. + + Args: + opt (dict): Config for train dataset. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + io_backend (dict): IO backend type and other kwarg. + cache_data (bool): Whether to cache testing datasets. + name (str): Dataset name. + meta_info_file (str): The path to the file storing the list of test folders. If not provided, all the folders + in the dataroot will be used. + num_frame (int): Window size for input frames. + padding (str): Padding mode. + """ + + def __init__(self, opt): + super(VideoTestVimeo90KDataset, self).__init__() + self.opt = opt + self.cache_data = opt['cache_data'] + if self.cache_data: + raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.') + self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq'] + self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []} + neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])] + + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.' + + logger = get_root_logger() + logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}') + with open(opt['meta_info_file'], 'r') as fin: + subfolders = [line.split(' ')[0] for line in fin] + for idx, subfolder in enumerate(subfolders): + gt_path = osp.join(self.gt_root, subfolder, 'im4.png') + self.data_info['gt_path'].append(gt_path) + lq_paths = [osp.join(self.lq_root, subfolder, f'im{i}.png') for i in neighbor_list] + self.data_info['lq_path'].append(lq_paths) + self.data_info['folder'].append('vimeo90k') + self.data_info['idx'].append(f'{idx}/{len(subfolders)}') + self.data_info['border'].append(0) + + def __getitem__(self, index): + lq_path = self.data_info['lq_path'][index] + gt_path = self.data_info['gt_path'][index] + imgs_lq = read_img_seq(lq_path) + img_gt = read_img_seq([gt_path]) + img_gt.squeeze_(0) + + return { + 'lq': imgs_lq, # (t, c, h, w) + 'gt': img_gt, # (c, h, w) + 'folder': self.data_info['folder'][index], # folder name + 'idx': self.data_info['idx'][index], # e.g., 0/843 + 'border': self.data_info['border'][index], # 0 for non-border + 'lq_path': lq_path[self.opt['num_frame'] // 2] # center frame + } + + def __len__(self): + return len(self.data_info['gt_path']) + + +@DATASET_REGISTRY.register() +class VideoTestDUFDataset(VideoTestDataset): + """ Video test dataset for DUF dataset. + + Args: + opt (dict): Config for train dataset. Most of keys are the same as VideoTestDataset. + It has the following extra keys: + use_duf_downsampling (bool): Whether to use duf downsampling to generate low-resolution frames. + scale (bool): Scale, which will be added automatically. 
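+
+        Example of the extra keys (a minimal sketch; the values are
+        illustrative):
+
+        ::
+
+            opt['use_duf_downsampling'] = True   # synthesize LQ frames from GT via duf_downsample
+            opt['scale'] = 4                     # normally filled in automatically from the config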
+ """ + + def __getitem__(self, index): + folder = self.data_info['folder'][index] + idx, max_idx = self.data_info['idx'][index].split('/') + idx, max_idx = int(idx), int(max_idx) + border = self.data_info['border'][index] + lq_path = self.data_info['lq_path'][index] + + select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding']) + + if self.cache_data: + if self.opt['use_duf_downsampling']: + # read imgs_gt to generate low-resolution frames + imgs_lq = self.imgs_gt[folder].index_select(0, torch.LongTensor(select_idx)) + imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale']) + else: + imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx)) + img_gt = self.imgs_gt[folder][idx] + else: + if self.opt['use_duf_downsampling']: + img_paths_lq = [self.imgs_gt[folder][i] for i in select_idx] + # read imgs_gt to generate low-resolution frames + imgs_lq = read_img_seq(img_paths_lq, require_mod_crop=True, scale=self.opt['scale']) + imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale']) + else: + img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx] + imgs_lq = read_img_seq(img_paths_lq) + img_gt = read_img_seq([self.imgs_gt[folder][idx]], require_mod_crop=True, scale=self.opt['scale']) + img_gt.squeeze_(0) + + return { + 'lq': imgs_lq, # (t, c, h, w) + 'gt': img_gt, # (c, h, w) + 'folder': folder, # folder name + 'idx': self.data_info['idx'][index], # e.g., 0/99 + 'border': border, # 1 for border, 0 for non-border + 'lq_path': lq_path # center frame + } + + +@DATASET_REGISTRY.register() +class VideoRecurrentTestDataset(VideoTestDataset): + """Video test dataset for recurrent architectures, which takes LR video + frames as input and output corresponding HR video frames. + + Args: + opt (dict): Same as VideoTestDataset. Unused opt: + padding (str): Padding mode. + + """ + + def __init__(self, opt): + super(VideoRecurrentTestDataset, self).__init__(opt) + # Find unique folder strings + self.folders = sorted(list(set(self.data_info['folder']))) + + def __getitem__(self, index): + folder = self.folders[index] + + if self.cache_data: + imgs_lq = self.imgs_lq[folder] + imgs_gt = self.imgs_gt[folder] + else: + raise NotImplementedError('Without cache_data is not implemented.') + + return { + 'lq': imgs_lq, + 'gt': imgs_gt, + 'folder': folder, + } + + def __len__(self): + return len(self.folders) diff --git a/basicsr/data/vimeo90k_dataset.py b/basicsr/data/vimeo90k_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e5e33e1082667aeee61fecf2436fb287e82e0936 --- /dev/null +++ b/basicsr/data/vimeo90k_dataset.py @@ -0,0 +1,199 @@ +import random +import torch +from pathlib import Path +from torch.utils import data as data + +from basicsr.data.transforms import augment, paired_random_crop +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY + + +@DATASET_REGISTRY.register() +class Vimeo90KDataset(data.Dataset): + """Vimeo90K dataset for training. + + The keys are generated from a meta info txt file. + basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt + + Each line contains the following items, separated by a white space. + + 1. clip name; + 2. frame number; + 3. image shape + + Examples: + + :: + + 00001/0001 7 (256,448,3) + 00001/0002 7 (256,448,3) + + - Key examples: "00001/0001" + - GT (gt): Ground-Truth; + - LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames. 
+ + The neighboring frame list for different num_frame: + + :: + + num_frame | frame list + 1 | 4 + 3 | 3,4,5 + 5 | 2,3,4,5,6 + 7 | 1,2,3,4,5,6,7 + + Args: + opt (dict): Config for train dataset. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + meta_info_file (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + num_frame (int): Window size for input frames. + gt_size (int): Cropped patched size for gt patches. + random_reverse (bool): Random reverse input frames. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). + scale (bool): Scale, which will be added automatically. + """ + + def __init__(self, opt): + super(Vimeo90KDataset, self).__init__() + self.opt = opt + self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq']) + + with open(opt['meta_info_file'], 'r') as fin: + self.keys = [line.split(' ')[0] for line in fin] + + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.is_lmdb = False + if self.io_backend_opt['type'] == 'lmdb': + self.is_lmdb = True + self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root] + self.io_backend_opt['client_keys'] = ['lq', 'gt'] + + # indices of input images + self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])] + + # temporal augmentation configs + self.random_reverse = opt['random_reverse'] + logger = get_root_logger() + logger.info(f'Random reverse is {self.random_reverse}.') + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # random reverse + if self.random_reverse and random.random() < 0.5: + self.neighbor_list.reverse() + + scale = self.opt['scale'] + gt_size = self.opt['gt_size'] + key = self.keys[index] + clip, seq = key.split('/') # key example: 00001/0001 + + # get the GT frame (im4.png) + if self.is_lmdb: + img_gt_path = f'{key}/im4' + else: + img_gt_path = self.gt_root / clip / seq / 'im4.png' + img_bytes = self.file_client.get(img_gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + + # get the neighboring LQ frames + img_lqs = [] + for neighbor in self.neighbor_list: + if self.is_lmdb: + img_lq_path = f'{clip}/{seq}/im{neighbor}' + else: + img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png' + img_bytes = self.file_client.get(img_lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + img_lqs.append(img_lq) + + # randomly crop + img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path) + + # augmentation - flip, rotate + img_lqs.append(img_gt) + img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot']) + + img_results = img2tensor(img_results) + img_lqs = torch.stack(img_results[0:-1], dim=0) + img_gt = img_results[-1] + + # img_lqs: (t, c, h, w) + # img_gt: (c, h, w) + # key: str + return {'lq': img_lqs, 'gt': img_gt, 'key': key} + + def __len__(self): + return len(self.keys) + + +@DATASET_REGISTRY.register() +class Vimeo90KRecurrentDataset(Vimeo90KDataset): + + def __init__(self, opt): + super(Vimeo90KRecurrentDataset, self).__init__(opt) + + self.flip_sequence = opt['flip_sequence'] + self.neighbor_list = [1, 2, 3, 4, 5, 6, 7] + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), 
**self.io_backend_opt) + + # random reverse + if self.random_reverse and random.random() < 0.5: + self.neighbor_list.reverse() + + scale = self.opt['scale'] + gt_size = self.opt['gt_size'] + key = self.keys[index] + clip, seq = key.split('/') # key example: 00001/0001 + + # get the neighboring LQ and GT frames + img_lqs = [] + img_gts = [] + for neighbor in self.neighbor_list: + if self.is_lmdb: + img_lq_path = f'{clip}/{seq}/im{neighbor}' + img_gt_path = f'{clip}/{seq}/im{neighbor}' + else: + img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png' + img_gt_path = self.gt_root / clip / seq / f'im{neighbor}.png' + # LQ + img_bytes = self.file_client.get(img_lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + # GT + img_bytes = self.file_client.get(img_gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + + img_lqs.append(img_lq) + img_gts.append(img_gt) + + # randomly crop + img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path) + + # augmentation - flip, rotate + img_lqs.extend(img_gts) + img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot']) + + img_results = img2tensor(img_results) + img_lqs = torch.stack(img_results[:7], dim=0) + img_gts = torch.stack(img_results[7:], dim=0) + + if self.flip_sequence: # flip the sequence: 7 frames to 14 frames + img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0) + img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0) + + # img_lqs: (t, c, h, w) + # img_gt: (c, h, w) + # key: str + return {'lq': img_lqs, 'gt': img_gts, 'key': key} + + def __len__(self): + return len(self.keys) diff --git a/basicsr/losses/__init__.py b/basicsr/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..70a172aeed5b388ae102466eb1f02d40ba30e9b4 --- /dev/null +++ b/basicsr/losses/__init__.py @@ -0,0 +1,31 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.registry import LOSS_REGISTRY +from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty + +__all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize'] + +# automatically scan and import loss modules for registry +# scan all the files under the 'losses' folder and collect files ending with '_loss.py' +loss_folder = osp.dirname(osp.abspath(__file__)) +loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')] +# import all the loss modules +_model_modules = [importlib.import_module(f'basicsr.losses.{file_name}') for file_name in loss_filenames] + + +def build_loss(opt): + """Build loss from options. + + Args: + opt (dict): Configuration. It must contain: + type (str): Model type. 
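+
+    Example (a minimal sketch; the option values are illustrative):
+
+    ::
+
+        loss_opt = {'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'}
+        cri_pix = build_loss(loss_opt)   # returns an L1Loss instance from LOSS_REGISTRY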
+ """ + opt = deepcopy(opt) + loss_type = opt.pop('type') + loss = LOSS_REGISTRY.get(loss_type)(**opt) + logger = get_root_logger() + logger.info(f'Loss [{loss.__class__.__name__}] is created.') + return loss diff --git a/basicsr/losses/__pycache__/__init__.cpython-310.pyc b/basicsr/losses/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45eab683236c4c18b5b7113cab433c8ab99d3b80 Binary files /dev/null and b/basicsr/losses/__pycache__/__init__.cpython-310.pyc differ diff --git a/basicsr/losses/__pycache__/basic_loss.cpython-310.pyc b/basicsr/losses/__pycache__/basic_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4d7b64853935dfecbe9502acf59fc4d9bf32a8b Binary files /dev/null and b/basicsr/losses/__pycache__/basic_loss.cpython-310.pyc differ diff --git a/basicsr/losses/__pycache__/gan_loss.cpython-310.pyc b/basicsr/losses/__pycache__/gan_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56923128d39ab39a1a0c7b81a43d869e83cd0ac2 Binary files /dev/null and b/basicsr/losses/__pycache__/gan_loss.cpython-310.pyc differ diff --git a/basicsr/losses/__pycache__/loss_util.cpython-310.pyc b/basicsr/losses/__pycache__/loss_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3cd31179ff2c3d240c340b2f5f004d76002c2103 Binary files /dev/null and b/basicsr/losses/__pycache__/loss_util.cpython-310.pyc differ diff --git a/basicsr/losses/basic_loss.py b/basicsr/losses/basic_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d2e965526a9b0e2686575bf93f0173cc2664d9bb --- /dev/null +++ b/basicsr/losses/basic_loss.py @@ -0,0 +1,253 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from basicsr.archs.vgg_arch import VGGFeatureExtractor +from basicsr.utils.registry import LOSS_REGISTRY +from .loss_util import weighted_loss + +_reduction_modes = ['none', 'mean', 'sum'] + + +@weighted_loss +def l1_loss(pred, target): + return F.l1_loss(pred, target, reduction='none') + + +@weighted_loss +def mse_loss(pred, target): + return F.mse_loss(pred, target, reduction='none') + + +@weighted_loss +def charbonnier_loss(pred, target, eps=1e-12): + return torch.sqrt((pred - target)**2 + eps) + + +@LOSS_REGISTRY.register() +class L1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(L1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class MSELoss(nn.Module): + """MSE (L2) loss. + + Args: + loss_weight (float): Loss weight for MSE loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. 
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(MSELoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class CharbonnierLoss(nn.Module): + """Charbonnier loss (one variant of Robust L1Loss, a differentiable + variant of L1Loss). + + Described in "Deep Laplacian Pyramid Networks for Fast and Accurate + Super-Resolution". + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + eps (float): A value used to control the curvature near zero. Default: 1e-12. + """ + + def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12): + super(CharbonnierLoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + self.eps = eps + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class WeightedTVLoss(L1Loss): + """Weighted TV loss. + + Args: + loss_weight (float): Loss weight. Default: 1.0. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + if reduction not in ['mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum') + super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction) + + def forward(self, pred, weight=None): + if weight is None: + y_weight = None + x_weight = None + else: + y_weight = weight[:, :, :-1, :] + x_weight = weight[:, :, :, :-1] + + y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight) + x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight) + + loss = x_diff + y_diff + + return loss + + +@LOSS_REGISTRY.register() +class PerceptualLoss(nn.Module): + """Perceptual loss with commonly used style loss. + + Args: + layer_weights (dict): The weight for each layer of vgg feature. + Here is an example: {'conv5_4': 1.}, which means the conv5_4 + feature layer (before relu5_4) will be extracted with weight + 1.0 in calculating losses. + vgg_type (str): The type of vgg network used as feature extractor. + Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image in vgg. + Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. 
+ perceptual_weight (float): If `perceptual_weight > 0`, the perceptual + loss will be calculated and the loss will multiplied by the + weight. Default: 1.0. + style_weight (float): If `style_weight > 0`, the style loss will be + calculated and the loss will multiplied by the weight. + Default: 0. + criterion (str): Criterion used for perceptual loss. Default: 'l1'. + """ + + def __init__(self, + layer_weights, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + perceptual_weight=1.0, + style_weight=0., + criterion='l1'): + super(PerceptualLoss, self).__init__() + self.perceptual_weight = perceptual_weight + self.style_weight = style_weight + self.layer_weights = layer_weights + self.vgg = VGGFeatureExtractor( + layer_name_list=list(layer_weights.keys()), + vgg_type=vgg_type, + use_input_norm=use_input_norm, + range_norm=range_norm) + + self.criterion_type = criterion + if self.criterion_type == 'l1': + self.criterion = torch.nn.L1Loss() + elif self.criterion_type == 'l2': + self.criterion = torch.nn.L2loss() + elif self.criterion_type == 'fro': + self.criterion = None + else: + raise NotImplementedError(f'{criterion} criterion has not been supported.') + + def forward(self, x, gt): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + gt (Tensor): Ground-truth tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + # extract vgg features + x_features = self.vgg(x) + gt_features = self.vgg(gt.detach()) + + # calculate perceptual loss + if self.perceptual_weight > 0: + percep_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k] + else: + percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] + percep_loss *= self.perceptual_weight + else: + percep_loss = None + + # calculate style loss + if self.style_weight > 0: + style_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + style_loss += torch.norm( + self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k] + else: + style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat( + gt_features[k])) * self.layer_weights[k] + style_loss *= self.style_weight + else: + style_loss = None + + return percep_loss, style_loss + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + torch.Tensor: Gram matrix. + """ + n, c, h, w = x.size() + features = x.view(n, c, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (c * h * w) + return gram diff --git a/basicsr/losses/gan_loss.py b/basicsr/losses/gan_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..870baa2227b79eab29a3141a216b4b614e2bcdf3 --- /dev/null +++ b/basicsr/losses/gan_loss.py @@ -0,0 +1,207 @@ +import math +import torch +from torch import autograd as autograd +from torch import nn as nn +from torch.nn import functional as F + +from basicsr.utils.registry import LOSS_REGISTRY + + +@LOSS_REGISTRY.register() +class GANLoss(nn.Module): + """Define GAN loss. + + Args: + gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. + real_label_val (float): The value for real label. Default: 1.0. + fake_label_val (float): The value for fake label. Default: 0.0. + loss_weight (float): Loss weight. Default: 1.0. 
+ Note that loss_weight is only for generators; and it is always 1.0 + for discriminators. + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type + self.loss_weight = loss_weight + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan': + self.loss = self._wgan_loss + elif self.gan_type == 'wgan_softplus': + self.loss = self._wgan_softplus_loss + elif self.gan_type == 'hinge': + self.loss = nn.ReLU() + else: + raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.') + + def _wgan_loss(self, input, target): + """wgan loss. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return -input.mean() if target else input.mean() + + def _wgan_softplus_loss(self, input, target): + """wgan loss with soft plus. softplus is a smooth approximation to the + ReLU function. + + In StyleGAN2, it is called: + Logistic loss for discriminator; + Non-saturating loss for generator. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return F.softplus(-input).mean() if target else F.softplus(input).mean() + + def get_target_label(self, input, target_is_real): + """Get target label. + + Args: + input (Tensor): Input tensor. + target_is_real (bool): Whether the target is real or fake. + + Returns: + (bool | Tensor): Target tensor. Return bool for wgan, otherwise, + return Tensor. + """ + + if self.gan_type in ['wgan', 'wgan_softplus']: + return target_is_real + target_val = (self.real_label_val if target_is_real else self.fake_label_val) + return input.new_ones(input.size()) * target_val + + def forward(self, input, target_is_real, is_disc=False): + """ + Args: + input (Tensor): The input for the loss module, i.e., the network + prediction. + target_is_real (bool): Whether the targe is real or fake. + is_disc (bool): Whether the loss for discriminators or not. + Default: False. + + Returns: + Tensor: GAN loss value. 
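+
+        Example (a minimal sketch; `real_pred` / `fake_pred` stand for dummy
+        discriminator outputs):
+
+        ::
+
+            cri_gan = GANLoss('vanilla', loss_weight=0.1)
+            # generator step: loss_weight is applied
+            l_g_gan = cri_gan(fake_pred, target_is_real=True, is_disc=False)
+            # discriminator step: loss_weight is always 1.0
+            l_d_real = cri_gan(real_pred, target_is_real=True, is_disc=True)
+            l_d_fake = cri_gan(fake_pred.detach(), target_is_real=False, is_disc=True)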
+ """ + target_label = self.get_target_label(input, target_is_real) + if self.gan_type == 'hinge': + if is_disc: # for discriminators in hinge-gan + input = -input if target_is_real else input + loss = self.loss(1 + input).mean() + else: # for generators in hinge-gan + loss = -input.mean() + else: # other gan types + loss = self.loss(input, target_label) + + # loss_weight is always 1.0 for discriminators + return loss if is_disc else loss * self.loss_weight + + +@LOSS_REGISTRY.register() +class MultiScaleGANLoss(GANLoss): + """ + MultiScaleGANLoss accepts a list of predictions + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight) + + def forward(self, input, target_is_real, is_disc=False): + """ + The input is a list of tensors, or a list of (a list of tensors) + """ + if isinstance(input, list): + loss = 0 + for pred_i in input: + if isinstance(pred_i, list): + # Only compute GAN loss for the last layer + # in case of multiscale feature matching + pred_i = pred_i[-1] + # Safe operation: 0-dim tensor calling self.mean() does nothing + loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean() + loss += loss_tensor + return loss / len(input) + else: + return super().forward(input, target_is_real, is_disc) + + +def r1_penalty(real_pred, real_img): + """R1 regularization for discriminator. The core idea is to + penalize the gradient on real data alone: when the + generator distribution produces the true data distribution + and the discriminator is equal to 0 on the data manifold, the + gradient penalty ensures that the discriminator cannot create + a non-zero gradient orthogonal to the data manifold without + suffering a loss in the GAN game. + + Reference: Eq. 9 in Which training methods for GANs do actually converge. + """ + grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0] + grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() + return grad_penalty + + +def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): + noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3]) + grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0] + path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) + + path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) + + path_penalty = (path_lengths - path_mean).pow(2).mean() + + return path_penalty, path_lengths.detach().mean(), path_mean.detach() + + +def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None): + """Calculate gradient penalty for wgan-gp. + + Args: + discriminator (nn.Module): Network for the discriminator. + real_data (Tensor): Real input data. + fake_data (Tensor): Fake input data. + weight (Tensor): Weight tensor. Default: None. + + Returns: + Tensor: A tensor for gradient penalty. + """ + + batch_size = real_data.size(0) + alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1)) + + # interpolate between real_data and fake_data + interpolates = alpha * real_data + (1. 
- alpha) * fake_data + interpolates = autograd.Variable(interpolates, requires_grad=True) + + disc_interpolates = discriminator(interpolates) + gradients = autograd.grad( + outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones_like(disc_interpolates), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + + if weight is not None: + gradients = gradients * weight + + gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean() + if weight is not None: + gradients_penalty /= torch.mean(weight) + + return gradients_penalty diff --git a/basicsr/losses/loss_util.py b/basicsr/losses/loss_util.py new file mode 100644 index 0000000000000000000000000000000000000000..fd293ff9e6a22814e5aeff6ae11fb54d2e4bafff --- /dev/null +++ b/basicsr/losses/loss_util.py @@ -0,0 +1,145 @@ +import functools +import torch +from torch.nn import functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are 'none', 'mean' and 'sum'. + + Returns: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + else: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean'): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. Default: None. + reduction (str): Same as built-in losses of PyTorch. Options are + 'none', 'mean' and 'sum'. Default: 'mean'. + + Returns: + Tensor: Loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if weight is not specified or reduction is sum, just reduce the loss + if weight is None or reduction == 'sum': + loss = reduce_loss(loss, reduction) + # if reduction is mean, then compute mean over weight region + elif reduction == 'mean': + if weight.size(1) > 1: + weight = weight.sum() + else: + weight = weight.sum() * loss.size(1) + loss = loss.sum() / weight + + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.5000) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, reduction='sum') + tensor(3.) 
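+
+    The element-wise losses in `basic_loss.py` (`l1_loss`, `mse_loss`,
+    `charbonnier_loss`) are defined with this decorator, which is how they
+    gain their `weight` and `reduction` arguments.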
+ """ + + @functools.wraps(loss_func) + def wrapper(pred, target, weight=None, reduction='mean', **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction) + return loss + + return wrapper + + +def get_local_weights(residual, ksize): + """Get local weights for generating the artifact map of LDL. + + It is only called by the `get_refined_artifact_map` function. + + Args: + residual (Tensor): Residual between predicted and ground truth images. + ksize (Int): size of the local window. + + Returns: + Tensor: weight for each pixel to be discriminated as an artifact pixel + """ + + pad = (ksize - 1) // 2 + residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect') + + unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1) + pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1) + + return pixel_level_weight + + +def get_refined_artifact_map(img_gt, img_output, img_ema, ksize): + """Calculate the artifact map of LDL + (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022) + + Args: + img_gt (Tensor): ground truth images. + img_output (Tensor): output images given by the optimizing model. + img_ema (Tensor): output images given by the ema model. + ksize (Int): size of the local window. + + Returns: + overall_weight: weight for each pixel to be discriminated as an artifact pixel + (calculated based on both local and global observations). + """ + + residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True) + residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True) + + patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5) + pixel_level_weight = get_local_weights(residual_sr.clone(), ksize) + overall_weight = patch_level_weight * pixel_level_weight + + overall_weight[residual_sr < residual_ema] = 0 + + return overall_weight diff --git a/basicsr/metrics/README.md b/basicsr/metrics/README.md new file mode 100644 index 0000000000000000000000000000000000000000..98d00308ab79e92a2393f9759190de8122a8e79d --- /dev/null +++ b/basicsr/metrics/README.md @@ -0,0 +1,48 @@ +# Metrics + +[English](README.md) **|** [简体中文](README_CN.md) + +- [约定](#约定) +- [PSNR 和 SSIM](#psnr-和-ssim) + +## 约定 + +因为不同的输入类型会导致结果的不同,因此我们对输入做如下约定: + +- Numpy 类型 (一般是 cv2 的结果) + - UINT8: BGR, [0, 255], (h, w, c) + - float: BGR, [0, 1], (h, w, c). 一般作为中间结果 +- Tensor 类型 + - float: RGB, [0, 1], (n, c, h, w) + +其他约定: + +- 以 `_pt` 结尾的是 PyTorch 结果 +- PyTorch version 支持 batch 计算 +- 颜色转换在 float32 上做;metric计算在 float64 上做 + +## PSNR 和 SSIM + +PSNR 和 SSIM 的结果趋势是一致的,即一般 PSNR 高,则 SSIM 也高。 +在实现上, PSNR 的各种实现都很一致。SSIM 有各种各样的实现,我们这里和 MATLAB 最原始版本保持 (参考 [NTIRE17比赛](https://competitions.codalab.org/competitions/16306#participate) 的 [evaluation代码](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378)) + +下面列了各个实现的结果比对. 
+总结:PyTorch 实现和 MATLAB 实现基本一致,在 GPU 运行上会有稍许差异 + +- PSNR 比对 + +|Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | +|:---| :---: | :---: | :---: | :---: | :---: | +|baboon| RGB | 20.419710 | 20.419710 | 20.419710 |20.419710 | +|baboon| Y | - |22.441898 | 22.441899 | 22.444916| +|comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 | +|comic | Y | - | 21.720398 | 21.720398 | 21.721663| + +- SSIM 比对 + +|Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | +|:---| :---: | :---: | :---: | :---: | :---: | +|baboon| RGB | 0.391853 | 0.391853 | 0.391853|0.391853 | +|baboon| Y | - |0.453097| 0.453097 | 0.453171| +|comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738| +|comic | Y | - | 0.585511 | 0.585511 | 0.585522 | diff --git a/basicsr/metrics/README_CN.md b/basicsr/metrics/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..98d00308ab79e92a2393f9759190de8122a8e79d --- /dev/null +++ b/basicsr/metrics/README_CN.md @@ -0,0 +1,48 @@ +# Metrics + +[English](README.md) **|** [简体中文](README_CN.md) + +- [约定](#约定) +- [PSNR 和 SSIM](#psnr-和-ssim) + +## 约定 + +因为不同的输入类型会导致结果的不同,因此我们对输入做如下约定: + +- Numpy 类型 (一般是 cv2 的结果) + - UINT8: BGR, [0, 255], (h, w, c) + - float: BGR, [0, 1], (h, w, c). 一般作为中间结果 +- Tensor 类型 + - float: RGB, [0, 1], (n, c, h, w) + +其他约定: + +- 以 `_pt` 结尾的是 PyTorch 结果 +- PyTorch version 支持 batch 计算 +- 颜色转换在 float32 上做;metric计算在 float64 上做 + +## PSNR 和 SSIM + +PSNR 和 SSIM 的结果趋势是一致的,即一般 PSNR 高,则 SSIM 也高。 +在实现上, PSNR 的各种实现都很一致。SSIM 有各种各样的实现,我们这里和 MATLAB 最原始版本保持 (参考 [NTIRE17比赛](https://competitions.codalab.org/competitions/16306#participate) 的 [evaluation代码](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378)) + +下面列了各个实现的结果比对. +总结:PyTorch 实现和 MATLAB 实现基本一致,在 GPU 运行上会有稍许差异 + +- PSNR 比对 + +|Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | +|:---| :---: | :---: | :---: | :---: | :---: | +|baboon| RGB | 20.419710 | 20.419710 | 20.419710 |20.419710 | +|baboon| Y | - |22.441898 | 22.441899 | 22.444916| +|comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 | +|comic | Y | - | 21.720398 | 21.720398 | 21.721663| + +- SSIM 比对 + +|Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | +|:---| :---: | :---: | :---: | :---: | :---: | +|baboon| RGB | 0.391853 | 0.391853 | 0.391853|0.391853 | +|baboon| Y | - |0.453097| 0.453097 | 0.453171| +|comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738| +|comic | Y | - | 0.585511 | 0.585511 | 0.585522 | diff --git a/basicsr/metrics/__init__.py b/basicsr/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..330f3c863f66a98d41942c6995837283265d94ef --- /dev/null +++ b/basicsr/metrics/__init__.py @@ -0,0 +1,20 @@ +from copy import deepcopy + +from basicsr.utils.registry import METRIC_REGISTRY +from .niqe import calculate_niqe +from .psnr_ssim import calculate_psnr, calculate_ssim, calculate_ssim_pt, calculate_psnr_pt + +__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] + + +def calculate_metric(data, opt): + """Calculate metric from data and options. + + Args: + opt (dict): Configuration. It must contain: + type (str): Model type. 
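+
+    Example (a minimal sketch; `sr_img` / `gt_img` stand for dummy numpy
+    images in HWC order with range [0, 255]):
+
+    ::
+
+        data = {'img': sr_img, 'img2': gt_img}
+        opt = {'type': 'calculate_psnr', 'crop_border': 4, 'test_y_channel': False}
+        psnr = calculate_metric(data, opt)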
+ """ + opt = deepcopy(opt) + metric_type = opt.pop('type') + metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) + return metric diff --git a/basicsr/metrics/__pycache__/__init__.cpython-310.pyc b/basicsr/metrics/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4d597fb39f26d9f4379800a6de6eb22cf7b6ee0 Binary files /dev/null and b/basicsr/metrics/__pycache__/__init__.cpython-310.pyc differ diff --git a/basicsr/metrics/__pycache__/metric_util.cpython-310.pyc b/basicsr/metrics/__pycache__/metric_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdfa57ae90e81ef558cd76a4811a95b3e9af1290 Binary files /dev/null and b/basicsr/metrics/__pycache__/metric_util.cpython-310.pyc differ diff --git a/basicsr/metrics/__pycache__/niqe.cpython-310.pyc b/basicsr/metrics/__pycache__/niqe.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3564457c955f01d0f9b0bc87df528c9c1c928d34 Binary files /dev/null and b/basicsr/metrics/__pycache__/niqe.cpython-310.pyc differ diff --git a/basicsr/metrics/__pycache__/psnr_ssim.cpython-310.pyc b/basicsr/metrics/__pycache__/psnr_ssim.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5772df06bc1e353cbd85ecf3707b1cca3fcb1f95 Binary files /dev/null and b/basicsr/metrics/__pycache__/psnr_ssim.cpython-310.pyc differ diff --git a/basicsr/metrics/fid.py b/basicsr/metrics/fid.py new file mode 100644 index 0000000000000000000000000000000000000000..1b0ba6df1de96d93a60c1cfd3dc1fcf4d3d31533 --- /dev/null +++ b/basicsr/metrics/fid.py @@ -0,0 +1,89 @@ +import numpy as np +import torch +import torch.nn as nn +from scipy import linalg +from tqdm import tqdm + +from basicsr.archs.inception import InceptionV3 + + +def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False): + # we may not resize the input, but in [rosinality/stylegan2-pytorch] it + # does resize the input. + inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input) + inception = nn.DataParallel(inception).eval().to(device) + return inception + + +@torch.no_grad() +def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'): + """Extract inception features. + + Args: + data_generator (generator): A data generator. + inception (nn.Module): Inception model. + len_generator (int): Length of the data_generator to show the + progressbar. Default: None. + device (str): Device. Default: cuda. + + Returns: + Tensor: Extracted features. + """ + if len_generator is not None: + pbar = tqdm(total=len_generator, unit='batch', desc='Extract') + else: + pbar = None + features = [] + + for data in data_generator: + if pbar: + pbar.update(1) + data = data.to(device) + feature = inception(data)[0].view(data.shape[0], -1) + features.append(feature.to('cpu')) + if pbar: + pbar.close() + features = torch.cat(features, 0) + return features + + +def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is: + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + + Args: + mu1 (np.array): The sample mean over activations. + sigma1 (np.array): The covariance matrix over activations for generated samples. + mu2 (np.array): The sample mean over activations, precalculated on an representative data set. 
+ sigma2 (np.array): The covariance matrix over activations, precalculated on a representative data set. + + Returns: + float: The Frechet Distance. + """ + assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions') + + cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) + + # Product might be almost singular + if not np.isfinite(cov_sqrt).all(): + print(f'Product of cov matrices is singular. Adding {eps} to diagonal of cov estimates') + offset = np.eye(sigma1.shape[0]) * eps + cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(cov_sqrt): + if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): + m = np.max(np.abs(cov_sqrt.imag)) + raise ValueError(f'Imaginary component {m}') + cov_sqrt = cov_sqrt.real + + mean_diff = mu1 - mu2 + mean_norm = mean_diff @ mean_diff + trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt) + fid = mean_norm + trace + + return fid diff --git a/basicsr/metrics/metric_util.py b/basicsr/metrics/metric_util.py new file mode 100644 index 0000000000000000000000000000000000000000..2a27c70a043beeeb59cfaf533079492293065448 --- /dev/null +++ b/basicsr/metrics/metric_util.py @@ -0,0 +1,45 @@ +import numpy as np + +from basicsr.utils import bgr2ycbcr + + +def reorder_image(img, input_order='HWC'): + """Reorder images to 'HWC' order. + + If the input_order is (h, w), return (h, w, 1); + If the input_order is (c, h, w), return (h, w, c); + If the input_order is (h, w, c), return as it is. + + Args: + img (ndarray): Input image. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + If the input image shape is (h, w), input_order will not have + effects. Default: 'HWC'. + + Returns: + ndarray: reordered image. + """ + + if input_order not in ['HWC', 'CHW']: + raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'") + if len(img.shape) == 2: + img = img[..., None] + if input_order == 'CHW': + img = img.transpose(1, 2, 0) + return img + + +def to_y_channel(img): + """Change to Y channel of YCbCr. + + Args: + img (ndarray): Images with range [0, 255]. + + Returns: + (ndarray): Images with range [0, 255] (float type) without round. + """ + img = img.astype(np.float32) / 255. + if img.ndim == 3 and img.shape[2] == 3: + img = bgr2ycbcr(img, y_only=True) + img = img[..., None] + return img * 255. diff --git a/basicsr/metrics/niqe.py b/basicsr/metrics/niqe.py new file mode 100644 index 0000000000000000000000000000000000000000..e3c1467f61d809ec3b2630073118460d9d61a861 --- /dev/null +++ b/basicsr/metrics/niqe.py @@ -0,0 +1,199 @@ +import cv2 +import math +import numpy as np +import os +from scipy.ndimage import convolve +from scipy.special import gamma + +from basicsr.metrics.metric_util import reorder_image, to_y_channel +from basicsr.utils.matlab_functions import imresize +from basicsr.utils.registry import METRIC_REGISTRY + + +def estimate_aggd_param(block): + """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters. + + Args: + block (ndarray): 2D Image block. + + Returns: + tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD + distribution (Estimating the params in Equation 7 in the paper).
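+ + Example (a rough sketch; any 2D block containing both negative and positive values works): + >>> block = np.random.randn(96, 96) + >>> alpha, beta_l, beta_r = estimate_aggd_param(block)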
+ """ + block = block.flatten() + gam = np.arange(0.2, 10.001, 0.001) # len = 9801 + gam_reciprocal = np.reciprocal(gam) + r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3)) + + left_std = np.sqrt(np.mean(block[block < 0]**2)) + right_std = np.sqrt(np.mean(block[block > 0]**2)) + gammahat = left_std / right_std + rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2) + rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1)**2) + array_position = np.argmin((r_gam - rhatnorm)**2) + + alpha = gam[array_position] + beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha)) + beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha)) + return (alpha, beta_l, beta_r) + + +def compute_feature(block): + """Compute features. + + Args: + block (ndarray): 2D Image block. + + Returns: + list: Features with length of 18. + """ + feat = [] + alpha, beta_l, beta_r = estimate_aggd_param(block) + feat.extend([alpha, (beta_l + beta_r) / 2]) + + # distortions disturb the fairly regular structure of natural images. + # This deviation can be captured by analyzing the sample distribution of + # the products of pairs of adjacent coefficients computed along + # horizontal, vertical and diagonal orientations. + shifts = [[0, 1], [1, 0], [1, 1], [1, -1]] + for i in range(len(shifts)): + shifted_block = np.roll(block, shifts[i], axis=(0, 1)) + alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block) + # Eq. 8 + mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha)) + feat.extend([alpha, mean, beta_l, beta_r]) + return feat + + +def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, block_size_w=96): + """Calculate NIQE (Natural Image Quality Evaluator) metric. + + ``Paper: Making a "Completely Blind" Image Quality Analyzer`` + + This implementation could produce almost the same results as the official + MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip + + Note that we do not include block overlap height and width, since they are + always 0 in the official implementation. + + For good performance, it is advisable by the official implementation to + divide the distorted image in to the same size patched as used for the + construction of multivariate Gaussian model. + + Args: + img (ndarray): Input image whose quality needs to be computed. The + image must be a gray or Y (of YCbCr) image with shape (h, w). + Range [0, 255] with float type. + mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian + model calculated on the pristine dataset. + cov_pris_param (ndarray): Covariance of a pre-defined multivariate + Gaussian model calculated on the pristine dataset. + gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the + image. + block_size_h (int): Height of the blocks in to which image is divided. + Default: 96 (the official recommended value). + block_size_w (int): Width of the blocks in to which image is divided. + Default: 96 (the official recommended value). 
+ """ + assert img.ndim == 2, ('Input image must be a gray or Y (of YCbCr) image with shape (h, w).') + # crop image + h, w = img.shape + num_block_h = math.floor(h / block_size_h) + num_block_w = math.floor(w / block_size_w) + img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w] + + distparam = [] # dist param is actually the multiscale features + for scale in (1, 2): # perform on two scales (1, 2) + mu = convolve(img, gaussian_window, mode='nearest') + sigma = np.sqrt(np.abs(convolve(np.square(img), gaussian_window, mode='nearest') - np.square(mu))) + # normalize, as in Eq. 1 in the paper + img_nomalized = (img - mu) / (sigma + 1) + + feat = [] + for idx_w in range(num_block_w): + for idx_h in range(num_block_h): + # process ecah block + block = img_nomalized[idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale, + idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale] + feat.append(compute_feature(block)) + + distparam.append(np.array(feat)) + + if scale == 1: + img = imresize(img / 255., scale=0.5, antialiasing=True) + img = img * 255. + + distparam = np.concatenate(distparam, axis=1) + + # fit a MVG (multivariate Gaussian) model to distorted patch features + mu_distparam = np.nanmean(distparam, axis=0) + # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html + distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)] + cov_distparam = np.cov(distparam_no_nan, rowvar=False) + + # compute niqe quality, Eq. 10 in the paper + invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2) + quality = np.matmul( + np.matmul((mu_pris_param - mu_distparam), invcov_param), np.transpose((mu_pris_param - mu_distparam))) + + quality = np.sqrt(quality) + quality = float(np.squeeze(quality)) + return quality + + +@METRIC_REGISTRY.register() +def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y', **kwargs): + """Calculate NIQE (Natural Image Quality Evaluator) metric. + + ``Paper: Making a "Completely Blind" Image Quality Analyzer`` + + This implementation could produce almost the same results as the official + MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip + + > MATLAB R2021a result for tests/data/baboon.png: 5.72957338 (5.7296) + > Our re-implementation result for tests/data/baboon.png: 5.7295763 (5.7296) + + We use the official params estimated from the pristine dataset. + We use the recommended block size (96, 96) without overlaps. + + Args: + img (ndarray): Input image whose quality needs to be computed. + The input image must be in range [0, 255] with float/int type. + The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order) + If the input order is 'HWC' or 'CHW', it will be converted to gray + or Y (of YCbCr) image according to the ``convert_to`` argument. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the metric calculation. + input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'. + Default: 'HWC'. + convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'. + Default: 'y'. + + Returns: + float: NIQE result. + """ + ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) + # we use the official params estimated from the pristine dataset. 
+ niqe_pris_params = np.load(os.path.join(ROOT_DIR, 'niqe_pris_params.npz')) + mu_pris_param = niqe_pris_params['mu_pris_param'] + cov_pris_param = niqe_pris_params['cov_pris_param'] + gaussian_window = niqe_pris_params['gaussian_window'] + + img = img.astype(np.float32) + if input_order != 'HW': + img = reorder_image(img, input_order=input_order) + if convert_to == 'y': + img = to_y_channel(img) + elif convert_to == 'gray': + img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255. + img = np.squeeze(img) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border] + + # round is necessary for being consistent with MATLAB's result + img = img.round() + + niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window) + + return niqe_result diff --git a/basicsr/metrics/niqe_pris_params.npz b/basicsr/metrics/niqe_pris_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..42f06a9a18e6ed8bbf7933bec1477b189ef798de --- /dev/null +++ b/basicsr/metrics/niqe_pris_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2a7c182a68c9e7f1b2e2e5ec723279d6f65d912b6fcaf37eb2bf03d7367c4296 +size 11850 diff --git a/basicsr/metrics/psnr_ssim.py b/basicsr/metrics/psnr_ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..ab03113f89805c990ff22795601274bf45db23a1 --- /dev/null +++ b/basicsr/metrics/psnr_ssim.py @@ -0,0 +1,231 @@ +import cv2 +import numpy as np +import torch +import torch.nn.functional as F + +from basicsr.metrics.metric_util import reorder_image, to_y_channel +from basicsr.utils.color_util import rgb2ycbcr_pt +from basicsr.utils.registry import METRIC_REGISTRY + + +@METRIC_REGISTRY.register() +def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + + Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: PSNR result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + mse = np.mean((img - img2)**2) + if mse == 0: + return float('inf') + return 10. * np.log10(255. * 255. / mse) + + +@METRIC_REGISTRY.register() +def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False, **kwargs): + """Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version). + + Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). 
+ crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: PSNR result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + + if crop_border != 0: + img = img[:, :, crop_border:-crop_border, crop_border:-crop_border] + img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border] + + if test_y_channel: + img = rgb2ycbcr_pt(img, y_only=True) + img2 = rgb2ycbcr_pt(img2, y_only=True) + + img = img.to(torch.float64) + img2 = img2.to(torch.float64) + + mse = torch.mean((img - img2)**2, dim=[1, 2, 3]) + return 10. * torch.log10(1. / (mse + 1e-8)) + + +@METRIC_REGISTRY.register() +def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): + """Calculate SSIM (structural similarity). + + ``Paper: Image quality assessment: From error visibility to structural similarity`` + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: SSIM result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + ssims = [] + for i in range(img.shape[2]): + ssims.append(_ssim(img[..., i], img2[..., i])) + return np.array(ssims).mean() + + +@METRIC_REGISTRY.register() +def calculate_ssim_pt(img, img2, crop_border, test_y_channel=False, **kwargs): + """Calculate SSIM (structural similarity) (PyTorch version). + + ``Paper: Image quality assessment: From error visibility to structural similarity`` + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: SSIM result. 
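+ + Example (a minimal sketch; assumes ``img`` and ``img2`` are float tensors in [0, 1] with shape (n, 3, h, w)): + >>> ssim = calculate_ssim_pt(img, img2, crop_border=4, test_y_channel=True) # one value per batch element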
+ """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + + if crop_border != 0: + img = img[:, :, crop_border:-crop_border, crop_border:-crop_border] + img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border] + + if test_y_channel: + img = rgb2ycbcr_pt(img, y_only=True) + img2 = rgb2ycbcr_pt(img2, y_only=True) + + img = img.to(torch.float64) + img2 = img2.to(torch.float64) + + ssim = _ssim_pth(img * 255., img2 * 255.) + return ssim + + +def _ssim(img, img2): + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: SSIM result. + """ + + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5] # valid mode for window size 11 + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)) + return ssim_map.mean() + + +def _ssim_pth(img, img2): + """Calculate SSIM (structural similarity) (PyTorch version). + + It is called by func:`calculate_ssim_pt`. + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + + Returns: + float: SSIM result. 
+ """ + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + window = torch.from_numpy(window).view(1, 1, 11, 11).expand(img.size(1), 1, 11, 11).to(img.dtype).to(img.device) + + mu1 = F.conv2d(img, window, stride=1, padding=0, groups=img.shape[1]) # valid mode + mu2 = F.conv2d(img2, window, stride=1, padding=0, groups=img2.shape[1]) # valid mode + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + sigma1_sq = F.conv2d(img * img, window, stride=1, padding=0, groups=img.shape[1]) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu2_sq + sigma12 = F.conv2d(img * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu1_mu2 + + cs_map = (2 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2) + ssim_map = ((2 * mu1_mu2 + c1) / (mu1_sq + mu2_sq + c1)) * cs_map + return ssim_map.mean([1, 2, 3]) diff --git a/basicsr/metrics/test_metrics/test_psnr_ssim.py b/basicsr/metrics/test_metrics/test_psnr_ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..18b05a73a0e38e89b2321ddc9415123a92f5c5a4 --- /dev/null +++ b/basicsr/metrics/test_metrics/test_psnr_ssim.py @@ -0,0 +1,52 @@ +import cv2 +import torch + +from basicsr.metrics import calculate_psnr, calculate_ssim +from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt +from basicsr.utils import img2tensor + + +def test(img_path, img_path2, crop_border, test_y_channel=False): + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + img2 = cv2.imread(img_path2, cv2.IMREAD_UNCHANGED) + + # --------------------- Numpy --------------------- + psnr = calculate_psnr(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel) + ssim = calculate_ssim(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel) + print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}') + + # --------------------- PyTorch (CPU) --------------------- + img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0) + img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0) + + psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) + ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) + print(f'\tTensor (CPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}') + + # --------------------- PyTorch (GPU) --------------------- + img = img.cuda() + img2 = img2.cuda() + psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) + ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) + print(f'\tTensor (GPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}') + + psnr_pth = calculate_psnr_pt( + torch.repeat_interleave(img, 2, dim=0), + torch.repeat_interleave(img2, 2, dim=0), + crop_border=crop_border, + test_y_channel=test_y_channel) + ssim_pth = calculate_ssim_pt( + torch.repeat_interleave(img, 2, dim=0), + torch.repeat_interleave(img2, 2, dim=0), + crop_border=crop_border, + test_y_channel=test_y_channel) + print(f'\tTensor (GPU batch) \tPSNR: {psnr_pth[0]:.6f}, {psnr_pth[1]:.6f} dB,' + f'\tSSIM: {ssim_pth[0]:.6f}, {ssim_pth[1]:.6f}') + + +if __name__ == '__main__': + test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=False) + test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, 
test_y_channel=True) + + test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=False) + test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=True) diff --git a/basicsr/models/__init__.py b/basicsr/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..85796deae014c20a9aa600133468d04900c4fb89 --- /dev/null +++ b/basicsr/models/__init__.py @@ -0,0 +1,29 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.registry import MODEL_REGISTRY + +__all__ = ['build_model'] + +# automatically scan and import model modules for registry +# scan all the files under the 'models' folder and collect files ending with '_model.py' +model_folder = osp.dirname(osp.abspath(__file__)) +model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] +# import all the model modules +_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] + + +def build_model(opt): + """Build model from options. + + Args: + opt (dict): Configuration. It must contain: + model_type (str): Model type. + """ + opt = deepcopy(opt) + model = MODEL_REGISTRY.get(opt['model_type'])(opt) + logger = get_root_logger() + logger.info(f'Model [{model.__class__.__name__}] is created.') + return model diff --git a/basicsr/models/__pycache__/__init__.cpython-310.pyc b/basicsr/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba727bf1df1871606d58e9ed3fb09d78dd475c75 Binary files /dev/null and b/basicsr/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/basicsr/models/__pycache__/base_model.cpython-310.pyc b/basicsr/models/__pycache__/base_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0bac3c7f0da27060809f0e360e906c1641456b6 Binary files /dev/null and b/basicsr/models/__pycache__/base_model.cpython-310.pyc differ diff --git a/basicsr/models/__pycache__/edvr_model.cpython-310.pyc b/basicsr/models/__pycache__/edvr_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9a2fd412f040e103fc453e3d4fcdb2bee147a26 Binary files /dev/null and b/basicsr/models/__pycache__/edvr_model.cpython-310.pyc differ diff --git a/basicsr/models/__pycache__/esrgan_model.cpython-310.pyc b/basicsr/models/__pycache__/esrgan_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a29bab0dec8f42183211c68c41703c6a21054266 Binary files /dev/null and b/basicsr/models/__pycache__/esrgan_model.cpython-310.pyc differ diff --git a/basicsr/models/__pycache__/hifacegan_model.cpython-310.pyc b/basicsr/models/__pycache__/hifacegan_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c76fa597dbfcccee32d3497130bcd89feca35e72 Binary files /dev/null and b/basicsr/models/__pycache__/hifacegan_model.cpython-310.pyc differ diff --git a/basicsr/models/__pycache__/lr_scheduler.cpython-310.pyc b/basicsr/models/__pycache__/lr_scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7aa57c5591728845ca57b616e0d170dd78f08159 Binary files /dev/null and b/basicsr/models/__pycache__/lr_scheduler.cpython-310.pyc differ diff --git a/basicsr/models/__pycache__/realesrgan_model.cpython-310.pyc b/basicsr/models/__pycache__/realesrgan_model.cpython-310.pyc 
new file mode 100644 index 0000000000000000000000000000000000000000..d87e0fb55748a1f84aa1a6a0972b0b0413cbd9d1 Binary files /dev/null and b/basicsr/models/__pycache__/realesrgan_model.cpython-310.pyc differ diff --git a/basicsr/models/__pycache__/realesrnet_model.cpython-310.pyc b/basicsr/models/__pycache__/realesrnet_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ede21c212796fa90149610073d5b3ce04529f1c Binary files /dev/null and b/basicsr/models/__pycache__/realesrnet_model.cpython-310.pyc differ diff --git a/basicsr/models/__pycache__/sr_model.cpython-310.pyc b/basicsr/models/__pycache__/sr_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d3b40c61913eb7fdb946423281da584a58e8fb3 Binary files /dev/null and b/basicsr/models/__pycache__/sr_model.cpython-310.pyc differ diff --git a/basicsr/models/__pycache__/srgan_model.cpython-310.pyc b/basicsr/models/__pycache__/srgan_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ab82228e4acdd5bba78db05feacd793e2be2d81 Binary files /dev/null and b/basicsr/models/__pycache__/srgan_model.cpython-310.pyc differ diff --git a/basicsr/models/__pycache__/stylegan2_model.cpython-310.pyc b/basicsr/models/__pycache__/stylegan2_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f083ad65878672a0631fcc943c46b5afa066f03 Binary files /dev/null and b/basicsr/models/__pycache__/stylegan2_model.cpython-310.pyc differ diff --git a/basicsr/models/__pycache__/swinir_model.cpython-310.pyc b/basicsr/models/__pycache__/swinir_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c9c20da06300db42bec0ea2827d8dec178e9591 Binary files /dev/null and b/basicsr/models/__pycache__/swinir_model.cpython-310.pyc differ diff --git a/basicsr/models/__pycache__/video_base_model.cpython-310.pyc b/basicsr/models/__pycache__/video_base_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e669db1dff4e1644abee687b1df84514756c2437 Binary files /dev/null and b/basicsr/models/__pycache__/video_base_model.cpython-310.pyc differ diff --git a/basicsr/models/__pycache__/video_gan_model.cpython-310.pyc b/basicsr/models/__pycache__/video_gan_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a48bd51a9bd87a6b9fbb5bcc25063e4059c6b87 Binary files /dev/null and b/basicsr/models/__pycache__/video_gan_model.cpython-310.pyc differ diff --git a/basicsr/models/__pycache__/video_recurrent_gan_model.cpython-310.pyc b/basicsr/models/__pycache__/video_recurrent_gan_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dbefc3f99e6734de5b1bed5d79d00c58129655c Binary files /dev/null and b/basicsr/models/__pycache__/video_recurrent_gan_model.cpython-310.pyc differ diff --git a/basicsr/models/__pycache__/video_recurrent_model.cpython-310.pyc b/basicsr/models/__pycache__/video_recurrent_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..209d9ff667bd722f36273ee01d3798544375607c Binary files /dev/null and b/basicsr/models/__pycache__/video_recurrent_model.cpython-310.pyc differ diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fbf8229f59dee86a7f9f95c1d07da785fb5f15b3 --- /dev/null +++ b/basicsr/models/base_model.py @@ -0,0 +1,392 @@ +import os +import time +import torch +from 
collections import OrderedDict +from copy import deepcopy +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from basicsr.models import lr_scheduler as lr_scheduler +from basicsr.utils import get_root_logger +from basicsr.utils.dist_util import master_only + + +class BaseModel(): + """Base model.""" + + def __init__(self, opt): + self.opt = opt + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.is_train = opt['is_train'] + self.schedulers = [] + self.optimizers = [] + + def feed_data(self, data): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + pass + + def save(self, epoch, current_iter): + """Save networks and training state.""" + pass + + def validation(self, dataloader, current_iter, tb_logger, save_img=False): + """Validation function. + + Args: + dataloader (torch.utils.data.DataLoader): Validation dataloader. + current_iter (int): Current iteration. + tb_logger (tensorboard logger): Tensorboard logger. + save_img (bool): Whether to save images. Default: False. + """ + if self.opt['dist']: + self.dist_validation(dataloader, current_iter, tb_logger, save_img) + else: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def _initialize_best_metric_results(self, dataset_name): + """Initialize the best metric results dict for recording the best metric value and iteration.""" + if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results: + return + elif not hasattr(self, 'best_metric_results'): + self.best_metric_results = dict() + + # add a dataset record + record = dict() + for metric, content in self.opt['val']['metrics'].items(): + better = content.get('better', 'higher') + init_val = float('-inf') if better == 'higher' else float('inf') + record[metric] = dict(better=better, val=init_val, iter=-1) + self.best_metric_results[dataset_name] = record + + def _update_best_metric_result(self, dataset_name, metric, val, current_iter): + if self.best_metric_results[dataset_name][metric]['better'] == 'higher': + if val >= self.best_metric_results[dataset_name][metric]['val']: + self.best_metric_results[dataset_name][metric]['val'] = val + self.best_metric_results[dataset_name][metric]['iter'] = current_iter + else: + if val <= self.best_metric_results[dataset_name][metric]['val']: + self.best_metric_results[dataset_name][metric]['val'] = val + self.best_metric_results[dataset_name][metric]['iter'] = current_iter + + def model_ema(self, decay=0.999): + net_g = self.get_bare_model(self.net_g) + + net_g_params = dict(net_g.named_parameters()) + net_g_ema_params = dict(self.net_g_ema.named_parameters()) + + for k in net_g_ema_params.keys(): + net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay) + + def get_current_log(self): + return self.log_dict + + def model_to_device(self, net): + """Model to device. It also warps models with DistributedDataParallel + or DataParallel. 
+ + Args: + net (nn.Module) + """ + net = net.to(self.device) + if self.opt['dist']: + find_unused_parameters = self.opt.get('find_unused_parameters', False) + net = DistributedDataParallel( + net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters) + elif self.opt['num_gpu'] > 1: + net = DataParallel(net) + return net + + def get_optimizer(self, optim_type, params, lr, **kwargs): + if optim_type == 'Adam': + optimizer = torch.optim.Adam(params, lr, **kwargs) + elif optim_type == 'AdamW': + optimizer = torch.optim.AdamW(params, lr, **kwargs) + elif optim_type == 'Adamax': + optimizer = torch.optim.Adamax(params, lr, **kwargs) + elif optim_type == 'SGD': + optimizer = torch.optim.SGD(params, lr, **kwargs) + elif optim_type == 'ASGD': + optimizer = torch.optim.ASGD(params, lr, **kwargs) + elif optim_type == 'RMSprop': + optimizer = torch.optim.RMSprop(params, lr, **kwargs) + elif optim_type == 'Rprop': + optimizer = torch.optim.Rprop(params, lr, **kwargs) + else: + raise NotImplementedError(f'optimizer {optim_type} is not supported yet.') + return optimizer + + def setup_schedulers(self): + """Set up schedulers.""" + train_opt = self.opt['train'] + scheduler_type = train_opt['scheduler'].pop('type') + if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'CosineAnnealingRestartLR': + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler'])) + else: + raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.') + + def get_bare_model(self, net): + """Get bare model, especially under wrapping with + DistributedDataParallel or DataParallel. + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net = net.module + return net + + @master_only + def print_network(self, net): + """Print the str and parameter number of a network. + + Args: + net (nn.Module) + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}' + else: + net_cls_str = f'{net.__class__.__name__}' + + net = self.get_bare_model(net) + net_str = str(net) + net_params = sum(map(lambda x: x.numel(), net.parameters())) + + logger = get_root_logger() + logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}') + logger.info(net_str) + + def _set_lr(self, lr_groups_l): + """Set learning rate for warm-up. + + Args: + lr_groups_l (list): List for lr_groups, each for an optimizer. + """ + for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): + for param_group, lr in zip(optimizer.param_groups, lr_groups): + param_group['lr'] = lr + + def _get_init_lr(self): + """Get the initial lr, which is set by the scheduler. + """ + init_lr_groups_l = [] + for optimizer in self.optimizers: + init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) + return init_lr_groups_l + + def update_learning_rate(self, current_iter, warmup_iter=-1): + """Update learning rate. + + Args: + current_iter (int): Current iteration. + warmup_iter (int): Warm-up iter numbers. -1 for no warm-up. + Default: -1. 
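+ + Example (illustrative numbers): with ``warmup_iter=1000`` and an initial lr of 2e-4, the lr applied at iteration 250 is 2e-4 * 250 / 1000 = 5e-5; once ``current_iter`` reaches ``warmup_iter``, the scheduler values are used unchanged.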
+ """ + if current_iter > 1: + for scheduler in self.schedulers: + scheduler.step() + # set up warm-up learning rate + if current_iter < warmup_iter: + # get initial lr for each group + init_lr_g_l = self._get_init_lr() + # modify warming-up learning rates + # currently only support linearly warm up + warm_up_lr_l = [] + for init_lr_g in init_lr_g_l: + warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g]) + # set learning rate + self._set_lr(warm_up_lr_l) + + def get_current_learning_rate(self): + return [param_group['lr'] for param_group in self.optimizers[0].param_groups] + + @master_only + def save_network(self, net, net_label, current_iter, param_key='params'): + """Save networks. + + Args: + net (nn.Module | list[nn.Module]): Network(s) to be saved. + net_label (str): Network label. + current_iter (int): Current iter number. + param_key (str | list[str]): The parameter key(s) to save network. + Default: 'params'. + """ + if current_iter == -1: + current_iter = 'latest' + save_filename = f'{net_label}_{current_iter}.pth' + save_path = os.path.join(self.opt['path']['models'], save_filename) + + net = net if isinstance(net, list) else [net] + param_key = param_key if isinstance(param_key, list) else [param_key] + assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.' + + save_dict = {} + for net_, param_key_ in zip(net, param_key): + net_ = self.get_bare_model(net_) + state_dict = net_.state_dict() + for key, param in state_dict.items(): + if key.startswith('module.'): # remove unnecessary 'module.' + key = key[7:] + state_dict[key] = param.cpu() + save_dict[param_key_] = state_dict + + # avoid occasional writing errors + retry = 3 + while retry > 0: + try: + torch.save(save_dict, save_path) + except Exception as e: + logger = get_root_logger() + logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}') + time.sleep(1) + else: + break + finally: + retry -= 1 + if retry == 0: + logger.warning(f'Still cannot save {save_path}. Just ignore it.') + # raise IOError(f'Cannot save {save_path}.') + + def _print_different_keys_loading(self, crt_net, load_net, strict=True): + """Print keys with different name or different size when loading models. + + 1. Print keys with different names. + 2. If strict=False, print the same key but with different tensor size. + It also ignore these keys with different sizes (not load). + + Args: + crt_net (torch model): Current network. + load_net (dict): Loaded network. + strict (bool): Whether strictly loaded. Default: True. + """ + crt_net = self.get_bare_model(crt_net) + crt_net = crt_net.state_dict() + crt_net_keys = set(crt_net.keys()) + load_net_keys = set(load_net.keys()) + + logger = get_root_logger() + if crt_net_keys != load_net_keys: + logger.warning('Current net - loaded net:') + for v in sorted(list(crt_net_keys - load_net_keys)): + logger.warning(f' {v}') + logger.warning('Loaded net - current net:') + for v in sorted(list(load_net_keys - crt_net_keys)): + logger.warning(f' {v}') + + # check the size for the same keys + if not strict: + common_keys = crt_net_keys & load_net_keys + for k in common_keys: + if crt_net[k].size() != load_net[k].size(): + logger.warning(f'Size different, ignore [{k}]: crt_net: ' + f'{crt_net[k].shape}; load_net: {load_net[k].shape}') + load_net[k + '.ignore'] = load_net.pop(k) + + def load_network(self, net, load_path, strict=True, param_key='params'): + """Load network. + + Args: + load_path (str): The path of networks to be loaded. 
+ net (nn.Module): Network. + strict (bool): Whether strictly loaded. + param_key (str): The parameter key of loaded network. If set to + None, use the root 'path'. + Default: 'params'. + """ + logger = get_root_logger() + net = self.get_bare_model(net) + load_net = torch.load(load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info('Loading: params_ema does not exist, use params.') + load_net = load_net[param_key] + logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].') + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + self._print_different_keys_loading(net, load_net, strict) + net.load_state_dict(load_net, strict=strict) + + @master_only + def save_training_state(self, epoch, current_iter): + """Save training states during training, which will be used for + resuming. + + Args: + epoch (int): Current epoch. + current_iter (int): Current iteration. + """ + if current_iter != -1: + state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []} + for o in self.optimizers: + state['optimizers'].append(o.state_dict()) + for s in self.schedulers: + state['schedulers'].append(s.state_dict()) + save_filename = f'{current_iter}.state' + save_path = os.path.join(self.opt['path']['training_states'], save_filename) + + # avoid occasional writing errors + retry = 3 + while retry > 0: + try: + torch.save(state, save_path) + except Exception as e: + logger = get_root_logger() + logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}') + time.sleep(1) + else: + break + finally: + retry -= 1 + if retry == 0: + logger.warning(f'Still cannot save {save_path}. Just ignore it.') + # raise IOError(f'Cannot save {save_path}.') + + def resume_training(self, resume_state): + """Reload the optimizers and schedulers for resumed training. + + Args: + resume_state (dict): Resume state. + """ + resume_optimizers = resume_state['optimizers'] + resume_schedulers = resume_state['schedulers'] + assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' + assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + for i, s in enumerate(resume_schedulers): + self.schedulers[i].load_state_dict(s) + + def reduce_loss_dict(self, loss_dict): + """reduce loss dict. + + In distributed training, it averages the losses among different GPUs . + + Args: + loss_dict (OrderedDict): Loss dict. 
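+ + Returns: + OrderedDict: Loss dict whose values are plain Python floats (the mean of each loss), averaged over all GPUs when distributed training is enabled.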
+ """ + with torch.no_grad(): + if self.opt['dist']: + keys = [] + losses = [] + for name, value in loss_dict.items(): + keys.append(name) + losses.append(value) + losses = torch.stack(losses, 0) + torch.distributed.reduce(losses, dst=0) + if self.opt['rank'] == 0: + losses /= self.opt['world_size'] + loss_dict = {key: loss for key, loss in zip(keys, losses)} + + log_dict = OrderedDict() + for name, value in loss_dict.items(): + log_dict[name] = value.mean().item() + + return log_dict diff --git a/basicsr/models/edvr_model.py b/basicsr/models/edvr_model.py new file mode 100644 index 0000000000000000000000000000000000000000..9bdbf7b94fe3f06c76fbf2a4941621f64e0003e7 --- /dev/null +++ b/basicsr/models/edvr_model.py @@ -0,0 +1,62 @@ +from basicsr.utils import get_root_logger +from basicsr.utils.registry import MODEL_REGISTRY +from .video_base_model import VideoBaseModel + + +@MODEL_REGISTRY.register() +class EDVRModel(VideoBaseModel): + """EDVR Model. + + Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks. # noqa: E501 + """ + + def __init__(self, opt): + super(EDVRModel, self).__init__(opt) + if self.is_train: + self.train_tsa_iter = opt['train'].get('tsa_iter') + + def setup_optimizers(self): + train_opt = self.opt['train'] + dcn_lr_mul = train_opt.get('dcn_lr_mul', 1) + logger = get_root_logger() + logger.info(f'Multiple the learning rate for dcn with {dcn_lr_mul}.') + if dcn_lr_mul == 1: + optim_params = self.net_g.parameters() + else: # separate dcn params and normal params for different lr + normal_params = [] + dcn_params = [] + for name, param in self.net_g.named_parameters(): + if 'dcn' in name: + dcn_params.append(param) + else: + normal_params.append(param) + optim_params = [ + { # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }, + { + 'params': dcn_params, + 'lr': train_opt['optim_g']['lr'] * dcn_lr_mul + }, + ] + + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + def optimize_parameters(self, current_iter): + if self.train_tsa_iter: + if current_iter == 1: + logger = get_root_logger() + logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.') + for name, param in self.net_g.named_parameters(): + if 'fusion' not in name: + param.requires_grad = False + elif current_iter == self.train_tsa_iter: + logger = get_root_logger() + logger.warning('Train all the parameters.') + for param in self.net_g.parameters(): + param.requires_grad = True + + super(EDVRModel, self).optimize_parameters(current_iter) diff --git a/basicsr/models/esrgan_model.py b/basicsr/models/esrgan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3d746d0e29418d9e8f35fa9c1e3a315d694075be --- /dev/null +++ b/basicsr/models/esrgan_model.py @@ -0,0 +1,83 @@ +import torch +from collections import OrderedDict + +from basicsr.utils.registry import MODEL_REGISTRY +from .srgan_model import SRGANModel + + +@MODEL_REGISTRY.register() +class ESRGANModel(SRGANModel): + """ESRGAN model for single image super-resolution.""" + + def optimize_parameters(self, current_iter): + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix 
= self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + # gan loss (relativistic gan) + real_d_pred = self.net_d(self.gt).detach() + fake_g_pred = self.net_d(self.output) + l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False) + l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False) + l_g_gan = (l_g_real + l_g_fake) / 2 + + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # gan loss (relativistic gan) + + # In order to avoid the error in distributed training: + # "Error detected in CudnnBatchNormBackward: RuntimeError: one of + # the variables needed for gradient computation has been modified by + # an inplace operation", + # we separate the backwards for real and fake, and also detach the + # tensor for calculating mean. + + # real + fake_d_pred = self.net_d(self.output).detach() + real_d_pred = self.net_d(self.gt) + l_d_real = self.cri_gan(real_d_pred - torch.mean(fake_d_pred), True, is_disc=True) * 0.5 + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach()) + l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred.detach()), False, is_disc=True) * 0.5 + l_d_fake.backward() + self.optimizer_d.step() + + loss_dict['l_d_real'] = l_d_real + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) diff --git a/basicsr/models/hifacegan_model.py b/basicsr/models/hifacegan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..435a2b179d6b7c670fe96a83ce45b461300b2c89 --- /dev/null +++ b/basicsr/models/hifacegan_model.py @@ -0,0 +1,288 @@ +import torch +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm + +from basicsr.archs import build_network +from basicsr.losses import build_loss +from basicsr.metrics import calculate_metric +from basicsr.utils import imwrite, tensor2img +from basicsr.utils.registry import MODEL_REGISTRY +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class HiFaceGANModel(SRModel): + """HiFaceGAN model for generic-purpose face restoration. + No prior modeling required, works for any degradations. + Currently doesn't support EMA for inference. + """ + + def init_training_settings(self): + + train_opt = self.opt['train'] + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + raise (NotImplementedError('HiFaceGAN does not support EMA now. 
Pass')) + + self.net_g.train() + + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + + # define losses + # HiFaceGAN does not use pixel loss by default + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('feature_matching_opt'): + self.cri_feat = build_loss(train_opt['feature_matching_opt']).to(self.device) + else: + self.cri_feat = None + + if self.cri_pix is None and self.cri_perceptual is None: + raise ValueError('Both pixel and perceptual losses are None.') + + if train_opt.get('gan_opt'): + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + # optimizer g + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + # optimizer d + optim_type = train_opt['optim_d'].pop('type') + self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) + self.optimizers.append(self.optimizer_d) + + def discriminate(self, input_lq, output, ground_truth): + """ + This is a conditional (on the input) discriminator + In Batch Normalization, the fake and real images are + recommended to be in the same batch to avoid disparate + statistics in fake and real images. + So both fake and real images are fed to D all at once. + """ + h, w = output.shape[-2:] + if output.shape[-2:] != input_lq.shape[-2:]: + lq = torch.nn.functional.interpolate(input_lq, (h, w)) + real = torch.nn.functional.interpolate(ground_truth, (h, w)) + fake_concat = torch.cat([lq, output], dim=1) + real_concat = torch.cat([lq, real], dim=1) + else: + fake_concat = torch.cat([input_lq, output], dim=1) + real_concat = torch.cat([input_lq, ground_truth], dim=1) + + fake_and_real = torch.cat([fake_concat, real_concat], dim=0) + discriminator_out = self.net_d(fake_and_real) + pred_fake, pred_real = self._divide_pred(discriminator_out) + return pred_fake, pred_real + + @staticmethod + def _divide_pred(pred): + """ + Take the prediction of fake and real images from the combined batch. 
+ The prediction contains the intermediate outputs of multiscale GAN, + so it's usually a list + """ + if type(pred) == list: + fake = [] + real = [] + for p in pred: + fake.append([tensor[:tensor.size(0) // 2] for tensor in p]) + real.append([tensor[tensor.size(0) // 2:] for tensor in p]) + else: + fake = pred[:pred.size(0) // 2] + real = pred[pred.size(0) // 2:] + + return fake, real + + def optimize_parameters(self, current_iter): + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_g_total = 0 + loss_dict = OrderedDict() + + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + + # Requires real prediction for feature matching loss + pred_fake, pred_real = self.discriminate(self.lq, self.output, self.gt) + l_g_gan = self.cri_gan(pred_fake, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + # feature matching loss + if self.cri_feat: + l_g_feat = self.cri_feat(pred_fake, pred_real) + l_g_total += l_g_feat + loss_dict['l_g_feat'] = l_g_feat + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # TODO: Benchmark test between HiFaceGAN and SRGAN implementation: + # SRGAN use the same fake output for discriminator update + # while HiFaceGAN regenerate a new output using updated net_g + # This should not make too much difference though. Stick to SRGAN now. + # ------------------------------------------------------------------- + # ---------- Below are original HiFaceGAN code snippet -------------- + # ------------------------------------------------------------------- + # with torch.no_grad(): + # fake_image = self.net_g(self.lq) + # fake_image = fake_image.detach() + # fake_image.requires_grad_() + # pred_fake, pred_real = self.discriminate(self.lq, fake_image, self.gt) + + # real + pred_fake, pred_real = self.discriminate(self.lq, self.output.detach(), self.gt) + l_d_real = self.cri_gan(pred_real, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + # fake + l_d_fake = self.cri_gan(pred_fake, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + + l_d_total = (l_d_real + l_d_fake) / 2 + l_d_total.backward() + self.optimizer_d.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + print('HiFaceGAN does not support EMA now. pass') + + def validation(self, dataloader, current_iter, tb_logger, save_img=False): + """ + Warning: HiFaceGAN requires train() mode even for validation + For more info, see https://github.com/Lotayou/Face-Renovation/issues/31 + + Args: + dataloader (torch.utils.data.DataLoader): Validation dataloader. + current_iter (int): Current iteration. + tb_logger (tensorboard logger): Tensorboard logger. + save_img (bool): Whether to save images. Default: False. 
+ """ + + if self.opt['network_g']['type'] in ('HiFaceGAN', 'SPADEGenerator'): + self.net_g.train() + + if self.opt['dist']: + self.dist_validation(dataloader, current_iter, tb_logger, save_img) + else: + print('In HiFaceGANModel: The new metrics package is under development.' + + 'Using super method now (Only PSNR & SSIM are supported)') + super().nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + """ + TODO: Validation using updated metric system + The metrics are now evaluated after all images have been tested + This allows batch processing, and also allows evaluation of + distributional metrics, such as: + + @ Frechet Inception Distance: FID + @ Maximum Mean Discrepancy: MMD + + Warning: + Need careful batch management for different inference settings. + + """ + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + if with_metrics: + self.metric_results = dict() # {metric: 0 for metric in self.opt['val']['metrics'].keys()} + sr_tensors = [] + gt_tensors = [] + + pbar = tqdm(total=len(dataloader), unit='image') + for val_data in dataloader: + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals() # detached cpu tensor, non-squeeze + sr_tensors.append(visuals['result']) + if 'gt' in visuals: + gt_tensors.append(visuals['gt']) + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + + imwrite(tensor2img(visuals['result']), save_img_path) + + pbar.update(1) + pbar.set_description(f'Test {img_name}') + pbar.close() + + if with_metrics: + sr_pack = torch.cat(sr_tensors, dim=0) + gt_pack = torch.cat(gt_tensors, dim=0) + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + # The new metric caller automatically returns mean value + # FIXME: ERROR: calculate_metric only supports two arguments. Now the codes cannot be successfully run + self.metric_results[name] = calculate_metric(dict(sr_pack=sr_pack, gt_pack=gt_pack), opt_) + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def save(self, epoch, current_iter): + if hasattr(self, 'net_g_ema'): + print('HiFaceGAN does not support EMA now. Fallback to normal mode.') + + self.save_network(self.net_g, 'net_g', current_iter) + self.save_network(self.net_d, 'net_d', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/basicsr/models/lr_scheduler.py b/basicsr/models/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..11e1c6c7a74f5233accda52370f92681d3d3cecf --- /dev/null +++ b/basicsr/models/lr_scheduler.py @@ -0,0 +1,96 @@ +import math +from collections import Counter +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepRestartLR(_LRScheduler): + """ MultiStep with restarts learning rate scheme. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. 
+ milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + restarts (list): Restart iterations. Default: [0]. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): + self.milestones = Counter(milestones) + self.gamma = gamma + self.restarts = restarts + self.restart_weights = restart_weights + assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' + super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.restarts: + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] + + +def get_position_from_periods(iteration, cumulative_period): + """Get the position from a period list. + + It will return the index of the right-closest number in the period list. + For example, the cumulative_period = [100, 200, 300, 400], + if iteration == 50, return 0; + if iteration == 210, return 2; + if iteration == 300, return 2. + + Args: + iteration (int): Current iteration. + cumulative_period (list[int]): Cumulative period list. + + Returns: + int: The position of the right-closest number in the period list. + """ + for i, period in enumerate(cumulative_period): + if iteration <= period: + return i + + +class CosineAnnealingRestartLR(_LRScheduler): + """ Cosine annealing with restarts learning rate scheme. + + An example of config: + periods = [10, 10, 10, 10] + restart_weights = [1, 0.5, 0.5, 0.5] + eta_min=1e-7 + + It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the + scheduler will restart with the weights in restart_weights. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + periods (list): Period for each cosine anneling cycle. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + eta_min (float): The minimum lr. Default: 0. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): + self.periods = periods + self.restart_weights = restart_weights + self.eta_min = eta_min + assert (len(self.periods) == len( + self.restart_weights)), 'periods and restart_weights should have the same length.' 
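# A minimal usage sketch of CosineAnnealingRestartLR with the example configuration quoted
# in its docstring. The toy model, the Adam optimizer and the 2e-4 base lr are assumptions.
import torch
from basicsr.models.lr_scheduler import CosineAnnealingRestartLR

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
scheduler = CosineAnnealingRestartLR(
    optimizer,
    periods=[10, 10, 10, 10],            # four cycles of 10 iterations each
    restart_weights=[1, 0.5, 0.5, 0.5],  # restart at iteration 10, 20, 30 with a damped peak lr
    eta_min=1e-7)

for _ in range(40):
    optimizer.step()
    scheduler.step()
    # within each cycle the lr decays from restart_weight * base_lr towards eta_min on a cosine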
+ self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] + super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + idx = get_position_from_periods(self.last_epoch, self.cumulative_period) + current_weight = self.restart_weights[idx] + nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] + current_period = self.periods[idx] + + return [ + self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * + (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) + for base_lr in self.base_lrs + ] diff --git a/basicsr/models/realesrgan_model.py b/basicsr/models/realesrgan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c74b28fb1dc6a7f5c5ad3f7d8bb96c19c52ee92b --- /dev/null +++ b/basicsr/models/realesrgan_model.py @@ -0,0 +1,267 @@ +import numpy as np +import random +import torch +from collections import OrderedDict +from torch.nn import functional as F + +from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt +from basicsr.data.transforms import paired_random_crop +from basicsr.losses.loss_util import get_refined_artifact_map +from basicsr.models.srgan_model import SRGANModel +from basicsr.utils import DiffJPEG, USMSharp +from basicsr.utils.img_process_util import filter2D +from basicsr.utils.registry import MODEL_REGISTRY + + +@MODEL_REGISTRY.register(suffix='basicsr') +class RealESRGANModel(SRGANModel): + """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It mainly performs: + 1. randomly synthesize LQ images in GPU tensors + 2. optimize the networks with GAN training. + """ + + def __init__(self, opt): + super(RealESRGANModel, self).__init__(opt) + self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts + self.usm_sharpener = USMSharp().cuda() # do usm sharpening + self.queue_size = opt.get('queue_size', 180) + + @torch.no_grad() + def _dequeue_and_enqueue(self): + """It is the training pair pool for increasing the diversity in a batch. + + Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a + batch could not have different resize scaling factors. Therefore, we employ this training pair pool + to increase the degradation diversity in a batch. 
+ """ + # initialize + b, c, h, w = self.lq.size() + if not hasattr(self, 'queue_lr'): + assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}' + self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() + _, c, h, w = self.gt.size() + self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() + self.queue_ptr = 0 + if self.queue_ptr == self.queue_size: # the pool is full + # do dequeue and enqueue + # shuffle + idx = torch.randperm(self.queue_size) + self.queue_lr = self.queue_lr[idx] + self.queue_gt = self.queue_gt[idx] + # get first b samples + lq_dequeue = self.queue_lr[0:b, :, :, :].clone() + gt_dequeue = self.queue_gt[0:b, :, :, :].clone() + # update the queue + self.queue_lr[0:b, :, :, :] = self.lq.clone() + self.queue_gt[0:b, :, :, :] = self.gt.clone() + + self.lq = lq_dequeue + self.gt = gt_dequeue + else: + # only do enqueue + self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone() + self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone() + self.queue_ptr = self.queue_ptr + b + + @torch.no_grad() + def feed_data(self, data): + """Accept data from dataloader, and then add two-order degradations to obtain LQ images. + """ + if self.is_train and self.opt.get('high_order_degradation', True): + # training data synthesis + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) + + self.kernel1 = data['kernel1'].to(self.device) + self.kernel2 = data['kernel2'].to(self.device) + self.sinc_kernel = data['sinc_kernel'].to(self.device) + + ori_h, ori_w = self.gt.size()[2:4] + + # ----------------------- The first degradation process ----------------------- # + # blur + out = filter2D(self.gt_usm, self.kernel1) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, scale_factor=scale, mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob'] + if np.random.uniform() < self.opt['gaussian_noise_prob']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) + out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts + out = self.jpeger(out, quality=jpeg_p) + + # ----------------------- The second degradation process ----------------------- # + # blur + if np.random.uniform() < self.opt['second_blur_prob']: + out = filter2D(out, self.kernel2) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range2'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range2'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate( + out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob2'] + 
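# The synthetic degradation above is driven entirely by keys read from self.opt. The dict
# below lists those keys for reference; the key names come from the code, while the values
# are illustrative (roughly in line with the public Real-ESRGAN configs), not something
# this diff prescribes.
degradation_opt = {
    # first-stage resize / noise / JPEG
    'resize_prob': [0.2, 0.7, 0.1],    # probabilities for up / down / keep
    'resize_range': [0.15, 1.5],
    'gaussian_noise_prob': 0.5,
    'noise_range': [1, 30],
    'poisson_scale_range': [0.05, 3],
    'gray_noise_prob': 0.4,
    'jpeg_range': [30, 95],
    # the second stage mirrors the same keys with a '2' suffix (resize_prob2, noise_range2, ...)
    'second_blur_prob': 0.8,
    # geometry of the synthetic training pairs
    'scale': 4,
    'gt_size': 256,
    'queue_size': 180,                 # must be divisible by the batch size (see the assert above)
}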
if np.random.uniform() < self.opt['gaussian_noise_prob2']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range2'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + + # JPEG compression + the final sinc filter + # We also need to resize images to desired sizes. We group [resize back + sinc filter] together + # as one operation. + # We consider two orders: + # 1. [resize back + sinc filter] + JPEG compression + # 2. JPEG compression + [resize back + sinc filter] + # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. + if np.random.uniform() < 0.5: + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + else: + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + + # clamp and round + self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. + + # random crop + gt_size = self.opt['gt_size'] + (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size, + self.opt['scale']) + + # training pair pool + self._dequeue_and_enqueue() + # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue + self.gt_usm = self.usm_sharpener(self.gt) + self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract + else: + # for paired training or validation + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + # do not use the synthetic process during validation + self.is_train = False + super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) + self.is_train = True + + def optimize_parameters(self, current_iter): + # usm sharpening + l1_gt = self.gt_usm + percep_gt = self.gt_usm + gan_gt = self.gt_usm + if self.opt['l1_gt_usm'] is False: + l1_gt = self.gt + if self.opt['percep_gt_usm'] is False: + percep_gt = self.gt + if self.opt['gan_gt_usm'] is False: + gan_gt = self.gt + + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + if self.cri_ldl: + self.output_ema = self.net_g_ema(self.lq) + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, l1_gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + if self.cri_ldl: + pixel_weight = get_refined_artifact_map(self.gt, self.output, self.output_ema, 7) + l_g_ldl = 
self.cri_ldl(torch.mul(pixel_weight, self.output), torch.mul(pixel_weight, self.gt)) + l_g_total += l_g_ldl + loss_dict['l_g_ldl'] = l_g_ldl + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + # gan loss + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + real_d_pred = self.net_d(gan_gt) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9 + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + self.optimizer_d.step() + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + self.log_dict = self.reduce_loss_dict(loss_dict) diff --git a/basicsr/models/realesrnet_model.py b/basicsr/models/realesrnet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f5790918b969682a0db0e2ed9236b7046d627b90 --- /dev/null +++ b/basicsr/models/realesrnet_model.py @@ -0,0 +1,189 @@ +import numpy as np +import random +import torch +from torch.nn import functional as F + +from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt +from basicsr.data.transforms import paired_random_crop +from basicsr.models.sr_model import SRModel +from basicsr.utils import DiffJPEG, USMSharp +from basicsr.utils.img_process_util import filter2D +from basicsr.utils.registry import MODEL_REGISTRY + + +@MODEL_REGISTRY.register(suffix='basicsr') +class RealESRNetModel(SRModel): + """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It is trained without GAN losses. + It mainly performs: + 1. randomly synthesize LQ images in GPU tensors + 2. optimize the networks with GAN training. + """ + + def __init__(self, opt): + super(RealESRNetModel, self).__init__(opt) + self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts + self.usm_sharpener = USMSharp().cuda() # do usm sharpening + self.queue_size = opt.get('queue_size', 180) + + @torch.no_grad() + def _dequeue_and_enqueue(self): + """It is the training pair pool for increasing the diversity in a batch. + + Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a + batch could not have different resize scaling factors. Therefore, we employ this training pair pool + to increase the degradation diversity in a batch. 
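# For reference, a hypothetical snippet of the three switches read at the top of
# RealESRGANModel.optimize_parameters() above. They decide whether each loss compares
# against the USM-sharpened ground truth or the raw one; the values are only an example.
usm_switches = {
    'l1_gt_usm': True,      # pixel / L1 loss uses self.gt_usm
    'percep_gt_usm': True,  # perceptual loss uses self.gt_usm
    'gan_gt_usm': False,    # GAN loss and the discriminator see the un-sharpened self.gt
}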
+ """ + # initialize + b, c, h, w = self.lq.size() + if not hasattr(self, 'queue_lr'): + assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}' + self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() + _, c, h, w = self.gt.size() + self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() + self.queue_ptr = 0 + if self.queue_ptr == self.queue_size: # the pool is full + # do dequeue and enqueue + # shuffle + idx = torch.randperm(self.queue_size) + self.queue_lr = self.queue_lr[idx] + self.queue_gt = self.queue_gt[idx] + # get first b samples + lq_dequeue = self.queue_lr[0:b, :, :, :].clone() + gt_dequeue = self.queue_gt[0:b, :, :, :].clone() + # update the queue + self.queue_lr[0:b, :, :, :] = self.lq.clone() + self.queue_gt[0:b, :, :, :] = self.gt.clone() + + self.lq = lq_dequeue + self.gt = gt_dequeue + else: + # only do enqueue + self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone() + self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone() + self.queue_ptr = self.queue_ptr + b + + @torch.no_grad() + def feed_data(self, data): + """Accept data from dataloader, and then add two-order degradations to obtain LQ images. + """ + if self.is_train and self.opt.get('high_order_degradation', True): + # training data synthesis + self.gt = data['gt'].to(self.device) + # USM sharpen the GT images + if self.opt['gt_usm'] is True: + self.gt = self.usm_sharpener(self.gt) + + self.kernel1 = data['kernel1'].to(self.device) + self.kernel2 = data['kernel2'].to(self.device) + self.sinc_kernel = data['sinc_kernel'].to(self.device) + + ori_h, ori_w = self.gt.size()[2:4] + + # ----------------------- The first degradation process ----------------------- # + # blur + out = filter2D(self.gt, self.kernel1) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, scale_factor=scale, mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob'] + if np.random.uniform() < self.opt['gaussian_noise_prob']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) + out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts + out = self.jpeger(out, quality=jpeg_p) + + # ----------------------- The second degradation process ----------------------- # + # blur + if np.random.uniform() < self.opt['second_blur_prob']: + out = filter2D(out, self.kernel2) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range2'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range2'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate( + out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) + # add 
noise + gray_noise_prob = self.opt['gray_noise_prob2'] + if np.random.uniform() < self.opt['gaussian_noise_prob2']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range2'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + + # JPEG compression + the final sinc filter + # We also need to resize images to desired sizes. We group [resize back + sinc filter] together + # as one operation. + # We consider two orders: + # 1. [resize back + sinc filter] + JPEG compression + # 2. JPEG compression + [resize back + sinc filter] + # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. + if np.random.uniform() < 0.5: + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + else: + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + + # clamp and round + self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. + + # random crop + gt_size = self.opt['gt_size'] + self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale']) + + # training pair pool + self._dequeue_and_enqueue() + self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract + else: + # for paired training or validation + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + # do not use the synthetic process during validation + self.is_train = False + super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) + self.is_train = True diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py new file mode 100644 index 0000000000000000000000000000000000000000..787f1fd2eab5963579c764c1bfb87199b7dd196f --- /dev/null +++ b/basicsr/models/sr_model.py @@ -0,0 +1,279 @@ +import torch +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm + +from basicsr.archs import build_network +from basicsr.losses import build_loss +from basicsr.metrics import calculate_metric +from basicsr.utils import get_root_logger, imwrite, tensor2img +from basicsr.utils.registry import MODEL_REGISTRY +from .base_model import BaseModel + + +@MODEL_REGISTRY.register() +class SRModel(BaseModel): + """Base SR model for single image super-resolution.""" + + def __init__(self, opt): + super(SRModel, self).__init__(opt) + + # define network + self.net_g = build_network(opt['network_g']) + self.net_g = self.model_to_device(self.net_g) + self.print_network(self.net_g) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_g', None) + 
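# SRModel.__init__() above reads its checkpoint settings from opt['path']. A hypothetical
# sketch of those keys follows; the file path is a placeholder, not defined by this diff.
path_opt = {
    'pretrain_network_g': 'experiments/pretrained_models/net_g_latest.pth',  # placeholder
    'param_key_g': 'params',   # use 'params_ema' to load the EMA weights instead
    'strict_load_g': True,     # set False to tolerate missing / unexpected keys
}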
if load_path is not None: + param_key = self.opt['path'].get('param_key_g', 'params') + self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) + + if self.is_train: + self.init_training_settings() + + def init_training_settings(self): + self.net_g.train() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger = get_root_logger() + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema is used only for testing on one GPU and saving + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if self.cri_pix is None and self.cri_perceptual is None: + raise ValueError('Both pixel and perceptual losses are None.') + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + optim_params = [] + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + optim_params.append(v) + else: + logger = get_root_logger() + logger.warning(f'Params {k} will not be optimized.') + + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + def feed_data(self, data): + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + + def optimize_parameters(self, current_iter): + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_total = 0 + loss_dict = OrderedDict() + # pixel loss + if self.cri_pix: + l_pix = self.cri_pix(self.output, self.gt) + l_total += l_pix + loss_dict['l_pix'] = l_pix + # perceptual loss + if self.cri_perceptual: + l_percep, l_style = self.cri_perceptual(self.output, self.gt) + if l_percep is not None: + l_total += l_percep + loss_dict['l_percep'] = l_percep + if l_style is not None: + l_total += l_style + loss_dict['l_style'] = l_style + + l_total.backward() + self.optimizer_g.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + def test(self): + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + self.output = self.net_g_ema(self.lq) + else: + self.net_g.eval() + with torch.no_grad(): + self.output = self.net_g(self.lq) + self.net_g.train() + + def test_selfensemble(self): + # TODO: to be tested + # 8 augmentations + # modified from https://github.com/thstkdgus35/EDSR-PyTorch + + def _transform(v, op): + # if self.precision != 'single': v = v.float() + v2np = v.data.cpu().numpy() + if op == 'v': + tfnp = v2np[:, :, :, ::-1].copy() + elif op == 'h': + tfnp = v2np[:, :, ::-1, :].copy() + elif op == 't': + tfnp = v2np.transpose((0, 1, 3, 
2)).copy() + + ret = torch.Tensor(tfnp).to(self.device) + # if self.precision == 'half': ret = ret.half() + + return ret + + # prepare augmented data + lq_list = [self.lq] + for tf in 'v', 'h', 't': + lq_list.extend([_transform(t, tf) for t in lq_list]) + + # inference + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + out_list = [self.net_g_ema(aug) for aug in lq_list] + else: + self.net_g.eval() + with torch.no_grad(): + out_list = [self.net_g_ema(aug) for aug in lq_list] + self.net_g.train() + + # merge results + for i in range(len(out_list)): + if i > 3: + out_list[i] = _transform(out_list[i], 't') + if i % 4 > 1: + out_list[i] = _transform(out_list[i], 'h') + if (i % 4) % 2 == 1: + out_list[i] = _transform(out_list[i], 'v') + output = torch.cat(out_list, dim=0) + + self.output = output.mean(dim=0, keepdim=True) + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + use_pbar = self.opt['val'].get('pbar', False) + + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + # initialize the best metric results for each dataset_name (supporting multiple validation datasets) + self._initialize_best_metric_results(dataset_name) + # zero self.metric_results + if with_metrics: + self.metric_results = {metric: 0 for metric in self.metric_results} + + metric_data = dict() + if use_pbar: + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['result']]) + metric_data['img'] = sr_img + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']]) + metric_data['img2'] = gt_img + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + self.metric_results[name] += calculate_metric(metric_data, opt_) + if use_pbar: + pbar.update(1) + pbar.set_description(f'Test {img_name}') + if use_pbar: + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + # update the best metric result + self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t 
# {metric}: {value:.4f}' + if hasattr(self, 'best_metric_results'): + log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ ' + f'{self.best_metric_results[dataset_name][metric]["iter"]} iter') + log_str += '\n' + + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter) + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['lq'] = self.lq.detach().cpu() + out_dict['result'] = self.output.detach().cpu() + if hasattr(self, 'gt'): + out_dict['gt'] = self.gt.detach().cpu() + return out_dict + + def save(self, epoch, current_iter): + if hasattr(self, 'net_g_ema'): + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/basicsr/models/srgan_model.py b/basicsr/models/srgan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..45387ca7908e3f38f59a605adb8242ad12fcf1a1 --- /dev/null +++ b/basicsr/models/srgan_model.py @@ -0,0 +1,149 @@ +import torch +from collections import OrderedDict + +from basicsr.archs import build_network +from basicsr.losses import build_loss +from basicsr.utils import get_root_logger +from basicsr.utils.registry import MODEL_REGISTRY +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class SRGANModel(SRModel): + """SRGAN model for single image super-resolution.""" + + def init_training_settings(self): + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger = get_root_logger() + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema is used only for testing on one GPU and saving + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define network net_d + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_d', 'params') + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key) + + self.net_g.train() + self.net_d.train() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('ldl_opt'): + self.cri_ldl = build_loss(train_opt['ldl_opt']).to(self.device) + else: + self.cri_ldl = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('gan_opt'): + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) + + # set 
up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + # optimizer g + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + # optimizer d + optim_type = train_opt['optim_d'].pop('type') + self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) + self.optimizers.append(self.optimizer_d) + + def optimize_parameters(self, current_iter): + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + # gan loss + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + real_d_pred = self.net_d(self.gt) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach()) + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + self.optimizer_d.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + def save(self, epoch, current_iter): + if hasattr(self, 'net_g_ema'): + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_network(self.net_d, 'net_d', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/basicsr/models/stylegan2_model.py b/basicsr/models/stylegan2_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d7da708122160f2be51a98a6a635349f34ee042e --- /dev/null +++ b/basicsr/models/stylegan2_model.py @@ -0,0 +1,283 @@ +import cv2 +import math +import numpy as np +import random +import torch +from collections import OrderedDict +from os import path as osp + +from basicsr.archs import build_network +from basicsr.losses import build_loss +from basicsr.losses.gan_loss import g_path_regularize, r1_penalty +from basicsr.utils import imwrite, tensor2img +from basicsr.utils.registry import MODEL_REGISTRY +from .base_model import BaseModel + + +@MODEL_REGISTRY.register() +class StyleGAN2Model(BaseModel): + """StyleGAN2 model.""" + + def __init__(self, opt): + super(StyleGAN2Model, self).__init__(opt) + + # define network net_g + self.net_g = build_network(opt['network_g']) + self.net_g = 
self.model_to_device(self.net_g) + self.print_network(self.net_g) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_g', 'params') + self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) + + # latent dimension: self.num_style_feat + self.num_style_feat = opt['network_g']['num_style_feat'] + num_val_samples = self.opt['val'].get('num_val_samples', 16) + self.fixed_sample = torch.randn(num_val_samples, self.num_style_feat, device=self.device) + + if self.is_train: + self.init_training_settings() + + def init_training_settings(self): + train_opt = self.opt['train'] + + # define network net_d + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_d', 'params') + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key) + + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema only used for testing on one GPU and saving, do not need to + # wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + + self.net_g.train() + self.net_d.train() + self.net_g_ema.eval() + + # define losses + # gan loss (wgan) + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + # regularization weights + self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator + self.path_reg_weight = train_opt['path_reg_weight'] # for generator + + self.net_g_reg_every = train_opt['net_g_reg_every'] + self.net_d_reg_every = train_opt['net_d_reg_every'] + self.mixing_prob = train_opt['mixing_prob'] + + self.mean_path_length = 0 + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + # optimizer g + net_g_reg_ratio = self.net_g_reg_every / (self.net_g_reg_every + 1) + if self.opt['network_g']['type'] == 'StyleGAN2GeneratorC': + normal_params = [] + style_mlp_params = [] + modulation_conv_params = [] + for name, param in self.net_g.named_parameters(): + if 'modulation' in name: + normal_params.append(param) + elif 'style_mlp' in name: + style_mlp_params.append(param) + elif 'modulated_conv' in name: + modulation_conv_params.append(param) + else: + normal_params.append(param) + optim_params_g = [ + { # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }, + { + 'params': style_mlp_params, + 'lr': train_opt['optim_g']['lr'] * 0.01 + }, + { + 'params': modulation_conv_params, + 'lr': train_opt['optim_g']['lr'] / 3 + } + ] + else: + normal_params = [] + for name, param in self.net_g.named_parameters(): + normal_params.append(param) + optim_params_g = [{ # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }] + + optim_type = train_opt['optim_g'].pop('type') + lr = train_opt['optim_g']['lr'] * net_g_reg_ratio + betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio) + self.optimizer_g = 
self.get_optimizer(optim_type, optim_params_g, lr, betas=betas) + self.optimizers.append(self.optimizer_g) + + # optimizer d + net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1) + if self.opt['network_d']['type'] == 'StyleGAN2DiscriminatorC': + normal_params = [] + linear_params = [] + for name, param in self.net_d.named_parameters(): + if 'final_linear' in name: + linear_params.append(param) + else: + normal_params.append(param) + optim_params_d = [ + { # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_d']['lr'] + }, + { + 'params': linear_params, + 'lr': train_opt['optim_d']['lr'] * (1 / math.sqrt(512)) + } + ] + else: + normal_params = [] + for name, param in self.net_d.named_parameters(): + normal_params.append(param) + optim_params_d = [{ # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_d']['lr'] + }] + + optim_type = train_opt['optim_d'].pop('type') + lr = train_opt['optim_d']['lr'] * net_d_reg_ratio + betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio) + self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas) + self.optimizers.append(self.optimizer_d) + + def feed_data(self, data): + self.real_img = data['gt'].to(self.device) + + def make_noise(self, batch, num_noise): + if num_noise == 1: + noises = torch.randn(batch, self.num_style_feat, device=self.device) + else: + noises = torch.randn(num_noise, batch, self.num_style_feat, device=self.device).unbind(0) + return noises + + def mixing_noise(self, batch, prob): + if random.random() < prob: + return self.make_noise(batch, 2) + else: + return [self.make_noise(batch, 1)] + + def optimize_parameters(self, current_iter): + loss_dict = OrderedDict() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + self.optimizer_d.zero_grad() + + batch = self.real_img.size(0) + noise = self.mixing_noise(batch, self.mixing_prob) + fake_img, _ = self.net_g(noise) + fake_pred = self.net_d(fake_img.detach()) + + real_pred = self.net_d(self.real_img) + # wgan loss with softplus (logistic loss) for discriminator + l_d = self.cri_gan(real_pred, True, is_disc=True) + self.cri_gan(fake_pred, False, is_disc=True) + loss_dict['l_d'] = l_d + # In wgan, real_score should be positive and fake_score should be + # negative + loss_dict['real_score'] = real_pred.detach().mean() + loss_dict['fake_score'] = fake_pred.detach().mean() + l_d.backward() + + if current_iter % self.net_d_reg_every == 0: + self.real_img.requires_grad = True + real_pred = self.net_d(self.real_img) + l_d_r1 = r1_penalty(real_pred, self.real_img) + l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0]) + # TODO: why do we need to add 0 * real_pred, otherwise, a runtime + # error will arise: RuntimeError: Expected to have finished + # reduction in the prior iteration before starting a new one. + # This error indicates that your module has parameters that were + # not used in producing loss. 
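# The lr / betas rescaling in setup_optimizers() above follows StyleGAN2's lazy
# regularization: when the R1 / path penalties run only every k iterations, the Adam
# hyper-parameters are scaled by k / (k + 1). Quick numeric sketch with assumed values.
net_d_reg_every = 16                             # assumed: R1 penalty every 16 D steps
ratio = net_d_reg_every / (net_d_reg_every + 1)  # 16 / 17 ≈ 0.941
base_lr = 2e-3                                   # assumed base learning rate
lr = base_lr * ratio                             # ≈ 1.88e-3
betas = (0 ** ratio, 0.99 ** ratio)              # ≈ (0.0, 0.991)
print(lr, betas)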
+ loss_dict['l_d_r1'] = l_d_r1.detach().mean() + l_d_r1.backward() + + self.optimizer_d.step() + + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + self.optimizer_g.zero_grad() + + noise = self.mixing_noise(batch, self.mixing_prob) + fake_img, _ = self.net_g(noise) + fake_pred = self.net_d(fake_img) + + # wgan loss with softplus (non-saturating loss) for generator + l_g = self.cri_gan(fake_pred, True, is_disc=False) + loss_dict['l_g'] = l_g + l_g.backward() + + if current_iter % self.net_g_reg_every == 0: + path_batch_size = max(1, batch // self.opt['train']['path_batch_shrink']) + noise = self.mixing_noise(path_batch_size, self.mixing_prob) + fake_img, latents = self.net_g(noise, return_latents=True) + l_g_path, path_lengths, self.mean_path_length = g_path_regularize(fake_img, latents, self.mean_path_length) + + l_g_path = (self.path_reg_weight * self.net_g_reg_every * l_g_path + 0 * fake_img[0, 0, 0, 0]) + # TODO: why do we need to add 0 * fake_img[0, 0, 0, 0] + l_g_path.backward() + loss_dict['l_g_path'] = l_g_path.detach().mean() + loss_dict['path_length'] = path_lengths + + self.optimizer_g.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + # EMA + self.model_ema(decay=0.5**(32 / (10 * 1000))) + + def test(self): + with torch.no_grad(): + self.net_g_ema.eval() + self.output, _ = self.net_g_ema([self.fixed_sample]) + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + assert dataloader is None, 'Validation dataloader should be None.' + self.test() + result = tensor2img(self.output, min_max=(-1, 1)) + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], 'train', f'train_{current_iter}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], 'test', f'test_{self.opt["name"]}.png') + imwrite(result, save_img_path) + # add sample images to tb_logger + result = (result / 255.).astype(np.float32) + result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB) + if tb_logger is not None: + tb_logger.add_image('samples', result, global_step=current_iter, dataformats='HWC') + + def save(self, epoch, current_iter): + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + self.save_network(self.net_d, 'net_d', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/basicsr/models/swinir_model.py b/basicsr/models/swinir_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac182f23b4a300aff14b2b45fcdca8c00da90c1 --- /dev/null +++ b/basicsr/models/swinir_model.py @@ -0,0 +1,33 @@ +import torch +from torch.nn import functional as F + +from basicsr.utils.registry import MODEL_REGISTRY +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class SwinIRModel(SRModel): + + def test(self): + # pad to multiplication of window_size + window_size = self.opt['network_g']['window_size'] + scale = self.opt.get('scale', 1) + mod_pad_h, mod_pad_w = 0, 0 + _, _, h, w = self.lq.size() + if h % window_size != 0: + mod_pad_h = window_size - h % window_size + if w % window_size != 0: + mod_pad_w = window_size - w % window_size + img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + self.output = self.net_g_ema(img) + else: + 
self.net_g.eval() + with torch.no_grad(): + self.output = self.net_g(img) + self.net_g.train() + + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale] diff --git a/basicsr/models/video_base_model.py b/basicsr/models/video_base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..9f7993a15e585526135d1ede094f4dcff47f64db --- /dev/null +++ b/basicsr/models/video_base_model.py @@ -0,0 +1,160 @@ +import torch +from collections import Counter +from os import path as osp +from torch import distributed as dist +from tqdm import tqdm + +from basicsr.metrics import calculate_metric +from basicsr.utils import get_root_logger, imwrite, tensor2img +from basicsr.utils.dist_util import get_dist_info +from basicsr.utils.registry import MODEL_REGISTRY +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class VideoBaseModel(SRModel): + """Base video SR model.""" + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset = dataloader.dataset + dataset_name = dataset.opt['name'] + with_metrics = self.opt['val']['metrics'] is not None + # initialize self.metric_results + # It is a dict: { + # 'folder1': tensor (num_frame x len(metrics)), + # 'folder2': tensor (num_frame x len(metrics)) + # } + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {} + num_frame_each_folder = Counter(dataset.data_info['folder']) + for folder, num_frame in num_frame_each_folder.items(): + self.metric_results[folder] = torch.zeros( + num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda') + # initialize the best metric results + self._initialize_best_metric_results(dataset_name) + # zero self.metric_results + rank, world_size = get_dist_info() + if with_metrics: + for _, tensor in self.metric_results.items(): + tensor.zero_() + + metric_data = dict() + # record all frames (border and center frames) + if rank == 0: + pbar = tqdm(total=len(dataset), unit='frame') + for idx in range(rank, len(dataset), world_size): + val_data = dataset[idx] + val_data['lq'].unsqueeze_(0) + val_data['gt'].unsqueeze_(0) + folder = val_data['folder'] + frame_idx, max_idx = val_data['idx'].split('/') + lq_path = val_data['lq_path'] + + self.feed_data(val_data) + self.test() + visuals = self.get_current_visuals() + result_img = tensor2img([visuals['result']]) + metric_data['img'] = result_img + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']]) + metric_data['img2'] = gt_img + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + raise NotImplementedError('saving image is not supported during training.') + else: + if 'vimeo' in dataset_name.lower(): # vimeo90k dataset + split_result = lq_path.split('/') + img_name = f'{split_result[-3]}_{split_result[-2]}_{split_result[-1].split(".")[0]}' + else: # other datasets, e.g., REDS, Vid4 + img_name = osp.splitext(osp.basename(lq_path))[0] + + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder, + f'{img_name}_{self.opt["name"]}.png') + imwrite(result_img, save_img_path) + + if with_metrics: + # calculate metrics + for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()): + result = 
calculate_metric(metric_data, opt_) + self.metric_results[folder][int(frame_idx), metric_idx] += result + + # progress bar + if rank == 0: + for _ in range(world_size): + pbar.update(1) + pbar.set_description(f'Test {folder}: {int(frame_idx) + world_size}/{max_idx}') + if rank == 0: + pbar.close() + + if with_metrics: + if self.opt['dist']: + # collect data among GPUs + for _, tensor in self.metric_results.items(): + dist.reduce(tensor, 0) + dist.barrier() + else: + pass # assume use one gpu in non-dist testing + + if rank == 0: + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + logger = get_root_logger() + logger.warning('nondist_validation is not implemented. Run dist_validation.') + self.dist_validation(dataloader, current_iter, tb_logger, save_img) + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + # ----------------- calculate the average values for each folder, and for each metric ----------------- # + # average all frames for each sub-folder + # metric_results_avg is a dict:{ + # 'folder1': tensor (len(metrics)), + # 'folder2': tensor (len(metrics)) + # } + metric_results_avg = { + folder: torch.mean(tensor, dim=0).cpu() + for (folder, tensor) in self.metric_results.items() + } + # total_avg_results is a dict: { + # 'metric1': float, + # 'metric2': float + # } + total_avg_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + for folder, tensor in metric_results_avg.items(): + for idx, metric in enumerate(total_avg_results.keys()): + total_avg_results[metric] += metric_results_avg[folder][idx].item() + # average among folders + for metric in total_avg_results.keys(): + total_avg_results[metric] /= len(metric_results_avg) + # update the best metric result + self._update_best_metric_result(dataset_name, metric, total_avg_results[metric], current_iter) + + # ------------------------------------------ log the metric ------------------------------------------ # + log_str = f'Validation {dataset_name}\n' + for metric_idx, (metric, value) in enumerate(total_avg_results.items()): + log_str += f'\t # {metric}: {value:.4f}' + for folder, tensor in metric_results_avg.items(): + log_str += f'\t # {folder}: {tensor[metric_idx].item():.4f}' + if hasattr(self, 'best_metric_results'): + log_str += (f'\n\t Best: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ ' + f'{self.best_metric_results[dataset_name][metric]["iter"]} iter') + log_str += '\n' + + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric_idx, (metric, value) in enumerate(total_avg_results.items()): + tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) + for folder, tensor in metric_results_avg.items(): + tb_logger.add_scalar(f'metrics/{metric}/{folder}', tensor[metric_idx].item(), current_iter) diff --git a/basicsr/models/video_gan_model.py b/basicsr/models/video_gan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a2adcdeee59e494dd7d1c285919fac5c99cd9efb --- /dev/null +++ b/basicsr/models/video_gan_model.py @@ -0,0 +1,19 @@ +from basicsr.utils.registry import MODEL_REGISTRY +from .srgan_model import SRGANModel +from .video_base_model import VideoBaseModel + + +@MODEL_REGISTRY.register() +class VideoGANModel(SRGANModel, VideoBaseModel): + """Video GAN model. + + Use multiple inheritance. 
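# A small sketch, with made-up numbers, of the per-folder bookkeeping used by
# _log_validation_metric_values() above: each folder maps to a (num_frames x num_metrics)
# tensor, averaged per folder first and then across folders.
import torch

metric_names = ['psnr', 'ssim']
metric_results = {
    'calendar': torch.tensor([[27.1, 0.80], [27.5, 0.81]]),  # 2 frames x 2 metrics
    'city': torch.tensor([[30.2, 0.90], [30.0, 0.89]]),
}
per_folder = {folder: t.mean(dim=0) for folder, t in metric_results.items()}
overall = {m: sum(v[i].item() for v in per_folder.values()) / len(per_folder)
           for i, m in enumerate(metric_names)}
print(overall)  # ≈ {'psnr': 28.7, 'ssim': 0.85}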
+ It will first use the functions of :class:`SRGANModel`: + + - :func:`init_training_settings` + - :func:`setup_optimizers` + - :func:`optimize_parameters` + - :func:`save` + + Then find functions in :class:`VideoBaseModel`. + """ diff --git a/basicsr/models/video_recurrent_gan_model.py b/basicsr/models/video_recurrent_gan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..74cf81145c50ffafb220d22b51e56746dee5ba41 --- /dev/null +++ b/basicsr/models/video_recurrent_gan_model.py @@ -0,0 +1,180 @@ +import torch +from collections import OrderedDict + +from basicsr.archs import build_network +from basicsr.losses import build_loss +from basicsr.utils import get_root_logger +from basicsr.utils.registry import MODEL_REGISTRY +from .video_recurrent_model import VideoRecurrentModel + + +@MODEL_REGISTRY.register() +class VideoRecurrentGANModel(VideoRecurrentModel): + + def init_training_settings(self): + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger = get_root_logger() + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # build network net_g with Exponential Moving Average (EMA) + # net_g_ema only used for testing on one GPU and saving. + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define network net_d + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_d', 'params') + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key) + + self.net_g.train() + self.net_d.train() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('gan_opt'): + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + if train_opt['fix_flow']: + normal_params = [] + flow_params = [] + for name, param in self.net_g.named_parameters(): + if 'spynet' in name: # The fix_flow now only works for spynet. 
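+ # Parameters whose names contain 'spynet' go into a separate optimizer group,
+ # so the optical-flow sub-network is trained with its own lr_flow learning rate below.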
+ flow_params.append(param) + else: + normal_params.append(param) + + optim_params = [ + { # add flow params first + 'params': flow_params, + 'lr': train_opt['lr_flow'] + }, + { + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }, + ] + else: + optim_params = self.net_g.parameters() + + # optimizer g + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + # optimizer d + optim_type = train_opt['optim_d'].pop('type') + self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) + self.optimizers.append(self.optimizer_d) + + def optimize_parameters(self, current_iter): + logger = get_root_logger() + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + if self.fix_flow_iter: + if current_iter == 1: + logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.') + for name, param in self.net_g.named_parameters(): + if 'spynet' in name or 'edvr' in name: + param.requires_grad_(False) + elif current_iter == self.fix_flow_iter: + logger.warning('Train all the parameters.') + self.net_g.requires_grad_(True) + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + _, _, c, h, w = self.output.size() + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output.view(-1, c, h, w), self.gt.view(-1, c, h, w)) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + # gan loss + fake_g_pred = self.net_d(self.output.view(-1, c, h, w)) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + # reshape to (b*n, c, h, w) + real_d_pred = self.net_d(self.gt.view(-1, c, h, w)) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + # reshape to (b*n, c, h, w) + fake_d_pred = self.net_d(self.output.view(-1, c, h, w).detach()) + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + self.optimizer_d.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + def save(self, epoch, current_iter): + if self.ema_decay > 0: + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_network(self.net_d, 'net_d', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/basicsr/models/video_recurrent_model.py b/basicsr/models/video_recurrent_model.py new file mode 100644 index 
0000000000000000000000000000000000000000..796ee57d5aeb84e81fe8dc769facc8339798cc3e --- /dev/null +++ b/basicsr/models/video_recurrent_model.py @@ -0,0 +1,197 @@ +import torch +from collections import Counter +from os import path as osp +from torch import distributed as dist +from tqdm import tqdm + +from basicsr.metrics import calculate_metric +from basicsr.utils import get_root_logger, imwrite, tensor2img +from basicsr.utils.dist_util import get_dist_info +from basicsr.utils.registry import MODEL_REGISTRY +from .video_base_model import VideoBaseModel + + +@MODEL_REGISTRY.register() +class VideoRecurrentModel(VideoBaseModel): + + def __init__(self, opt): + super(VideoRecurrentModel, self).__init__(opt) + if self.is_train: + self.fix_flow_iter = opt['train'].get('fix_flow') + + def setup_optimizers(self): + train_opt = self.opt['train'] + flow_lr_mul = train_opt.get('flow_lr_mul', 1) + logger = get_root_logger() + logger.info(f'Multiple the learning rate for flow network with {flow_lr_mul}.') + if flow_lr_mul == 1: + optim_params = self.net_g.parameters() + else: # separate flow params and normal params for different lr + normal_params = [] + flow_params = [] + for name, param in self.net_g.named_parameters(): + if 'spynet' in name: + flow_params.append(param) + else: + normal_params.append(param) + optim_params = [ + { # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }, + { + 'params': flow_params, + 'lr': train_opt['optim_g']['lr'] * flow_lr_mul + }, + ] + + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + def optimize_parameters(self, current_iter): + if self.fix_flow_iter: + logger = get_root_logger() + if current_iter == 1: + logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.') + for name, param in self.net_g.named_parameters(): + if 'spynet' in name or 'edvr' in name: + param.requires_grad_(False) + elif current_iter == self.fix_flow_iter: + logger.warning('Train all the parameters.') + self.net_g.requires_grad_(True) + + super(VideoRecurrentModel, self).optimize_parameters(current_iter) + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset = dataloader.dataset + dataset_name = dataset.opt['name'] + with_metrics = self.opt['val']['metrics'] is not None + # initialize self.metric_results + # It is a dict: { + # 'folder1': tensor (num_frame x len(metrics)), + # 'folder2': tensor (num_frame x len(metrics)) + # } + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {} + num_frame_each_folder = Counter(dataset.data_info['folder']) + for folder, num_frame in num_frame_each_folder.items(): + self.metric_results[folder] = torch.zeros( + num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda') + # initialize the best metric results + self._initialize_best_metric_results(dataset_name) + # zero self.metric_results + rank, world_size = get_dist_info() + if with_metrics: + for _, tensor in self.metric_results.items(): + tensor.zero_() + + metric_data = dict() + num_folders = len(dataset) + num_pad = (world_size - (num_folders % world_size)) % world_size + if rank == 0: + pbar = tqdm(total=len(dataset), unit='folder') + # Will evaluate (num_folders + num_pad) times, but only the first num_folders results will be recorded. 
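+ # The num_pad extra iterations re-evaluate already-seen folders so that every rank
+ # loops the same number of times and reaches dist.reduce / dist.barrier together.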
+ # (To avoid wait-dead) + for i in range(rank, num_folders + num_pad, world_size): + idx = min(i, num_folders - 1) + val_data = dataset[idx] + folder = val_data['folder'] + + # compute outputs + val_data['lq'].unsqueeze_(0) + val_data['gt'].unsqueeze_(0) + self.feed_data(val_data) + val_data['lq'].squeeze_(0) + val_data['gt'].squeeze_(0) + + self.test() + visuals = self.get_current_visuals() + + # tentative for out of GPU memory + del self.lq + del self.output + if 'gt' in visuals: + del self.gt + torch.cuda.empty_cache() + + if self.center_frame_only: + visuals['result'] = visuals['result'].unsqueeze(1) + if 'gt' in visuals: + visuals['gt'] = visuals['gt'].unsqueeze(1) + + # evaluate + if i < num_folders: + for idx in range(visuals['result'].size(1)): + result = visuals['result'][0, idx, :, :, :] + result_img = tensor2img([result]) # uint8, bgr + metric_data['img'] = result_img + if 'gt' in visuals: + gt = visuals['gt'][0, idx, :, :, :] + gt_img = tensor2img([gt]) # uint8, bgr + metric_data['img2'] = gt_img + + if save_img: + if self.opt['is_train']: + raise NotImplementedError('saving image is not supported during training.') + else: + if self.center_frame_only: # vimeo-90k + clip_ = val_data['lq_path'].split('/')[-3] + seq_ = val_data['lq_path'].split('/')[-2] + name_ = f'{clip_}_{seq_}' + img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder, + f"{name_}_{self.opt['name']}.png") + else: # others + img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder, + f"{idx:08d}_{self.opt['name']}.png") + # image name only for REDS dataset + imwrite(result_img, img_path) + + # calculate metrics + if with_metrics: + for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()): + result = calculate_metric(metric_data, opt_) + self.metric_results[folder][idx, metric_idx] += result + + # progress bar + if rank == 0: + for _ in range(world_size): + pbar.update(1) + pbar.set_description(f'Folder: {folder}') + + if rank == 0: + pbar.close() + + if with_metrics: + if self.opt['dist']: + # collect data among GPUs + for _, tensor in self.metric_results.items(): + dist.reduce(tensor, 0) + dist.barrier() + + if rank == 0: + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def test(self): + n = self.lq.size(1) + self.net_g.eval() + + flip_seq = self.opt['val'].get('flip_seq', False) + self.center_frame_only = self.opt['val'].get('center_frame_only', False) + + if flip_seq: + self.lq = torch.cat([self.lq, self.lq.flip(1)], dim=1) + + with torch.no_grad(): + self.output = self.net_g(self.lq) + + if flip_seq: + output_1 = self.output[:, :n, :, :, :] + output_2 = self.output[:, n:, :, :, :].flip(1) + self.output = 0.5 * (output_1 + output_2) + + if self.center_frame_only: + self.output = self.output[:, n // 2, :, :, :] + + self.net_g.train() diff --git a/basicsr/ops/__init__.py b/basicsr/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/basicsr/ops/__pycache__/__init__.cpython-310.pyc b/basicsr/ops/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd770add360a14cfbfc2f9f6d6eaf9fc76707b13 Binary files /dev/null and b/basicsr/ops/__pycache__/__init__.cpython-310.pyc differ diff --git a/basicsr/ops/dcn/__init__.py b/basicsr/ops/dcn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..32e3592f896d61b4127e09d0476381b9d55e32ff --- /dev/null +++ 
b/basicsr/ops/dcn/__init__.py @@ -0,0 +1,7 @@ +from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, + modulated_deform_conv) + +__all__ = [ + 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', + 'modulated_deform_conv' +] diff --git a/basicsr/ops/dcn/__pycache__/__init__.cpython-310.pyc b/basicsr/ops/dcn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00c76fac1cc097dc5695c9bab30c17258133c742 Binary files /dev/null and b/basicsr/ops/dcn/__pycache__/__init__.cpython-310.pyc differ diff --git a/basicsr/ops/dcn/__pycache__/deform_conv.cpython-310.pyc b/basicsr/ops/dcn/__pycache__/deform_conv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a0c47a985176753cf6d28415c1dc14e964fd515 Binary files /dev/null and b/basicsr/ops/dcn/__pycache__/deform_conv.cpython-310.pyc differ diff --git a/basicsr/ops/dcn/deform_conv.py b/basicsr/ops/dcn/deform_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..6268ca825d59ef4a30d4d2156c4438cbbe9b3c1e --- /dev/null +++ b/basicsr/ops/dcn/deform_conv.py @@ -0,0 +1,379 @@ +import math +import os +import torch +from torch import nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn import functional as F +from torch.nn.modules.utils import _pair, _single + +BASICSR_JIT = os.getenv('BASICSR_JIT') +if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + deform_conv_ext = load( + 'deform_conv', + sources=[ + os.path.join(module_path, 'src', 'deform_conv_ext.cpp'), + os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'), + os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'), + ], + ) +else: + try: + from . import deform_conv_ext + except ImportError: + pass + # avoid annoying print output + # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' + # '1. compile with BASICSR_EXT=True. or\n ' + # '2. 
set BASICSR_JIT=True during running') + + +class DeformConvFunction(Function): + + @staticmethod + def forward(ctx, + input, + offset, + weight, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + im2col_step=64): + if input is not None and input.dim() != 4: + raise ValueError(f'Expected 4D tensor as input, got {input.dim()}D tensor instead.') + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.im2col_step = im2col_step + + ctx.save_for_backward(input, offset, weight) + + output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride)) + + ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones + + if not input.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' + deform_conv_ext.deform_conv_forward(input, weight, + offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], + ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, cur_im2col_step) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, weight = ctx.saved_tensors + + grad_input = grad_offset = grad_weight = None + + if not grad_output.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input, + grad_offset, weight, ctx.bufs_[0], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], + ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, cur_im2col_step) + + if ctx.needs_input_grad[2]: + grad_weight = torch.zeros_like(weight) + deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight, + ctx.bufs_[0], ctx.bufs_[1], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], + ctx.padding[1], ctx.padding[0], ctx.dilation[1], + ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1, + cur_im2col_step) + + return (grad_input, grad_offset, grad_weight, None, None, None, None, None) + + @staticmethod + def _output_size(input, weight, padding, dilation, stride): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = padding[d] + kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError(f'convolution input is too small (output would be {"x".join(map(str, output_size))})') + return output_size + + +class ModulatedDeformConvFunction(Function): + + @staticmethod + def forward(ctx, + input, + offset, + mask, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1): + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.with_bias = bias is not None + if not 
ctx.with_bias: + bias = input.new_empty(1) # fake tensor + if not input.is_cuda: + raise NotImplementedError + if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad: + ctx.save_for_backward(input, offset, mask, weight, bias) + output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) + ctx._bufs = [input.new_empty(0), input.new_empty(0)] + deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output, + ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + grad_mask = torch.zeros_like(mask) + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) + deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], + grad_input, grad_weight, grad_bias, grad_offset, grad_mask, + grad_output, weight.shape[2], weight.shape[3], ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + if not ctx.with_bias: + grad_bias = None + + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None) + + @staticmethod + def _infer_shape(ctx, input, weight): + n = input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1 + width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1 + return n, channels_out, height_out, width_out + + +deform_conv = DeformConvFunction.apply +modulated_deform_conv = ModulatedDeformConvFunction.apply + + +class DeformConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=False): + super(DeformConv, self).__init__() + + assert not bias + assert in_channels % groups == 0, f'in_channels {in_channels} is not divisible by groups {groups}' + assert out_channels % groups == 0, f'out_channels {out_channels} is not divisible by groups {groups}' + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deformable_groups = deformable_groups + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)) + + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. 
/ math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + + def forward(self, x, offset): + # To fix an assert error in deform_conv_cuda.cpp:128 + # input image is smaller than kernel + input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1]) + if input_pad: + pad_h = max(self.kernel_size[0] - x.size(2), 0) + pad_w = max(self.kernel_size[1] - x.size(3), 0) + x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, + self.deformable_groups) + if input_pad: + out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous() + return out + + +class DeformConvPack(DeformConv): + """A Deformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(DeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + offset = self.conv_offset(x) + return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, + self.deformable_groups) + + +class ModulatedDeformConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=True): + super(ModulatedDeformConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.deformable_groups = deformable_groups + self.with_bias = bias + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.init_weights() + + def init_weights(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def forward(self, x, offset, mask): + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) + + +class ModulatedDeformConvPack(ModulatedDeformConv): + """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. 
+ kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_weights() + + def init_weights(self): + super(ModulatedDeformConvPack, self).init_weights() + if hasattr(self, 'conv_offset'): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + out = self.conv_offset(x) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) diff --git a/basicsr/ops/fused_act/__init__.py b/basicsr/ops/fused_act/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..241dc0754fae7d88dbbd9a02e665ca30a73c7422 --- /dev/null +++ b/basicsr/ops/fused_act/__init__.py @@ -0,0 +1,3 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu + +__all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] diff --git a/basicsr/ops/fused_act/__pycache__/__init__.cpython-310.pyc b/basicsr/ops/fused_act/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..841671207d25a059649c695f10d6676fa41db268 Binary files /dev/null and b/basicsr/ops/fused_act/__pycache__/__init__.cpython-310.pyc differ diff --git a/basicsr/ops/fused_act/__pycache__/fused_act.cpython-310.pyc b/basicsr/ops/fused_act/__pycache__/fused_act.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c481e75a5df24eacc08c57a01d1b74e60328004e Binary files /dev/null and b/basicsr/ops/fused_act/__pycache__/fused_act.cpython-310.pyc differ diff --git a/basicsr/ops/fused_act/fused_act.py b/basicsr/ops/fused_act/fused_act.py new file mode 100644 index 0000000000000000000000000000000000000000..88edc445484b71119dc22a258e83aef49ce39b07 --- /dev/null +++ b/basicsr/ops/fused_act/fused_act.py @@ -0,0 +1,95 @@ +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 + +import os +import torch +from torch import nn +from torch.autograd import Function + +BASICSR_JIT = os.getenv('BASICSR_JIT') +if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + fused_act_ext = load( + 'fused', + sources=[ + os.path.join(module_path, 'src', 'fused_bias_act.cpp'), + os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), + ], + ) +else: + try: + from . import fused_act_ext + except ImportError: + pass + # avoid annoying print output + # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' + # '1. compile with BASICSR_EXT=True. or\n ' + # '2. 
set BASICSR_JIT=True during running') + + +class FusedLeakyReLUFunctionBackward(Function): + + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, + ctx.scale) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + + def __init__(self, channel, negative_slope=0.2, scale=2**0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/basicsr/ops/upfirdn2d/__init__.py b/basicsr/ops/upfirdn2d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..397e85bea063e97fc4c12ad4d3e15669b69290bd --- /dev/null +++ b/basicsr/ops/upfirdn2d/__init__.py @@ -0,0 +1,3 @@ +from .upfirdn2d import upfirdn2d + +__all__ = ['upfirdn2d'] diff --git a/basicsr/ops/upfirdn2d/__pycache__/__init__.cpython-310.pyc b/basicsr/ops/upfirdn2d/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..160f3fc62d7714850fedc2303303884069f8d28c Binary files /dev/null and b/basicsr/ops/upfirdn2d/__pycache__/__init__.cpython-310.pyc differ diff --git a/basicsr/ops/upfirdn2d/__pycache__/upfirdn2d.cpython-310.pyc b/basicsr/ops/upfirdn2d/__pycache__/upfirdn2d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..001d0881eeb6bcb7011d47d06e9c48b8c14b628a Binary files /dev/null and b/basicsr/ops/upfirdn2d/__pycache__/upfirdn2d.cpython-310.pyc differ diff --git a/basicsr/ops/upfirdn2d/upfirdn2d.py b/basicsr/ops/upfirdn2d/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..d6122d59aa32fd52e956bd36200ba79af4a17b17 --- /dev/null +++ b/basicsr/ops/upfirdn2d/upfirdn2d.py @@ -0,0 +1,192 @@ +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 + +import os +import torch +from torch.autograd import Function +from torch.nn import functional as F + +BASICSR_JIT = os.getenv('BASICSR_JIT') +if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + upfirdn2d_ext = load( + 'upfirdn2d', + 
sources=[ + os.path.join(module_path, 'src', 'upfirdn2d.cpp'), + os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'), + ], + ) +else: + try: + from . import upfirdn2d_ext + except ImportError: + pass + # avoid annoying print output + # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' + # '1. compile with BASICSR_EXT=True. or\n ' + # '2. set BASICSR_JIT=True during running') + + +class UpFirDn2dBackward(Function): + + @staticmethod + def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_ext.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_ext.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], + # ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + _, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + if input.device.type == 'cpu': + out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + else: + out = 
UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) + + return out + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/basicsr/test.py b/basicsr/test.py new file mode 100644 index 0000000000000000000000000000000000000000..53cb3b7aa860c90518e15ba76e1a55fdf404bcc2 --- /dev/null +++ b/basicsr/test.py @@ -0,0 +1,45 @@ +import logging +import torch +from os import path as osp + +from basicsr.data import build_dataloader, build_dataset +from basicsr.models import build_model +from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs +from basicsr.utils.options import dict2str, parse_options + + +def test_pipeline(root_path): + # parse options, set distributed setting, set ramdom seed + opt, _ = parse_options(root_path, is_train=False) + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # mkdir and initialize loggers + make_exp_dirs(opt) + log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log") + logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) + logger.info(get_env_info()) + logger.info(dict2str(opt)) + + # create test dataset and dataloader + test_loaders = [] + for _, dataset_opt in sorted(opt['datasets'].items()): + test_set = build_dataset(dataset_opt) + test_loader = build_dataloader( + test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) + logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}") + test_loaders.append(test_loader) + + # create model + model = build_model(opt) + + for test_loader in test_loaders: + test_set_name = test_loader.dataset.opt['name'] + logger.info(f'Testing {test_set_name}...') + model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img']) + + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + test_pipeline(root_path) diff --git a/basicsr/train.py b/basicsr/train.py new file mode 100644 index 0000000000000000000000000000000000000000..e02d98fe07f8c2924dda5b49f95adfa21990fa91 --- /dev/null +++ b/basicsr/train.py @@ -0,0 +1,215 @@ +import datetime +import logging +import math +import time +import torch +from os import path as osp + +from basicsr.data import 
build_dataloader, build_dataset +from basicsr.data.data_sampler import EnlargedSampler +from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher +from basicsr.models import build_model +from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str, + init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir) +from basicsr.utils.options import copy_opt_file, dict2str, parse_options + + +def init_tb_loggers(opt): + # initialize wandb logger before tensorboard logger to allow proper sync + if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') + is not None) and ('debug' not in opt['name']): + assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb') + init_wandb_logger(opt) + tb_logger = None + if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']: + tb_logger = init_tb_logger(log_dir=osp.join(opt['root_path'], 'tb_logger', opt['name'])) + return tb_logger + + +def create_train_val_dataloader(opt, logger): + # create train and val dataloaders + train_loader, val_loaders = None, [] + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1) + train_set = build_dataset(dataset_opt) + train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio) + train_loader = build_dataloader( + train_set, + dataset_opt, + num_gpu=opt['num_gpu'], + dist=opt['dist'], + sampler=train_sampler, + seed=opt['manual_seed']) + + num_iter_per_epoch = math.ceil( + len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size'])) + total_iters = int(opt['train']['total_iter']) + total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) + logger.info('Training statistics:' + f'\n\tNumber of train images: {len(train_set)}' + f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}' + f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}' + f'\n\tWorld size (gpu number): {opt["world_size"]}' + f'\n\tRequire iter number per epoch: {num_iter_per_epoch}' + f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.') + elif phase.split('_')[0] == 'val': + val_set = build_dataset(dataset_opt) + val_loader = build_dataloader( + val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) + logger.info(f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}') + val_loaders.append(val_loader) + else: + raise ValueError(f'Dataset phase {phase} is not recognized.') + + return train_loader, train_sampler, val_loaders, total_epochs, total_iters + + +def load_resume_state(opt): + resume_state_path = None + if opt['auto_resume']: + state_path = osp.join('experiments', opt['name'], 'training_states') + if osp.isdir(state_path): + states = list(scandir(state_path, suffix='state', recursive=False, full_path=False)) + if len(states) != 0: + states = [float(v.split('.state')[0]) for v in states] + resume_state_path = osp.join(state_path, f'{max(states):.0f}.state') + opt['path']['resume_state'] = resume_state_path + else: + if opt['path'].get('resume_state'): + resume_state_path = opt['path']['resume_state'] + + if resume_state_path is None: + resume_state = None + else: + device_id = torch.cuda.current_device() + resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id)) + check_resume(opt, resume_state['iter']) 
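+ # resume_state is now either None (fresh run) or the training-state dict loaded onto
+ # the current GPU; with auto_resume the newest checkpoint is chosen by the numeric
+ # prefix of its '<iter>.state' filename via max(states).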
+ return resume_state + + +def train_pipeline(root_path): + # parse options, set distributed setting, set random seed + opt, args = parse_options(root_path, is_train=True) + opt['root_path'] = root_path + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # load resume states if necessary + resume_state = load_resume_state(opt) + # mkdir for experiments and logger + if resume_state is None: + make_exp_dirs(opt) + if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name'] and opt['rank'] == 0: + mkdir_and_rename(osp.join(opt['root_path'], 'tb_logger', opt['name'])) + + # copy the yml file to the experiment root + copy_opt_file(args.opt, opt['path']['experiments_root']) + + # WARNING: should not use get_root_logger in the above codes, including the called functions + # Otherwise the logger will not be properly initialized + log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log") + logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) + logger.info(get_env_info()) + logger.info(dict2str(opt)) + # initialize wandb and tb loggers + tb_logger = init_tb_loggers(opt) + + # create train and validation dataloaders + result = create_train_val_dataloader(opt, logger) + train_loader, train_sampler, val_loaders, total_epochs, total_iters = result + + # create model + model = build_model(opt) + if resume_state: # resume training + model.resume_training(resume_state) # handle optimizers and schedulers + logger.info(f"Resuming training from epoch: {resume_state['epoch']}, iter: {resume_state['iter']}.") + start_epoch = resume_state['epoch'] + current_iter = resume_state['iter'] + else: + start_epoch = 0 + current_iter = 0 + + # create message logger (formatted outputs) + msg_logger = MessageLogger(opt, current_iter, tb_logger) + + # dataloader prefetcher + prefetch_mode = opt['datasets']['train'].get('prefetch_mode') + if prefetch_mode is None or prefetch_mode == 'cpu': + prefetcher = CPUPrefetcher(train_loader) + elif prefetch_mode == 'cuda': + prefetcher = CUDAPrefetcher(train_loader, opt) + logger.info(f'Use {prefetch_mode} prefetch dataloader') + if opt['datasets']['train'].get('pin_memory') is not True: + raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') + else: + raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. 
Supported ones are: None, 'cuda', 'cpu'.") + + # training + logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}') + data_timer, iter_timer = AvgTimer(), AvgTimer() + start_time = time.time() + + for epoch in range(start_epoch, total_epochs + 1): + train_sampler.set_epoch(epoch) + prefetcher.reset() + train_data = prefetcher.next() + + while train_data is not None: + data_timer.record() + + current_iter += 1 + if current_iter > total_iters: + break + # update learning rate + model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1)) + # training + model.feed_data(train_data) + model.optimize_parameters(current_iter) + iter_timer.record() + if current_iter == 1: + # reset start time in msg_logger for more accurate eta_time + # not work in resume mode + msg_logger.reset_start_time() + # log + if current_iter % opt['logger']['print_freq'] == 0: + log_vars = {'epoch': epoch, 'iter': current_iter} + log_vars.update({'lrs': model.get_current_learning_rate()}) + log_vars.update({'time': iter_timer.get_avg_time(), 'data_time': data_timer.get_avg_time()}) + log_vars.update(model.get_current_log()) + msg_logger(log_vars) + + # save models and training states + if current_iter % opt['logger']['save_checkpoint_freq'] == 0: + logger.info('Saving models and training states.') + model.save(epoch, current_iter) + + # validation + if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0): + if len(val_loaders) > 1: + logger.warning('Multiple validation datasets are *only* supported by SRModel.') + for val_loader in val_loaders: + model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) + + data_timer.start() + iter_timer.start() + train_data = prefetcher.next() + # end of iter + + # end of epoch + + consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time))) + logger.info(f'End of training. 
Time consumed: {consumed_time}') + logger.info('Save the latest model.') + model.save(epoch=-1, current_iter=-1) # -1 stands for the latest + if opt.get('val') is not None: + for val_loader in val_loaders: + model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) + if tb_logger: + tb_logger.close() + + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + train_pipeline(root_path) diff --git a/basicsr/utils/__init__.py b/basicsr/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9569c50780415b356c8e06edac5d960cf1fe1e91 --- /dev/null +++ b/basicsr/utils/__init__.py @@ -0,0 +1,47 @@ +from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb +from .diffjpeg import DiffJPEG +from .file_client import FileClient +from .img_process_util import USMSharp, usm_sharp +from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img +from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger +from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt +from .options import yaml_load + +__all__ = [ + # color_util.py + 'bgr2ycbcr', + 'rgb2ycbcr', + 'rgb2ycbcr_pt', + 'ycbcr2bgr', + 'ycbcr2rgb', + # file_client.py + 'FileClient', + # img_util.py + 'img2tensor', + 'tensor2img', + 'imfrombytes', + 'imwrite', + 'crop_border', + # logger.py + 'MessageLogger', + 'AvgTimer', + 'init_tb_logger', + 'init_wandb_logger', + 'get_root_logger', + 'get_env_info', + # misc.py + 'set_random_seed', + 'get_time_str', + 'mkdir_and_rename', + 'make_exp_dirs', + 'scandir', + 'check_resume', + 'sizeof_fmt', + # diffjpeg + 'DiffJPEG', + # img_process_util + 'USMSharp', + 'usm_sharp', + # options + 'yaml_load' +] diff --git a/basicsr/utils/__pycache__/__init__.cpython-310.pyc b/basicsr/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef1a705154e127900a848704a4c76aafff0269d4 Binary files /dev/null and b/basicsr/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/basicsr/utils/__pycache__/color_util.cpython-310.pyc b/basicsr/utils/__pycache__/color_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..576a8cd964a3dd41e72a8716393da318a83a8273 Binary files /dev/null and b/basicsr/utils/__pycache__/color_util.cpython-310.pyc differ diff --git a/basicsr/utils/__pycache__/diffjpeg.cpython-310.pyc b/basicsr/utils/__pycache__/diffjpeg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea950373c071f149e69493f9de069532ac0f8e5d Binary files /dev/null and b/basicsr/utils/__pycache__/diffjpeg.cpython-310.pyc differ diff --git a/basicsr/utils/__pycache__/dist_util.cpython-310.pyc b/basicsr/utils/__pycache__/dist_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8dbc93ef8125020211bd165e6f76dd5387d2ff3 Binary files /dev/null and b/basicsr/utils/__pycache__/dist_util.cpython-310.pyc differ diff --git a/basicsr/utils/__pycache__/file_client.cpython-310.pyc b/basicsr/utils/__pycache__/file_client.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f627cdd480d8fb3481b480ce9b8b18a6c257dd0 Binary files /dev/null and b/basicsr/utils/__pycache__/file_client.cpython-310.pyc differ diff --git a/basicsr/utils/__pycache__/flow_util.cpython-310.pyc 
b/basicsr/utils/__pycache__/flow_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76555d3a2048adca186eff33b7e514ab17944580 Binary files /dev/null and b/basicsr/utils/__pycache__/flow_util.cpython-310.pyc differ diff --git a/basicsr/utils/__pycache__/img_process_util.cpython-310.pyc b/basicsr/utils/__pycache__/img_process_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef699484d4152007212f77fd6deb861ee821c987 Binary files /dev/null and b/basicsr/utils/__pycache__/img_process_util.cpython-310.pyc differ diff --git a/basicsr/utils/__pycache__/img_util.cpython-310.pyc b/basicsr/utils/__pycache__/img_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f0022bc34275d5591e2b4f3113cc57a6325cddc Binary files /dev/null and b/basicsr/utils/__pycache__/img_util.cpython-310.pyc differ diff --git a/basicsr/utils/__pycache__/logger.cpython-310.pyc b/basicsr/utils/__pycache__/logger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a2ae8c7c35503a9da42e45b7ff5b43b0b97624a Binary files /dev/null and b/basicsr/utils/__pycache__/logger.cpython-310.pyc differ diff --git a/basicsr/utils/__pycache__/matlab_functions.cpython-310.pyc b/basicsr/utils/__pycache__/matlab_functions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2089612e4b1e67dbbe99c774d7bed56b7edda7cd Binary files /dev/null and b/basicsr/utils/__pycache__/matlab_functions.cpython-310.pyc differ diff --git a/basicsr/utils/__pycache__/misc.cpython-310.pyc b/basicsr/utils/__pycache__/misc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d14fbc10feb34106511e674dd21a50597644db22 Binary files /dev/null and b/basicsr/utils/__pycache__/misc.cpython-310.pyc differ diff --git a/basicsr/utils/__pycache__/options.cpython-310.pyc b/basicsr/utils/__pycache__/options.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d68bd857da9c7eab8bbabb17329de9922f34f5a Binary files /dev/null and b/basicsr/utils/__pycache__/options.cpython-310.pyc differ diff --git a/basicsr/utils/__pycache__/registry.cpython-310.pyc b/basicsr/utils/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4dd273d40e29345391f53ce916e55bf99f60f3b Binary files /dev/null and b/basicsr/utils/__pycache__/registry.cpython-310.pyc differ diff --git a/basicsr/utils/color_util.py b/basicsr/utils/color_util.py new file mode 100644 index 0000000000000000000000000000000000000000..4740d5c98dd0680654e20d46b81ab30dfe936d6e --- /dev/null +++ b/basicsr/utils/color_util.py @@ -0,0 +1,208 @@ +import numpy as np +import torch + + +def rgb2ycbcr(img, y_only=False): + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. 
The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img): + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img): + """Convert a YCbCr image to BGR image. + + The bgr version of ycbcr2rgb. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. 
+ """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0], + [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + conversion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}') + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace conversion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}') + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. + return img.astype(dst_type) + + +def rgb2ycbcr_pt(img, y_only=False): + """Convert RGB images to YCbCr images (PyTorch version). + + It implements the ITU-R BT.601 conversion for standard-definition television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + Args: + img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + (Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float. + """ + if y_only: + weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img) + out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0 + else: + weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img) + bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img) + out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias + + out_img = out_img / 255. 
+ return out_img diff --git a/basicsr/utils/diffjpeg.py b/basicsr/utils/diffjpeg.py new file mode 100644 index 0000000000000000000000000000000000000000..65f96b44f9e7f3f8a589668f0003adf328cc5742 --- /dev/null +++ b/basicsr/utils/diffjpeg.py @@ -0,0 +1,515 @@ +""" +Modified from https://github.com/mlomnitz/DiffJPEG + +For images not divisible by 8 +https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343 +""" +import itertools +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + +# ------------------------ utils ------------------------# +y_table = np.array( + [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56], + [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92], + [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]], + dtype=np.float32).T +y_table = nn.Parameter(torch.from_numpy(y_table)) +c_table = np.empty((8, 8), dtype=np.float32) +c_table.fill(99) +c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T +c_table = nn.Parameter(torch.from_numpy(c_table)) + + +def diff_round(x): + """ Differentiable rounding function + """ + return torch.round(x) + (x - torch.round(x))**3 + + +def quality_to_factor(quality): + """ Calculate factor corresponding to quality + + Args: + quality(float): Quality for jpeg compression. + + Returns: + float: Compression factor. + """ + if quality < 50: + quality = 5000. / quality + else: + quality = 200. - quality * 2 + return quality / 100. + + +# ------------------------ compression ------------------------# +class RGB2YCbCrJpeg(nn.Module): + """ Converts RGB image to YCbCr + """ + + def __init__(self): + super(RGB2YCbCrJpeg, self).__init__() + matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]], + dtype=np.float32).T + self.shift = nn.Parameter(torch.tensor([0., 128., 128.])) + self.matrix = nn.Parameter(torch.from_numpy(matrix)) + + def forward(self, image): + """ + Args: + image(Tensor): batch x 3 x height x width + + Returns: + Tensor: batch x height x width x 3 + """ + image = image.permute(0, 2, 3, 1) + result = torch.tensordot(image, self.matrix, dims=1) + self.shift + return result.view(image.shape) + + +class ChromaSubsampling(nn.Module): + """ Chroma subsampling on CbCr channels + """ + + def __init__(self): + super(ChromaSubsampling, self).__init__() + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width x 3 + + Returns: + y(tensor): batch x height x width + cb(tensor): batch x height/2 x width/2 + cr(tensor): batch x height/2 x width/2 + """ + image_2 = image.permute(0, 3, 1, 2).clone() + cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False) + cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False) + cb = cb.permute(0, 2, 3, 1) + cr = cr.permute(0, 2, 3, 1) + return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3) + + +class BlockSplitting(nn.Module): + """ Splitting image into patches + """ + + def __init__(self): + super(BlockSplitting, self).__init__() + self.k = 8 + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x h*w/64 x h x w + """ + height, _ = image.shape[1:3] + batch_size = image.shape[0] + image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k) 
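+        # shape is now (b, h/k, k, w/k, k); swapping the two middle axes below
+        # groups the pixels of each k x k block together before the final reshape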
+ image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) + return image_transposed.contiguous().view(batch_size, -1, self.k, self.k) + + +class DCT8x8(nn.Module): + """ Discrete Cosine Transformation + """ + + def __init__(self): + super(DCT8x8, self).__init__() + tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) + for x, y, u, v in itertools.product(range(8), repeat=4): + tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16) + alpha = np.array([1. / np.sqrt(2)] + [1] * 7) + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float()) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + image = image - 128 + result = self.scale * torch.tensordot(image, self.tensor, dims=2) + result.view(image.shape) + return result + + +class YQuantize(nn.Module): + """ JPEG Quantization for Y channel + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding): + super(YQuantize, self).__init__() + self.rounding = rounding + self.y_table = y_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + image = image.float() / (self.y_table * factor) + else: + b = factor.size(0) + table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + image = image.float() / table + image = self.rounding(image) + return image + + +class CQuantize(nn.Module): + """ JPEG Quantization for CbCr channels + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding): + super(CQuantize, self).__init__() + self.rounding = rounding + self.c_table = c_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + image = image.float() / (self.c_table * factor) + else: + b = factor.size(0) + table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + image = image.float() / table + image = self.rounding(image) + return image + + +class CompressJpeg(nn.Module): + """Full JPEG compression algorithm + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding=torch.round): + super(CompressJpeg, self).__init__() + self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling()) + self.l2 = nn.Sequential(BlockSplitting(), DCT8x8()) + self.c_quantize = CQuantize(rounding=rounding) + self.y_quantize = YQuantize(rounding=rounding) + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x 3 x height x width + + Returns: + dict(tensor): Compressed tensor with batch x h*w/64 x 8 x 8. 
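+            The three planes are returned as a tuple ``(y, cb, cr)`` of quantized
+            8x8 block tensors (cb/cr come from the subsampled chroma channels).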
+ """ + y, cb, cr = self.l1(image * 255) + components = {'y': y, 'cb': cb, 'cr': cr} + for k in components.keys(): + comp = self.l2(components[k]) + if k in ('cb', 'cr'): + comp = self.c_quantize(comp, factor=factor) + else: + comp = self.y_quantize(comp, factor=factor) + + components[k] = comp + + return components['y'], components['cb'], components['cr'] + + +# ------------------------ decompression ------------------------# + + +class YDequantize(nn.Module): + """Dequantize Y channel + """ + + def __init__(self): + super(YDequantize, self).__init__() + self.y_table = y_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + out = image * (self.y_table * factor) + else: + b = factor.size(0) + table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + out = image * table + return out + + +class CDequantize(nn.Module): + """Dequantize CbCr channel + """ + + def __init__(self): + super(CDequantize, self).__init__() + self.c_table = c_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + out = image * (self.c_table * factor) + else: + b = factor.size(0) + table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + out = image * table + return out + + +class iDCT8x8(nn.Module): + """Inverse discrete Cosine Transformation + """ + + def __init__(self): + super(iDCT8x8, self).__init__() + alpha = np.array([1. / np.sqrt(2)] + [1] * 7) + self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float()) + tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) + for x, y, u, v in itertools.product(range(8), repeat=4): + tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16) + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + image = image * self.alpha + result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128 + result.view(image.shape) + return result + + +class BlockMerging(nn.Module): + """Merge patches into image + """ + + def __init__(self): + super(BlockMerging, self).__init__() + + def forward(self, patches, height, width): + """ + Args: + patches(tensor) batch x height*width/64, height x width + height(int) + width(int) + + Returns: + Tensor: batch x height x width + """ + k = 8 + batch_size = patches.shape[0] + image_reshaped = patches.view(batch_size, height // k, width // k, k, k) + image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) + return image_transposed.contiguous().view(batch_size, height, width) + + +class ChromaUpsampling(nn.Module): + """Upsample chroma layers + """ + + def __init__(self): + super(ChromaUpsampling, self).__init__() + + def forward(self, y, cb, cr): + """ + Args: + y(tensor): y channel image + cb(tensor): cb channel + cr(tensor): cr channel + + Returns: + Tensor: batch x height x width x 3 + """ + + def repeat(x, k=2): + height, width = x.shape[1:3] + x = x.unsqueeze(-1) + x = x.repeat(1, 1, k, k) + x = x.view(-1, height * k, width * k) + return x + + cb = repeat(cb) + cr = repeat(cr) + return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3) + + +class YCbCr2RGBJpeg(nn.Module): + """Converts YCbCr image to RGB JPEG + """ + + def __init__(self): + 
super(YCbCr2RGBJpeg, self).__init__() + + matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T + self.shift = nn.Parameter(torch.tensor([0, -128., -128.])) + self.matrix = nn.Parameter(torch.from_numpy(matrix)) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width x 3 + + Returns: + Tensor: batch x 3 x height x width + """ + result = torch.tensordot(image + self.shift, self.matrix, dims=1) + return result.view(image.shape).permute(0, 3, 1, 2) + + +class DeCompressJpeg(nn.Module): + """Full JPEG decompression algorithm + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding=torch.round): + super(DeCompressJpeg, self).__init__() + self.c_dequantize = CDequantize() + self.y_dequantize = YDequantize() + self.idct = iDCT8x8() + self.merging = BlockMerging() + self.chroma = ChromaUpsampling() + self.colors = YCbCr2RGBJpeg() + + def forward(self, y, cb, cr, imgh, imgw, factor=1): + """ + Args: + compressed(dict(tensor)): batch x h*w/64 x 8 x 8 + imgh(int) + imgw(int) + factor(float) + + Returns: + Tensor: batch x 3 x height x width + """ + components = {'y': y, 'cb': cb, 'cr': cr} + for k in components.keys(): + if k in ('cb', 'cr'): + comp = self.c_dequantize(components[k], factor=factor) + height, width = int(imgh / 2), int(imgw / 2) + else: + comp = self.y_dequantize(components[k], factor=factor) + height, width = imgh, imgw + comp = self.idct(comp) + components[k] = self.merging(comp, height, width) + # + image = self.chroma(components['y'], components['cb'], components['cr']) + image = self.colors(image) + + image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image)) + return image / 255 + + +# ------------------------ main DiffJPEG ------------------------ # + + +class DiffJPEG(nn.Module): + """This JPEG algorithm result is slightly different from cv2. + DiffJPEG supports batch processing. + + Args: + differentiable(bool): If True, uses custom differentiable rounding function, if False, uses standard torch.round + """ + + def __init__(self, differentiable=True): + super(DiffJPEG, self).__init__() + if differentiable: + rounding = diff_round + else: + rounding = torch.round + + self.compress = CompressJpeg(rounding=rounding) + self.decompress = DeCompressJpeg(rounding=rounding) + + def forward(self, x, quality): + """ + Args: + x (Tensor): Input image, bchw, rgb, [0, 1] + quality(float): Quality factor for jpeg compression scheme. + """ + factor = quality + if isinstance(factor, (int, float)): + factor = quality_to_factor(factor) + else: + for i in range(factor.size(0)): + factor[i] = quality_to_factor(factor[i]) + h, w = x.size()[-2:] + h_pad, w_pad = 0, 0 + # why should use 16 + if h % 16 != 0: + h_pad = 16 - h % 16 + if w % 16 != 0: + w_pad = 16 - w % 16 + x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0) + + y, cb, cr = self.compress(x, factor=factor) + recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor) + recovered = recovered[:, :, 0:h, 0:w] + return recovered + + +if __name__ == '__main__': + import cv2 + + from basicsr.utils import img2tensor, tensor2img + + img_gt = cv2.imread('test.png') / 255. 
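+    # 'test.png' is assumed to be a local sample image; cv2.imread returns a BGR
+    # uint8 array, so the division above yields a float image in the [0, 1] range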
+ + # -------------- cv2 -------------- # + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20] + _, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param) + img_lq = np.float32(cv2.imdecode(encimg, 1)) + cv2.imwrite('cv2_JPEG_20.png', img_lq) + + # -------------- DiffJPEG -------------- # + jpeger = DiffJPEG(differentiable=False).cuda() + img_gt = img2tensor(img_gt) + img_gt = torch.stack([img_gt, img_gt]).cuda() + quality = img_gt.new_tensor([20, 40]) + out = jpeger(img_gt, quality=quality) + + cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0])) + cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1])) diff --git a/basicsr/utils/dist_util.py b/basicsr/utils/dist_util.py new file mode 100644 index 0000000000000000000000000000000000000000..0fab887b2cb1ce8533d2e8fdee72ae0c24f68fd0 --- /dev/null +++ b/basicsr/utils/dist_util.py @@ -0,0 +1,82 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 +import functools +import os +import subprocess +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. 
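+
+    A typical call from inside a SLURM step is, for example,
+    ``init_dist('slurm', port=29510)``; rank, world size and the master address
+    are derived from the ``SLURM_*`` environment variables.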
+ """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper diff --git a/basicsr/utils/download_util.py b/basicsr/utils/download_util.py new file mode 100644 index 0000000000000000000000000000000000000000..f73abd0e1831b8cab6277d780331a5103785b9ec --- /dev/null +++ b/basicsr/utils/download_util.py @@ -0,0 +1,98 @@ +import math +import os +import requests +from torch.hub import download_url_to_file, get_dir +from tqdm import tqdm +from urllib.parse import urlparse + +from .misc import sizeof_fmt + + +def download_file_from_google_drive(file_id, save_path): + """Download files from google drive. + + Reference: https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive + + Args: + file_id (str): File id. + save_path (str): Save path. + """ + + session = requests.Session() + URL = 'https://docs.google.com/uc?export=download' + params = {'id': file_id} + + response = session.get(URL, params=params, stream=True) + token = get_confirm_token(response) + if token: + params['confirm'] = token + response = session.get(URL, params=params, stream=True) + + # get file size + response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) + if 'Content-Range' in response_file_size.headers: + file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) + else: + file_size = None + + save_response_content(response, save_path, file_size) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + +def save_response_content(response, destination, file_size=None, chunk_size=32768): + if file_size is not None: + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') + + readable_file_size = sizeof_fmt(file_size) + else: + pbar = None + + with open(destination, 'wb') as f: + downloaded_size = 0 + for chunk in response.iter_content(chunk_size): + downloaded_size += chunk_size + if pbar is not None: + pbar.update(1) + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if pbar is not None: + pbar.close() + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file form http url, will download models if necessary. 
+ + Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + + Returns: + str: The path to the downloaded file. + """ + if model_dir is None: # use the pytorch hub_dir + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file diff --git a/basicsr/utils/file_client.py b/basicsr/utils/file_client.py new file mode 100644 index 0000000000000000000000000000000000000000..89d83ab9e0d4314f8cdf2393908a561c6d1dca92 --- /dev/null +++ b/basicsr/utils/file_client.py @@ -0,0 +1,167 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 +from abc import ABCMeta, abstractmethod + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. + """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError('Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + def get(self, filepath): + filepath = str(filepath) + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, filepath): + filepath = str(filepath) + with open(filepath, 'r') as f: + value_buf = f.read() + return value_buf + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_paths (str | list[str]): Lmdb database paths. + client_keys (str | list[str]): Lmdb client keys. Default: 'default'. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. 
+ lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_paths (list): Lmdb database path. + _client (list): A list of several lmdb envs. + """ + + def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + if isinstance(client_keys, str): + client_keys = [client_keys] + + if isinstance(db_paths, list): + self.db_paths = [str(v) for v in db_paths] + elif isinstance(db_paths, str): + self.db_paths = [str(db_paths)] + assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' + f'but received {len(client_keys)} and {len(self.db_paths)}.') + + self._client = {} + for client, path in zip(client_keys, self.db_paths): + self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) + + def get(self, filepath, client_key): + """Get values according to the filepath from one lmdb named client_key. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + client_key (str): Used for distinguishing different lmdb envs. + """ + filepath = str(filepath) + assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.') + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient(object): + """A general file client to access files in different backend. + + The client loads a file or text in a specified backend from its path + and return it as a binary file. it can also register other backend + accessor with a given name and backend class. + + Attributes: + backend (str): The storage backend type. Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + } + + def __init__(self, backend='disk', **kwargs): + if backend not in self._backends: + raise ValueError(f'Backend {backend} is not supported. Currently supported ones' + f' are {list(self._backends.keys())}') + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key='default'): + # client_key is used only for lmdb, where different fileclients have + # different lmdb environments. + if self.backend == 'lmdb': + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) diff --git a/basicsr/utils/flow_util.py b/basicsr/utils/flow_util.py new file mode 100644 index 0000000000000000000000000000000000000000..3d7180b4e9b5c8f2eb36a9a0e4ff6affdaae84b8 --- /dev/null +++ b/basicsr/utils/flow_util.py @@ -0,0 +1,170 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501 +import cv2 +import numpy as np +import os + + +def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): + """Read an optical flow map. 
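+    Supports raw ``.flo`` files (identified by the ``PIEH`` header) as well as
+    quantized flow maps stored as a single image (see :func:`dequantize_flow`).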
+ + Args: + flow_path (ndarray or str): Flow path. + quantize (bool): whether to read quantized pair, if set to True, + remaining args will be passed to :func:`dequantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. + + Returns: + ndarray: Optical flow represented as a (h, w, 2) numpy array + """ + if quantize: + assert concat_axis in [0, 1] + cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) + if cat_flow.ndim != 2: + raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.') + assert cat_flow.shape[concat_axis] % 2 == 0 + dx, dy = np.split(cat_flow, 2, axis=concat_axis) + flow = dequantize_flow(dx, dy, *args, **kwargs) + else: + with open(flow_path, 'rb') as f: + try: + header = f.read(4).decode('utf-8') + except Exception: + raise IOError(f'Invalid flow file: {flow_path}') + else: + if header != 'PIEH': + raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH') + + w = np.fromfile(f, np.int32, 1).squeeze() + h = np.fromfile(f, np.int32, 1).squeeze() + flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) + + return flow.astype(np.float32) + + +def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): + """Write optical flow to file. + + If the flow is not quantized, it will be saved as a .flo file losslessly, + otherwise a jpeg image which is lossy but of much smaller size. (dx and dy + will be concatenated horizontally into a single image if quantize is True.) + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + filename (str): Output filepath. + quantize (bool): Whether to quantize the flow and save it to 2 jpeg + images. If set to True, remaining args will be passed to + :func:`quantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. + """ + if not quantize: + with open(filename, 'wb') as f: + f.write('PIEH'.encode('utf-8')) + np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) + flow = flow.astype(np.float32) + flow.tofile(f) + f.flush() + else: + assert concat_axis in [0, 1] + dx, dy = quantize_flow(flow, *args, **kwargs) + dxdy = np.concatenate((dx, dy), axis=concat_axis) + os.makedirs(os.path.dirname(filename), exist_ok=True) + cv2.imwrite(filename, dxdy) + + +def quantize_flow(flow, max_val=0.02, norm=True): + """Quantize flow to [0, 255]. + + After this step, the size of flow will be much smaller, and can be + dumped as jpeg images. + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + max_val (float): Maximum value of flow, values beyond + [-max_val, max_val] will be truncated. + norm (bool): Whether to divide flow values by image width/height. + + Returns: + tuple[ndarray]: Quantized dx and dy. + """ + h, w, _ = flow.shape + dx = flow[..., 0] + dy = flow[..., 1] + if norm: + dx = dx / w # avoid inplace operations + dy = dy / h + # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. + flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]] + return tuple(flow_comps) + + +def dequantize_flow(dx, dy, max_val=0.02, denorm=True): + """Recover from quantized flow. + + Args: + dx (ndarray): Quantized dx. + dy (ndarray): Quantized dy. + max_val (float): Maximum value used when quantizing. + denorm (bool): Whether to multiply flow values with width/height. + + Returns: + ndarray: Dequantized flow. 
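+
+    Example (round trip for an (h, w, 2) flow array; values are approximate):
+        >>> dx, dy = quantize_flow(flow)
+        >>> restored = dequantize_flow(dx, dy)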
+ """ + assert dx.shape == dy.shape + assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) + + dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] + + if denorm: + dx *= dx.shape[1] + dy *= dx.shape[0] + flow = np.dstack((dx, dy)) + return flow + + +def quantize(arr, min_val, max_val, levels, dtype=np.int64): + """Quantize an array of (-inf, inf) to [0, levels-1]. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the quantized array. + + Returns: + tuple: Quantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError(f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + arr = np.clip(arr, min_val, max_val) - min_val + quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) + + return quantized_arr + + +def dequantize(arr, min_val, max_val, levels, dtype=np.float64): + """Dequantize an array. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the dequantized array. + + Returns: + tuple: Dequantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError(f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val + + return dequantized_arr diff --git a/basicsr/utils/img_process_util.py b/basicsr/utils/img_process_util.py new file mode 100644 index 0000000000000000000000000000000000000000..52e02f09930dbf13bcd12bbe16b76e4fce52578e --- /dev/null +++ b/basicsr/utils/img_process_util.py @@ -0,0 +1,83 @@ +import cv2 +import numpy as np +import torch +from torch.nn import functional as F + + +def filter2D(img, kernel): + """PyTorch version of cv2.filter2D + + Args: + img (Tensor): (b, c, h, w) + kernel (Tensor): (b, k, k) + """ + k = kernel.size(-1) + b, c, h, w = img.size() + if k % 2 == 1: + img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') + else: + raise ValueError('Wrong kernel size') + + ph, pw = img.size()[-2:] + + if kernel.size(0) == 1: + # apply the same kernel to all batch images + img = img.view(b * c, 1, ph, pw) + kernel = kernel.view(1, 1, k, k) + return F.conv2d(img, kernel, padding=0).view(b, c, h, w) + else: + img = img.view(1, b * c, ph, pw) + kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) + return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) + + +def usm_sharp(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. + + Input image: I; Blurry image: B. + 1. sharp = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * sharp + (1 - Mask) * I + + + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. 
+ threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + sharp = img + weight * residual + sharp = np.clip(sharp, 0, 1) + return soft_mask * sharp + (1 - soft_mask) * img + + +class USMSharp(torch.nn.Module): + + def __init__(self, radius=50, sigma=0): + super(USMSharp, self).__init__() + if radius % 2 == 0: + radius += 1 + self.radius = radius + kernel = cv2.getGaussianKernel(radius, sigma) + kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) + self.register_buffer('kernel', kernel) + + def forward(self, img, weight=0.5, threshold=10): + blur = filter2D(img, self.kernel) + residual = img - blur + + mask = torch.abs(residual) * 255 > threshold + mask = mask.float() + soft_mask = filter2D(mask, self.kernel) + sharp = img + weight * residual + sharp = torch.clip(sharp, 0, 1) + return soft_mask * sharp + (1 - soft_mask) * img diff --git a/basicsr/utils/img_util.py b/basicsr/utils/img_util.py new file mode 100644 index 0000000000000000000000000000000000000000..fbce5dba5b01deb78f2453edc801a76e6a126998 --- /dev/null +++ b/basicsr/utils/img_util.py @@ -0,0 +1,172 @@ +import cv2 +import math +import numpy as np +import os +import torch +from torchvision.utils import make_grid + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. 
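+
+    Example (minimal sketch):
+        >>> outs = tensor2img([torch.rand(3, 8, 8), torch.rand(1, 8, 8)])
+        >>> # outs holds an (8, 8, 3) BGR uint8 image and an (8, 8) grayscale one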
+ """ + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1 and torch.is_tensor(tensor): + result = result[0] + return result + + +def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): + """This implementation is slightly faster than tensor2img. + It now only supports torch tensor with shape (1, c, h, w). + + Args: + tensor (Tensor): Now only support torch tensor with (1, c, h, w). + rgb2bgr (bool): Whether to change rgb to bgr. Default: True. + min_max (tuple[int]): min and max values for clamp. + """ + output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) + output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 + output = output.type(torch.uint8).cpu().numpy() + if rgb2bgr: + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + +def imfrombytes(content, flag='color', float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255. + return img + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + ok = cv2.imwrite(file_path, img, params) + if not ok: + raise IOError('Failed in writing images.') + + +def crop_border(imgs, crop_border): + """Crop borders of images. + + Args: + imgs (list[ndarray] | ndarray): Images with shape (h, w, c). + crop_border (int): Crop border for each end of height and weight. + + Returns: + list[ndarray]: Cropped images. 
+ """ + if crop_border == 0: + return imgs + else: + if isinstance(imgs, list): + return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] + else: + return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] diff --git a/basicsr/utils/lmdb_util.py b/basicsr/utils/lmdb_util.py new file mode 100644 index 0000000000000000000000000000000000000000..a2b45ce01d5e32ddbf8354d71fd1c8678bede822 --- /dev/null +++ b/basicsr/utils/lmdb_util.py @@ -0,0 +1,199 @@ +import cv2 +import lmdb +import sys +from multiprocessing import Pool +from os import path as osp +from tqdm import tqdm + + +def make_lmdb_from_imgs(data_path, + lmdb_path, + img_path_list, + keys, + batch=5000, + compress_level=1, + multiprocessing_read=False, + n_thread=40, + map_size=None): + """Make lmdb from images. + + Contents of lmdb. The file structure is: + + :: + + example.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records 1)image name (with extension), + 2)image shape, and 3)compression level, separated by a white space. + + For example, the meta information could be: + `000_00000000.png (720,1280,3) 1`, which means: + 1) image name (with extension): 000_00000000.png; + 2) image shape: (720,1280,3); + 3) compression level: 1 + + We use the image name without extension as the lmdb key. + + If `multiprocessing_read` is True, it will read all the images to memory + using multiprocessing. Thus, your server needs to have enough memory. + + Args: + data_path (str): Data path for reading images. + lmdb_path (str): Lmdb save path. + img_path_list (str): Image path list. + keys (str): Used for lmdb keys. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + multiprocessing_read (bool): Whether use multiprocessing to read all + the images to memory. Default: False. + n_thread (int): For multiprocessing. + map_size (int | None): Map size for lmdb env. If None, use the + estimated size from images. Default: None + """ + + assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, ' + f'but got {len(img_path_list)} and {len(keys)}') + print(f'Create lmdb for {data_path}, save to {lmdb_path}...') + print(f'Totoal images: {len(img_path_list)}') + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. 
Exit.') + sys.exit(1) + + if multiprocessing_read: + # read all the images to memory (multiprocessing) + dataset = {} # use dict to keep the order for multiprocessing + shapes = {} + print(f'Read images with multiprocessing, #thread: {n_thread} ...') + pbar = tqdm(total=len(img_path_list), unit='image') + + def callback(arg): + """get the image data and update pbar.""" + key, dataset[key], shapes[key] = arg + pbar.update(1) + pbar.set_description(f'Read {key}') + + pool = Pool(n_thread) + for path, key in zip(img_path_list, keys): + pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) + pool.close() + pool.join() + pbar.close() + print(f'Finish reading {len(img_path_list)} images.') + + # create lmdb environment + if map_size is None: + # obtain data size for one image + img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + data_size_per_img = img_byte.nbytes + print('Data size per image is: ', data_size_per_img) + data_size = data_size_per_img * len(img_path_list) + map_size = data_size * 10 + + env = lmdb.open(lmdb_path, map_size=map_size) + + # write data to lmdb + pbar = tqdm(total=len(img_path_list), unit='chunk') + txn = env.begin(write=True) + txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + for idx, (path, key) in enumerate(zip(img_path_list, keys)): + pbar.update(1) + pbar.set_description(f'Write {key}') + key_byte = key.encode('ascii') + if multiprocessing_read: + img_byte = dataset[key] + h, w, c = shapes[key] + else: + _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) + h, w, c = img_shape + + txn.put(key_byte, img_byte) + # write meta information + txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') + if idx % batch == 0: + txn.commit() + txn = env.begin(write=True) + pbar.close() + txn.commit() + env.close() + txt_file.close() + print('\nFinish writing lmdb.') + + +def read_img_worker(path, key, compress_level): + """Read image worker. + + Args: + path (str): Image path. + key (str): Image key. + compress_level (int): Compress level when encoding images. + + Returns: + str: Image key. + byte: Image byte. + tuple[int]: Image shape. + """ + + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if img.ndim == 2: + h, w = img.shape + c = 1 + else: + h, w, c = img.shape + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + return (key, img_byte, (h, w, c)) + + +class LmdbMaker(): + """LMDB Maker. + + Args: + lmdb_path (str): Lmdb save path. + map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + """ + + def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. 
Exit.') + sys.exit(1) + + self.lmdb_path = lmdb_path + self.batch = batch + self.compress_level = compress_level + self.env = lmdb.open(lmdb_path, map_size=map_size) + self.txn = self.env.begin(write=True) + self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + self.counter = 0 + + def put(self, img_byte, key, img_shape): + self.counter += 1 + key_byte = key.encode('ascii') + self.txn.put(key_byte, img_byte) + # write meta information + h, w, c = img_shape + self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') + if self.counter % self.batch == 0: + self.txn.commit() + self.txn = self.env.begin(write=True) + + def close(self): + self.txn.commit() + self.env.close() + self.txt_file.close() diff --git a/basicsr/utils/logger.py b/basicsr/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..73553dc664781a061737e94880ea1c6788c09043 --- /dev/null +++ b/basicsr/utils/logger.py @@ -0,0 +1,213 @@ +import datetime +import logging +import time + +from .dist_util import get_dist_info, master_only + +initialized_logger = {} + + +class AvgTimer(): + + def __init__(self, window=200): + self.window = window # average window + self.current_time = 0 + self.total_time = 0 + self.count = 0 + self.avg_time = 0 + self.start() + + def start(self): + self.start_time = self.tic = time.time() + + def record(self): + self.count += 1 + self.toc = time.time() + self.current_time = self.toc - self.tic + self.total_time += self.current_time + # calculate average time + self.avg_time = self.total_time / self.count + + # reset + if self.count > self.window: + self.count = 0 + self.total_time = 0 + + self.tic = time.time() + + def get_current_time(self): + return self.current_time + + def get_avg_time(self): + return self.avg_time + + +class MessageLogger(): + """Message logger for printing. + + Args: + opt (dict): Config. It contains the following keys: + name (str): Exp name. + logger (dict): Contains 'print_freq' (str) for logger interval. + train (dict): Contains 'total_iter' (int) for total iters. + use_tb_logger (bool): Use tensorboard logger. + start_iter (int): Start iter. Default: 1. + tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. + """ + + def __init__(self, opt, start_iter=1, tb_logger=None): + self.exp_name = opt['name'] + self.interval = opt['logger']['print_freq'] + self.start_iter = start_iter + self.max_iters = opt['train']['total_iter'] + self.use_tb_logger = opt['logger']['use_tb_logger'] + self.tb_logger = tb_logger + self.start_time = time.time() + self.logger = get_root_logger() + + def reset_start_time(self): + self.start_time = time.time() + + @master_only + def __call__(self, log_vars): + """Format logging message. + + Args: + log_vars (dict): It contains the following keys: + epoch (int): Epoch number. + iter (int): Current iter. + lrs (list): List for learning rates. + + time (float): Iter time. + data_time (float): Data time for each iter. 
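+
+                Any remaining keys (typically losses prefixed with ``l_``) are
+                printed and, when enabled, also written to the tensorboard logger.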
+ """ + # epoch, iter, learning rates + epoch = log_vars.pop('epoch') + current_iter = log_vars.pop('iter') + lrs = log_vars.pop('lrs') + + message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(') + for v in lrs: + message += f'{v:.3e},' + message += ')] ' + + # time and estimated time + if 'time' in log_vars.keys(): + iter_time = log_vars.pop('time') + data_time = log_vars.pop('data_time') + + total_time = time.time() - self.start_time + time_sec_avg = total_time / (current_iter - self.start_iter + 1) + eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) + message += f'[eta: {eta_str}, ' + message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' + + # other items, especially losses + for k, v in log_vars.items(): + message += f'{k}: {v:.4e} ' + # tensorboard logger + if self.use_tb_logger and 'debug' not in self.exp_name: + if k.startswith('l_'): + self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) + else: + self.tb_logger.add_scalar(k, v, current_iter) + self.logger.info(message) + + +@master_only +def init_tb_logger(log_dir): + from torch.utils.tensorboard import SummaryWriter + tb_logger = SummaryWriter(log_dir=log_dir) + return tb_logger + + +@master_only +def init_wandb_logger(opt): + """We now only use wandb to sync tensorboard log.""" + import wandb + logger = get_root_logger() + + project = opt['logger']['wandb']['project'] + resume_id = opt['logger']['wandb'].get('resume_id') + if resume_id: + wandb_id = resume_id + resume = 'allow' + logger.warning(f'Resume wandb logger with id={wandb_id}.') + else: + wandb_id = wandb.util.generate_id() + resume = 'never' + + wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) + + logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') + + +def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): + """Get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. + + Args: + logger_name (str): root logger name. Default: 'basicsr'. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + + Returns: + logging.Logger: The root logger. + """ + logger = logging.getLogger(logger_name) + # if the logger has been initialized, just return it + if logger_name in initialized_logger: + return logger + + format_str = '%(asctime)s %(levelname)s: %(message)s' + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(logging.Formatter(format_str)) + logger.addHandler(stream_handler) + logger.propagate = False + rank, _ = get_dist_info() + if rank != 0: + logger.setLevel('ERROR') + elif log_file is not None: + logger.setLevel(log_level) + # add file handler + file_handler = logging.FileHandler(log_file, 'w') + file_handler.setFormatter(logging.Formatter(format_str)) + file_handler.setLevel(log_level) + logger.addHandler(file_handler) + initialized_logger[logger_name] = True + return logger + + +def get_env_info(): + """Get environment information. + + Currently, only log the software version. 
+ """ + import torch + import torchvision + + from basicsr.version import __version__ + msg = r""" + ____ _ _____ ____ + / __ ) ____ _ _____ (_)_____/ ___/ / __ \ + / __ |/ __ `// ___// // ___/\__ \ / /_/ / + / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ + /_____/ \__,_//____//_/ \___//____//_/ |_| + ______ __ __ __ __ + / ____/____ ____ ____/ / / / __ __ _____ / /__ / / + / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / + / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ + \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) + """ + msg += ('\nVersion Information: ' + f'\n\tBasicSR: {__version__}' + f'\n\tPyTorch: {torch.__version__}' + f'\n\tTorchVision: {torchvision.__version__}') + return msg diff --git a/basicsr/utils/matlab_functions.py b/basicsr/utils/matlab_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..a201f79aaf030cdba710dd97c28af1b29a93ed2a --- /dev/null +++ b/basicsr/utils/matlab_functions.py @@ -0,0 +1,178 @@ +import math +import numpy as np +import torch + + +def cubic(x): + """cubic function used for calculate_weights_indices.""" + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) * + (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + """Calculate weights and indices, used for imresize function. + + Args: + in_length (int): Input length. + out_length (int): Output length. + scale (float): Scale factor. + kernel_width (int): Kernel width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + """ + + if (scale < 1) and antialiasing: + # Use a modified kernel (larger kernel width) to simultaneously + # interpolate and antialias + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5 + scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + p = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand( + out_length, p) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices + + # apply cubic kernel + if (scale < 1) and antialiasing: + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, p) + + # If a column in weights is all zero, get rid of it. only consider the + # first and last column. 
+ weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, p - 2) + weights = weights.narrow(1, 1, p - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, p - 2) + weights = weights.narrow(1, 0, p - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +@torch.no_grad() +def imresize(img, scale, antialiasing=True): + """imresize function same as MATLAB. + + It now only supports bicubic. + The same scale applies for both height and width. + + Args: + img (Tensor | Numpy array): + Tensor: Input image with shape (c, h, w), [0, 1] range. + Numpy: Input image with shape (h, w, c), [0, 1] range. + scale (float): Scale factor. The same scale applies for both height + and width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + Default: True. + + Returns: + Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round. + """ + squeeze_flag = False + if type(img).__module__ == np.__name__: # numpy type + numpy_type = True + if img.ndim == 2: + img = img[:, :, None] + squeeze_flag = True + img = torch.from_numpy(img.transpose(2, 0, 1)).float() + else: + numpy_type = False + if img.ndim == 2: + img = img.unsqueeze(0) + squeeze_flag = True + + in_c, in_h, in_w = img.size() + out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale) + kernel_width = 4 + kernel = 'cubic' + + # get weights and indices + weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width, + antialiasing) + weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width, + antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) + img_aug.narrow(1, sym_len_hs, in_h).copy_(img) + + sym_patch = img[:, :sym_len_hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_he:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_c, out_h, in_w) + kernel_width = weights_h.size(1) + for i in range(out_h): + idx = int(indices_h[i][0]) + for j in range(in_c): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) + out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_we:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_c, out_h, out_w) + kernel_width = weights_w.size(1) + for i in range(out_w): + idx = int(indices_w[i][0]) + 
for j in range(in_c): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i]) + + if squeeze_flag: + out_2 = out_2.squeeze(0) + if numpy_type: + out_2 = out_2.numpy() + if not squeeze_flag: + out_2 = out_2.transpose(1, 2, 0) + + return out_2 diff --git a/basicsr/utils/misc.py b/basicsr/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..c8d4a1403509672e85e74ac476e028cefb6dbb62 --- /dev/null +++ b/basicsr/utils/misc.py @@ -0,0 +1,141 @@ +import numpy as np +import os +import random +import time +import torch +from os import path as osp + +from .dist_util import master_only + + +def set_random_seed(seed): + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def mkdir_and_rename(path): + """mkdirs. If path exists, rename it with timestamp and create a new one. + + Args: + path (str): Folder path. + """ + if osp.exists(path): + new_name = path + '_archived_' + get_time_str() + print(f'Path already exists. Rename it to {new_name}', flush=True) + os.rename(path, new_name) + os.makedirs(path, exist_ok=True) + + +@master_only +def make_exp_dirs(opt): + """Make dirs for experiments.""" + path_opt = opt['path'].copy() + if opt['is_train']: + mkdir_and_rename(path_opt.pop('experiments_root')) + else: + mkdir_and_rename(path_opt.pop('results_root')) + for key, path in path_opt.items(): + if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key): + continue + else: + os.makedirs(path, exist_ok=True) + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative paths. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def check_resume(opt, resume_iter): + """Check resume states and pretrain_network paths. + + Args: + opt (dict): Options. + resume_iter (int): Resume iteration. 
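A small usage sketch of the MATLAB-style `imresize` defined above; the shapes follow its docstring and the scale values are illustrative:

```python
import torch

from basicsr.utils.matlab_functions import imresize

hr = torch.rand(3, 256, 256)       # (c, h, w) image in the [0, 1] range
lr = imresize(hr, scale=0.25)      # antialiased bicubic downscale -> (3, 64, 64)
hr_up = imresize(lr, scale=4.0)    # bicubic upscale back to (3, 256, 256), values not rounded
```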
+ """ + if opt['path']['resume_state']: + # get all the networks + networks = [key for key in opt.keys() if key.startswith('network_')] + flag_pretrain = False + for network in networks: + if opt['path'].get(f'pretrain_{network}') is not None: + flag_pretrain = True + if flag_pretrain: + print('pretrain_network path will be ignored during resuming.') + # set pretrained model paths + for network in networks: + name = f'pretrain_{network}' + basename = network.replace('network_', '') + if opt['path'].get('ignore_resume_networks') is None or (network + not in opt['path']['ignore_resume_networks']): + opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') + print(f"Set {name} to {opt['path'][name]}") + + # change param_key to params in resume + param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')] + for param_key in param_keys: + if opt['path'][param_key] == 'params_ema': + opt['path'][param_key] = 'params' + print(f'Set {param_key} to params') + + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formatted file size. + """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py new file mode 100644 index 0000000000000000000000000000000000000000..3afd79c4f3e73f44f36503288c3959125ac3df34 --- /dev/null +++ b/basicsr/utils/options.py @@ -0,0 +1,210 @@ +import argparse +import os +import random +import torch +import yaml +from collections import OrderedDict +from os import path as osp + +from basicsr.utils import set_random_seed +from basicsr.utils.dist_util import get_dist_info, init_dist, master_only + + +def ordered_yaml(): + """Support OrderedDict for yaml. + + Returns: + tuple: yaml Loader and Dumper. + """ + try: + from yaml import CDumper as Dumper + from yaml import CLoader as Loader + except ImportError: + from yaml import Dumper, Loader + + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + + +def yaml_load(f): + """Load yaml file or string. + + Args: + f (str): File path or a python string. + + Returns: + dict: Loaded dict. + """ + if os.path.isfile(f): + with open(f, 'r') as f: + return yaml.load(f, Loader=ordered_yaml()[0]) + else: + return yaml.load(f, Loader=ordered_yaml()[0]) + + +def dict2str(opt, indent_level=1): + """dict to string for printing options. + + Args: + opt (dict): Option dict. + indent_level (int): Indent level. Default: 1. + + Return: + (str): Option string for printing. 
+ """ + msg = '\n' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_level * 2) + k + ':[' + msg += dict2str(v, indent_level + 1) + msg += ' ' * (indent_level * 2) + ']\n' + else: + msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' + return msg + + +def _postprocess_yml_value(value): + # None + if value == '~' or value.lower() == 'none': + return None + # bool + if value.lower() == 'true': + return True + elif value.lower() == 'false': + return False + # !!float number + if value.startswith('!!float'): + return float(value.replace('!!float', '')) + # number + if value.isdigit(): + return int(value) + elif value.replace('.', '', 1).isdigit() and value.count('.') < 2: + return float(value) + # list + if value.startswith('['): + return eval(value) + # str + return value + + +def parse_options(root_path, is_train=True): + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') + parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') + parser.add_argument('--auto_resume', action='store_true') + parser.add_argument('--debug', action='store_true') + parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument( + '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999') + args = parser.parse_args() + + # parse yml to dict + opt = yaml_load(args.opt) + + # distributed settings + if args.launcher == 'none': + opt['dist'] = False + print('Disable distributed.', flush=True) + else: + opt['dist'] = True + if args.launcher == 'slurm' and 'dist_params' in opt: + init_dist(args.launcher, **opt['dist_params']) + else: + init_dist(args.launcher) + opt['rank'], opt['world_size'] = get_dist_info() + + # random seed + seed = opt.get('manual_seed') + if seed is None: + seed = random.randint(1, 10000) + opt['manual_seed'] = seed + set_random_seed(seed + opt['rank']) + + # force to update yml options + if args.force_yml is not None: + for entry in args.force_yml: + # now do not support creating new keys + keys, value = entry.split('=') + keys, value = keys.strip(), value.strip() + value = _postprocess_yml_value(value) + eval_str = 'opt' + for key in keys.split(':'): + eval_str += f'["{key}"]' + eval_str += '=value' + # using exec function + exec(eval_str) + + opt['auto_resume'] = args.auto_resume + opt['is_train'] = is_train + + # debug setting + if args.debug and not opt['name'].startswith('debug'): + opt['name'] = 'debug_' + opt['name'] + + if opt['num_gpu'] == 'auto': + opt['num_gpu'] = torch.cuda.device_count() + + # datasets + for phase, dataset in opt['datasets'].items(): + # for multiple datasets, e.g., val_1, val_2; test_1, test_2 + phase = phase.split('_')[0] + dataset['phase'] = phase + if 'scale' in opt: + dataset['scale'] = opt['scale'] + if dataset.get('dataroot_gt') is not None: + dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) + if dataset.get('dataroot_lq') is not None: + dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) + + # paths + for key, val in opt['path'].items(): + if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): + opt['path'][key] = osp.expanduser(val) + + if is_train: + experiments_root = osp.join(root_path, 'experiments', opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['models'] = osp.join(experiments_root, 'models') + opt['path']['training_states'] = 
osp.join(experiments_root, 'training_states') + opt['path']['log'] = experiments_root + opt['path']['visualization'] = osp.join(experiments_root, 'visualization') + + # change some options for debug mode + if 'debug' in opt['name']: + if 'val' in opt: + opt['val']['val_freq'] = 8 + opt['logger']['print_freq'] = 1 + opt['logger']['save_checkpoint_freq'] = 8 + else: # test + results_root = osp.join(root_path, 'results', opt['name']) + opt['path']['results_root'] = results_root + opt['path']['log'] = results_root + opt['path']['visualization'] = osp.join(results_root, 'visualization') + + return opt, args + + +@master_only +def copy_opt_file(opt_file, experiments_root): + # copy the yml file to the experiment root + import sys + import time + from shutil import copyfile + cmd = ' '.join(sys.argv) + filename = osp.join(experiments_root, osp.basename(opt_file)) + copyfile(opt_file, filename) + + with open(filename, 'r+') as f: + lines = f.readlines() + lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n') + f.seek(0) + f.writelines(lines) diff --git a/basicsr/utils/plot_util.py b/basicsr/utils/plot_util.py new file mode 100644 index 0000000000000000000000000000000000000000..1e6da5bc29e706da87ab83af6d5367176fe78763 --- /dev/null +++ b/basicsr/utils/plot_util.py @@ -0,0 +1,83 @@ +import re + + +def read_data_from_tensorboard(log_path, tag): + """Get raw data (steps and values) from tensorboard events. + + Args: + log_path (str): Path to the tensorboard log. + tag (str): tag to be read. + """ + from tensorboard.backend.event_processing.event_accumulator import EventAccumulator + + # tensorboard event + event_acc = EventAccumulator(log_path) + event_acc.Reload() + scalar_list = event_acc.Tags()['scalars'] + print('tag list: ', scalar_list) + steps = [int(s.step) for s in event_acc.Scalars(tag)] + values = [s.value for s in event_acc.Scalars(tag)] + return steps, values + + +def read_data_from_txt_2v(path, pattern, step_one=False): + """Read data from txt with 2 returned values (usually [step, value]). + + Args: + path (str): path to the txt file. + pattern (str): re (regular expression) pattern. + step_one (bool): add 1 to steps. Default: False. + """ + with open(path) as f: + lines = f.readlines() + lines = [line.strip() for line in lines] + steps = [] + values = [] + + pattern = re.compile(pattern) + for line in lines: + match = pattern.match(line) + if match: + steps.append(int(match.group(1))) + values.append(float(match.group(2))) + if step_one: + steps = [v + 1 for v in steps] + return steps, values + + +def read_data_from_txt_1v(path, pattern): + """Read data from txt with 1 returned values. + + Args: + path (str): path to the txt file. + pattern (str): re (regular expression) pattern. + """ + with open(path) as f: + lines = f.readlines() + lines = [line.strip() for line in lines] + data = [] + + pattern = re.compile(pattern) + for line in lines: + match = pattern.match(line) + if match: + data.append(float(match.group(1))) + return data + + +def smooth_data(values, smooth_weight): + """ Smooth data using 1st-order IIR low-pass filter (what tensorflow does). + + Reference: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704 # noqa: E501 + + Args: + values (list): A list of values to be smoothed. + smooth_weight (float): Smooth weight. 
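A quick numerical check of this smoothing routine, which is the usual exponential moving average applied by the tensorboard front end (the numbers are chosen only for illustration):

```python
from basicsr.utils.plot_util import smooth_data

values = [0.0, 1.0, 1.0, 1.0]
smooth_data(values, smooth_weight=0.9)   # -> [0.0, 0.1, 0.19, 0.271], slowly approaching 1.0
```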
+ """ + values_sm = [] + last_sm_value = values[0] + for value in values: + value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value + values_sm.append(value_sm) + last_sm_value = value_sm + return values_sm diff --git a/basicsr/utils/realesrgan_utils.py b/basicsr/utils/realesrgan_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ff934e5150b4aa568a51ab9614a2057b011a6014 --- /dev/null +++ b/basicsr/utils/realesrgan_utils.py @@ -0,0 +1,293 @@ +import cv2 +import math +import numpy as np +import os +import queue +import threading +import torch +from basicsr.utils.download_util import load_file_from_url +from torch.nn import functional as F + +# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class RealESRGANer(): + """A helper class for upsampling images with RealESRGAN. + + Args: + scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4. + model_path (str): The path to the pretrained model. It can be urls (will first download it automatically). + model (nn.Module): The defined network. Default: None. + tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop + input images into tiles, and then process each of them. Finally, they will be merged into one image. + 0 denotes for do not use tile. Default: 0. + tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10. + pre_pad (int): Pad the input images to avoid border artifacts. Default: 10. + half (float): Whether to use half precision during inference. Default: False. + """ + + def __init__(self, + scale, + model_path, + model=None, + tile=0, + tile_pad=10, + pre_pad=10, + half=False, + device=None, + gpu_id=None): + self.scale = scale + self.tile_size = tile + self.tile_pad = tile_pad + self.pre_pad = pre_pad + self.mod_scale = None + self.half = half + + # initialize model + if gpu_id: + self.device = torch.device( + f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device + else: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device + # if the model_path starts with https, it will first download models to the folder: realesrgan/weights + if model_path.startswith('https://'): + model_path = load_file_from_url( + url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None) + loadnet = torch.load(model_path, map_location=torch.device('cpu')) + # prefer to use params_ema + if 'params_ema' in loadnet: + keyname = 'params_ema' + else: + keyname = 'params' + model.load_state_dict(loadnet[keyname], strict=True) + model.eval() + self.model = model.to(self.device) + if self.half: + self.model = self.model.half() + + def pre_process(self, img): + """Pre-process, such as pre-pad and mod pad, so that the images can be divisible + """ + img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() + self.img = img.unsqueeze(0).to(self.device) + if self.half: + self.img = self.img.half() + + # pre_pad + if self.pre_pad != 0: + self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') + # mod pad for divisible borders + if self.scale == 2: + self.mod_scale = 2 + elif self.scale == 1: + self.mod_scale = 4 + if self.mod_scale is not None: + self.mod_pad_h, self.mod_pad_w = 0, 0 + _, _, h, w = self.img.size() + if (h % self.mod_scale != 0): + self.mod_pad_h = (self.mod_scale - h % self.mod_scale) + if (w % self.mod_scale != 0): + self.mod_pad_w = (self.mod_scale 
- w % self.mod_scale) + self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') + + def process(self): + # model inference + self.output = self.model(self.img) + + def tile_process(self): + """It will first crop input images to tiles, and then process each tile. + Finally, all the processed tiles are merged into one images. + + Modified from: https://github.com/ata4/esrgan-launcher + """ + batch, channel, height, width = self.img.shape + output_height = height * self.scale + output_width = width * self.scale + output_shape = (batch, channel, output_height, output_width) + + # start with black image + self.output = self.img.new_zeros(output_shape) + tiles_x = math.ceil(width / self.tile_size) + tiles_y = math.ceil(height / self.tile_size) + + # loop over all tiles + for y in range(tiles_y): + for x in range(tiles_x): + # extract tile from input image + ofs_x = x * self.tile_size + ofs_y = y * self.tile_size + # input tile area on total image + input_start_x = ofs_x + input_end_x = min(ofs_x + self.tile_size, width) + input_start_y = ofs_y + input_end_y = min(ofs_y + self.tile_size, height) + + # input tile area on total image with padding + input_start_x_pad = max(input_start_x - self.tile_pad, 0) + input_end_x_pad = min(input_end_x + self.tile_pad, width) + input_start_y_pad = max(input_start_y - self.tile_pad, 0) + input_end_y_pad = min(input_end_y + self.tile_pad, height) + + # input tile dimensions + input_tile_width = input_end_x - input_start_x + input_tile_height = input_end_y - input_start_y + tile_idx = y * tiles_x + x + 1 + input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] + + # upscale tile + try: + with torch.no_grad(): + output_tile = self.model(input_tile) + except RuntimeError as error: + print('Error', error) + # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}') + + # output tile area on total image + output_start_x = input_start_x * self.scale + output_end_x = input_end_x * self.scale + output_start_y = input_start_y * self.scale + output_end_y = input_end_y * self.scale + + # output tile area without padding + output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale + output_end_x_tile = output_start_x_tile + input_tile_width * self.scale + output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale + output_end_y_tile = output_start_y_tile + input_tile_height * self.scale + + # put tile into output image + self.output[:, :, output_start_y:output_end_y, + output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, + output_start_x_tile:output_end_x_tile] + + def post_process(self): + # remove extra pad + if self.mod_scale is not None: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale] + # remove prepad + if self.pre_pad != 0: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale] + return self.output + + @torch.no_grad() + def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'): + h_input, w_input = img.shape[0:2] + # img: numpy + img = img.astype(np.float32) + if np.max(img) > 256: # 16-bit image + max_range = 65535 + print('\tInput is a 16-bit image') + else: + max_range = 255 + img = img / max_range + if len(img.shape) == 2: # gray image + img_mode = 'L' + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + elif img.shape[2] == 4: # RGBA image with alpha channel + img_mode = 'RGBA' 
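As a worked example of the tiling arithmetic in `tile_process` (all sizes here are illustrative, not repository defaults):

```python
import math

width, height, tile_size, scale = 1000, 720, 400, 4
tiles_x, tiles_y = math.ceil(width / tile_size), math.ceil(height / tile_size)
print(tiles_x * tiles_y)              # 6 tiles, each padded by tile_pad before inference
print(width * scale, height * scale)  # 4000 2880, the size of the stitched output
```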
+ alpha = img[:, :, 3] + img = img[:, :, 0:3] + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if alpha_upsampler == 'realesrgan': + alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) + else: + img_mode = 'RGB' + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # ------------------- process image (without the alpha channel) ------------------- # + self.pre_process(img) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_img = self.post_process() + output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) + if img_mode == 'L': + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) + + # ------------------- process the alpha channel if necessary ------------------- # + if img_mode == 'RGBA': + if alpha_upsampler == 'realesrgan': + self.pre_process(alpha) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_alpha = self.post_process() + output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) + output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) + else: # use the cv2 resize for alpha channel + h, w = alpha.shape[0:2] + output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) + + # merge the alpha channel + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) + output_img[:, :, 3] = output_alpha + + # ------------------------------ return ------------------------------ # + if max_range == 65535: # 16-bit image + output = (output_img * 65535.0).round().astype(np.uint16) + else: + output = (output_img * 255.0).round().astype(np.uint8) + + if outscale is not None and outscale != float(self.scale): + output = cv2.resize( + output, ( + int(w_input * outscale), + int(h_input * outscale), + ), interpolation=cv2.INTER_LANCZOS4) + + return output, img_mode + + +class PrefetchReader(threading.Thread): + """Prefetch images. + + Args: + img_list (list[str]): A image list of image paths to be read. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, img_list, num_prefetch_queue): + super().__init__() + self.que = queue.Queue(num_prefetch_queue) + self.img_list = img_list + + def run(self): + for img_path in self.img_list: + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + self.que.put(img) + + self.que.put(None) + + def __next__(self): + next_item = self.que.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class IOConsumer(threading.Thread): + + def __init__(self, opt, que, qid): + super().__init__() + self._queue = que + self.qid = qid + self.opt = opt + + def run(self): + while True: + msg = self._queue.get() + if isinstance(msg, str) and msg == 'quit': + break + + output = msg['output'] + save_path = msg['save_path'] + cv2.imwrite(save_path, output) + print(f'IO worker {self.qid} is done.') diff --git a/basicsr/utils/registry.py b/basicsr/utils/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..5e72ef7ff21b94f50e6caa8948f69ca0b04bc968 --- /dev/null +++ b/basicsr/utils/registry.py @@ -0,0 +1,88 @@ +# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 + + +class Registry(): + """ + The registry that provides name -> object mapping, to support third-party + users' custom modules. + + To create a registry (e.g. a backbone registry): + + .. 
code-block:: python + + BACKBONE_REGISTRY = Registry('BACKBONE') + + To register an object: + + .. code-block:: python + + @BACKBONE_REGISTRY.register() + class MyBackbone(): + ... + + Or: + + .. code-block:: python + + BACKBONE_REGISTRY.register(MyBackbone) + """ + + def __init__(self, name): + """ + Args: + name (str): the name of this registry + """ + self._name = name + self._obj_map = {} + + def _do_register(self, name, obj, suffix=None): + if isinstance(suffix, str): + name = name + '_' + suffix + + assert (name not in self._obj_map), (f"An object named '{name}' was already registered " + f"in '{self._name}' registry!") + self._obj_map[name] = obj + + def register(self, obj=None, suffix=None): + """ + Register the given object under the the name `obj.__name__`. + Can be used as either a decorator or not. + See docstring of this class for usage. + """ + if obj is None: + # used as a decorator + def deco(func_or_class): + name = func_or_class.__name__ + self._do_register(name, func_or_class, suffix) + return func_or_class + + return deco + + # used as a function call + name = obj.__name__ + self._do_register(name, obj, suffix) + + def get(self, name, suffix='basicsr'): + ret = self._obj_map.get(name) + if ret is None: + ret = self._obj_map.get(name + '_' + suffix) + print(f'Name {name} is not found, use name: {name}_{suffix}!') + if ret is None: + raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") + return ret + + def __contains__(self, name): + return name in self._obj_map + + def __iter__(self): + return iter(self._obj_map.items()) + + def keys(self): + return self._obj_map.keys() + + +DATASET_REGISTRY = Registry('dataset') +ARCH_REGISTRY = Registry('arch') +MODEL_REGISTRY = Registry('model') +LOSS_REGISTRY = Registry('loss') +METRIC_REGISTRY = Registry('metric') diff --git a/configs/sr.yaml b/configs/sr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9ba43a29fde4690dc9483f6fb536ca5dbe8f6e18 --- /dev/null +++ b/configs/sr.yaml @@ -0,0 +1,110 @@ +sf: 4 +degradation: + # the first degradation process + resize_prob: [0.2, 0.7, 0.1] # up, down, keep + resize_range: [0.3, 1.5] + gaussian_noise_prob: 0.5 + noise_range: [1, 15] + poisson_scale_range: [0.05, 2.0] + gray_noise_prob: 0.4 + jpeg_range: [60, 95] + + # the second degradation process + second_blur_prob: 0.5 + resize_prob2: [0.3, 0.4, 0.3] # up, down, keep + resize_range2: [0.6, 1.2] + gaussian_noise_prob2: 0.5 + noise_range2: [1, 12] + poisson_scale_range2: [0.05, 1.0] + gray_noise_prob2: 0.4 + jpeg_range2: [60, 100] + + gt_size: 512 + no_degradation_prob: 0.01 + +train: + queue_size: 180 + gt_path: ['dataset_path/LSDIR/'] + face_gt_path: 'dataset_path/FFHQ/' + num_face: 10000 + crop_size: 512 + io_backend: + type: disk + + blur_kernel_size: 21 + kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob: 0.1 + blur_sigma: [0.2, 1.5] + betag_range: [0.5, 2.0] + betap_range: [1, 1.5] + + blur_kernel_size2: 11 + kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob2: 0.1 + blur_sigma2: [0.2, 1.0] + betag_range2: [0.5, 2.0] + betap_range2: [1, 1.5] + + final_sinc_prob: 0.8 + + gt_size: 512 + use_hflip: True + use_rot: False + +validation: + gt_path: dataset_path/DIV2K_valid_HR/ + crop_size: 512 + io_backend: + type: disk + + blur_kernel_size: 21 + kernel_list: 
['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob: 0.1 + blur_sigma: [0.2, 1.5] + betag_range: [0.5, 2.0] + betap_range: [1, 1.5] + + blur_kernel_size2: 11 + kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob2: 0.1 + blur_sigma2: [0.2, 1.0] + betag_range2: [0.5, 2.0] + betap_range2: [1, 1.5] + + final_sinc_prob: 0.8 + + gt_size: 512 + use_hflip: True + use_rot: False + +test: + gt_path: dataset_path/DIV2K_valid_HR/ + crop_size: 512 + io_backend: + type: disk + + blur_kernel_size: 21 + kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob: 0.1 + blur_sigma: [0.2, 1.5] + betag_range: [0.5, 2.0] + betap_range: [1, 1.5] + + blur_kernel_size2: 11 + kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob2: 0.1 + blur_sigma2: [0.2, 1.0] + betag_range2: [0.5, 2.0] + betap_range2: [1, 1.5] + + final_sinc_prob: 0.8 + + gt_size: 512 + use_hflip: True + use_rot: False \ No newline at end of file diff --git a/configs/sr_test.yaml b/configs/sr_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1ae403a0275e35a3637264946db822b4bc6e4855 --- /dev/null +++ b/configs/sr_test.yaml @@ -0,0 +1,6 @@ +sf: 4 + +validation: + lr_path: path_to_LR_image_folder + io_backend: + type: disk \ No newline at end of file diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0730f5e21cd6c4b75b7659e989a6fc51c61e40f6 --- /dev/null +++ b/environment.yaml @@ -0,0 +1,35 @@ +name: s3diff +channels: + - pytorch + - defaults +dependencies: + - python=3.10 + - pip: + - einops>=0.6.1 + - numpy>=1.24.4 + - open-clip-torch>=2.20.0 + - opencv-python==4.6.0.66 + - pillow>=9.5.0 + - scipy==1.11.1 + - timm>=0.9.2 + - tokenizers + - torch>=2.0.1 + + - torchaudio>=2.0.2 + - torchdata==0.6.1 + - torchmetrics>=1.0.1 + - torchvision>=0.15.2 + + - tqdm>=4.65.0 + - transformers==4.35.2 + - triton==2.0.0 + - urllib3<1.27,>=1.25.4 + - xformers>=0.0.20 + - streamlit-keyup==0.2.0 + - lpips + - peft + - pyiqa + - omegaconf + - dominate + - diffusers==0.25.1 + - gradio==3.43.1 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..177c3c3edd583cead7360e95b5d4774034218662 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,28 @@ +einops>=0.6.1 +numpy>=1.24.4 +open-clip-torch>=2.20.0 +opencv-python==4.6.0.66 +pillow>=9.5.0 +scipy==1.11.1 +timm>=0.9.2 +tokenizers + +torch>=2.0.1 +torchaudio>=2.0.2 +torchdata==0.6.1 +torchmetrics>=1.0.1 +torchvision>=0.15.2 + +tqdm>=4.65.0 +transformers==4.35.2 +triton==2.0.0 +urllib3<1.27,>=1.25.4 +xformers>=0.0.20 +streamlit-keyup==0.2.0 +lpips +peft +pyiqa +omegaconf +dominate +diffusers==0.25.1 +gradio==3.43.1 diff --git a/run_evaluate.sh b/run_evaluate.sh new file mode 100644 index 0000000000000000000000000000000000000000..0a47eff30dee5e0297655ea88a1531b7f3e80749 --- /dev/null +++ b/run_evaluate.sh @@ -0,0 +1 @@ +python src/evaluate_img.py -i "path_to_generated_HR" -r "path_to_ground_truth" diff --git a/run_inference.sh b/run_inference.sh new file mode 100644 index 
0000000000000000000000000000000000000000..57cab9d9fbdc349059325b8693dfd877aef02943 --- /dev/null +++ b/run_inference.sh @@ -0,0 +1,7 @@ +accelerate launch --num_processes=1 --gpu_ids="0," --main_process_port 29300 src/inference_s3diff.py \ + --sd_path="path_to_checkpoints/sd-turbo" \ + --de_net_path="assets/mm-realsr/de_net.pth" \ + --pretrained_path="path_to_checkpoints_folder/model_30001.pkl" \ + --output_dir="./output" \ + --ref_path="path_to_ground_truth_folder" \ + --align_method="wavelet" diff --git a/run_training.sh b/run_training.sh new file mode 100644 index 0000000000000000000000000000000000000000..c1ab504323fd203cae2445a7dea10181207a3f26 --- /dev/null +++ b/run_training.sh @@ -0,0 +1,8 @@ +accelerate launch --num_processes=4 --gpu_ids="0,1,2,3" --main_process_port 29300 src/train_s3diff.py \ + --sd_path="path_to_checkpoints/sd-turbo" \ + --de_net_path="assets/mm-realsr/de_net.pth" \ + --output_dir="./output" \ + --resolution=512 \ + --train_batch_size=4 \ + --enable_xformers_memory_efficient_attention \ + --viz_freq 25 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..1b74bb6c0195ef838585c37e9c95c45db9e98a16 --- /dev/null +++ b/setup.py @@ -0,0 +1,13 @@ +from setuptools import setup, find_packages + +setup( + name='S3Diff', + version='0.0.1', + description='', + packages=find_packages(), + install_requires=[ + 'torch', + 'numpy', + 'tqdm', + ], +) diff --git a/src/de_net.py b/src/de_net.py new file mode 100644 index 0000000000000000000000000000000000000000..0465d9d6c043ed8171c3ceb178f25e05ca5fc558 --- /dev/null +++ b/src/de_net.py @@ -0,0 +1,127 @@ +import torch +import copy +from torch import nn as nn +from basicsr.archs.arch_util import ResidualBlockNoBN, default_init_weights + +class DEResNet(nn.Module): + """Degradation Estimator with ResNetNoBN arch. v2.1, no vector anymore + + As shown in paper 'Towards Flexible Blind JPEG Artifacts Removal', + resnet arch works for image quality estimation. + + Args: + num_in_ch (int): channel number of inputs. Default: 3. + num_degradation (int): num of degradation the DE should estimate. Default: 2(blur+noise). + degradation_embed_size (int): embedding size of each degradation vector. + degradation_degree_actv (int): activation function for degradation degree scalar. Default: sigmoid. + num_feats (list): channel number of each stage. + num_blocks (list): residual block of each stage. + downscales (list): downscales of each stage. 
+ """ + + def __init__(self, + num_in_ch=3, + num_degradation=2, + degradation_degree_actv='sigmoid', + num_feats=[64, 64, 64, 128], + num_blocks=[2, 2, 2, 2], + downscales=[1, 1, 2, 1]): + super(DEResNet, self).__init__() + + assert isinstance(num_feats, list) + assert isinstance(num_blocks, list) + assert isinstance(downscales, list) + assert len(num_feats) == len(num_blocks) and len(num_feats) == len(downscales) + + num_stage = len(num_feats) + + self.conv_first = nn.ModuleList() + for _ in range(num_degradation): + self.conv_first.append(nn.Conv2d(num_in_ch, num_feats[0], 3, 1, 1)) + self.body = nn.ModuleList() + for _ in range(num_degradation): + body = list() + for stage in range(num_stage): + for _ in range(num_blocks[stage]): + body.append(ResidualBlockNoBN(num_feats[stage])) + if downscales[stage] == 1: + if stage < num_stage - 1 and num_feats[stage] != num_feats[stage + 1]: + body.append(nn.Conv2d(num_feats[stage], num_feats[stage + 1], 3, 1, 1)) + continue + elif downscales[stage] == 2: + body.append(nn.Conv2d(num_feats[stage], num_feats[min(stage + 1, num_stage - 1)], 3, 2, 1)) + else: + raise NotImplementedError + self.body.append(nn.Sequential(*body)) + + self.num_degradation = num_degradation + self.fc_degree = nn.ModuleList() + if degradation_degree_actv == 'sigmoid': + actv = nn.Sigmoid + elif degradation_degree_actv == 'tanh': + actv = nn.Tanh + else: + raise NotImplementedError(f'only sigmoid and tanh are supported for degradation_degree_actv, ' + f'{degradation_degree_actv} is not supported yet.') + for _ in range(num_degradation): + self.fc_degree.append( + nn.Sequential( + nn.Linear(num_feats[-1], 512), + nn.ReLU(inplace=True), + nn.Linear(512, 1), + actv(), + )) + + self.avg_pool = nn.AdaptiveAvgPool2d(1) + + default_init_weights([self.conv_first, self.body, self.fc_degree], 0.1) + + def clone_module(self, module): + new_module = copy.deepcopy(module) + return new_module + + def average_parameters(self, modules): + avg_module = self.clone_module(modules[0]) + for name, param in avg_module.named_parameters(): + avg_param = sum([mod.state_dict()[name].data for mod in modules]) / len(modules) + param.data.copy_(avg_param) + return avg_module + + def expand_degradation_modules(self, new_num_degradation): + if new_num_degradation <= self.num_degradation: + return + initial_modules = [self.conv_first, self.body, self.fc_degree] + + for modules in initial_modules: + avg_module = self.average_parameters(modules[:2]) + while len(modules) < new_num_degradation: + modules.append(self.clone_module(avg_module)) + + def load_and_expand_model(self, path, num_degradation): + state_dict = torch.load(path, map_location=torch.device('cpu')) + self.load_state_dict(state_dict, strict=True) + + self.expand_degradation_modules(num_degradation) + self.num_degradation = num_degradation + + def load_model(self, path): + state_dict = torch.load(path, map_location=torch.device('cpu')) + self.load_state_dict(state_dict, strict=True) + + def set_train(self): + self.conv_first.requires_grad_(True) + self.fc_degree.requires_grad_(True) + for n, _p in self.body.named_parameters(): + if "lora" in n: + _p.requires_grad = True + + def forward(self, x): + degrees = [] + for i in range(self.num_degradation): + x_out = self.conv_first[i](x) + feat = self.body[i](x_out) + feat = self.avg_pool(feat) + feat = feat.squeeze(-1).squeeze(-1) + # for i in range(self.num_degradation): + degrees.append(self.fc_degree[i](feat).squeeze(-1)) + return torch.stack(degrees, dim=1) \ No newline at end of file diff --git 
a/src/evaluate_img.py b/src/evaluate_img.py new file mode 100644 index 0000000000000000000000000000000000000000..1ac421da6f30d0b57d611d60478562ae8219616b --- /dev/null +++ b/src/evaluate_img.py @@ -0,0 +1,72 @@ +import pyiqa +import os +import argparse +from pathlib import Path +import torch +from utils import util_image +import tqdm + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + +print(pyiqa.list_models()) +def evaluate(in_path, ref_path, ntest): + metric_dict = {} + metric_dict["clipiqa"] = pyiqa.create_metric('clipiqa').to(device) + metric_dict["musiq"] = pyiqa.create_metric('musiq').to(device) + metric_dict["niqe"] = pyiqa.create_metric('niqe').to(device) + metric_dict["maniqa"] = pyiqa.create_metric('maniqa').to(device) + metric_paired_dict = {} + + in_path = Path(in_path) if not isinstance(in_path, Path) else in_path + assert in_path.is_dir() + + ref_path_list = None + if ref_path is not None: + ref_path = Path(ref_path) if not isinstance(ref_path, Path) else ref_path + ref_path_list = sorted([x for x in ref_path.glob("*.[jpJP][pnPN]*[gG]")]) + if ntest is not None: ref_path_list = ref_path_list[:ntest] + + metric_paired_dict["psnr"]=pyiqa.create_metric('psnr', test_y_channel=True, color_space='ycbcr').to(device) + metric_paired_dict["lpips"]=pyiqa.create_metric('lpips').to(device) + metric_paired_dict["dists"]=pyiqa.create_metric('dists').to(device) + metric_paired_dict["ssim"]=pyiqa.create_metric('ssim', test_y_channel=True, color_space='ycbcr' ).to(device) + + lr_path_list = sorted([x for x in in_path.glob("*.[jpJP][pnPN]*[gG]")]) + if ntest is not None: lr_path_list = lr_path_list[:ntest] + + print(f'Find {len(lr_path_list)} images in {in_path}') + result = {} + for i in tqdm.tqdm(range(len(lr_path_list))): + _in_path = lr_path_list[i] + _ref_path = ref_path_list[i] if ref_path_list is not None else None + + im_in = util_image.imread(_in_path, chn='rgb', dtype='float32') # h x w x c + im_in_tensor = util_image.img2tensor(im_in).cuda() # 1 x c x h x w + for key, metric in metric_dict.items(): + with torch.cuda.amp.autocast(): + result[key] = result.get(key, 0) + metric(im_in_tensor).item() + + if ref_path is not None: + im_ref = util_image.imread(_ref_path, chn='rgb', dtype='float32') # h x w x c + im_ref_tensor = util_image.img2tensor(im_ref).cuda() + for key, metric in metric_paired_dict.items(): + result[key] = result.get(key, 0) + metric(im_in_tensor, im_ref_tensor).item() + + if ref_path is not None: + fid_metric = pyiqa.create_metric('fid') + result['fid'] = fid_metric(in_path, ref_path) + + for key, res in result.items(): + if key == 'fid': + print(f"{key}: {res:.2f}") + else: + print(f"{key}: {res/len(lr_path_list):.5f}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-i',"--in_path", type=str, required=True) + parser.add_argument("-r", "--ref_path", type=str, default=None) + parser.add_argument("--ntest", type=int, default=None) + args = parser.parse_args() + evaluate(args.in_path, args.ref_path, args.ntest) + \ No newline at end of file diff --git a/src/gradio_s3diff.py b/src/gradio_s3diff.py new file mode 100644 index 0000000000000000000000000000000000000000..fbc1a6d18ac2f92692a51829559bc86d9550ca64 --- /dev/null +++ b/src/gradio_s3diff.py @@ -0,0 +1,157 @@ +import gradio as gr +import os +import sys +import math +from typing import List + +import numpy as np +from PIL import Image + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from 
diffusers.utils.import_utils import is_xformers_available + +from my_utils.testing_utils import parse_args_paired_testing +from de_net import DEResNet +from s3diff_tile import S3Diff +from torchvision import transforms +from utils.wavelet_color import wavelet_color_fix, adain_color_fix + +tensor_transforms = transforms.Compose([ + transforms.ToTensor(), + ]) + +args = parse_args_paired_testing() + +# Load scheduler, tokenizer and models. +pretrained_model_path = 'checkpoint-path/s3diff.pkl' +t2i_path = 'sd-turbo-path' +de_net_path = 'assets/mm-realsr/de_net.pth' + +# initialize net_sr +net_sr = S3Diff(lora_rank_unet=args.lora_rank_unet, lora_rank_vae=args.lora_rank_vae, sd_path=t2i_path, pretrained_path=pretrained_model_path, args=args) +net_sr.set_eval() + +# initalize degradation estimation network +net_de = DEResNet(num_in_ch=3, num_degradation=2) +net_de.load_model(de_net_path) +net_de = net_de.cuda() +net_de.eval() + +if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + net_sr.unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + +if args.gradient_checkpointing: + net_sr.unet.enable_gradient_checkpointing() + +weight_dtype = torch.float32 +device = "cuda" + +# Move text_encode and vae to gpu and cast to weight_dtype +net_sr.to(device, dtype=weight_dtype) +net_de.to(device, dtype=weight_dtype) + +@torch.no_grad() +def process( + input_image: Image.Image, + scale_factor: float, + cfg_scale: float, + latent_tiled_size: int, + latent_tiled_overlap: int, + align_method: str, + ) -> List[np.ndarray]: + + # positive_prompt = "" + # negative_prompt = "" + + net_sr._set_latent_tile(latent_tiled_size = latent_tiled_size, latent_tiled_overlap = latent_tiled_overlap) + + im_lr = tensor_transforms(input_image).unsqueeze(0).to(device) + ori_h, ori_w = im_lr.shape[2:] + im_lr_resize = F.interpolate( + im_lr, + size=(int(ori_h * scale_factor), + int(ori_w * scale_factor)), + mode='bicubic', + ) + im_lr_resize = im_lr_resize.contiguous() + im_lr_resize_norm = im_lr_resize * 2 - 1.0 + im_lr_resize_norm = torch.clamp(im_lr_resize_norm, -1.0, 1.0) + resize_h, resize_w = im_lr_resize_norm.shape[2:] + + pad_h = (math.ceil(resize_h / 64)) * 64 - resize_h + pad_w = (math.ceil(resize_w / 64)) * 64 - resize_w + im_lr_resize_norm = F.pad(im_lr_resize_norm, pad=(0, pad_w, 0, pad_h), mode='reflect') + + try: + with torch.autocast("cuda"): + deg_score = net_de(im_lr) + + pos_tag_prompt = [args.pos_prompt] + neg_tag_prompt = [args.neg_prompt] + + x_tgt_pred = net_sr(im_lr_resize_norm, deg_score, pos_prompt=pos_tag_prompt, neg_prompt=neg_tag_prompt) + x_tgt_pred = x_tgt_pred[:, :, :resize_h, :resize_w] + out_img = (x_tgt_pred * 0.5 + 0.5).cpu().detach() + + output_pil = transforms.ToPILImage()(out_img[0]) + + if align_method == 'no fix': + image = output_pil + else: + im_lr_resize = transforms.ToPILImage()(im_lr_resize[0]) + if align_method == 'wavelet': + image = wavelet_color_fix(output_pil, im_lr_resize) + elif align_method == 'adain': + image = adain_color_fix(output_pil, im_lr_resize) + + except Exception as e: + print(e) + image = Image.new(mode="RGB", size=(512, 512)) + + return image + + +# +MARKDOWN = \ +""" +## Degradation-Guided One-Step Image Super-Resolution with Diffusion Priors + +[GitHub](https://github.com/ArcticHare105/S3Diff) | [Paper](https://arxiv.org/abs/2409.17058) + +If S3Diff is helpful for you, please help star the GitHub Repo. Thanks! 
+""" + +block = gr.Blocks().queue() +with block: + with gr.Row(): + gr.Markdown(MARKDOWN) + with gr.Row(): + with gr.Column(): + input_image = gr.Image(source="upload", type="pil") + run_button = gr.Button(label="Run") + with gr.Accordion("Options", open=True): + cfg_scale = gr.Slider(label="Classifier Free Guidance Scale (Set a value larger than 1 to enable it!)", minimum=1.0, maximum=1.1, value=1.07, step=0.01) + scale_factor = gr.Number(label="SR Scale", value=4) + latent_tiled_size = gr.Slider(label="Tile Size", minimum=64, maximum=160, value=96, step=1) + latent_tiled_overlap = gr.Slider(label="Tile Overlap", minimum=16, maximum=48, value=32, step=1) + align_method = gr.Dropdown(label="Color Correction", choices=["wavelet", "adain", "no fix"], value="wavelet") + with gr.Column(): + result_image = gr.Image(label="Output", show_label=False, elem_id="result_image", source="canvas", width="100%", height="auto") + + inputs = [ + input_image, + scale_factor, + cfg_scale, + latent_tiled_size, + latent_tiled_overlap, + align_method + ] + run_button.click(fn=process, inputs=inputs, outputs=[result_image]) + +block.launch() + diff --git a/src/inference_s3diff.py b/src/inference_s3diff.py new file mode 100644 index 0000000000000000000000000000000000000000..1240141c523b7d33769068b713634924f3c670c1 --- /dev/null +++ b/src/inference_s3diff.py @@ -0,0 +1,218 @@ +import os +import gc +import tqdm +import math +import lpips +import pyiqa +import argparse +import clip +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers + +from omegaconf import OmegaConf +from accelerate import Accelerator +from accelerate.utils import set_seed +from PIL import Image +from torchvision import transforms +# from tqdm.auto import tqdm + +import diffusers +import utils.misc as misc + +from diffusers.utils.import_utils import is_xformers_available +from diffusers.optimization import get_scheduler + +from de_net import DEResNet +from s3diff_tile import S3Diff +from my_utils.testing_utils import parse_args_paired_testing, PlainDataset, lr_proc +from utils.util_image import ImageSpliterTh +from my_utils.utils import instantiate_from_config +from pathlib import Path +from utils import util_image +from utils.wavelet_color import wavelet_color_fix, adain_color_fix + +def evaluate(in_path, ref_path, ntest): + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + metric_dict = {} + metric_dict["clipiqa"] = pyiqa.create_metric('clipiqa').to(device) + metric_dict["musiq"] = pyiqa.create_metric('musiq').to(device) + metric_dict["niqe"] = pyiqa.create_metric('niqe').to(device) + metric_dict["maniqa"] = pyiqa.create_metric('maniqa').to(device) + metric_paired_dict = {} + + in_path = Path(in_path) if not isinstance(in_path, Path) else in_path + assert in_path.is_dir() + + ref_path_list = None + if ref_path is not None: + ref_path = Path(ref_path) if not isinstance(ref_path, Path) else ref_path + ref_path_list = sorted([x for x in ref_path.glob("*.[jpJP][pnPN]*[gG]")]) + if ntest is not None: ref_path_list = ref_path_list[:ntest] + + metric_paired_dict["psnr"]=pyiqa.create_metric('psnr', test_y_channel=True, color_space='ycbcr').to(device) + metric_paired_dict["lpips"]=pyiqa.create_metric('lpips').to(device) + metric_paired_dict["dists"]=pyiqa.create_metric('dists').to(device) + metric_paired_dict["ssim"]=pyiqa.create_metric('ssim', test_y_channel=True, color_space='ycbcr' ).to(device) + + lr_path_list = sorted([x for x in 
in_path.glob("*.[jpJP][pnPN]*[gG]")]) + if ntest is not None: lr_path_list = lr_path_list[:ntest] + + print(f'Find {len(lr_path_list)} images in {in_path}') + result = {} + for i in tqdm.tqdm(range(len(lr_path_list))): + _in_path = lr_path_list[i] + _ref_path = ref_path_list[i] if ref_path_list is not None else None + + im_in = util_image.imread(_in_path, chn='rgb', dtype='float32') # h x w x c + im_in_tensor = util_image.img2tensor(im_in).cuda() # 1 x c x h x w + for key, metric in metric_dict.items(): + with torch.cuda.amp.autocast(): + result[key] = result.get(key, 0) + metric(im_in_tensor).item() + + if ref_path is not None: + im_ref = util_image.imread(_ref_path, chn='rgb', dtype='float32') # h x w x c + im_ref_tensor = util_image.img2tensor(im_ref).cuda() + for key, metric in metric_paired_dict.items(): + result[key] = result.get(key, 0) + metric(im_in_tensor, im_ref_tensor).item() + + if ref_path is not None: + fid_metric = pyiqa.create_metric('fid') + result['fid'] = fid_metric(in_path, ref_path) + + print_results = [] + for key, res in result.items(): + if key == 'fid': + print(f"{key}: {res:.2f}") + print_results.append(f"{key}: {res:.2f}") + else: + print(f"{key}: {res/len(lr_path_list):.5f}") + print_results.append(f"{key}: {res/len(lr_path_list):.5f}") + return print_results + + +def main(args): + config = OmegaConf.load(args.base_config) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + ) + + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + if args.seed is not None: + set_seed(args.seed) + + if accelerator.is_main_process: + os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True) + + # initialize net_sr + net_sr = S3Diff(lora_rank_unet=args.lora_rank_unet, lora_rank_vae=args.lora_rank_vae, sd_path=args.sd_path, pretrained_path=args.pretrained_path, args=args) + net_sr.set_eval() + + net_de = DEResNet(num_in_ch=3, num_degradation=2) + net_de.load_model(args.de_net_path) + net_de = net_de.cuda() + net_de.eval() + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + net_sr.unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available, please install it by running `pip install xformers`") + + if args.gradient_checkpointing: + net_sr.unet.enable_gradient_checkpointing() + + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + dataset_val = PlainDataset(config.validation) + dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0) + + # Prepare everything with our `accelerator`. 
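Before the one-step model is called below, the LR image is bicubically upsampled by `sf`, mapped to [-1, 1], and reflect-padded up to a multiple of 64. A small sketch of that arithmetic with illustrative sizes:

```python
import math

import torch
import torch.nn.functional as F

im_lr = torch.rand(1, 3, 150, 200)                     # LR batch in [0, 1]
sf = 4
im = F.interpolate(im_lr, size=(150 * sf, 200 * sf), mode='bicubic')
h, w = im.shape[2:]
pad_h = math.ceil(h / 64) * 64 - h                     # 640 - 600 = 40
pad_w = math.ceil(w / 64) * 64 - w                     # 832 - 800 = 32
im = F.pad(im * 2 - 1.0, (0, pad_w, 0, pad_h), mode='reflect')
print(im.shape)                                        # torch.Size([1, 3, 640, 832])
```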
+ net_sr, net_de = accelerator.prepare(net_sr, net_de) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move al networksr to device and cast to weight_dtype + net_sr.to(accelerator.device, dtype=weight_dtype) + net_de.to(accelerator.device, dtype=weight_dtype) + + offset = args.padding_offset + for step, batch_val in enumerate(dl_val): + lr_path = batch_val['lr_path'][0] + (path, name) = os.path.split(lr_path) + + im_lr = batch_val['lr'].cuda() + im_lr = im_lr.to(memory_format=torch.contiguous_format).float() + + ori_h, ori_w = im_lr.shape[2:] + im_lr_resize = F.interpolate( + im_lr, + size=(ori_h * config.sf, + ori_w * config.sf), + mode='bicubic', + ) + + im_lr_resize = im_lr_resize.contiguous() + im_lr_resize_norm = im_lr_resize * 2 - 1.0 + im_lr_resize_norm = torch.clamp(im_lr_resize_norm, -1.0, 1.0) + resize_h, resize_w = im_lr_resize_norm.shape[2:] + + pad_h = (math.ceil(resize_h / 64)) * 64 - resize_h + pad_w = (math.ceil(resize_w / 64)) * 64 - resize_w + im_lr_resize_norm = F.pad(im_lr_resize_norm, pad=(0, pad_w, 0, pad_h), mode='reflect') + + B = im_lr_resize.size(0) + with torch.no_grad(): + # forward pass + deg_score = net_de(im_lr) + pos_tag_prompt = [args.pos_prompt for _ in range(B)] + neg_tag_prompt = [args.neg_prompt for _ in range(B)] + x_tgt_pred = accelerator.unwrap_model(net_sr)(im_lr_resize_norm, deg_score, pos_prompt=pos_tag_prompt, neg_prompt=neg_tag_prompt) + x_tgt_pred = x_tgt_pred[:, :, :resize_h, :resize_w] + out_img = (x_tgt_pred * 0.5 + 0.5).cpu().detach() + + output_pil = transforms.ToPILImage()(out_img[0]) + + if args.align_method == 'nofix': + output_pil = output_pil + else: + im_lr_resize = transforms.ToPILImage()(im_lr_resize[0].cpu().detach()) + if args.align_method == 'wavelet': + output_pil = wavelet_color_fix(output_pil, im_lr_resize) + elif args.align_method == 'adain': + output_pil = adain_color_fix(output_pil, im_lr_resize) + + fname, ext = os.path.splitext(name) + outf = os.path.join(args.output_dir, fname+'.png') + output_pil.save(outf) + + print_results = evaluate(args.output_dir, args.ref_path, None) + out_t = os.path.join(args.output_dir, 'results.txt') + with open(out_t, 'w', encoding='utf-8') as f: + for item in print_results: + f.write(f"{item}\n") + + gc.collect() + torch.cuda.empty_cache() + +if __name__ == "__main__": + args = parse_args_paired_testing() + main(args) diff --git a/src/model.py b/src/model.py new file mode 100644 index 0000000000000000000000000000000000000000..2eb96328092875072b0b0e4447d3539e5371c312 --- /dev/null +++ b/src/model.py @@ -0,0 +1,80 @@ +import torch +import os +import requests +from tqdm import tqdm +from diffusers import DDPMScheduler, EulerDiscreteScheduler +from typing import Any, Optional, Union + +# def make_1step_sched(pretrained_path, step=4): +# noise_scheduler_1step = EulerDiscreteScheduler.from_pretrained(pretrained_path, subfolder="scheduler") +# noise_scheduler_1step.set_timesteps(step, device="cuda") +# noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda() + # return noise_scheduler_1step + + +def make_1step_sched(pretrained_path): + noise_scheduler_1step = DDPMScheduler.from_pretrained(pretrained_path, subfolder="scheduler") + noise_scheduler_1step.set_timesteps(1, device="cuda") + noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda() + return noise_scheduler_1step + + +def my_lora_fwd(self, x: 
torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + torch_result_dtype = result.dtype + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + x = x.to(lora_A.weight.dtype) + + if not self.use_dora[active_adapter]: + _tmp = lora_A(dropout(x)) + if isinstance(lora_A, torch.nn.Conv2d): + _tmp = torch.einsum('...khw,...kr->...rhw', _tmp, self.de_mod) + elif isinstance(lora_A, torch.nn.Linear): + _tmp = torch.einsum('...lk,...kr->...lr', _tmp, self.de_mod) + else: + raise NotImplementedError('only conv and linear are supported yet.') + + result = result + lora_B(_tmp) * scaling + else: + x = dropout(x) + result = result + self._apply_dora(x, lora_A, lora_B, scaling, active_adapter) + + result = result.to(torch_result_dtype) + + return result + +def download_url(url, outf): + if not os.path.exists(outf): + print(f"Downloading checkpoint to {outf}") + response = requests.get(url, stream=True) + total_size_in_bytes = int(response.headers.get('content-length', 0)) + block_size = 1024 # 1 Kibibyte + progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) + with open(outf, 'wb') as file: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + progress_bar.close() + if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: + print("ERROR, something went wrong") + print(f"Downloaded successfully to {outf}") + else: + print(f"Skipping download, {outf} already exists") diff --git a/src/my_utils/devices.py b/src/my_utils/devices.py new file mode 100644 index 0000000000000000000000000000000000000000..7313838c3c1f153816031dc70c8beb765751ed9e --- /dev/null +++ b/src/my_utils/devices.py @@ -0,0 +1,138 @@ +import sys +import contextlib +from functools import lru_cache + +import torch +#from modules import errors + +if sys.platform == "darwin": + from modules import mac_specific + + +def has_mps() -> bool: + if sys.platform != "darwin": + return False + else: + return mac_specific.has_mps + + +def get_cuda_device_string(): + return "cuda" + + +def get_optimal_device_name(): + if torch.cuda.is_available(): + return get_cuda_device_string() + + if has_mps(): + return "mps" + + return "cpu" + + +def get_optimal_device(): + return torch.device(get_optimal_device_name()) + + +def get_device_for(task): + return get_optimal_device() + + +def torch_gc(): + + if torch.cuda.is_available(): + with torch.cuda.device(get_cuda_device_string()): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + if has_mps(): + mac_specific.torch_mps_gc() + + +def enable_tf32(): + if torch.cuda.is_available(): + + # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't + # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 + if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())): + 
torch.backends.cudnn.benchmark = True + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + +enable_tf32() +#errors.run(enable_tf32, "Enabling TF32") + +cpu = torch.device("cpu") +device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device("cuda") +dtype = torch.float16 +dtype_vae = torch.float16 +dtype_unet = torch.float16 +unet_needs_upcast = False + + +def cond_cast_unet(input): + return input.to(dtype_unet) if unet_needs_upcast else input + + +def cond_cast_float(input): + return input.float() if unet_needs_upcast else input + + +def randn(seed, shape): + torch.manual_seed(seed) + return torch.randn(shape, device=device) + + +def randn_without_seed(shape): + return torch.randn(shape, device=device) + + +def autocast(disable=False): + if disable: + return contextlib.nullcontext() + + return torch.autocast("cuda") + + +def without_autocast(disable=False): + return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() + + +class NansException(Exception): + pass + + +def test_for_nans(x, where): + if not torch.all(torch.isnan(x)).item(): + return + + if where == "unet": + message = "A tensor with all NaNs was produced in Unet." + + elif where == "vae": + message = "A tensor with all NaNs was produced in VAE." + + else: + message = "A tensor with all NaNs was produced." + + message += " Use --disable-nan-check commandline argument to disable this check." + + raise NansException(message) + + +@lru_cache +def first_time_calculation(): + """ + just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and + spends about 2.7 seconds doing that, at least wih NVidia. + """ + + x = torch.zeros((1, 1)).to(device, dtype) + linear = torch.nn.Linear(1, 1).to(device, dtype) + linear(x) + + x = torch.zeros((1, 1, 3, 3)).to(device, dtype) + conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) + conv2d(x) diff --git a/src/my_utils/dino_struct.py b/src/my_utils/dino_struct.py new file mode 100644 index 0000000000000000000000000000000000000000..d2721c9b61b5fbef650e5c9e2133c93a6b6a4ea4 --- /dev/null +++ b/src/my_utils/dino_struct.py @@ -0,0 +1,185 @@ +import torch +import torchvision +import torch.nn.functional as F + + +def attn_cosine_sim(x, eps=1e-08): + x = x[0] # TEMP: getting rid of redundant dimension, TBF + norm1 = x.norm(dim=2, keepdim=True) + factor = torch.clamp(norm1 @ norm1.permute(0, 2, 1), min=eps) + sim_matrix = (x @ x.permute(0, 2, 1)) / factor + return sim_matrix + + +class VitExtractor: + BLOCK_KEY = 'block' + ATTN_KEY = 'attn' + PATCH_IMD_KEY = 'patch_imd' + QKV_KEY = 'qkv' + KEY_LIST = [BLOCK_KEY, ATTN_KEY, PATCH_IMD_KEY, QKV_KEY] + + def __init__(self, model_name, device): + # pdb.set_trace() + self.model = torch.hub.load('facebookresearch/dino:main', model_name).to(device) + self.model.eval() + self.model_name = model_name + self.hook_handlers = [] + self.layers_dict = {} + self.outputs_dict = {} + for key in VitExtractor.KEY_LIST: + self.layers_dict[key] = [] + self.outputs_dict[key] = [] + self._init_hooks_data() + + def _init_hooks_data(self): + self.layers_dict[VitExtractor.BLOCK_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + self.layers_dict[VitExtractor.ATTN_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + self.layers_dict[VitExtractor.QKV_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + self.layers_dict[VitExtractor.PATCH_IMD_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + for key in 
VitExtractor.KEY_LIST: + # self.layers_dict[key] = kwargs[key] if key in kwargs.keys() else [] + self.outputs_dict[key] = [] + + def _register_hooks(self, **kwargs): + for block_idx, block in enumerate(self.model.blocks): + if block_idx in self.layers_dict[VitExtractor.BLOCK_KEY]: + self.hook_handlers.append(block.register_forward_hook(self._get_block_hook())) + if block_idx in self.layers_dict[VitExtractor.ATTN_KEY]: + self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_attn_hook())) + if block_idx in self.layers_dict[VitExtractor.QKV_KEY]: + self.hook_handlers.append(block.attn.qkv.register_forward_hook(self._get_qkv_hook())) + if block_idx in self.layers_dict[VitExtractor.PATCH_IMD_KEY]: + self.hook_handlers.append(block.attn.register_forward_hook(self._get_patch_imd_hook())) + + def _clear_hooks(self): + for handler in self.hook_handlers: + handler.remove() + self.hook_handlers = [] + + def _get_block_hook(self): + def _get_block_output(model, input, output): + self.outputs_dict[VitExtractor.BLOCK_KEY].append(output) + + return _get_block_output + + def _get_attn_hook(self): + def _get_attn_output(model, inp, output): + self.outputs_dict[VitExtractor.ATTN_KEY].append(output) + + return _get_attn_output + + def _get_qkv_hook(self): + def _get_qkv_output(model, inp, output): + self.outputs_dict[VitExtractor.QKV_KEY].append(output) + + return _get_qkv_output + + # TODO: CHECK ATTN OUTPUT TUPLE + def _get_patch_imd_hook(self): + def _get_attn_output(model, inp, output): + self.outputs_dict[VitExtractor.PATCH_IMD_KEY].append(output[0]) + + return _get_attn_output + + def get_feature_from_input(self, input_img): # List([B, N, D]) + self._register_hooks() + self.model(input_img) + feature = self.outputs_dict[VitExtractor.BLOCK_KEY] + self._clear_hooks() + self._init_hooks_data() + return feature + + def get_qkv_feature_from_input(self, input_img): + self._register_hooks() + self.model(input_img) + feature = self.outputs_dict[VitExtractor.QKV_KEY] + self._clear_hooks() + self._init_hooks_data() + return feature + + def get_attn_feature_from_input(self, input_img): + self._register_hooks() + self.model(input_img) + feature = self.outputs_dict[VitExtractor.ATTN_KEY] + self._clear_hooks() + self._init_hooks_data() + return feature + + def get_patch_size(self): + return 8 if "8" in self.model_name else 16 + + def get_width_patch_num(self, input_img_shape): + b, c, h, w = input_img_shape + patch_size = self.get_patch_size() + return w // patch_size + + def get_height_patch_num(self, input_img_shape): + b, c, h, w = input_img_shape + patch_size = self.get_patch_size() + return h // patch_size + + def get_patch_num(self, input_img_shape): + patch_num = 1 + (self.get_height_patch_num(input_img_shape) * self.get_width_patch_num(input_img_shape)) + return patch_num + + def get_head_num(self): + if "dino" in self.model_name: + return 6 if "s" in self.model_name else 12 + return 6 if "small" in self.model_name else 12 + + def get_embedding_dim(self): + if "dino" in self.model_name: + return 384 if "s" in self.model_name else 768 + return 384 if "small" in self.model_name else 768 + + def get_queries_from_qkv(self, qkv, input_img_shape): + patch_num = self.get_patch_num(input_img_shape) + head_num = self.get_head_num() + embedding_dim = self.get_embedding_dim() + q = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[0] + return q + + def get_keys_from_qkv(self, qkv, input_img_shape): + patch_num = self.get_patch_num(input_img_shape) + head_num = 
self.get_head_num() + embedding_dim = self.get_embedding_dim() + k = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[1] + return k + + def get_values_from_qkv(self, qkv, input_img_shape): + patch_num = self.get_patch_num(input_img_shape) + head_num = self.get_head_num() + embedding_dim = self.get_embedding_dim() + v = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[2] + return v + + def get_keys_from_input(self, input_img, layer_num): + qkv_features = self.get_qkv_feature_from_input(input_img)[layer_num] + keys = self.get_keys_from_qkv(qkv_features, input_img.shape) + return keys + + def get_keys_self_sim_from_input(self, input_img, layer_num): + keys = self.get_keys_from_input(input_img, layer_num=layer_num) + h, t, d = keys.shape + concatenated_keys = keys.transpose(0, 1).reshape(t, h * d) + ssim_map = attn_cosine_sim(concatenated_keys[None, None, ...]) + return ssim_map + + +class DinoStructureLoss: + def __init__(self, ): + self.extractor = VitExtractor(model_name="dino_vitb8", device="cuda") + self.preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Resize(224), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) + ]) + + def calculate_global_ssim_loss(self, outputs, inputs): + loss = 0.0 + for a, b in zip(inputs, outputs): # avoid memory limitations + with torch.no_grad(): + target_keys_self_sim = self.extractor.get_keys_self_sim_from_input(a.unsqueeze(0), layer_num=11) + keys_ssim = self.extractor.get_keys_self_sim_from_input(b.unsqueeze(0), layer_num=11) + loss += F.mse_loss(keys_ssim, target_keys_self_sim) + return loss diff --git a/src/my_utils/testing_utils.py b/src/my_utils/testing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ca4667c5f82cbcf83da5fd4f8fc29e7dbb88337a --- /dev/null +++ b/src/my_utils/testing_utils.py @@ -0,0 +1,210 @@ +import argparse +import json +from PIL import Image +from torchvision import transforms +import torch.nn.functional as F +from glob import glob + +import cv2 +import math +import numpy as np +import os +import os.path as osp +import random +import time +import torch +from pathlib import Path +from torch.utils import data as data + +from basicsr.utils import DiffJPEG, USMSharp +from basicsr.utils.img_process_util import filter2D +from basicsr.data.transforms import paired_random_crop, triplet_random_crop +from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt, random_add_speckle_noise_pt, random_add_saltpepper_noise_pt, bivariate_Gaussian + +from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels +from basicsr.data.transforms import augment +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY + + +def parse_args_paired_testing(input_args=None): + """ + Parses command-line arguments used for configuring an paired session (pix2pix-Turbo). + This function sets up an argument parser to handle various training options. + + Returns: + argparse.Namespace: The parsed command-line arguments. 
+ """ + parser = argparse.ArgumentParser() + parser.add_argument("--ref_path", type=str, default=None,) + parser.add_argument("--base_config", default="./configs/sr_test.yaml", type=str) + parser.add_argument("--tracker_project_name", type=str, default="train_pix2pix_turbo", help="The name of the wandb project to log to.") + + # details about the model architecture + parser.add_argument("--sd_path") + parser.add_argument("--de_net_path") + parser.add_argument("--pretrained_path", type=str, default=None,) + parser.add_argument("--revision", type=str, default=None,) + parser.add_argument("--variant", type=str, default=None,) + parser.add_argument("--tokenizer_name", type=str, default=None) + parser.add_argument("--lora_rank_unet", default=32, type=int) + parser.add_argument("--lora_rank_vae", default=16, type=int) + + parser.add_argument("--scale", type=int, default=4, help="Scale factor for SR.") + parser.add_argument("--chop_size", type=int, default=128, choices=[512, 256, 128], help="Chopping forward.") + parser.add_argument("--chop_stride", type=int, default=96, help="Chopping stride.") + parser.add_argument("--padding_offset", type=int, default=32, help="padding offset.") + + parser.add_argument("--vae_decoder_tiled_size", type=int, default=224) + parser.add_argument("--vae_encoder_tiled_size", type=int, default=1024) + parser.add_argument("--latent_tiled_size", type=int, default=96) + parser.add_argument("--latent_tiled_overlap", type=int, default=32) + + parser.add_argument("--align_method", type=str, default="wavelet") + + parser.add_argument("--pos_prompt", type=str, default="A high-resolution, 8K, ultra-realistic image with sharp focus, vibrant colors, and natural lighting.") + parser.add_argument("--neg_prompt", type=str, default="oil painting, cartoon, blur, dirty, messy, low quality, deformation, low resolution, oversmooth") + + # training details + parser.add_argument("--output_dir", required=True) + parser.add_argument("--cache_dir", default=None,) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument("--resolution", type=int, default=512,) + parser.add_argument("--checkpointing_steps", type=int, default=500,) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.",) + parser.add_argument("--gradient_checkpointing", action="store_true",) + + parser.add_argument("--dataloader_num_workers", type=int, default=0,) + parser.add_argument("--allow_tf32", action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--report_to", type=str, default="wandb", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 
+ ), + ) + parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],) + parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.") + parser.add_argument("--set_grads_to_none", action="store_true",) + + parser.add_argument('--world_size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--local_rank', default=-1, type=int) + parser.add_argument('--dist_url', default='env://', + help='url used to set up distributed training') + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + return args + + +class PlainDataset(data.Dataset): + """Modified dataset based on the dataset used for Real-ESRGAN model: + Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It loads gt (Ground-Truth) images, and augments them. + It also generates blur kernels and sinc kernels for generating low-quality images. + Note that the low-quality images are processed in tensors on GPUS for faster processing. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + meta_info (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). + Please see more options in the codes. + """ + + def __init__(self, opt): + super(PlainDataset, self).__init__() + self.opt = opt + self.file_client = None + self.io_backend_opt = opt['io_backend'] + + if 'image_type' not in opt: + opt['image_type'] = 'png' + + # support multiple type of data: file path and meta data, remove support of lmdb + self.lr_paths = [] + if 'lr_path' in opt: + if isinstance(opt['lr_path'], str): + self.lr_paths.extend(sorted( + [str(x) for x in Path(opt['lr_path']).glob('*.png')] + + [str(x) for x in Path(opt['lr_path']).glob('*.jpg')] + + [str(x) for x in Path(opt['lr_path']).glob('*.jpeg')] + )) + else: + self.lr_paths.extend(sorted([str(x) for x in Path(opt['lr_path'][0]).glob('*.'+opt['image_type'])])) + if len(opt['lr_path']) > 1: + for i in range(len(opt['lr_path'])-1): + self.lr_paths.extend(sorted([str(x) for x in Path(opt['lr_path'][i+1]).glob('*.'+opt['image_type'])])) + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # -------------------------------- Load gt images -------------------------------- # + # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32. 
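# A usage sketch for PlainDataset (shown here for reference, with placeholder
# paths): the only keys its __init__ above actually reads are 'lr_path',
# 'io_backend' and the optional 'image_type'. In the real test script these
# values come from the OmegaConf YAML as config.validation; 'disk' is the
# standard basicsr FileClient backend.
val_opt = {
    'lr_path': '/path/to/lr_images',      # folder scanned for *.png / *.jpg / *.jpeg
    'io_backend': {'type': 'disk'},
    'image_type': 'png',
}
dataset_val = PlainDataset(val_opt)
sample = dataset_val[0]                   # {'lr': c x h x w float tensor in [0, 1], 'lr_path': str}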
+ lr_path = self.lr_paths[index] + + # avoid errors caused by high latency in reading files + retry = 3 + while retry > 0: + try: + lr_img_bytes = self.file_client.get(lr_path, 'gt') + except (IOError, OSError) as e: + # logger = get_root_logger() + # logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}') + # change another file to read + index = random.randint(0, self.__len__()-1) + lr_path = self.lr_paths[index] + time.sleep(1) # sleep 1s for occasional server congestion + else: + break + finally: + retry -= 1 + + img_lr = imfrombytes(lr_img_bytes, float32=True) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_lr = img2tensor([img_lr], bgr2rgb=True, float32=True)[0] + + return_d = {'lr': img_lr, 'lr_path': lr_path} + return return_d + + def __len__(self): + return len(self.lr_paths) + + +def lr_proc(config, batch, device): + im_lr = batch['lr'].cuda() + im_lr = im_lr.to(memory_format=torch.contiguous_format).float() + + ori_lr = im_lr + + im_lr = F.interpolate( + im_lr, + size=(im_lr.size(-2) * config.sf, + im_lr.size(-1) * config.sf), + mode='bicubic', + ) + + im_lr = im_lr.contiguous() + im_lr = im_lr * 2 - 1.0 + im_lr = torch.clamp(im_lr, -1.0, 1.0) + + ori_h, ori_w = im_lr.size(-2), im_lr.size(-1) + + pad_h = (math.ceil(ori_h / 64)) * 64 - ori_h + pad_w = (math.ceil(ori_w / 64)) * 64 - ori_w + im_lr = F.pad(im_lr, pad=(0, pad_w, 0, pad_h), mode='reflect') + + return im_lr.to(device), ori_lr.to(device), (ori_h, ori_w) diff --git a/src/my_utils/training_utils.py b/src/my_utils/training_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5f3f7d92f17f8b0f54f88bdefdb610744c115b37 --- /dev/null +++ b/src/my_utils/training_utils.py @@ -0,0 +1,532 @@ +import argparse +import json +from PIL import Image +from torchvision import transforms +import torch.nn.functional as F +from glob import glob + +import cv2 +import math +import numpy as np +import os +import os.path as osp +import random +import time +import torch +from pathlib import Path +from torch.utils import data as data + +from basicsr.utils import DiffJPEG, USMSharp +from basicsr.utils.img_process_util import filter2D +from basicsr.data.transforms import paired_random_crop, triplet_random_crop +from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt, random_add_speckle_noise_pt, random_add_saltpepper_noise_pt, bivariate_Gaussian + +from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels +from basicsr.data.transforms import augment +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY + +def parse_args_paired_training(input_args=None): + """ + Parses command-line arguments used for configuring an paired session (pix2pix-Turbo). + This function sets up an argument parser to handle various training options. + + Returns: + argparse.Namespace: The parsed command-line arguments. 
+ """ + parser = argparse.ArgumentParser() + # args for the loss function + parser.add_argument("--gan_disc_type", default="vagan") + parser.add_argument("--gan_loss_type", default="multilevel_sigmoid_s") + parser.add_argument("--lambda_gan", default=0.5, type=float) + parser.add_argument("--lambda_lpips", default=5.0, type=float) + parser.add_argument("--lambda_l2", default=2.0, type=float) + parser.add_argument("--base_config", default="./configs/sr.yaml", type=str) + + # validation eval args + parser.add_argument("--eval_freq", default=100, type=int) + parser.add_argument("--save_val", default=True, action="store_false") + parser.add_argument("--num_samples_eval", type=int, default=100, help="Number of samples to use for all evaluation") + + parser.add_argument("--viz_freq", type=int, default=100, help="Frequency of visualizing the outputs.") + + # details about the model architecture + parser.add_argument("--sd_path") + parser.add_argument("--pretrained_path", type=str, default=None,) + parser.add_argument("--de_net_path") + parser.add_argument("--revision", type=str, default=None,) + parser.add_argument("--variant", type=str, default=None,) + parser.add_argument("--tokenizer_name", type=str, default=None) + parser.add_argument("--lora_rank_unet", default=32, type=int) + parser.add_argument("--lora_rank_vae", default=16, type=int) + parser.add_argument("--neg_prob", default=0.05, type=float) + parser.add_argument("--pos_prompt", type=str, default="A high-resolution, 8K, ultra-realistic image with sharp focus, vibrant colors, and natural lighting.") + parser.add_argument("--neg_prompt", type=str, default="oil painting, cartoon, blur, dirty, messy, low quality, deformation, low resolution, oversmooth") + + # training details + parser.add_argument("--output_dir", required=True) + parser.add_argument("--cache_dir", default=None,) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument("--resolution", type=int, default=512,) + parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.") + parser.add_argument("--num_training_epochs", type=int, default=50) + parser.add_argument("--max_train_steps", type=int, default=50000,) + parser.add_argument("--checkpointing_steps", type=int, default=500,) + parser.add_argument("--gradient_accumulation_steps", type=int, default=4, help="Number of updates steps to accumulate before performing a backward/update pass.",) + parser.add_argument("--gradient_checkpointing", action="store_true",) + parser.add_argument("--learning_rate", type=float, default=2e-5) + parser.add_argument("--lr_scheduler", type=str, default="constant", + help=( + 'The scheduler type to use. 
Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "piecewise_constant", "constant_with_warmup"]' + ), + ) + parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.") + parser.add_argument("--lr_num_cycles", type=int, default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=0.1, help="Power factor of the polynomial scheduler.") + + parser.add_argument("--dataloader_num_workers", type=int, default=0,) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--allow_tf32", action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--report_to", type=str, default="wandb", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],) + parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.") + parser.add_argument("--set_grads_to_none", action="store_true",) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + return args + + +# @DATASET_REGISTRY.register(suffix='basicsr') +class PairedDataset(data.Dataset): + """Modified dataset based on the dataset used for Real-ESRGAN model: + Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It loads gt (Ground-Truth) images, and augments them. + It also generates blur kernels and sinc kernels for generating low-quality images. + Note that the low-quality images are processed in tensors on GPUS for faster processing. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + meta_info (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). + Please see more options in the codes. 
+ """ + + def __init__(self, opt): + super(PairedDataset, self).__init__() + self.opt = opt + self.file_client = None + self.io_backend_opt = opt['io_backend'] + if 'crop_size' in opt: + self.crop_size = opt['crop_size'] + else: + self.crop_size = 512 + if 'image_type' not in opt: + opt['image_type'] = 'png' + + # support multiple type of data: file path and meta data, remove support of lmdb + self.paths = [] + if 'meta_info' in opt: + with open(self.opt['meta_info']) as fin: + paths = [line.strip().split(' ')[0] for line in fin] + self.paths = [v for v in paths] + if 'meta_num' in opt: + self.paths = sorted(self.paths)[:opt['meta_num']] + if 'gt_path' in opt: + if isinstance(opt['gt_path'], str): + # Use rglob to recursively search for images + self.paths.extend(sorted([str(x) for x in Path(opt['gt_path']).rglob('*.' + opt['image_type'])])) + else: + for path in opt['gt_path']: + self.paths.extend(sorted([str(x) for x in Path(path).rglob('*.' + opt['image_type'])])) + + # if 'gt_path' in opt: + # if isinstance(opt['gt_path'], str): + # self.paths.extend(sorted([str(x) for x in Path(opt['gt_path']).glob('*.'+opt['image_type'])])) + # else: + # self.paths.extend(sorted([str(x) for x in Path(opt['gt_path'][0]).glob('*.'+opt['image_type'])])) + # if len(opt['gt_path']) > 1: + # for i in range(len(opt['gt_path'])-1): + # self.paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i+1]).glob('*.'+opt['image_type'])])) + if 'imagenet_path' in opt: + class_list = os.listdir(opt['imagenet_path']) + for class_file in class_list: + self.paths.extend(sorted([str(x) for x in Path(os.path.join(opt['imagenet_path'], class_file)).glob('*.'+'JPEG')])) + if 'face_gt_path' in opt: + if isinstance(opt['face_gt_path'], str): + face_list = sorted([str(x) for x in Path(opt['face_gt_path']).glob('*.'+opt['image_type'])]) + self.paths.extend(face_list[:opt['num_face']]) + else: + face_list = sorted([str(x) for x in Path(opt['face_gt_path'][0]).glob('*.'+opt['image_type'])]) + self.paths.extend(face_list[:opt['num_face']]) + if len(opt['face_gt_path']) > 1: + for i in range(len(opt['face_gt_path'])-1): + self.paths.extend(sorted([str(x) for x in Path(opt['face_gt_path'][0]).glob('*.'+opt['image_type'])])[:opt['num_face']]) + + # limit number of pictures for test + if 'num_pic' in opt: + if 'val' or 'test' in opt: + random.shuffle(self.paths) + self.paths = self.paths[:opt['num_pic']] + else: + self.paths = self.paths[:opt['num_pic']] + + if 'mul_num' in opt: + self.paths = self.paths * opt['mul_num'] + # print('>>>>>>>>>>>>>>>>>>>>>') + # print(self.paths) + + # blur settings for the first degradation + self.blur_kernel_size = opt['blur_kernel_size'] + self.kernel_list = opt['kernel_list'] + self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability + self.blur_sigma = opt['blur_sigma'] + self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels + self.betap_range = opt['betap_range'] # betap used in plateau blur kernels + self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters + + # blur settings for the second degradation + self.blur_kernel_size2 = opt['blur_kernel_size2'] + self.kernel_list2 = opt['kernel_list2'] + self.kernel_prob2 = opt['kernel_prob2'] + self.blur_sigma2 = opt['blur_sigma2'] + self.betag_range2 = opt['betag_range2'] + self.betap_range2 = opt['betap_range2'] + self.sinc_prob2 = opt['sinc_prob2'] + + # a final sinc filter + self.final_sinc_prob = opt['final_sinc_prob'] + + self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel 
size ranges from 7 to 21 + # TODO: kernel range is now hard-coded, should be in the configure file + self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect + self.pulse_tensor[10, 10] = 1 + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # -------------------------------- Load gt images -------------------------------- # + # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32. + gt_path = self.paths[index] + # avoid errors caused by high latency in reading files + retry = 3 + while retry > 0: + try: + img_bytes = self.file_client.get(gt_path, 'gt') + except (IOError, OSError) as e: + # logger = get_root_logger() + # logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}') + # change another file to read + index = random.randint(0, self.__len__()-1) + gt_path = self.paths[index] + time.sleep(1) # sleep 1s for occasional server congestion + else: + break + finally: + retry -= 1 + img_gt = imfrombytes(img_bytes, float32=True) + # filter the dataset and remove images with too low quality + img_size = os.path.getsize(gt_path) + img_size = img_size / 1024 + + while img_gt.shape[0] * img_gt.shape[1] < 384*384 or img_size<100: + index = random.randint(0, self.__len__()-1) + gt_path = self.paths[index] + + time.sleep(0.1) # sleep 1s for occasional server congestion + img_bytes = self.file_client.get(gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + img_size = os.path.getsize(gt_path) + img_size = img_size / 1024 + + # -------------------- Do augmentation for training: flip, rotation -------------------- # + img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot']) + + # crop or pad to 400 + # TODO: 400 is hard-coded. You may change it accordingly + h, w = img_gt.shape[0:2] + crop_pad_size = self.crop_size + # pad + if h < crop_pad_size or w < crop_pad_size: + pad_h = max(0, crop_pad_size - h) + pad_w = max(0, crop_pad_size - w) + img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101) + # crop + if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size: + h, w = img_gt.shape[0:2] + # randomly choose top and left coordinates + top = random.randint(0, h - crop_pad_size) + left = random.randint(0, w - crop_pad_size) + # top = (h - crop_pad_size) // 2 -1 + # left = (w - crop_pad_size) // 2 -1 + img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...] 
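# A worked example of the pad-then-crop step above (a sketch, not part of the
# dataset class): a 350x700 GT image is first reflect-padded up to the 512 crop
# size along the short side, then randomly cropped to a 512x512 patch, so the
# patch size is guaranteed regardless of the original resolution.
import random
import cv2
import numpy as np

def pad_then_random_crop(img, crop_pad_size=512):
    h, w = img.shape[:2]
    pad_h, pad_w = max(0, crop_pad_size - h), max(0, crop_pad_size - w)
    if pad_h or pad_w:  # reflect-pad any side that is shorter than the target
        img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
    h, w = img.shape[:2]
    top = random.randint(0, h - crop_pad_size)
    left = random.randint(0, w - crop_pad_size)
    return img[top:top + crop_pad_size, left:left + crop_pad_size]

assert pad_then_random_crop(np.zeros((350, 700, 3), np.float32)).shape == (512, 512, 3)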
+ + # ------------------------ Generate kernels (used in the first degradation) ------------------------ # + kernel_size = random.choice(self.kernel_range) + if np.random.uniform() < self.opt['sinc_prob']: + # this sinc filter setting is for kernels ranging from [7, 21] + if kernel_size < 13: + omega_c = np.random.uniform(np.pi / 3, np.pi) + else: + omega_c = np.random.uniform(np.pi / 5, np.pi) + kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) + else: + kernel = random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + kernel_size, + self.blur_sigma, + self.blur_sigma, [-math.pi, math.pi], + self.betag_range, + self.betap_range, + noise_range=None) + # pad kernel + pad_size = (21 - kernel_size) // 2 + kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) + + # ------------------------ Generate kernels (used in the second degradation) ------------------------ # + kernel_size = random.choice(self.kernel_range) + if np.random.uniform() < self.opt['sinc_prob2']: + if kernel_size < 13: + omega_c = np.random.uniform(np.pi / 3, np.pi) + else: + omega_c = np.random.uniform(np.pi / 5, np.pi) + kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) + else: + kernel2 = random_mixed_kernels( + self.kernel_list2, + self.kernel_prob2, + kernel_size, + self.blur_sigma2, + self.blur_sigma2, [-math.pi, math.pi], + self.betag_range2, + self.betap_range2, + noise_range=None) + + # pad kernel + pad_size = (21 - kernel_size) // 2 + kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size))) + + # ------------------------------------- the final sinc kernel ------------------------------------- # + if np.random.uniform() < self.opt['final_sinc_prob']: + kernel_size = random.choice(self.kernel_range) + omega_c = np.random.uniform(np.pi / 3, np.pi) + sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21) + sinc_kernel = torch.FloatTensor(sinc_kernel) + else: + sinc_kernel = self.pulse_tensor + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0] + kernel = torch.FloatTensor(kernel) + kernel2 = torch.FloatTensor(kernel2) + + return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path} + return return_d + + def __len__(self): + return len(self.paths) + + +def randn_cropinput(lq, gt, base_size=[64, 128, 256, 512]): + cur_size_h = random.choice(base_size) + cur_size_w = random.choice(base_size) + init_h = lq.size(-2)//2 + init_w = lq.size(-1)//2 + lq = lq[:, :, init_h-cur_size_h//2:init_h+cur_size_h//2, init_w-cur_size_w//2:init_w+cur_size_w//2] + gt = gt[:, :, init_h-cur_size_h//2:init_h+cur_size_h//2, init_w-cur_size_w//2:init_w+cur_size_w//2] + assert lq.size(-1)>=64 + assert lq.size(-2)>=64 + return [lq, gt] + + +def degradation_proc(configs, batch, device, val=False, use_usm=False, resize_lq=True, random_size=False): + + """Degradation pipeline, modified from Real-ESRGAN: + https://github.com/xinntao/Real-ESRGAN + """ + + jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts + usm_sharpener = USMSharp().cuda() # do usm sharpening + + im_gt = batch['gt'].cuda() + if use_usm: + im_gt = usm_sharpener(im_gt) + im_gt = im_gt.to(memory_format=torch.contiguous_format).float() + kernel1 = batch['kernel1'].cuda() + kernel2 = batch['kernel2'].cuda() + sinc_kernel = batch['sinc_kernel'].cuda() + + ori_h, ori_w = im_gt.size()[2:4] + + # ----------------------- The first degradation process 
----------------------- # + # blur + out = filter2D(im_gt, kernel1) + # random resize + updown_type = random.choices( + ['up', 'down', 'keep'], + configs.degradation['resize_prob'], + )[0] + if updown_type == 'up': + scale = random.uniform(1, configs.degradation['resize_range'][1]) + elif updown_type == 'down': + scale = random.uniform(configs.degradation['resize_range'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, scale_factor=scale, mode=mode) + # add noise + gray_noise_prob = configs.degradation['gray_noise_prob'] + if random.random() < configs.degradation['gaussian_noise_prob']: + out = random_add_gaussian_noise_pt( + out, + sigma_range=configs.degradation['noise_range'], + clip=True, + rounds=False, + gray_prob=gray_noise_prob, + ) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=configs.degradation['poisson_scale_range'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*configs.degradation['jpeg_range']) + out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts + out = jpeger(out, quality=jpeg_p) + + # ----------------------- The second degradation process ----------------------- # + # blur + if random.random() < configs.degradation['second_blur_prob']: + out = filter2D(out, kernel2) + # random resize + updown_type = random.choices( + ['up', 'down', 'keep'], + configs.degradation['resize_prob2'], + )[0] + if updown_type == 'up': + scale = random.uniform(1, configs.degradation['resize_range2'][1]) + elif updown_type == 'down': + scale = random.uniform(configs.degradation['resize_range2'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate( + out, + size=(int(ori_h / configs.sf * scale), + int(ori_w / configs.sf * scale)), + mode=mode, + ) + # add noise + gray_noise_prob = configs.degradation['gray_noise_prob2'] + if random.random() < configs.degradation['gaussian_noise_prob2']: + out = random_add_gaussian_noise_pt( + out, + sigma_range=configs.degradation['noise_range2'], + clip=True, + rounds=False, + gray_prob=gray_noise_prob, + ) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=configs.degradation['poisson_scale_range2'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False, + ) + + # JPEG compression + the final sinc filter + # We also need to resize images to desired sizes. We group [resize back + sinc filter] together + # as one operation. + # We consider two orders: + # 1. [resize back + sinc filter] + JPEG compression + # 2. JPEG compression + [resize back + sinc filter] + # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. 
+ if random.random() < 0.5: + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate( + out, + size=(ori_h // configs.sf, + ori_w // configs.sf), + mode=mode, + ) + out = filter2D(out, sinc_kernel) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*configs.degradation['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = jpeger(out, quality=jpeg_p) + else: + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*configs.degradation['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = jpeger(out, quality=jpeg_p) + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate( + out, + size=(ori_h // configs.sf, + ori_w // configs.sf), + mode=mode, + ) + out = filter2D(out, sinc_kernel) + + # clamp and round + im_lq = torch.clamp(out, 0, 1.0) + + # random crop + gt_size = configs.degradation['gt_size'] + im_gt, im_lq = paired_random_crop(im_gt, im_lq, gt_size, configs.sf) + lq, gt = im_lq, im_gt + ori_lq = im_lq + + if resize_lq: + lq = F.interpolate( + lq, + size=(gt.size(-2), + gt.size(-1)), + mode='bicubic', + ) + + if random.random() < configs.degradation['no_degradation_prob'] or torch.isnan(lq).any(): + lq = gt + + # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue + lq = lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract + lq = lq * 2 - 1.0 # TODO 0~1? + gt = gt * 2 - 1.0 + + if random_size: + lq, gt = randn_cropinput(lq, gt) + + lq = torch.clamp(lq, -1.0, 1.0) + + return lq.to(device), gt.to(device), ori_lq.to(device) diff --git a/src/my_utils/utils.py b/src/my_utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1c1bdc214ecc7defaa76cf47b592174ab1d580bf --- /dev/null +++ b/src/my_utils/utils.py @@ -0,0 +1,213 @@ +import importlib + +import torch +import numpy as np +from collections import abc +from einops import rearrange +from functools import partial + +import multiprocessing as mp +from threading import Thread +from queue import Queue + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. 
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def instantiate_from_config_sr(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." + ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i: i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. 
[{time.time() - start} sec.]") + + if target_data_type == 'ndarray': + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == 'list': + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res diff --git a/src/my_utils/vaehook.py b/src/my_utils/vaehook.py new file mode 100644 index 0000000000000000000000000000000000000000..2975dea13fb55fca903df6d26501e3c6ab1541c5 --- /dev/null +++ b/src/my_utils/vaehook.py @@ -0,0 +1,828 @@ +# ------------------------------------------------------------------------ +# +# Ultimate VAE Tile Optimization +# +# Introducing a revolutionary new optimization designed to make +# the VAE work with giant images on limited VRAM! +# Say goodbye to the frustration of OOM and hello to seamless output! +# +# ------------------------------------------------------------------------ +# +# This script is a wild hack that splits the image into tiles, +# encodes each tile separately, and merges the result back together. +# +# Advantages: +# - The VAE can now work with giant images on limited VRAM +# (~10 GB for 8K images!) +# - The merged output is completely seamless without any post-processing. +# +# Drawbacks: +# - Giant RAM needed. To store the intermediate results for a 4096x4096 +# images, you need 32 GB RAM it consumes ~20GB); for 8192x8192 +# you need 128 GB RAM machine (it consumes ~100 GB) +# - NaNs always appear in for 8k images when you use fp16 (half) VAE +# You must use --no-half-vae to disable half VAE for that giant image. +# - Slow speed. With default tile size, it takes around 50/200 seconds +# to encode/decode a 4096x4096 image; and 200/900 seconds to encode/decode +# a 8192x8192 image. (The speed is limited by both the GPU and the CPU.) +# - The gradient calculation is not compatible with this hack. It +# will break any backward() or torch.autograd.grad() that passes VAE. +# (But you can still use the VAE to generate training data.) +# +# How it works: +# 1) The image is split into tiles. +# - To ensure perfect results, each tile is padded with 32 pixels +# on each side. +# - Then the conv2d/silu/upsample/downsample can produce identical +# results to the original image without splitting. +# 2) The original forward is decomposed into a task queue and a task worker. +# - The task queue is a list of functions that will be executed in order. +# - The task worker is a loop that executes the tasks in the queue. +# 3) The task queue is executed for each tile. +# - Current tile is sent to GPU. +# - local operations are directly executed. +# - Group norm calculation is temporarily suspended until the mean +# and var of all tiles are calculated. +# - The residual is pre-calculated and stored and addded back later. +# - When need to go to the next tile, the current tile is send to cpu. +# 4) After all tiles are processed, tiles are merged on cpu and return. +# +# Enjoy! +# +# @author: LI YI @ Nanyang Technological University - Singapore +# @date: 2023-03-02 +# @license: MIT License +# +# Please give me a star if you like this project! 
+# +# ------------------------------------------------------------------------- + +import gc +from time import time +import math +from tqdm import tqdm + +import torch +import torch.version +import torch.nn.functional as F +from einops import rearrange +import os +import sys +sys.path.append(os.getcwd()) +import my_utils.devices as devices + +try: + import xformers + import xformers.ops +except ImportError: + pass + +sd_flag = False + +def get_recommend_encoder_tile_size(): + if torch.cuda.is_available(): + total_memory = torch.cuda.get_device_properties( + devices.device).total_memory // 2**20 + if total_memory > 16*1000: + ENCODER_TILE_SIZE = 3072 + elif total_memory > 12*1000: + ENCODER_TILE_SIZE = 2048 + elif total_memory > 8*1000: + ENCODER_TILE_SIZE = 1536 + else: + ENCODER_TILE_SIZE = 960 + else: + ENCODER_TILE_SIZE = 512 + return ENCODER_TILE_SIZE + + +def get_recommend_decoder_tile_size(): + if torch.cuda.is_available(): + total_memory = torch.cuda.get_device_properties( + devices.device).total_memory // 2**20 + if total_memory > 30*1000: + DECODER_TILE_SIZE = 256 + elif total_memory > 16*1000: + DECODER_TILE_SIZE = 192 + elif total_memory > 12*1000: + DECODER_TILE_SIZE = 128 + elif total_memory > 8*1000: + DECODER_TILE_SIZE = 96 + else: + DECODER_TILE_SIZE = 64 + else: + DECODER_TILE_SIZE = 64 + return DECODER_TILE_SIZE + + +if 'global const': + DEFAULT_ENABLED = False + DEFAULT_MOVE_TO_GPU = False + DEFAULT_FAST_ENCODER = True + DEFAULT_FAST_DECODER = True + DEFAULT_COLOR_FIX = 0 + DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size() + DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size() + + +# inplace version of silu +def inplace_nonlinearity(x): + # Test: fix for Nans + return F.silu(x, inplace=True) + +# extracted from ldm.modules.diffusionmodules.model + +# from diffusers lib +def attn_forward_new(self, h_): + batch_size, channel, height, width = h_.shape + hidden_states = h_.view(batch_size, channel, height * width).transpose(1, 2) + + attention_mask = None + encoder_hidden_states = None + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = self.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) + + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + query = self.head_to_batch_dim(query) + key = self.head_to_batch_dim(key) + value = self.head_to_batch_dim(value) + + attention_probs = self.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = self.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + return hidden_states + +def attn_forward(self, h_): + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h*w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h*w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h*w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + # b, c,hw 
(hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return h_ + + +def xformer_attn_forward(self, h_): + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention( + q, k, v, attn_bias=None, op=self.attention_op) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = self.proj_out(out) + return out + + +def attn2task(task_queue, net): + if False: #isinstance(net, AttnBlock): + task_queue.append(('store_res', lambda x: x)) + task_queue.append(('pre_norm', net.norm)) + task_queue.append(('attn', lambda x, net=net: attn_forward(net, x))) + task_queue.append(['add_res', None]) + elif False: #isinstance(net, MemoryEfficientAttnBlock): + task_queue.append(('store_res', lambda x: x)) + task_queue.append(('pre_norm', net.norm)) + task_queue.append( + ('attn', lambda x, net=net: xformer_attn_forward(net, x))) + task_queue.append(['add_res', None]) + else: + task_queue.append(('store_res', lambda x: x)) + task_queue.append(('pre_norm', net.group_norm)) + task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x))) + task_queue.append(['add_res', None]) + +def resblock2task(queue, block): + """ + Turn a ResNetBlock into a sequence of tasks and append to the task queue + + @param queue: the target task queue + @param block: ResNetBlock + + """ + if block.in_channels != block.out_channels: + if sd_flag: + if block.use_conv_shortcut: + queue.append(('store_res', block.conv_shortcut)) + else: + queue.append(('store_res', block.nin_shortcut)) + else: + if block.use_in_shortcut: + queue.append(('store_res', block.conv_shortcut)) + else: + queue.append(('store_res', block.nin_shortcut)) + + else: + queue.append(('store_res', lambda x: x)) + queue.append(('pre_norm', block.norm1)) + queue.append(('silu', inplace_nonlinearity)) + queue.append(('conv1', block.conv1)) + queue.append(('pre_norm', block.norm2)) + queue.append(('silu', inplace_nonlinearity)) + queue.append(('conv2', block.conv2)) + queue.append(['add_res', None]) + + + +def build_sampling(task_queue, net, is_decoder): + """ + Build the sampling part of a task queue + @param task_queue: the target task queue + @param net: the network + @param is_decoder: currently building decoder or encoder + """ + if is_decoder: + # resblock2task(task_queue, net.mid.block_1) + # attn2task(task_queue, net.mid.attn_1) + # resblock2task(task_queue, net.mid.block_2) + # resolution_iter = reversed(range(net.num_resolutions)) + # block_ids = net.num_res_blocks + 1 + # condition = 0 + # module = net.up + # func_name = 'upsample' + resblock2task(task_queue, net.mid_block.resnets[0]) + attn2task(task_queue, net.mid_block.attentions[0]) + resblock2task(task_queue, net.mid_block.resnets[1]) + resolution_iter = (range(len(net.up_blocks))) # range(0,4) + block_ids = 2 + 1 + condition = len(net.up_blocks) - 1 + module = net.up_blocks + func_name = 'upsamplers' + else: + # resolution_iter = range(net.num_resolutions) + # block_ids = net.num_res_blocks + # condition = net.num_resolutions - 1 + # module = net.down + # 
func_name = 'downsample' + resolution_iter = (range(len(net.down_blocks))) # range(0,4) + block_ids = 2 + condition = len(net.down_blocks) - 1 + module = net.down_blocks + func_name = 'downsamplers' + + + for i_level in resolution_iter: + for i_block in range(block_ids): + resblock2task(task_queue, module[i_level].resnets[i_block]) + if i_level != condition: + if is_decoder: + task_queue.append((func_name, module[i_level].upsamplers[0])) + else: + task_queue.append((func_name, module[i_level].downsamplers[0])) + + if not is_decoder: + resblock2task(task_queue, net.mid_block.resnets[0]) + attn2task(task_queue, net.mid_block.attentions[0]) + resblock2task(task_queue, net.mid_block.resnets[1]) + + +def build_task_queue(net, is_decoder): + """ + Build a single task queue for the encoder or decoder + @param net: the VAE decoder or encoder network + @param is_decoder: currently building decoder or encoder + @return: the task queue + """ + task_queue = [] + task_queue.append(('conv_in', net.conv_in)) + + # construct the sampling part of the task queue + # because encoder and decoder share the same architecture, we extract the sampling part + build_sampling(task_queue, net, is_decoder) + if is_decoder and not sd_flag: + net.give_pre_end = False + net.tanh_out = False + + if not is_decoder or not net.give_pre_end: + if sd_flag: + task_queue.append(('pre_norm', net.norm_out)) + else: + task_queue.append(('pre_norm', net.conv_norm_out)) + task_queue.append(('silu', inplace_nonlinearity)) + task_queue.append(('conv_out', net.conv_out)) + if is_decoder and net.tanh_out: + task_queue.append(('tanh', torch.tanh)) + + return task_queue + + +def clone_task_queue(task_queue): + """ + Clone a task queue + @param task_queue: the task queue to be cloned + @return: the cloned task queue + """ + return [[item for item in task] for task in task_queue] + + +def get_var_mean(input, num_groups, eps=1e-6): + """ + Get mean and var for group norm + """ + b, c = input.size(0), input.size(1) + channel_in_group = int(c/num_groups) + input_reshaped = input.contiguous().view( + 1, int(b * num_groups), channel_in_group, *input.size()[2:]) + var, mean = torch.var_mean( + input_reshaped, dim=[0, 2, 3, 4], unbiased=False) + return var, mean + + +def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6): + """ + Custom group norm with fixed mean and var + + @param input: input tensor + @param num_groups: number of groups. 
by default, num_groups = 32 + @param mean: mean, must be pre-calculated by get_var_mean + @param var: var, must be pre-calculated by get_var_mean + @param weight: weight, should be fetched from the original group norm + @param bias: bias, should be fetched from the original group norm + @param eps: epsilon, by default, eps = 1e-6 to match the original group norm + + @return: normalized tensor + """ + b, c = input.size(0), input.size(1) + channel_in_group = int(c/num_groups) + input_reshaped = input.contiguous().view( + 1, int(b * num_groups), channel_in_group, *input.size()[2:]) + + out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None, + training=False, momentum=0, eps=eps) + + out = out.view(b, c, *input.size()[2:]) + + # post affine transform + if weight is not None: + out *= weight.view(1, -1, 1, 1) + if bias is not None: + out += bias.view(1, -1, 1, 1) + return out + + +def crop_valid_region(x, input_bbox, target_bbox, is_decoder): + """ + Crop the valid region from the tile + @param x: input tile + @param input_bbox: original input bounding box + @param target_bbox: output bounding box + @param scale: scale factor + @return: cropped tile + """ + padded_bbox = [i * 8 if is_decoder else i//8 for i in input_bbox] + margin = [target_bbox[i] - padded_bbox[i] for i in range(4)] + return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]] + +# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓ + + +def perfcount(fn): + def wrapper(*args, **kwargs): + ts = time() + + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats(devices.device) + devices.torch_gc() + gc.collect() + + ret = fn(*args, **kwargs) + + devices.torch_gc() + gc.collect() + if torch.cuda.is_available(): + vram = torch.cuda.max_memory_allocated(devices.device) / 2**20 + torch.cuda.reset_peak_memory_stats(devices.device) + print( + f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB') + else: + print(f'[Tiled VAE]: Done in {time() - ts:.3f}s') + + return ret + return wrapper + +# copy end :) + + +class GroupNormParam: + def __init__(self): + self.var_list = [] + self.mean_list = [] + self.pixel_list = [] + self.weight = None + self.bias = None + + def add_tile(self, tile, layer): + var, mean = get_var_mean(tile, 32) + # For giant images, the variance can be larger than max float16 + # In this case we create a copy to float32 + if var.dtype == torch.float16 and var.isinf().any(): + fp32_tile = tile.float() + var, mean = get_var_mean(fp32_tile, 32) + # ============= DEBUG: test for infinite ============= + # if torch.isinf(var).any(): + # print('var: ', var) + # ==================================================== + self.var_list.append(var) + self.mean_list.append(mean) + self.pixel_list.append( + tile.shape[2]*tile.shape[3]) + if hasattr(layer, 'weight'): + self.weight = layer.weight + self.bias = layer.bias + else: + self.weight = None + self.bias = None + + def summary(self): + """ + summarize the mean and var and return a function + that apply group norm on each tile + """ + if len(self.var_list) == 0: + return None + var = torch.vstack(self.var_list) + mean = torch.vstack(self.mean_list) + max_value = max(self.pixel_list) + pixels = torch.tensor( + self.pixel_list, dtype=torch.float32, device=devices.device) / max_value + sum_pixels = torch.sum(pixels) + pixels = pixels.unsqueeze( + 1) / sum_pixels + var = torch.sum( + var * pixels, dim=0) + mean = torch.sum( + mean * pixels, dim=0) + return lambda x: custom_group_norm(x, 32, mean, var, 
self.weight, self.bias) + + @staticmethod + def from_tile(tile, norm): + """ + create a function from a single tile without summary + """ + var, mean = get_var_mean(tile, 32) + if var.dtype == torch.float16 and var.isinf().any(): + fp32_tile = tile.float() + var, mean = get_var_mean(fp32_tile, 32) + # if it is a macbook, we need to convert back to float16 + if var.device.type == 'mps': + # clamp to avoid overflow + var = torch.clamp(var, 0, 60000) + var = var.half() + mean = mean.half() + if hasattr(norm, 'weight'): + weight = norm.weight + bias = norm.bias + else: + weight = None + bias = None + + def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias): + return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6) + return group_norm_func + + +class VAEHook: + def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False): + self.net = net # encoder | decoder + self.tile_size = tile_size + self.is_decoder = is_decoder + self.fast_mode = (fast_encoder and not is_decoder) or ( + fast_decoder and is_decoder) + self.color_fix = color_fix and not is_decoder + self.to_gpu = to_gpu + self.pad = 11 if is_decoder else 32 + + def __call__(self, x): + B, C, H, W = x.shape + original_device = next(self.net.parameters()).device + try: + if self.to_gpu: + self.net.to(devices.get_optimal_device()) + if max(H, W) <= self.pad * 2 + self.tile_size: + print("[Tiled VAE]: the input size is tiny and unnecessary to tile.") + return self.net.original_forward(x) + else: + return self.vae_tile_forward(x) + finally: + self.net.to(original_device) + + def get_best_tile_size(self, lowerbound, upperbound): + """ + Get the best tile size for GPU memory + """ + divider = 32 + while divider >= 2: + remainer = lowerbound % divider + if remainer == 0: + return lowerbound + candidate = lowerbound - remainer + divider + if candidate <= upperbound: + return candidate + divider //= 2 + return lowerbound + + def split_tiles(self, h, w): + """ + Tool function to split the image into tiles + @param h: height of the image + @param w: width of the image + @return: tile_input_bboxes, tile_output_bboxes + """ + tile_input_bboxes, tile_output_bboxes = [], [] + tile_size = self.tile_size + pad = self.pad + num_height_tiles = math.ceil((h - 2 * pad) / tile_size) + num_width_tiles = math.ceil((w - 2 * pad) / tile_size) + # If any of the numbers are 0, we let it be 1 + # This is to deal with long and thin images + num_height_tiles = max(num_height_tiles, 1) + num_width_tiles = max(num_width_tiles, 1) + + # Suggestions from https://github.com/Kahsolt: auto shrink the tile size + real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles) + real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles) + real_tile_height = self.get_best_tile_size(real_tile_height, tile_size) + real_tile_width = self.get_best_tile_size(real_tile_width, tile_size) + + print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' + + f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}') + + for i in range(num_height_tiles): + for j in range(num_width_tiles): + # bbox: [x1, x2, y1, y2] + # the padding is is unnessary for image borders. 
So we directly start from (32, 32) + input_bbox = [ + pad + j * real_tile_width, + min(pad + (j + 1) * real_tile_width, w), + pad + i * real_tile_height, + min(pad + (i + 1) * real_tile_height, h), + ] + + # if the output bbox is close to the image boundary, we extend it to the image boundary + output_bbox = [ + input_bbox[0] if input_bbox[0] > pad else 0, + input_bbox[1] if input_bbox[1] < w - pad else w, + input_bbox[2] if input_bbox[2] > pad else 0, + input_bbox[3] if input_bbox[3] < h - pad else h, + ] + + # scale to get the final output bbox + output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox] + tile_output_bboxes.append(output_bbox) + + # indistinguishable expand the input bbox by pad pixels + tile_input_bboxes.append([ + max(0, input_bbox[0] - pad), + min(w, input_bbox[1] + pad), + max(0, input_bbox[2] - pad), + min(h, input_bbox[3] + pad), + ]) + + return tile_input_bboxes, tile_output_bboxes + + @torch.no_grad() + def estimate_group_norm(self, z, task_queue, color_fix): + device = z.device + tile = z + last_id = len(task_queue) - 1 + while last_id >= 0 and task_queue[last_id][0] != 'pre_norm': + last_id -= 1 + if last_id <= 0 or task_queue[last_id][0] != 'pre_norm': + raise ValueError('No group norm found in the task queue') + # estimate until the last group norm + for i in range(last_id + 1): + task = task_queue[i] + if task[0] == 'pre_norm': + group_norm_func = GroupNormParam.from_tile(tile, task[1]) + task_queue[i] = ('apply_norm', group_norm_func) + if i == last_id: + return True + tile = group_norm_func(tile) + elif task[0] == 'store_res': + task_id = i + 1 + while task_id < last_id and task_queue[task_id][0] != 'add_res': + task_id += 1 + if task_id >= last_id: + continue + task_queue[task_id][1] = task[1](tile) + elif task[0] == 'add_res': + tile += task[1].to(device) + task[1] = None + elif color_fix and task[0] == 'downsample': + for j in range(i, last_id + 1): + if task_queue[j][0] == 'store_res': + task_queue[j] = ('store_res_cpu', task_queue[j][1]) + return True + else: + tile = task[1](tile) + try: + devices.test_for_nans(tile, "vae") + except: + print(f'Nan detected in fast mode estimation. Fast mode disabled.') + return False + + raise IndexError('Should not reach here') + + @perfcount + @torch.no_grad() + def vae_tile_forward(self, z): + """ + Decode a latent vector z into an image in a tiled manner. 
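+        When hooked on the encoder (is_decoder=False), the same routine instead encodes an image tile by tile and returns the latent.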
+ @param z: latent vector + @return: image + """ + device = next(self.net.parameters()).device + net = self.net + tile_size = self.tile_size + is_decoder = self.is_decoder + + z = z.detach() # detach the input to avoid backprop + + N, height, width = z.shape[0], z.shape[2], z.shape[3] + net.last_z_shape = z.shape + + # Split the input into tiles and build a task queue for each tile + print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}') + + in_bboxes, out_bboxes = self.split_tiles(height, width) + + # Prepare tiles by split the input latents + tiles = [] + for input_bbox in in_bboxes: + tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu() + tiles.append(tile) + + num_tiles = len(tiles) + num_completed = 0 + + # Build task queues + single_task_queue = build_task_queue(net, is_decoder) + #print(single_task_queue) + if self.fast_mode: + # Fast mode: downsample the input image to the tile size, + # then estimate the group norm parameters on the downsampled image + scale_factor = tile_size / max(height, width) + z = z.to(device) + downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact') + # use nearest-exact to keep statictics as close as possible + print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image') + + # ======= Special thanks to @Kahsolt for distribution shift issue ======= # + # The downsampling will heavily distort its mean and std, so we need to recover it. + std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True) + std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True) + downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old + del std_old, mean_old, std_new, mean_new + # occasionally the std_new is too small or too large, which exceeds the range of float16 + # so we need to clamp it to max z's range. 
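+            # The re-standardisation above maps the statistics of the nearest-exact downsample
+            # back to those of the full-resolution z, so group-norm parameters estimated on the
+            # small image transfer to the real tiles; the clamp below only guards against fp16
+            # overflow when std_new is extreme.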
+ downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max()) + estimate_task_queue = clone_task_queue(single_task_queue) + if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix): + single_task_queue = estimate_task_queue + del downsampled_z + + task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)] + + # Dummy result + result = None + result_approx = None + #try: + # with devices.autocast(): + # result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu() + #except: pass + # Free memory of input latent tensor + del z + + # Task queue execution + pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ") + + # execute the task back and forth when switch tiles so that we always + # keep one tile on the GPU to reduce unnecessary data transfer + forward = True + interrupted = False + #state.interrupted = interrupted + while True: + #if state.interrupted: interrupted = True ; break + + group_norm_param = GroupNormParam() + for i in range(num_tiles) if forward else reversed(range(num_tiles)): + #if state.interrupted: interrupted = True ; break + + tile = tiles[i].to(device) + input_bbox = in_bboxes[i] + task_queue = task_queues[i] + + interrupted = False + while len(task_queue) > 0: + #if state.interrupted: interrupted = True ; break + + # DEBUG: current task + # print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape) + task = task_queue.pop(0) + if task[0] == 'pre_norm': + group_norm_param.add_tile(tile, task[1]) + break + elif task[0] == 'store_res' or task[0] == 'store_res_cpu': + task_id = 0 + res = task[1](tile) + if not self.fast_mode or task[0] == 'store_res_cpu': + res = res.cpu() + while task_queue[task_id][0] != 'add_res': + task_id += 1 + task_queue[task_id][1] = res + elif task[0] == 'add_res': + tile += task[1].to(device) + task[1] = None + else: + tile = task[1](tile) + pbar.update(1) + + if interrupted: break + + # check for NaNs in the tile. + # If there are NaNs, we abort the process to save user's time + #devices.test_for_nans(tile, "vae") + + #print(tiles[i].shape, tile.shape, i, num_tiles) + if len(task_queue) == 0: + tiles[i] = None + num_completed += 1 + if result is None: # NOTE: dim C varies from different cases, can only be inited dynamically + result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False) + result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder) + del tile + elif i == num_tiles - 1 and forward: + forward = False + tiles[i] = tile + elif i == 0 and not forward: + forward = True + tiles[i] = tile + else: + tiles[i] = tile.cpu() + del tile + + if interrupted: break + if num_completed == num_tiles: break + + # insert the group norm task to the head of each task queue + group_norm_func = group_norm_param.summary() + if group_norm_func is not None: + for i in range(num_tiles): + task_queue = task_queues[i] + task_queue.insert(0, ('apply_norm', group_norm_func)) + + # Done! 
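+        # Every tile has now been normalised, processed and written into `result`
+        # (unless execution was interrupted). The zig-zag tile order above keeps one tile
+        # resident on the GPU between passes, and each pass pauses at the next group norm
+        # so GroupNormParam can aggregate mean/var over all tiles before the queues continue.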
+ pbar.close() + return result if result is not None else result_approx.to(device) \ No newline at end of file diff --git a/src/s3diff.py b/src/s3diff.py new file mode 100644 index 0000000000000000000000000000000000000000..9b2f03d14dd07e5d0ab85dfc1e90a342f6d3a1b6 --- /dev/null +++ b/src/s3diff.py @@ -0,0 +1,305 @@ +import os +import re +import requests +import sys +import copy +import numpy as np +from tqdm import tqdm +import torch +import torch.nn as nn +from transformers import AutoTokenizer, CLIPTextModel +from diffusers import AutoencoderKL, UNet2DConditionModel +from peft import LoraConfig, get_peft_model +p = "src/" +sys.path.append(p) +from model import make_1step_sched, my_lora_fwd +from basicsr.archs.arch_util import default_init_weights + +def get_layer_number(module_name): + base_layers = { + 'down_blocks': 0, + 'mid_block': 4, + 'up_blocks': 5 + } + + if module_name == 'conv_out': + return 9 + + base_layer = None + for key in base_layers: + if key in module_name: + base_layer = base_layers[key] + break + + if base_layer is None: + return None + + additional_layers = int(re.findall(r'\.(\d+)', module_name)[0]) #sum(int(num) for num in re.findall(r'\d+', module_name)) + final_layer = base_layer + additional_layers + return final_layer + + +class S3Diff(torch.nn.Module): + def __init__(self, sd_path=None, pretrained_path=None, lora_rank_unet=32, lora_rank_vae=16, block_embedding_dim=64): + super().__init__() + self.tokenizer = AutoTokenizer.from_pretrained(sd_path, subfolder="tokenizer") + self.text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").cuda() + self.sched = make_1step_sched(sd_path) + + vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae") + unet = UNet2DConditionModel.from_pretrained(sd_path, subfolder="unet") + + target_modules_vae = r"^encoder\..*(conv1|conv2|conv_in|conv_shortcut|conv|conv_out|to_k|to_q|to_v|to_out\.0)$" + target_modules_unet = [ + "to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_shortcut", "conv_out", + "proj_in", "proj_out", "ff.net.2", "ff.net.0.proj" + ] + + num_embeddings = 64 + self.W = nn.Parameter(torch.randn(num_embeddings), requires_grad=False) + + self.vae_de_mlp = nn.Sequential( + nn.Linear(num_embeddings * 4, 256), + nn.ReLU(True), + ) + + self.unet_de_mlp = nn.Sequential( + nn.Linear(num_embeddings * 4, 256), + nn.ReLU(True), + ) + + self.vae_block_mlp = nn.Sequential( + nn.Linear(block_embedding_dim, 64), + nn.ReLU(True), + ) + + self.unet_block_mlp = nn.Sequential( + nn.Linear(block_embedding_dim, 64), + nn.ReLU(True), + ) + + self.vae_fuse_mlp = nn.Linear(256 + 64, lora_rank_vae ** 2) + self.unet_fuse_mlp = nn.Linear(256 + 64, lora_rank_unet ** 2) + + default_init_weights([self.vae_de_mlp, self.unet_de_mlp, self.vae_block_mlp, self.unet_block_mlp, \ + self.vae_fuse_mlp, self.unet_fuse_mlp], 1e-5) + + # vae + self.vae_block_embeddings = nn.Embedding(6, block_embedding_dim) + self.unet_block_embeddings = nn.Embedding(10, block_embedding_dim) + + if pretrained_path is not None: + sd = torch.load(pretrained_path, map_location="cpu") + vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"]) + vae.add_adapter(vae_lora_config, adapter_name="vae_skip") + _sd_vae = vae.state_dict() + for k in sd["state_dict_vae"]: + _sd_vae[k] = sd["state_dict_vae"][k] + vae.load_state_dict(_sd_vae) + + unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"]) + 
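+            # Recreate the UNet LoRA adapter with the rank and target modules stored in the
+            # checkpoint, then overlay only the saved LoRA / conv_in tensors onto the base
+            # state dict; the remaining SD weights stay as loaded from sd_path.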
unet.add_adapter(unet_lora_config) + _sd_unet = unet.state_dict() + for k in sd["state_dict_unet"]: + _sd_unet[k] = sd["state_dict_unet"][k] + unet.load_state_dict(_sd_unet) + + _vae_de_mlp = self.vae_de_mlp.state_dict() + for k in sd["state_dict_vae_de_mlp"]: + _vae_de_mlp[k] = sd["state_dict_vae_de_mlp"][k] + self.vae_de_mlp.load_state_dict(_vae_de_mlp) + + _unet_de_mlp = self.unet_de_mlp.state_dict() + for k in sd["state_dict_unet_de_mlp"]: + _unet_de_mlp[k] = sd["state_dict_unet_de_mlp"][k] + self.unet_de_mlp.load_state_dict(_unet_de_mlp) + + _vae_block_mlp = self.vae_block_mlp.state_dict() + for k in sd["state_dict_vae_block_mlp"]: + _vae_block_mlp[k] = sd["state_dict_vae_block_mlp"][k] + self.vae_block_mlp.load_state_dict(_vae_block_mlp) + + _unet_block_mlp = self.unet_block_mlp.state_dict() + for k in sd["state_dict_unet_block_mlp"]: + _unet_block_mlp[k] = sd["state_dict_unet_block_mlp"][k] + self.unet_block_mlp.load_state_dict(_unet_block_mlp) + + _vae_fuse_mlp = self.vae_fuse_mlp.state_dict() + for k in sd["state_dict_vae_fuse_mlp"]: + _vae_fuse_mlp[k] = sd["state_dict_vae_fuse_mlp"][k] + self.vae_fuse_mlp.load_state_dict(_vae_fuse_mlp) + + _unet_fuse_mlp = self.unet_fuse_mlp.state_dict() + for k in sd["state_dict_unet_fuse_mlp"]: + _unet_fuse_mlp[k] = sd["state_dict_unet_fuse_mlp"][k] + self.unet_fuse_mlp.load_state_dict(_unet_fuse_mlp) + + self.W = nn.Parameter(sd["w"], requires_grad=False) + + embeddings_state_dict = sd["state_embeddings"] + self.vae_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_vae_block']) + self.unet_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_unet_block']) + else: + print("Initializing model with random weights") + vae_lora_config = LoraConfig(r=lora_rank_vae, init_lora_weights="gaussian", + target_modules=target_modules_vae) + vae.add_adapter(vae_lora_config, adapter_name="vae_skip") + unet_lora_config = LoraConfig(r=lora_rank_unet, init_lora_weights="gaussian", + target_modules=target_modules_unet + ) + unet.add_adapter(unet_lora_config) + + self.lora_rank_unet = lora_rank_unet + self.lora_rank_vae = lora_rank_vae + self.target_modules_vae = target_modules_vae + self.target_modules_unet = target_modules_unet + + self.vae_lora_layers = [] + for name, module in vae.named_modules(): + if 'base_layer' in name: + self.vae_lora_layers.append(name[:-len(".base_layer")]) + + for name, module in vae.named_modules(): + if name in self.vae_lora_layers: + module.forward = my_lora_fwd.__get__(module, module.__class__) + + self.unet_lora_layers = [] + for name, module in unet.named_modules(): + if 'base_layer' in name: + self.unet_lora_layers.append(name[:-len(".base_layer")]) + + for name, module in unet.named_modules(): + if name in self.unet_lora_layers: + module.forward = my_lora_fwd.__get__(module, module.__class__) + + self.unet_layer_dict = {name: get_layer_number(name) for name in self.unet_lora_layers} + + unet.to("cuda") + vae.to("cuda") + self.unet, self.vae = unet, vae + self.timesteps = torch.tensor([999], device="cuda").long() + self.text_encoder.requires_grad_(False) + + def set_eval(self): + self.unet.eval() + self.vae.eval() + self.vae_de_mlp.eval() + self.unet_de_mlp.eval() + self.vae_block_mlp.eval() + self.unet_block_mlp.eval() + self.vae_fuse_mlp.eval() + self.unet_fuse_mlp.eval() + + self.vae_block_embeddings.requires_grad_(False) + self.unet_block_embeddings.requires_grad_(False) + + self.unet.requires_grad_(False) + self.vae.requires_grad_(False) + + def set_train(self): + self.unet.train() + 
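+        # Put every learned component into train mode; requires_grad is then re-enabled
+        # selectively below for the LoRA parameters, unet.conv_in and the block embeddings.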
self.vae.train() + self.vae_de_mlp.train() + self.unet_de_mlp.train() + self.vae_block_mlp.train() + self.unet_block_mlp.train() + self.vae_fuse_mlp.train() + self.unet_fuse_mlp.train() + + self.vae_block_embeddings.requires_grad_(True) + self.unet_block_embeddings.requires_grad_(True) + + for n, _p in self.unet.named_parameters(): + if "lora" in n: + _p.requires_grad = True + + self.unet.conv_in.requires_grad_(True) + + for n, _p in self.vae.named_parameters(): + if "lora" in n: + _p.requires_grad = True + + def forward(self, c_t, deg_score, prompt): + + if prompt is not None: + # encode the text prompt + caption_tokens = self.tokenizer(prompt, max_length=self.tokenizer.model_max_length, + padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda() + caption_enc = self.text_encoder(caption_tokens)[0] + else: + caption_enc = self.text_encoder(prompt_tokens)[0] + + # degradation fourier embedding + deg_proj = deg_score[..., None] * self.W[None, None, :] * 2 * np.pi + deg_proj = torch.cat([torch.sin(deg_proj), torch.cos(deg_proj)], dim=-1) + deg_proj = torch.cat([deg_proj[:, 0], deg_proj[:, 1]], dim=-1) + + # degradation mlp forward + vae_de_c_embed = self.vae_de_mlp(deg_proj) + unet_de_c_embed = self.unet_de_mlp(deg_proj) + + # block embedding mlp forward + vae_block_c_embeds = self.vae_block_mlp(self.vae_block_embeddings.weight) + unet_block_c_embeds = self.unet_block_mlp(self.unet_block_embeddings.weight) + vae_embeds = self.vae_fuse_mlp(torch.cat([vae_de_c_embed.unsqueeze(1).repeat(1, vae_block_c_embeds.shape[0], 1), \ + vae_block_c_embeds.unsqueeze(0).repeat(vae_de_c_embed.shape[0],1,1)], -1)) + unet_embeds = self.unet_fuse_mlp(torch.cat([unet_de_c_embed.unsqueeze(1).repeat(1, unet_block_c_embeds.shape[0], 1), \ + unet_block_c_embeds.unsqueeze(0).repeat(unet_de_c_embed.shape[0],1,1)], -1)) + + for layer_name, module in self.vae.named_modules(): + if layer_name in self.vae_lora_layers: + split_name = layer_name.split(".") + if split_name[1] == 'down_blocks': + block_id = int(split_name[2]) + vae_embed = vae_embeds[:, block_id] + elif split_name[1] == 'mid_block': + vae_embed = vae_embeds[:, -2] + else: + vae_embed = vae_embeds[:, -1] + module.de_mod = vae_embed.reshape(-1, self.lora_rank_vae, self.lora_rank_vae) + + for layer_name, module in self.unet.named_modules(): + if layer_name in self.unet_lora_layers: + split_name = layer_name.split(".") + + if split_name[0] == 'down_blocks': + block_id = int(split_name[1]) + unet_embed = unet_embeds[:, block_id] + elif split_name[0] == 'mid_block': + unet_embed = unet_embeds[:, 4] + elif split_name[0] == 'up_blocks': + block_id = int(split_name[1]) + 5 + unet_embed = unet_embeds[:, block_id] + else: + unet_embed = unet_embeds[:, -1] + module.de_mod = unet_embed.reshape(-1, self.lora_rank_unet, self.lora_rank_unet) + + encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor + model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc,).sample + x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample + output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1) + + return output_image + + def save_model(self, outf): + sd = {} + sd["unet_lora_target_modules"] = self.target_modules_unet + sd["vae_lora_target_modules"] = self.target_modules_vae + sd["rank_unet"] = self.lora_rank_unet + sd["rank_vae"] = self.lora_rank_vae + sd["state_dict_unet"] = {k: v for k, v in 
self.unet.state_dict().items() if "lora" in k or "conv_in" in k} + sd["state_dict_vae"] = {k: v for k, v in self.vae.state_dict().items() if "lora" in k or "skip_conv" in k} + sd["state_dict_vae_de_mlp"] = {k: v for k, v in self.vae_de_mlp.state_dict().items()} + sd["state_dict_unet_de_mlp"] = {k: v for k, v in self.unet_de_mlp.state_dict().items()} + sd["state_dict_vae_block_mlp"] = {k: v for k, v in self.vae_block_mlp.state_dict().items()} + sd["state_dict_unet_block_mlp"] = {k: v for k, v in self.unet_block_mlp.state_dict().items()} + sd["state_dict_vae_fuse_mlp"] = {k: v for k, v in self.vae_fuse_mlp.state_dict().items()} + sd["state_dict_unet_fuse_mlp"] = {k: v for k, v in self.unet_fuse_mlp.state_dict().items()} + sd["w"] = self.W + + sd["state_embeddings"] = { + "state_dict_vae_block": self.vae_block_embeddings.state_dict(), + "state_dict_unet_block": self.unet_block_embeddings.state_dict(), + } + + torch.save(sd, outf) diff --git a/src/s3diff_cfg.py b/src/s3diff_cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..54a7b2353a702fffb01edb1b4a63a1e6d77dc389 --- /dev/null +++ b/src/s3diff_cfg.py @@ -0,0 +1,316 @@ +import os +import re +import requests +import sys +import copy +import numpy as np +from tqdm import tqdm +import torch +import torch.nn as nn +from transformers import AutoTokenizer, CLIPTextModel +from diffusers import AutoencoderKL, UNet2DConditionModel +from peft import LoraConfig, get_peft_model +p = "src/" +sys.path.append(p) +from model import make_1step_sched, my_lora_fwd +from basicsr.archs.arch_util import default_init_weights + + +def get_layer_number(module_name): + base_layers = { + 'down_blocks': 0, + 'mid_block': 4, + 'up_blocks': 5 + } + + if module_name == 'conv_out': + return 9 + + base_layer = None + for key in base_layers: + if key in module_name: + base_layer = base_layers[key] + break + + if base_layer is None: + return None + + additional_layers = int(re.findall(r'\.(\d+)', module_name)[0]) #sum(int(num) for num in re.findall(r'\d+', module_name)) + final_layer = base_layer + additional_layers + return final_layer + + +class S3Diff(torch.nn.Module): + def __init__(self, sd_path=None, pretrained_path=None, lora_rank_unet=8, lora_rank_vae=4, block_embedding_dim=64): + super().__init__() + self.tokenizer = AutoTokenizer.from_pretrained(sd_path, subfolder="tokenizer") + self.text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").cuda() + self.sched = make_1step_sched(sd_path) + self.guidance_scale = 1.07 + + vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae") + unet = UNet2DConditionModel.from_pretrained(sd_path, subfolder="unet") + + target_modules_vae = r"^encoder\..*(conv1|conv2|conv_in|conv_shortcut|conv|conv_out|to_k|to_q|to_v|to_out\.0)$" + target_modules_unet = [ + "to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_shortcut", "conv_out", + "proj_in", "proj_out", "ff.net.2", "ff.net.0.proj" + ] + + num_embeddings = 64 + self.W = nn.Parameter(torch.randn(num_embeddings), requires_grad=False) + + self.vae_de_mlp = nn.Sequential( + nn.Linear(num_embeddings * 4, 256), + nn.ReLU(True), + ) + + self.unet_de_mlp = nn.Sequential( + nn.Linear(num_embeddings * 4, 256), + nn.ReLU(True), + ) + + self.vae_block_mlp = nn.Sequential( + nn.Linear(block_embedding_dim, 64), + nn.ReLU(True), + ) + + self.unet_block_mlp = nn.Sequential( + nn.Linear(block_embedding_dim, 64), + nn.ReLU(True), + ) + + self.vae_fuse_mlp = nn.Linear(256 + 64, lora_rank_vae ** 2) + self.unet_fuse_mlp = nn.Linear(256 + 
64, lora_rank_unet ** 2) + + default_init_weights([self.vae_de_mlp, self.unet_de_mlp, self.vae_block_mlp, self.unet_block_mlp, \ + self.vae_fuse_mlp, self.unet_fuse_mlp], 1e-5) + + # vae + self.vae_block_embeddings = nn.Embedding(6, block_embedding_dim) + self.unet_block_embeddings = nn.Embedding(10, block_embedding_dim) + + if pretrained_path is not None: + sd = torch.load(pretrained_path, map_location="cpu") + vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"]) + vae.add_adapter(vae_lora_config, adapter_name="vae_skip") + _sd_vae = vae.state_dict() + for k in sd["state_dict_vae"]: + _sd_vae[k] = sd["state_dict_vae"][k] + vae.load_state_dict(_sd_vae) + + unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"]) + unet.add_adapter(unet_lora_config) + _sd_unet = unet.state_dict() + for k in sd["state_dict_unet"]: + _sd_unet[k] = sd["state_dict_unet"][k] + unet.load_state_dict(_sd_unet) + + _vae_de_mlp = self.vae_de_mlp.state_dict() + for k in sd["state_dict_vae_de_mlp"]: + _vae_de_mlp[k] = sd["state_dict_vae_de_mlp"][k] + self.vae_de_mlp.load_state_dict(_vae_de_mlp) + + _unet_de_mlp = self.unet_de_mlp.state_dict() + for k in sd["state_dict_unet_de_mlp"]: + _unet_de_mlp[k] = sd["state_dict_unet_de_mlp"][k] + self.unet_de_mlp.load_state_dict(_unet_de_mlp) + + _vae_block_mlp = self.vae_block_mlp.state_dict() + for k in sd["state_dict_vae_block_mlp"]: + _vae_block_mlp[k] = sd["state_dict_vae_block_mlp"][k] + self.vae_block_mlp.load_state_dict(_vae_block_mlp) + + _unet_block_mlp = self.unet_block_mlp.state_dict() + for k in sd["state_dict_unet_block_mlp"]: + _unet_block_mlp[k] = sd["state_dict_unet_block_mlp"][k] + self.unet_block_mlp.load_state_dict(_unet_block_mlp) + + _vae_fuse_mlp = self.vae_fuse_mlp.state_dict() + for k in sd["state_dict_vae_fuse_mlp"]: + _vae_fuse_mlp[k] = sd["state_dict_vae_fuse_mlp"][k] + self.vae_fuse_mlp.load_state_dict(_vae_fuse_mlp) + + _unet_fuse_mlp = self.unet_fuse_mlp.state_dict() + for k in sd["state_dict_unet_fuse_mlp"]: + _unet_fuse_mlp[k] = sd["state_dict_unet_fuse_mlp"][k] + self.unet_fuse_mlp.load_state_dict(_unet_fuse_mlp) + + self.W = nn.Parameter(sd["w"], requires_grad=False) + + embeddings_state_dict = sd["state_embeddings"] + self.vae_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_vae_block']) + self.unet_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_unet_block']) + else: + print("Initializing model with random weights") + vae_lora_config = LoraConfig(r=lora_rank_vae, init_lora_weights="gaussian", + target_modules=target_modules_vae) + vae.add_adapter(vae_lora_config, adapter_name="vae_skip") + unet_lora_config = LoraConfig(r=lora_rank_unet, init_lora_weights="gaussian", + target_modules=target_modules_unet + ) + unet.add_adapter(unet_lora_config) + + self.lora_rank_unet = lora_rank_unet + self.lora_rank_vae = lora_rank_vae + self.target_modules_vae = target_modules_vae + self.target_modules_unet = target_modules_unet + + self.vae_lora_layers = [] + for name, module in vae.named_modules(): + if 'base_layer' in name: + self.vae_lora_layers.append(name[:-len(".base_layer")]) + + for name, module in vae.named_modules(): + if name in self.vae_lora_layers: + module.forward = my_lora_fwd.__get__(module, module.__class__) + + self.unet_lora_layers = [] + for name, module in unet.named_modules(): + if 'base_layer' in name: + self.unet_lora_layers.append(name[:-len(".base_layer")]) + 
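+        # Swap the forward of every collected LoRA-wrapped UNet layer for my_lora_fwd,
+        # mirroring the VAE patching above; forward() later attaches a per-layer `de_mod`
+        # matrix to these modules.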
+ for name, module in unet.named_modules(): + if name in self.unet_lora_layers: + module.forward = my_lora_fwd.__get__(module, module.__class__) + + self.unet_layer_dict = {name: get_layer_number(name) for name in self.unet_lora_layers} + + unet.to("cuda") + vae.to("cuda") + self.unet, self.vae = unet, vae + self.timesteps = torch.tensor([999], device="cuda").long() + self.text_encoder.requires_grad_(False) + + def set_eval(self): + self.unet.eval() + self.vae.eval() + self.vae_de_mlp.eval() + self.unet_de_mlp.eval() + self.vae_block_mlp.eval() + self.unet_block_mlp.eval() + self.vae_fuse_mlp.eval() + self.unet_fuse_mlp.eval() + + self.vae_block_embeddings.requires_grad_(False) + self.unet_block_embeddings.requires_grad_(False) + + self.unet.requires_grad_(False) + self.vae.requires_grad_(False) + + def set_train(self): + self.unet.train() + self.vae.train() + self.vae_de_mlp.train() + self.unet_de_mlp.train() + self.vae_block_mlp.train() + self.unet_block_mlp.train() + self.vae_fuse_mlp.train() + self.unet_fuse_mlp.train() + + self.vae_block_embeddings.requires_grad_(True) + self.unet_block_embeddings.requires_grad_(True) + + for n, _p in self.unet.named_parameters(): + if "lora" in n: + _p.requires_grad = True + self.unet.conv_in.requires_grad_(True) + + for n, _p in self.vae.named_parameters(): + if "lora" in n: + _p.requires_grad = True + + def forward(self, c_t, deg_score, pos_prompt, neg_prompt): + + if pos_prompt is not None: + # encode the text prompt + pos_caption_tokens = self.tokenizer(pos_prompt, max_length=self.tokenizer.model_max_length, + padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda() + pos_caption_enc = self.text_encoder(pos_caption_tokens)[0] + else: + pos_caption_enc = self.text_encoder(prompt_tokens)[0] + + if neg_prompt is not None: + # encode the text prompt + neg_caption_tokens = self.tokenizer(neg_prompt, max_length=self.tokenizer.model_max_length, + padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda() + neg_caption_enc = self.text_encoder(neg_caption_tokens)[0] + else: + neg_caption_enc = self.text_encoder(neg_prompt_tokens)[0] + + # degradation fourier embedding + deg_proj = deg_score[..., None] * self.W[None, None, :] * 2 * np.pi + deg_proj = torch.cat([torch.sin(deg_proj), torch.cos(deg_proj)], dim=-1) + deg_proj = torch.cat([deg_proj[:, 0], deg_proj[:, 1]], dim=-1) + + # degradation mlp forward + vae_de_c_embed = self.vae_de_mlp(deg_proj) + unet_de_c_embed = self.unet_de_mlp(deg_proj) + + # block embedding mlp forward + vae_block_c_embeds = self.vae_block_mlp(self.vae_block_embeddings.weight) + unet_block_c_embeds = self.unet_block_mlp(self.unet_block_embeddings.weight) + vae_embeds = self.vae_fuse_mlp(torch.cat([vae_de_c_embed.unsqueeze(1).repeat(1, vae_block_c_embeds.shape[0], 1), \ + vae_block_c_embeds.unsqueeze(0).repeat(vae_de_c_embed.shape[0],1,1)], -1)) + unet_embeds = self.unet_fuse_mlp(torch.cat([unet_de_c_embed.unsqueeze(1).repeat(1, unet_block_c_embeds.shape[0], 1), \ + unet_block_c_embeds.unsqueeze(0).repeat(unet_de_c_embed.shape[0],1,1)], -1)) + + for layer_name, module in self.vae.named_modules(): + if layer_name in self.vae_lora_layers: + split_name = layer_name.split(".") + if split_name[1] == 'down_blocks': + block_id = int(split_name[2]) + vae_embed = vae_embeds[:, block_id] + elif split_name[1] == 'mid_block': + vae_embed = vae_embeds[:, -2] + else: + vae_embed = vae_embeds[:, -1] + module.de_mod = vae_embed.reshape(-1, self.lora_rank_vae, self.lora_rank_vae) + + for layer_name, module in 
self.unet.named_modules(): + if layer_name in self.unet_lora_layers: + split_name = layer_name.split(".") + if split_name[0] == 'down_blocks': + block_id = int(split_name[1]) + unet_embed = unet_embeds[:, block_id] + elif split_name[0] == 'mid_block': + unet_embed = unet_embeds[:, 4] + elif split_name[0] == 'up_blocks': + block_id = int(split_name[1]) + 5 + unet_embed = unet_embeds[:, block_id] + else: + unet_embed = unet_embeds[:, -1] + module.de_mod = unet_embed.reshape(-1, self.lora_rank_unet, self.lora_rank_unet) + + encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor + pos_model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=pos_caption_enc).sample + neg_model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=neg_caption_enc).sample + model_pred = neg_model_pred + self.guidance_scale * (pos_model_pred - neg_model_pred) + + x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample + output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1) + + return output_image + + def save_model(self, outf): + sd = {} + sd["unet_lora_target_modules"] = self.target_modules_unet + sd["vae_lora_target_modules"] = self.target_modules_vae + sd["rank_unet"] = self.lora_rank_unet + sd["rank_vae"] = self.lora_rank_vae + sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k} + sd["state_dict_vae"] = {k: v for k, v in self.vae.state_dict().items() if "lora" in k or "skip_conv" in k} + sd["state_dict_vae_de_mlp"] = {k: v for k, v in self.vae_de_mlp.state_dict().items()} + sd["state_dict_unet_de_mlp"] = {k: v for k, v in self.unet_de_mlp.state_dict().items()} + sd["state_dict_vae_block_mlp"] = {k: v for k, v in self.vae_block_mlp.state_dict().items()} + sd["state_dict_unet_block_mlp"] = {k: v for k, v in self.unet_block_mlp.state_dict().items()} + sd["state_dict_vae_fuse_mlp"] = {k: v for k, v in self.vae_fuse_mlp.state_dict().items()} + sd["state_dict_unet_fuse_mlp"] = {k: v for k, v in self.unet_fuse_mlp.state_dict().items()} + sd["w"] = self.W + + sd["state_embeddings"] = { + "state_dict_vae_block": self.vae_block_embeddings.state_dict(), + "state_dict_unet_block": self.unet_block_embeddings.state_dict(), + } + + torch.save(sd, outf) diff --git a/src/s3diff_tile.py b/src/s3diff_tile.py new file mode 100644 index 0000000000000000000000000000000000000000..7391d15868856c13860028f7b7a9e9e2b5036971 --- /dev/null +++ b/src/s3diff_tile.py @@ -0,0 +1,455 @@ +import os +import re +import requests +import sys +import copy +import numpy as np +from tqdm import tqdm +import torch +import torch.nn as nn +from transformers import AutoTokenizer, CLIPTextModel +from diffusers import AutoencoderKL, UNet2DConditionModel +from peft import LoraConfig, get_peft_model +p = "src/" +sys.path.append(p) +from model import make_1step_sched, my_lora_fwd +from basicsr.archs.arch_util import default_init_weights +from my_utils.vaehook import VAEHook, perfcount + + +def get_layer_number(module_name): + base_layers = { + 'down_blocks': 0, + 'mid_block': 4, + 'up_blocks': 5 + } + + if module_name == 'conv_out': + return 9 + + base_layer = None + for key in base_layers: + if key in module_name: + base_layer = base_layers[key] + break + + if base_layer is None: + return None + + additional_layers = int(re.findall(r'\.(\d+)', module_name)[0]) #sum(int(num) for num in re.findall(r'\d+', module_name)) + final_layer = 
base_layer + additional_layers + return final_layer + + +class S3Diff(torch.nn.Module): + def __init__(self, sd_path=None, pretrained_path=None, lora_rank_unet=32, lora_rank_vae=16, block_embedding_dim=64, args=None): + super().__init__() + self.args = args + self.latent_tiled_size = args.latent_tiled_size + self.latent_tiled_overlap = args.latent_tiled_overlap + + self.tokenizer = AutoTokenizer.from_pretrained(sd_path, subfolder="tokenizer") + self.text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").cuda() + self.sched = make_1step_sched(sd_path) + self.guidance_scale = 1.07 + + vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae") + unet = UNet2DConditionModel.from_pretrained(sd_path, subfolder="unet") + + target_modules_vae = r"^encoder\..*(conv1|conv2|conv_in|conv_shortcut|conv|conv_out|to_k|to_q|to_v|to_out\.0)$" + target_modules_unet = [ + "to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_shortcut", "conv_out", + "proj_in", "proj_out", "ff.net.2", "ff.net.0.proj" + ] + + num_embeddings = 64 + self.W = nn.Parameter(torch.randn(num_embeddings), requires_grad=False) + + self.vae_de_mlp = nn.Sequential( + nn.Linear(num_embeddings * 4, 256), + nn.ReLU(True), + ) + + self.unet_de_mlp = nn.Sequential( + nn.Linear(num_embeddings * 4, 256), + nn.ReLU(True), + ) + + self.vae_block_mlp = nn.Sequential( + nn.Linear(block_embedding_dim, 64), + nn.ReLU(True), + ) + + self.unet_block_mlp = nn.Sequential( + nn.Linear(block_embedding_dim, 64), + nn.ReLU(True), + ) + + self.vae_fuse_mlp = nn.Linear(256 + 64, lora_rank_vae ** 2) + self.unet_fuse_mlp = nn.Linear(256 + 64, lora_rank_unet ** 2) + + default_init_weights([self.vae_de_mlp, self.unet_de_mlp, self.vae_block_mlp, self.unet_block_mlp, \ + self.vae_fuse_mlp, self.unet_fuse_mlp], 1e-5) + + # vae + self.vae_block_embeddings = nn.Embedding(6, block_embedding_dim) + self.unet_block_embeddings = nn.Embedding(10, block_embedding_dim) + + if pretrained_path is not None: + sd = torch.load(pretrained_path, map_location="cpu") + vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"]) + vae.add_adapter(vae_lora_config, adapter_name="vae_skip") + _sd_vae = vae.state_dict() + for k in sd["state_dict_vae"]: + _sd_vae[k] = sd["state_dict_vae"][k] + vae.load_state_dict(_sd_vae) + + unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"]) + unet.add_adapter(unet_lora_config) + _sd_unet = unet.state_dict() + for k in sd["state_dict_unet"]: + _sd_unet[k] = sd["state_dict_unet"][k] + unet.load_state_dict(_sd_unet) + + _vae_de_mlp = self.vae_de_mlp.state_dict() + for k in sd["state_dict_vae_de_mlp"]: + _vae_de_mlp[k] = sd["state_dict_vae_de_mlp"][k] + self.vae_de_mlp.load_state_dict(_vae_de_mlp) + + _unet_de_mlp = self.unet_de_mlp.state_dict() + for k in sd["state_dict_unet_de_mlp"]: + _unet_de_mlp[k] = sd["state_dict_unet_de_mlp"][k] + self.unet_de_mlp.load_state_dict(_unet_de_mlp) + + _vae_block_mlp = self.vae_block_mlp.state_dict() + for k in sd["state_dict_vae_block_mlp"]: + _vae_block_mlp[k] = sd["state_dict_vae_block_mlp"][k] + self.vae_block_mlp.load_state_dict(_vae_block_mlp) + + _unet_block_mlp = self.unet_block_mlp.state_dict() + for k in sd["state_dict_unet_block_mlp"]: + _unet_block_mlp[k] = sd["state_dict_unet_block_mlp"][k] + self.unet_block_mlp.load_state_dict(_unet_block_mlp) + + _vae_fuse_mlp = self.vae_fuse_mlp.state_dict() + for k in 
sd["state_dict_vae_fuse_mlp"]: + _vae_fuse_mlp[k] = sd["state_dict_vae_fuse_mlp"][k] + self.vae_fuse_mlp.load_state_dict(_vae_fuse_mlp) + + _unet_fuse_mlp = self.unet_fuse_mlp.state_dict() + for k in sd["state_dict_unet_fuse_mlp"]: + _unet_fuse_mlp[k] = sd["state_dict_unet_fuse_mlp"][k] + self.unet_fuse_mlp.load_state_dict(_unet_fuse_mlp) + + self.W = nn.Parameter(sd["w"], requires_grad=False) + + embeddings_state_dict = sd["state_embeddings"] + self.vae_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_vae_block']) + self.unet_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_unet_block']) + else: + print("Initializing model with random weights") + vae_lora_config = LoraConfig(r=lora_rank_vae, init_lora_weights="gaussian", + target_modules=target_modules_vae) + vae.add_adapter(vae_lora_config, adapter_name="vae_skip") + unet_lora_config = LoraConfig(r=lora_rank_unet, init_lora_weights="gaussian", + target_modules=target_modules_unet + ) + unet.add_adapter(unet_lora_config) + + self.lora_rank_unet = lora_rank_unet + self.lora_rank_vae = lora_rank_vae + self.target_modules_vae = target_modules_vae + self.target_modules_unet = target_modules_unet + + self.vae_lora_layers = [] + for name, module in vae.named_modules(): + if 'base_layer' in name: + self.vae_lora_layers.append(name[:-len(".base_layer")]) + + for name, module in vae.named_modules(): + if name in self.vae_lora_layers: + module.forward = my_lora_fwd.__get__(module, module.__class__) + + self.unet_lora_layers = [] + for name, module in unet.named_modules(): + if 'base_layer' in name: + self.unet_lora_layers.append(name[:-len(".base_layer")]) + + for name, module in unet.named_modules(): + if name in self.unet_lora_layers: + module.forward = my_lora_fwd.__get__(module, module.__class__) + + self.unet_layer_dict = {name: get_layer_number(name) for name in self.unet_lora_layers} + + unet.to("cuda") + vae.to("cuda") + self.unet, self.vae = unet, vae + self.timesteps = torch.tensor([999], device="cuda").long() + self.text_encoder.requires_grad_(False) + + # vae tile + self._init_tiled_vae(encoder_tile_size=args.vae_encoder_tiled_size, decoder_tile_size=args.vae_decoder_tiled_size) + + def set_eval(self): + self.unet.eval() + self.vae.eval() + self.vae_de_mlp.eval() + self.unet_de_mlp.eval() + self.vae_block_mlp.eval() + self.unet_block_mlp.eval() + self.vae_fuse_mlp.eval() + self.unet_fuse_mlp.eval() + + self.vae_block_embeddings.requires_grad_(False) + self.unet_block_embeddings.requires_grad_(False) + + self.unet.requires_grad_(False) + self.vae.requires_grad_(False) + + def set_train(self): + self.unet.train() + self.vae.train() + self.vae_de_mlp.train() + self.unet_de_mlp.train() + self.vae_block_mlp.train() + self.unet_block_mlp.train() + self.vae_fuse_mlp.train() + self.unet_fuse_mlp.train() + + self.vae_block_embeddings.requires_grad_(True) + self.unet_block_embeddings.requires_grad_(True) + + for n, _p in self.unet.named_parameters(): + if "lora" in n: + _p.requires_grad = True + self.unet.conv_in.requires_grad_(True) + + for n, _p in self.vae.named_parameters(): + if "lora" in n: + _p.requires_grad = True + + @perfcount + @torch.no_grad() + def forward(self, c_t, deg_score, pos_prompt, neg_prompt): + + if pos_prompt is not None: + # encode the text prompt + pos_caption_tokens = self.tokenizer(pos_prompt, max_length=self.tokenizer.model_max_length, + padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda() + pos_caption_enc = self.text_encoder(pos_caption_tokens)[0] + else: + 
pos_caption_enc = self.text_encoder(prompt_tokens)[0] + + if neg_prompt is not None: + # encode the text prompt + neg_caption_tokens = self.tokenizer(neg_prompt, max_length=self.tokenizer.model_max_length, + padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda() + neg_caption_enc = self.text_encoder(neg_caption_tokens)[0] + else: + neg_caption_enc = self.text_encoder(neg_prompt_tokens)[0] + + # degradation fourier embedding + deg_proj = deg_score[..., None] * self.W[None, None, :] * 2 * np.pi + deg_proj = torch.cat([torch.sin(deg_proj), torch.cos(deg_proj)], dim=-1) + deg_proj = torch.cat([deg_proj[:, 0], deg_proj[:, 1]], dim=-1) + + # degradation mlp forward + vae_de_c_embed = self.vae_de_mlp(deg_proj) + unet_de_c_embed = self.unet_de_mlp(deg_proj) + + # block embedding mlp forward + vae_block_c_embeds = self.vae_block_mlp(self.vae_block_embeddings.weight) + unet_block_c_embeds = self.unet_block_mlp(self.unet_block_embeddings.weight) + + vae_embeds = self.vae_fuse_mlp(torch.cat([vae_de_c_embed.unsqueeze(1).repeat(1, vae_block_c_embeds.shape[0], 1), \ + vae_block_c_embeds.unsqueeze(0).repeat(vae_de_c_embed.shape[0],1,1)], -1)) + unet_embeds = self.unet_fuse_mlp(torch.cat([unet_de_c_embed.unsqueeze(1).repeat(1, unet_block_c_embeds.shape[0], 1), \ + unet_block_c_embeds.unsqueeze(0).repeat(unet_de_c_embed.shape[0],1,1)], -1)) + + for layer_name, module in self.vae.named_modules(): + if layer_name in self.vae_lora_layers: + split_name = layer_name.split(".") + if split_name[1] == 'down_blocks': + block_id = int(split_name[2]) + vae_embed = vae_embeds[:, block_id] + elif split_name[1] == 'mid_block': + vae_embed = vae_embeds[:, -2] + else: + vae_embed = vae_embeds[:, -1] + module.de_mod = vae_embed.reshape(-1, self.lora_rank_vae, self.lora_rank_vae) + + for layer_name, module in self.unet.named_modules(): + if layer_name in self.unet_lora_layers: + split_name = layer_name.split(".") + if split_name[0] == 'down_blocks': + block_id = int(split_name[1]) + unet_embed = unet_embeds[:, block_id] + elif split_name[0] == 'mid_block': + unet_embed = unet_embeds[:, 4] + elif split_name[0] == 'up_blocks': + block_id = int(split_name[1]) + 5 + unet_embed = unet_embeds[:, block_id] + else: + unet_embed = unet_embeds[:, -1] + module.de_mod = unet_embed.reshape(-1, self.lora_rank_unet, self.lora_rank_unet) + + lq_latent = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor + + ## add tile function + _, _, h, w = lq_latent.size() + tile_size, tile_overlap = (self.latent_tiled_size, self.latent_tiled_overlap) + if h * w <= tile_size * tile_size: + print(f"[Tiled Latent]: the input size is tiny and unnecessary to tile.") + pos_model_pred = self.unet(lq_latent, self.timesteps, encoder_hidden_states=pos_caption_enc).sample + neg_model_pred = self.unet(lq_latent, self.timesteps, encoder_hidden_states=neg_caption_enc).sample + model_pred = neg_model_pred + self.guidance_scale * (pos_model_pred - neg_model_pred) + else: + print(f"[Tiled Latent]: the input size is {c_t.shape[-2]}x{c_t.shape[-1]}, need to tiled") + # tile_weights = self._gaussian_weights(tile_size, tile_size, 1).to() + tile_size = min(tile_size, min(h, w)) + tile_weights = self._gaussian_weights(tile_size, tile_size, 1).to(c_t.device) + + grid_rows = 0 + cur_x = 0 + while cur_x < lq_latent.size(-1): + cur_x = max(grid_rows * tile_size-tile_overlap * grid_rows, 0)+tile_size + grid_rows += 1 + + grid_cols = 0 + cur_y = 0 + while cur_y < lq_latent.size(-2): + cur_y = max(grid_cols * tile_size-tile_overlap * 
grid_cols, 0)+tile_size + grid_cols += 1 + + input_list = [] + noise_preds = [] + for row in range(grid_rows): + noise_preds_row = [] + for col in range(grid_cols): + if col < grid_cols-1 or row < grid_rows-1: + # extract tile from input image + ofs_x = max(row * tile_size-tile_overlap * row, 0) + ofs_y = max(col * tile_size-tile_overlap * col, 0) + # input tile area on total image + if row == grid_rows-1: + ofs_x = w - tile_size + if col == grid_cols-1: + ofs_y = h - tile_size + + input_start_x = ofs_x + input_end_x = ofs_x + tile_size + input_start_y = ofs_y + input_end_y = ofs_y + tile_size + + # input tile dimensions + input_tile = lq_latent[:, :, input_start_y:input_end_y, input_start_x:input_end_x] + input_list.append(input_tile) + + if len(input_list) == 1 or col == grid_cols-1: + input_list_t = torch.cat(input_list, dim=0) + # predict the noise residual + pos_model_pred = self.unet(input_list_t, self.timesteps, encoder_hidden_states=pos_caption_enc).sample + neg_model_pred = self.unet(input_list_t, self.timesteps, encoder_hidden_states=neg_caption_enc).sample + model_out = neg_model_pred + self.guidance_scale * (pos_model_pred - neg_model_pred) + input_list = [] + noise_preds.append(model_out) + + # Stitch noise predictions for all tiles + noise_pred = torch.zeros(lq_latent.shape, device=lq_latent.device) + contributors = torch.zeros(lq_latent.shape, device=lq_latent.device) + # Add each tile contribution to overall latents + for row in range(grid_rows): + for col in range(grid_cols): + if col < grid_cols-1 or row < grid_rows-1: + # extract tile from input image + ofs_x = max(row * tile_size-tile_overlap * row, 0) + ofs_y = max(col * tile_size-tile_overlap * col, 0) + # input tile area on total image + if row == grid_rows-1: + ofs_x = w - tile_size + if col == grid_cols-1: + ofs_y = h - tile_size + + input_start_x = ofs_x + input_end_x = ofs_x + tile_size + input_start_y = ofs_y + input_end_y = ofs_y + tile_size + + noise_pred[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += noise_preds[row*grid_cols + col] * tile_weights + contributors[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += tile_weights + # Average overlapping areas with more than 1 contributor + noise_pred /= contributors + model_pred = noise_pred + + x_denoised = self.sched.step(model_pred, self.timesteps, lq_latent, return_dict=True).prev_sample + output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1) + + return output_image + + def save_model(self, outf): + sd = {} + sd["unet_lora_target_modules"] = self.target_modules_unet + sd["vae_lora_target_modules"] = self.target_modules_vae + sd["rank_unet"] = self.lora_rank_unet + sd["rank_vae"] = self.lora_rank_vae + sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k} + sd["state_dict_vae"] = {k: v for k, v in self.vae.state_dict().items() if "lora" in k or "skip_conv" in k} + sd["state_dict_vae_de_mlp"] = {k: v for k, v in self.vae_de_mlp.state_dict().items()} + sd["state_dict_unet_de_mlp"] = {k: v for k, v in self.unet_de_mlp.state_dict().items()} + sd["state_dict_vae_block_mlp"] = {k: v for k, v in self.vae_block_mlp.state_dict().items()} + sd["state_dict_unet_block_mlp"] = {k: v for k, v in self.unet_block_mlp.state_dict().items()} + sd["state_dict_vae_fuse_mlp"] = {k: v for k, v in self.vae_fuse_mlp.state_dict().items()} + sd["state_dict_unet_fuse_mlp"] = {k: v for k, v in self.unet_fuse_mlp.state_dict().items()} + sd["w"] = self.W + + 
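+        # Only the lightweight trained pieces are checkpointed: LoRA / conv_in tensors,
+        # the degradation and block MLPs, the Fourier frequencies W and the block
+        # embeddings below. The frozen SD backbone is reloaded from sd_path when the
+        # model is rebuilt.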
sd["state_embeddings"] = { + "state_dict_vae_block": self.vae_block_embeddings.state_dict(), + "state_dict_unet_block": self.unet_block_embeddings.state_dict(), + } + + torch.save(sd, outf) + + def _set_latent_tile(self, + latent_tiled_size = 96, + latent_tiled_overlap = 32): + self.latent_tiled_size = latent_tiled_size + self.latent_tiled_overlap = latent_tiled_overlap + + def _init_tiled_vae(self, + encoder_tile_size = 256, + decoder_tile_size = 256, + fast_decoder = False, + fast_encoder = False, + color_fix = False, + vae_to_gpu = True): + # save original forward (only once) + if not hasattr(self.vae.encoder, 'original_forward'): + setattr(self.vae.encoder, 'original_forward', self.vae.encoder.forward) + if not hasattr(self.vae.decoder, 'original_forward'): + setattr(self.vae.decoder, 'original_forward', self.vae.decoder.forward) + + encoder = self.vae.encoder + decoder = self.vae.decoder + + self.vae.encoder.forward = VAEHook( + encoder, encoder_tile_size, is_decoder=False, fast_decoder=fast_decoder, fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu) + self.vae.decoder.forward = VAEHook( + decoder, decoder_tile_size, is_decoder=True, fast_decoder=fast_decoder, fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu) + + def _gaussian_weights(self, tile_width, tile_height, nbatches): + """Generates a gaussian mask of weights for tile contributions""" + from numpy import pi, exp, sqrt + import numpy as np + + latent_width = tile_width + latent_height = tile_height + + var = 0.01 + midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1 + x_probs = [exp(-(x-midpoint)*(x-midpoint)/(latent_width*latent_width)/(2*var)) / sqrt(2*pi*var) for x in range(latent_width)] + midpoint = latent_height / 2 + y_probs = [exp(-(y-midpoint)*(y-midpoint)/(latent_height*latent_height)/(2*var)) / sqrt(2*pi*var) for y in range(latent_height)] + + weights = np.outer(y_probs, x_probs) + return torch.tile(torch.tensor(weights), (nbatches, self.unet.config.in_channels, 1, 1)) + diff --git a/src/train_s3diff.py b/src/train_s3diff.py new file mode 100644 index 0000000000000000000000000000000000000000..ee5b2129b1fa5171de529566f78b6d80ff2d1580 --- /dev/null +++ b/src/train_s3diff.py @@ -0,0 +1,284 @@ +import os +os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO' + +import gc +import lpips +import clip +import random +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers + +from omegaconf import OmegaConf +from accelerate import Accelerator +from accelerate.utils import set_seed +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm + +import diffusers +from diffusers.utils.import_utils import is_xformers_available +from diffusers.optimization import get_scheduler + +from de_net import DEResNet +from s3diff import S3Diff +from my_utils.training_utils import parse_args_paired_training, PairedDataset, degradation_proc + +def main(args): + + # init and save configs + config = OmegaConf.load(args.base_config) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + ) + + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + if args.seed is not None: + set_seed(args.seed) + + if 
accelerator.is_main_process: + os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True) + + # initialize degradation estimation network + net_de = DEResNet(num_in_ch=3, num_degradation=2) + net_de.load_model(args.de_net_path) + net_de = net_de.cuda() + net_de.eval() + + # initialize net_sr + net_sr = S3Diff(lora_rank_unet=args.lora_rank_unet, lora_rank_vae=args.lora_rank_vae, sd_path=args.sd_path, pretrained_path=args.pretrained_path) + net_sr.set_train() + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + net_sr.unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available, please install it by running `pip install xformers`") + + if args.gradient_checkpointing: + net_sr.unet.enable_gradient_checkpointing() + + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.gan_disc_type == "vagan": + import vision_aided_loss + net_disc = vision_aided_loss.Discriminator(cv_type='dino', output_type='conv_multi_level', loss_type=args.gan_loss_type, device="cuda") + else: + raise NotImplementedError(f"Discriminator type {args.gan_disc_type} not implemented") + + net_disc = net_disc.cuda() + net_disc.requires_grad_(True) + net_disc.cv_ensemble.requires_grad_(False) + net_disc.train() + + net_lpips = lpips.LPIPS(net='vgg').cuda() + net_lpips.requires_grad_(False) + + # make the optimizer + layers_to_opt = [] + layers_to_opt = layers_to_opt + list(net_sr.vae_block_embeddings.parameters()) + list(net_sr.unet_block_embeddings.parameters()) + layers_to_opt = layers_to_opt + list(net_sr.vae_de_mlp.parameters()) + list(net_sr.unet_de_mlp.parameters()) + \ + list(net_sr.vae_block_mlp.parameters()) + list(net_sr.unet_block_mlp.parameters()) + \ + list(net_sr.vae_fuse_mlp.parameters()) + list(net_sr.unet_fuse_mlp.parameters()) + + for n, _p in net_sr.unet.named_parameters(): + if "lora" in n: + assert _p.requires_grad + layers_to_opt.append(_p) + layers_to_opt += list(net_sr.unet.conv_in.parameters()) + + for n, _p in net_sr.vae.named_parameters(): + if "lora" in n: + assert _p.requires_grad + layers_to_opt.append(_p) + + dataset_train = PairedDataset(config.train) + dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers) + dataset_val = PairedDataset(config.validation) + dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0) + + + optimizer = torch.optim.AdamW(layers_to_opt, lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon,) + lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, power=args.lr_power,) + + optimizer_disc = torch.optim.AdamW(net_disc.parameters(), lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon,) + lr_scheduler_disc = get_scheduler(args.lr_scheduler, optimizer=optimizer_disc, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, power=args.lr_power) + + # Prepare everything with our `accelerator`. 
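+    # accelerator.prepare wraps the networks, optimizers, dataloaders and LR schedulers for (optionally distributed) mixed-precision training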
+ net_sr, net_disc, optimizer, optimizer_disc, dl_train, lr_scheduler, lr_scheduler_disc = accelerator.prepare( + net_sr, net_disc, optimizer, optimizer_disc, dl_train, lr_scheduler, lr_scheduler_disc + ) + net_de, net_lpips = accelerator.prepare(net_de, net_lpips) + # # renorm with image net statistics + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move al networksr to device and cast to weight_dtype + net_sr.to(accelerator.device, dtype=weight_dtype) + net_de.to(accelerator.device, dtype=weight_dtype) + net_disc.to(accelerator.device, dtype=weight_dtype) + net_lpips.to(accelerator.device, dtype=weight_dtype) + + progress_bar = tqdm(range(0, args.max_train_steps), initial=0, desc="Steps", + disable=not accelerator.is_local_main_process,) + + for name, module in net_disc.named_modules(): + if "attn" in name: + module.fused_attn = False + + # start the training loop + global_step = 0 + for epoch in range(0, args.num_training_epochs): + for step, batch in enumerate(dl_train): + l_acc = [net_sr, net_disc] + with accelerator.accumulate(*l_acc): + x_src, x_tgt, x_ori_size_src = degradation_proc(config, batch, accelerator.device) + B, C, H, W = x_src.shape + with torch.no_grad(): + deg_score = net_de(x_ori_size_src.detach()).detach() + + pos_tag_prompt = [args.pos_prompt for _ in range(B)] + neg_tag_prompt = [args.neg_prompt for _ in range(B)] + + neg_probs = torch.rand(B).to(accelerator.device) + + # build mixed prompt and target + mixed_tag_prompt = [_neg_tag if p_i < args.neg_prob else _pos_tag for _neg_tag, _pos_tag, p_i in zip(neg_tag_prompt, pos_tag_prompt, neg_probs)] + neg_probs = neg_probs.reshape(B, 1, 1, 1) + mixed_tgt = torch.where(neg_probs < args.neg_prob, x_src, x_tgt) + + x_tgt_pred = net_sr(x_src.detach(), deg_score, mixed_tag_prompt) + loss_l2 = F.mse_loss(x_tgt_pred.float(), mixed_tgt.detach().float(), reduction="mean") * args.lambda_l2 + loss_lpips = net_lpips(x_tgt_pred.float(), mixed_tgt.detach().float()).mean() * args.lambda_lpips + + loss = loss_l2 + loss_lpips + + accelerator.backward(loss, retain_graph=False) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + + """ + Generator loss: fool the discriminator + """ + x_tgt_pred = net_sr(x_src.detach(), deg_score, pos_tag_prompt) + lossG = net_disc(x_tgt_pred, for_G=True).mean() * args.lambda_gan + accelerator.backward(lossG) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + + """ + Discriminator loss: fake image vs real image + """ + # real image + lossD_real = net_disc(x_tgt.detach(), for_real=True).mean() * args.lambda_gan + accelerator.backward(lossD_real.mean()) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(net_disc.parameters(), args.max_grad_norm) + optimizer_disc.step() + lr_scheduler_disc.step() + optimizer_disc.zero_grad(set_to_none=args.set_grads_to_none) + # fake image + lossD_fake = net_disc(x_tgt_pred.detach(), for_real=False).mean() * args.lambda_gan + accelerator.backward(lossD_fake.mean()) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(net_disc.parameters(), args.max_grad_norm) + optimizer_disc.step() + 
optimizer_disc.zero_grad(set_to_none=args.set_grads_to_none) + lossD = lossD_real + lossD_fake + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + logs = {} + logs["lossG"] = lossG.detach().item() + logs["lossD"] = lossD.detach().item() + logs["loss_l2"] = loss_l2.detach().item() + logs["loss_lpips"] = loss_lpips.detach().item() + progress_bar.set_postfix(**logs) + + # checkpoint the model + if global_step % args.checkpointing_steps == 1: + outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl") + accelerator.unwrap_model(net_sr).save_model(outf) + + # compute validation set FID, L2, LPIPS, CLIP-SIM + if global_step % args.eval_freq == 1: + l_l2, l_lpips = [], [] + + val_count = 0 + for step, batch_val in enumerate(dl_val): + if step >= args.num_samples_eval: + break + x_src, x_tgt, x_ori_size_src = degradation_proc(config, batch_val, accelerator.device) + B, C, H, W = x_src.shape + assert B == 1, "Use batch size 1 for eval." + with torch.no_grad(): + # forward pass + with torch.no_grad(): + deg_score = net_de(x_ori_size_src.detach()) + + pos_tag_prompt = [args.pos_prompt for _ in range(B)] + x_tgt_pred = accelerator.unwrap_model(net_sr)(x_src.detach(), deg_score, pos_tag_prompt) + # compute the reconstruction losses + loss_l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.detach().float(), reduction="mean") + loss_lpips = net_lpips(x_tgt_pred.float(), x_tgt.detach().float()).mean() + + l_l2.append(loss_l2.item()) + l_lpips.append(loss_lpips.item()) + + if args.save_val and val_count < 5: + x_src = x_src.cpu().detach() * 0.5 + 0.5 + x_tgt = x_tgt.cpu().detach() * 0.5 + 0.5 + x_tgt_pred = x_tgt_pred.cpu().detach() * 0.5 + 0.5 + + combined = torch.cat([x_src, x_tgt_pred, x_tgt], dim=3) + output_pil = transforms.ToPILImage()(combined[0]) + outf = os.path.join(args.output_dir, f"val_{step}.png") + output_pil.save(outf) + val_count += 1 + + logs["val/l2"] = np.mean(l_l2) + logs["val/lpips"] = np.mean(l_lpips) + gc.collect() + torch.cuda.empty_cache() + accelerator.log(logs, step=global_step) + + +if __name__ == "__main__": + args = parse_args_paired_training() + main(args) diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..999d2f5205b22dda9ac8acc3e738fa0e1099a426 --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,528 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import builtins +import datetime +import os +import time +from collections import defaultdict, deque +from pathlib import Path +import json +import subprocess + +import torch +import torch.distributed as dist + +from typing import List, Dict, Tuple, Optional +from torch import Tensor + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. 
+ """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time 
/ len(iterable))) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + force = force or (get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print('[{}] '.format(now), end='') # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + args.dist_url = 'env://' + os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count()) + elif 'SLURM_PROCID' in os.environ: + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + addr = subprocess.getoutput( + 'scontrol show hostname {} | head -n1'.format(node_list)) + os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29200') + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['RANK'] = str(proc_id) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['LOCAL_SIZE'] = str(num_gpus) + args.dist_url = 'env://' + args.world_size = ntasks + args.rank = proc_id + args.gpu = proc_id % num_gpus + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + +def clip_grad_norm_( + parameters, max_norm: float, norm_type: float = 2.0, + error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor: + r"""Clips gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float): max norm of the gradients + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. 
+ Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.) + first_device = grads[0].device + grouped_grads: Dict[Tuple[torch.device, torch.dtype], List[List[Tensor]]] \ + = {(first_device, grads[0].dtype): [[g.detach() for g in grads]]} + + norms = [torch.norm(g) for g in grads] + total_norm = torch.norm(torch.stack(norms)) + + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for ((device, _), [grads]) in grouped_grads.items(): + if (foreach is None or foreach): + torch._foreach_mul_(grads, clip_coef_clamped.to(device)) # type: ignore[call-overload] + elif foreach: + raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors') + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in grads: + g.detach().mul_(clip_coef_clamped_device) + + return total_norm + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): + + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place + norm = clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.) 
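+    # `inf` is not imported at the top of this file, so import it locally for the inf-norm branch below
+    from math import inf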
+ device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) + return total_norm + + +def save_model(args, epoch, model, model_without_ddp, optimizer): + output_dir = Path(args.output_dir) + epoch_name = str(epoch) + + # checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] + checkpoint_paths = [output_dir / 'checkpoint.pth'] + for checkpoint_path in checkpoint_paths: + to_save = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'epoch': epoch, + 'args': args, + } + + save_on_master(to_save, checkpoint_path) + +def load_model(args, model_without_ddp, optimizer): + if args.resume: + if args.resume.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.resume, map_location='cpu') + model_without_ddp.load_state_dict(checkpoint['model']) + print("Resume checkpoint %s" % args.resume) + if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): + optimizer.load_state_dict(checkpoint['optimizer']) + args.start_epoch = checkpoint['epoch'] + 1 + print("With optim & sched!") + +def auto_load_model(args, model, model_without_ddp, optimizer): + output_dir = Path(args.output_dir) + + # torch.amp + if args.auto_resume and len(args.resume) == 0: + import glob + all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) + latest_ckpt = -1 + for ckpt in all_checkpoints: + t = ckpt.split('-')[-1].split('.')[0] + if t.isdigit(): + latest_ckpt = max(int(t), latest_ckpt) + if latest_ckpt >= 0: + args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) + print("Auto resume checkpoint: %s" % args.resume) + + if args.resume: + if args.resume.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.resume, map_location='cpu') + model_without_ddp.load_state_dict(checkpoint['model']) + print("Resume checkpoint %s" % args.resume) + if 'optimizer' in checkpoint and 'epoch' in checkpoint: + optimizer.load_state_dict(checkpoint['optimizer']) + args.start_epoch = checkpoint['epoch'] + 1 + print("With optim & sched!") + + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x + + +def create_ds_config(args): + args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json") + with open(args.deepspeed_config, mode="w") as writer: + ds_config = { + "train_batch_size": args.batch_size * args.accum_iter * get_world_size(), + "train_micro_batch_size_per_gpu": args.batch_size, + "steps_per_print": 1000, + "optimizer": { + "type": "Adam", + "adam_w_mode": True, + "params": { + "lr": args.lr, + "weight_decay": args.weight_decay, + "bias_correction": True, + "betas": [ + args.opt_betas[0], + args.opt_betas[1] + ], + "eps": args.opt_eps + } + }, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + # "bf16": { + # "enabled": True + # }, + "amp": { + "enabled": False, + "opt_level": "O2" + }, + "flops_profiler": { + "enabled": 
True, + "profile_step": -1, + "module_depth": -1, + "top_modules": 1, + "detailed": True, + }, + } + + if args.clip_grad is not None: + ds_config.update({'gradient_clipping': args.clip_grad}) + + if args.zero_stage == 1: + ds_config.update({"zero_optimization": {"stage": args.zero_stage, "reduce_bucket_size": 5e8}}) + elif args.zero_stage > 1: + raise NotImplementedError() + + writer.write(json.dumps(ds_config, indent=2)) + +def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): + parameter_group_names = {} + parameter_group_vars = {} + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: + group_name = "no_decay" + this_weight_decay = 0. + else: + group_name = "decay" + this_weight_decay = weight_decay + if get_num_layer is not None: + layer_id = get_num_layer(name) + group_name = "layer_%d_%s" % (layer_id, group_name) + else: + layer_id = None + + if group_name not in parameter_group_names: + if get_layer_scale is not None: + scale = get_layer_scale(layer_id) + else: + scale = 1. + + parameter_group_names[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr_scale": scale + } + parameter_group_vars[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr_scale": scale + } + + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(name) + print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) + return list(parameter_group_vars.values()) + diff --git a/utils/util_image.py b/utils/util_image.py new file mode 100644 index 0000000000000000000000000000000000000000..88d6f307b208ead04389c870d1b24a82c5c0960e --- /dev/null +++ b/utils/util_image.py @@ -0,0 +1,935 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +# Power by Zongsheng Yue 2021-11-24 16:54:19 + +import sys +import cv2 +import math +import torch +import random +import numpy as np +from scipy import fft +from pathlib import Path +from einops import rearrange +from skimage import img_as_ubyte, img_as_float32 + +# --------------------------Metrics---------------------------- +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + +def calculate_ssim(im1, im2, border=0, ycbcr=False): + ''' + SSIM the same outputs as MATLAB's + im1, im2: h x w x , [0, 255], uint8 + ''' + if not im1.shape == im2.shape: + raise ValueError('Input images must have the same dimensions.') + + if ycbcr: + im1 = rgb2ycbcr(im1, True) + im2 = rgb2ycbcr(im2, True) + + h, w = im1.shape[:2] + im1 = im1[border:h-border, border:w-border] + im2 = im2[border:h-border, border:w-border] + + if im1.ndim == 2: + return ssim(im1, im2) + elif im1.ndim == 3: + if im1.shape[2] == 3: + ssims = [] + for i in 
range(3): + ssims.append(ssim(im1[:,:,i], im2[:,:,i])) + return np.array(ssims).mean() + elif im1.shape[2] == 1: + return ssim(np.squeeze(im1), np.squeeze(im2)) + else: + raise ValueError('Wrong input image dimensions.') + +def calculate_psnr(im1, im2, border=0, ycbcr=False): + ''' + PSNR metric. + im1, im2: h x w x , [0, 255], uint8 + ''' + if not im1.shape == im2.shape: + raise ValueError('Input images must have the same dimensions.') + + if ycbcr: + im1 = rgb2ycbcr(im1, True) + im2 = rgb2ycbcr(im2, True) + + h, w = im1.shape[:2] + im1 = im1[border:h-border, border:w-border] + im2 = im2[border:h-border, border:w-border] + + im1 = im1.astype(np.float64) + im2 = im2.astype(np.float64) + mse = np.mean((im1 - im2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + +def batch_PSNR(img, imclean, border=0, ycbcr=False): + if ycbcr: + img = rgb2ycbcrTorch(img, True) + imclean = rgb2ycbcrTorch(imclean, True) + Img = img.data.cpu().numpy() + Iclean = imclean.data.cpu().numpy() + Img = img_as_ubyte(Img) + Iclean = img_as_ubyte(Iclean) + PSNR = 0 + h, w = Iclean.shape[2:] + for i in range(Img.shape[0]): + PSNR += calculate_psnr(Iclean[i,:,].transpose((1,2,0)), Img[i,:,].transpose((1,2,0)), border) + return PSNR + +def batch_SSIM(img, imclean, border=0, ycbcr=False): + if ycbcr: + img = rgb2ycbcrTorch(img, True) + imclean = rgb2ycbcrTorch(imclean, True) + Img = img.data.cpu().numpy() + Iclean = imclean.data.cpu().numpy() + Img = img_as_ubyte(Img) + Iclean = img_as_ubyte(Iclean) + SSIM = 0 + for i in range(Img.shape[0]): + SSIM += calculate_ssim(Iclean[i,:,].transpose((1,2,0)), Img[i,:,].transpose((1,2,0)), border) + return SSIM + +def normalize_np(im, mean=0.5, std=0.5, reverse=False): + ''' + Input: + im: h x w x c, numpy array + Normalize: (im - mean) / std + Reverse: im * std + mean + + ''' + if not isinstance(mean, (list, tuple)): + mean = [mean, ] * im.shape[2] + mean = np.array(mean).reshape([1, 1, im.shape[2]]) + + if not isinstance(std, (list, tuple)): + std = [std, ] * im.shape[2] + std = np.array(std).reshape([1, 1, im.shape[2]]) + + if not reverse: + out = (im.astype(np.float32) - mean) / std + else: + out = im.astype(np.float32) * std + mean + return out + +def normalize_th(im, mean=0.5, std=0.5, reverse=False): + ''' + Input: + im: b x c x h x w, torch tensor + Normalize: (im - mean) / std + Reverse: im * std + mean + + ''' + if not isinstance(mean, (list, tuple)): + mean = [mean, ] * im.shape[1] + mean = torch.tensor(mean, device=im.device).view([1, im.shape[1], 1, 1]) + + if not isinstance(std, (list, tuple)): + std = [std, ] * im.shape[1] + std = torch.tensor(std, device=im.device).view([1, im.shape[1], 1, 1]) + + if not reverse: + out = (im - mean) / std + else: + out = im * std + mean + return out + +# ------------------------Image format-------------------------- +def rgb2ycbcr(im, only_y=True): + ''' + same as matlab rgb2ycbcr + Input: + im: uint8 [0,255] or float [0,1] + only_y: only return Y channel + ''' + # transform to float64 data type, range [0, 255] + if im.dtype == np.uint8: + im_temp = im.astype(np.float64) + else: + im_temp = (im * 255).astype(np.float64) + + # convert + if only_y: + rlt = np.dot(im_temp, np.array([65.481, 128.553, 24.966])/ 255.0) + 16.0 + else: + rlt = np.matmul(im_temp, np.array([[65.481, -37.797, 112.0 ], + [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]])/255.0) + [16, 128, 128] + if im.dtype == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. 
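+        # cast back so the output keeps the dtype of the input (rounded uint8, or float scaled back to [0, 1])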
+ return rlt.astype(im.dtype) + +def rgb2ycbcrTorch(im, only_y=True): + ''' + same as matlab rgb2ycbcr + Input: + im: float [0,1], N x 3 x H x W + only_y: only return Y channel + ''' + # transform to range [0,255.0] + im_temp = im.permute([0,2,3,1]) * 255.0 # N x H x W x C --> N x H x W x C + # convert + if only_y: + rlt = torch.matmul(im_temp, torch.tensor([65.481, 128.553, 24.966], + device=im.device, dtype=im.dtype).view([3,1])/ 255.0) + 16.0 + else: + rlt = torch.matmul(im_temp, torch.tensor([[65.481, -37.797, 112.0 ], + [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]], + device=im.device, dtype=im.dtype)/255.0) + \ + torch.tensor([16, 128, 128]).view([-1, 1, 1, 3]) + rlt /= 255.0 + rlt.clamp_(0.0, 1.0) + return rlt.permute([0, 3, 1, 2]) + +def bgr2rgb(im): return cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + +def rgb2bgr(im): return cv2.cvtColor(im, cv2.COLOR_RGB2BGR) + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. + """ + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + + flag_tensor = torch.is_tensor(tensor) + if flag_tensor: + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1 and flag_tensor: + result = result[0] + return result + +def img2tensor(imgs, out_type=torch.float32): + """Convert image numpy arrays into torch tensor. + Args: + imgs (Array or list[array]): Accept shapes: + 3) list of numpy arrays + 1) 3D numpy array of shape (H x W x 3/1); + 2) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. 
+ + Returns: + (array or list): 4D ndarray of shape (1 x C x H x W) + """ + + def _img2tensor(img): + if img.ndim == 2: + tensor = torch.from_numpy(img[None, None,]).type(out_type) + elif img.ndim == 3: + tensor = torch.from_numpy(rearrange(img, 'h w c -> c h w')).type(out_type).unsqueeze(0) + else: + raise TypeError(f'2D or 3D numpy array expected, got{img.ndim}D array') + return tensor + + if not (isinstance(imgs, np.ndarray) or (isinstance(imgs, list) and all(isinstance(t, np.ndarray) for t in imgs))): + raise TypeError(f'Numpy array or list of numpy array expected, got {type(imgs)}') + + flag_numpy = isinstance(imgs, np.ndarray) + if flag_numpy: + imgs = [imgs,] + result = [] + for _img in imgs: + result.append(_img2tensor(_img)) + + if len(result) == 1 and flag_numpy: + result = result[0] + return result + +# ------------------------Image resize----------------------------- +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, 
j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + +# ------------------------Image I/O----------------------------- +def imread(path, chn='rgb', dtype='float32'): + ''' + Read image. + chn: 'rgb', 'bgr' or 'gray' + out: + im: h x w x c, numpy tensor + ''' + im = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) # BGR, uint8 + try: + if chn.lower() == 'rgb': + if im.ndim == 3: + im = bgr2rgb(im) + else: + im = np.stack((im, im, im), axis=2) + elif chn.lower() == 'gray': + assert im.ndim == 2 + except: + print(str(path)) + + if dtype == 'float32': + im = im.astype(np.float32) / 255. + elif dtype == 'float64': + im = im.astype(np.float64) / 255. + elif dtype == 'uint8': + pass + else: + sys.exit('Please input corrected dtype: float32, float64 or uint8!') + + return im + +def imwrite(im_in, path, chn='rgb', dtype_in='float32', qf=None): + ''' + Save image. 
+ Input: + im: h x w x c, numpy tensor + path: the saving path + chn: the channel order of the im, + ''' + im = im_in.copy() + if isinstance(path, str): + path = Path(path) + if dtype_in != 'uint8': + im = img_as_ubyte(im) + + if chn.lower() == 'rgb' and im.ndim == 3: + im = rgb2bgr(im) + + if qf is not None and path.suffix.lower() in ['.jpg', '.jpeg']: + flag = cv2.imwrite(str(path), im, [int(cv2.IMWRITE_JPEG_QUALITY), int(qf)]) + else: + flag = cv2.imwrite(str(path), im) + + return flag + +def jpeg_compress(im, qf, chn_in='rgb'): + ''' + Input: + im: h x w x 3 array + qf: compress factor, (0, 100] + chn_in: 'rgb' or 'bgr' + Return: + Compressed Image with channel order: chn_in + ''' + # transform to BGR channle and uint8 data type + im_bgr = rgb2bgr(im) if chn_in.lower() == 'rgb' else im + if im.dtype != np.dtype('uint8'): im_bgr = img_as_ubyte(im_bgr) + + # JPEG compress + flag, encimg = cv2.imencode('.jpg', im_bgr, [int(cv2.IMWRITE_JPEG_QUALITY), qf]) + assert flag + im_jpg_bgr = cv2.imdecode(encimg, 1) # uint8, BGR + + # transform back to original channel and the original data type + im_out = bgr2rgb(im_jpg_bgr) if chn_in.lower() == 'rgb' else im_jpg_bgr + if im.dtype != np.dtype('uint8'): im_out = img_as_float32(im_out).astype(im.dtype) + return im_out + +# ------------------------Augmentation----------------------------- +def data_aug_np(image, mode): + ''' + Performs data augmentation of the input image + Input: + image: a cv2 (OpenCV) image + mode: int. Choice of transformation to apply to the image + 0 - no transformation + 1 - flip up and down + 2 - rotate counterwise 90 degree + 3 - rotate 90 degree and flip up and down + 4 - rotate 180 degree + 5 - rotate 180 degree and flip + 6 - rotate 270 degree + 7 - rotate 270 degree and flip + ''' + if mode == 0: + # original + out = image + elif mode == 1: + # flip up and down + out = np.flipud(image) + elif mode == 2: + # rotate counterwise 90 degree + out = np.rot90(image) + elif mode == 3: + # rotate 90 degree and flip up and down + out = np.rot90(image) + out = np.flipud(out) + elif mode == 4: + # rotate 180 degree + out = np.rot90(image, k=2) + elif mode == 5: + # rotate 180 degree and flip + out = np.rot90(image, k=2) + out = np.flipud(out) + elif mode == 6: + # rotate 270 degree + out = np.rot90(image, k=3) + elif mode == 7: + # rotate 270 degree and flip + out = np.rot90(image, k=3) + out = np.flipud(out) + else: + raise Exception('Invalid choice of image transformation') + + return out.copy() + +def inverse_data_aug_np(image, mode): + ''' + Performs inverse data augmentation of the input image + ''' + if mode == 0: + # original + out = image + elif mode == 1: + out = np.flipud(image) + elif mode == 2: + out = np.rot90(image, axes=(1,0)) + elif mode == 3: + out = np.flipud(image) + out = np.rot90(out, axes=(1,0)) + elif mode == 4: + out = np.rot90(image, k=2, axes=(1,0)) + elif mode == 5: + out = np.flipud(image) + out = np.rot90(out, k=2, axes=(1,0)) + elif mode == 6: + out = np.rot90(image, k=3, axes=(1,0)) + elif mode == 7: + # rotate 270 degree and flip + out = np.flipud(image) + out = np.rot90(out, k=3, axes=(1,0)) + else: + raise Exception('Invalid choice of image transformation') + + return out + +class SpatialAug: + def __init__(self): + pass + + def __call__(self, im, flag=None): + if flag is None: + flag = random.randint(0, 7) + + out = data_aug_np(im, flag) + return out + +# ----------------------Visualization---------------------------- +def imshow(x, title=None, cbar=False): + import matplotlib.pyplot as plt + 
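+    # matplotlib is imported lazily so it is only required when visualization is actually used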
plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + if title: + plt.title(title) + if cbar: + plt.colorbar() + plt.show() + +# -----------------------Covolution------------------------------ +def imgrad(im, pading_mode='mirror'): + ''' + Calculate image gradient. + Input: + im: h x w x c numpy array + ''' + from scipy.ndimage import correlate # lazy import + wx = np.array([[0, 0, 0], + [-1, 1, 0], + [0, 0, 0]], dtype=np.float32) + wy = np.array([[0, -1, 0], + [0, 1, 0], + [0, 0, 0]], dtype=np.float32) + if im.ndim == 3: + gradx = np.stack( + [correlate(im[:,:,c], wx, mode=pading_mode) for c in range(im.shape[2])], + axis=2 + ) + grady = np.stack( + [correlate(im[:,:,c], wy, mode=pading_mode) for c in range(im.shape[2])], + axis=2 + ) + grad = np.concatenate((gradx, grady), axis=2) + else: + gradx = correlate(im, wx, mode=pading_mode) + grady = correlate(im, wy, mode=pading_mode) + grad = np.stack((gradx, grady), axis=2) + + return {'gradx': gradx, 'grady': grady, 'grad':grad} + +def imgrad_fft(im): + ''' + Calculate image gradient. + Input: + im: h x w x c numpy array + ''' + wx = np.rot90(np.array([[0, 0, 0], + [-1, 1, 0], + [0, 0, 0]], dtype=np.float32), k=2) + gradx = convfft(im, wx) + wy = np.rot90(np.array([[0, -1, 0], + [0, 1, 0], + [0, 0, 0]], dtype=np.float32), k=2) + grady = convfft(im, wy) + grad = np.concatenate((gradx, grady), axis=2) + + return {'gradx': gradx, 'grady': grady, 'grad':grad} + +def convfft(im, weight): + ''' + Convolution with FFT + Input: + im: h1 x w1 x c numpy array + weight: h2 x w2 numpy array + Output: + out: h1 x w1 x c numpy array + ''' + axes = (0,1) + otf = psf2otf(weight, im.shape[:2]) + if im.ndim == 3: + otf = np.tile(otf[:, :, None], (1,1,im.shape[2])) + out = fft.ifft2(fft.fft2(im, axes=axes) * otf, axes=axes).real + return out + +def psf2otf(psf, shape): + """ + MATLAB psf2otf function. + Borrowed from https://github.com/aboucaud/pypher/blob/master/pypher/pypher.py. + Input: + psf : h x w numpy array + shape : list or tuple, output shape of the OTF array + Output: + otf : OTF array with the desirable shape + """ + if np.all(psf == 0): + return np.zeros_like(psf) + + inshape = psf.shape + # Pad the PSF to outsize + psf = zero_pad(psf, shape, position='corner') + + # Circularly shift OTF so that the 'center' of the PSF is [0,0] element of the array + for axis, axis_size in enumerate(inshape): + psf = np.roll(psf, -int(axis_size / 2), axis=axis) + + # Compute the OTF + otf = fft.fft2(psf) + + # Estimate the rough number of operations involved in the FFT + # and discard the PSF imaginary part if within roundoff error + # roundoff error = machine epsilon = sys.float_info.epsilon + # or np.finfo().eps + n_ops = np.sum(psf.size * np.log2(psf.shape)) + otf = np.real_if_close(otf, tol=n_ops) + + return otf + +# ----------------------Patch Cropping---------------------------- +def random_crop(im, pch_size): + ''' + Randomly crop a patch from the give image. 
+ ''' + h, w = im.shape[:2] + if h == pch_size and w == pch_size: + im_pch = im + else: + assert h >= pch_size or w >= pch_size + ind_h = random.randint(0, h-pch_size) + ind_w = random.randint(0, w-pch_size) + im_pch = im[ind_h:ind_h+pch_size, ind_w:ind_w+pch_size,] + + return im_pch + +class RandomCrop: + def __init__(self, pch_size): + self.pch_size = pch_size + + def __call__(self, im): + return random_crop(im, self.pch_size) + +class ImageSpliterNp: + def __init__(self, im, pch_size, stride, sf=1): + ''' + Input: + im: h x w x c, numpy array, [0, 1], low-resolution image in SR + pch_size, stride: patch setting + sf: scale factor in image super-resolution + ''' + assert stride <= pch_size + self.stride = stride + self.pch_size = pch_size + self.sf = sf + + if im.ndim == 2: + im = im[:, :, None] + + height, width, chn = im.shape + self.height_starts_list = self.extract_starts(height) + self.width_starts_list = self.extract_starts(width) + self.length = self.__len__() + self.num_pchs = 0 + + self.im_ori = im + self.im_res = np.zeros([height*sf, width*sf, chn], dtype=im.dtype) + self.pixel_count = np.zeros([height*sf, width*sf, chn], dtype=im.dtype) + + def extract_starts(self, length): + starts = list(range(0, length, self.stride)) + if starts[-1] + self.pch_size > length: + starts[-1] = length - self.pch_size + return starts + + def __len__(self): + return len(self.height_starts_list) * len(self.width_starts_list) + + def __iter__(self): + return self + + def __next__(self): + if self.num_pchs < self.length: + w_start_idx = self.num_pchs // len(self.height_starts_list) + w_start = self.width_starts_list[w_start_idx] * self.sf + w_end = w_start + self.pch_size * self.sf + + h_start_idx = self.num_pchs % len(self.height_starts_list) + h_start = self.height_starts_list[h_start_idx] * self.sf + h_end = h_start + self.pch_size * self.sf + + pch = self.im_ori[h_start:h_end, w_start:w_end,] + self.w_start, self.w_end = w_start, w_end + self.h_start, self.h_end = h_start, h_end + + self.num_pchs += 1 + else: + raise StopIteration(0) + + return pch, (h_start, h_end, w_start, w_end) + + def update(self, pch_res, index_infos): + ''' + Input: + pch_res: pch_size x pch_size x 3, [0,1] + index_infos: (h_start, h_end, w_start, w_end) + ''' + if index_infos is None: + w_start, w_end = self.w_start, self.w_end + h_start, h_end = self.h_start, self.h_end + else: + h_start, h_end, w_start, w_end = index_infos + + self.im_res[h_start:h_end, w_start:w_end] += pch_res + self.pixel_count[h_start:h_end, w_start:w_end] += 1 + + def gather(self): + assert np.all(self.pixel_count != 0) + return self.im_res / self.pixel_count + +class ImageSpliterTh: + def __init__(self, im, pch_size, stride, sf=1, extra_bs=1): + ''' + Input: + im: n x c x h x w, torch tensor, float, low-resolution image in SR + pch_size, stride: patch setting + sf: scale factor in image super-resolution + pch_bs: aggregate pchs to processing, only used when inputing single image + ''' + assert stride <= pch_size + self.stride = stride + self.pch_size = pch_size + self.sf = sf + self.extra_bs = extra_bs + + bs, chn, height, width= im.shape + self.true_bs = bs + + self.height_starts_list = self.extract_starts(height) + self.width_starts_list = self.extract_starts(width) + self.starts_list = [] + for ii in self.height_starts_list: + for jj in self.width_starts_list: + self.starts_list.append([ii, jj]) + + self.length = self.__len__() + self.count_pchs = 0 + + self.im_ori = im + self.im_res = torch.zeros([bs, chn, height*sf, width*sf], dtype=im.dtype, 
device=im.device) + self.pixel_count = torch.zeros([bs, chn, height*sf, width*sf], dtype=im.dtype, device=im.device) + + def extract_starts(self, length): + if length <= self.pch_size: + starts = [0,] + else: + starts = list(range(0, length, self.stride)) + for ii in range(len(starts)): + if starts[ii] + self.pch_size > length: + starts[ii] = length - self.pch_size + starts = sorted(set(starts), key=starts.index) + return starts + + def __len__(self): + return len(self.height_starts_list) * len(self.width_starts_list) + + def __iter__(self): + return self + + def __next__(self): + if self.count_pchs < self.length: + index_infos = [] + current_starts_list = self.starts_list[self.count_pchs:self.count_pchs+self.extra_bs] + for ii, (h_start, w_start) in enumerate(current_starts_list): + w_end = w_start + self.pch_size + h_end = h_start + self.pch_size + current_pch = self.im_ori[:, :, h_start:h_end, w_start:w_end] + if ii == 0: + pch = current_pch + else: + pch = torch.cat([pch, current_pch], dim=0) + + h_start *= self.sf + h_end *= self.sf + w_start *= self.sf + w_end *= self.sf + index_infos.append([h_start, h_end, w_start, w_end]) + + self.count_pchs += len(current_starts_list) + else: + raise StopIteration() + + return pch, index_infos + + def update(self, pch_res, index_infos): + ''' + Input: + pch_res: (n*extra_bs) x c x pch_size x pch_size, float + index_infos: [(h_start, h_end, w_start, w_end),] + ''' + assert pch_res.shape[0] % self.true_bs == 0 + pch_list = torch.split(pch_res, self.true_bs, dim=0) + assert len(pch_list) == len(index_infos) + for ii, (h_start, h_end, w_start, w_end) in enumerate(index_infos): + current_pch = pch_list[ii] + self.im_res[:, :, h_start:h_end, w_start:w_end] += current_pch + self.pixel_count[:, :, h_start:h_end, w_start:w_end] += 1 + + def gather(self): + assert torch.all(self.pixel_count != 0) + return self.im_res.div(self.pixel_count) + +# ----------------------Patch Cropping---------------------------- +class Clamper: + def __init__(self, min_max=(-1, 1)): + self.min_bound, self.max_bound = min_max[0], min_max[1] + + def __call__(self, im): + if isinstance(im, np.ndarray): + return np.clip(im, a_min=self.min_bound, a_max=self.max_bound) + elif isinstance(im, torch.Tensor): + return torch.clamp(im, min=self.min_bound, max=self.max_bound) + else: + raise TypeError(f'ndarray or Tensor expected, got {type(im)}') + +if __name__ == '__main__': + im = np.random.randn(64, 64, 3).astype(np.float32) + + grad1 = imgrad(im)['grad'] + grad2 = imgrad_fft(im)['grad'] + + error = np.abs(grad1 -grad2).max() + mean_error = np.abs(grad1 -grad2).mean() + print('The largest error is {:.2e}'.format(error)) + print('The mean error is {:.2e}'.format(mean_error)) \ No newline at end of file diff --git a/utils/wavelet_color.py b/utils/wavelet_color.py new file mode 100644 index 0000000000000000000000000000000000000000..0947a8621ecea7ef3b96d8b56acbdaecd3821a3d --- /dev/null +++ b/utils/wavelet_color.py @@ -0,0 +1,119 @@ +''' +# -------------------------------------------------------------------------------- +# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py) +# -------------------------------------------------------------------------------- +''' + +import torch +from PIL import Image +from torch import Tensor +from torch.nn import functional as F + +from torchvision.transforms import ToTensor, ToPILImage + +def adain_color_fix(target: Image, source: Image): + # Convert images to tensors + to_tensor = ToTensor() + 
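+    # ToTensor yields a C x H x W tensor in [0, 1]; unsqueeze(0) adds the batch dimension expected by calc_mean_std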
target_tensor = to_tensor(target).unsqueeze(0) + source_tensor = to_tensor(source).unsqueeze(0) + + # Apply adaptive instance normalization + result_tensor = adaptive_instance_normalization(target_tensor, source_tensor) + + # Convert tensor back to image + to_image = ToPILImage() + result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) + + return result_image + +def wavelet_color_fix(target: Image, source: Image): + # Convert images to tensors + to_tensor = ToTensor() + target_tensor = to_tensor(target).unsqueeze(0) + source_tensor = to_tensor(source).unsqueeze(0) + + # Apply wavelet reconstruction + result_tensor = wavelet_reconstruction(target_tensor, source_tensor) + + # Convert tensor back to image + to_image = ToPILImage() + result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) + + return result_image + +def calc_mean_std(feat: Tensor, eps=1e-5): + """Calculate mean and std for adaptive_instance_normalization. + Args: + feat (Tensor): 4D tensor. + eps (float): A small value added to the variance to avoid + divide-by-zero. Default: 1e-5. + """ + size = feat.size() + assert len(size) == 4, 'The input feature should be 4D tensor.' + b, c = size[:2] + feat_var = feat.reshape(b, c, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().reshape(b, c, 1, 1) + feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1) + return feat_mean, feat_std + +def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor): + """Adaptive instance normalization. + Adjust the reference features to have the similar color and illuminations + as those in the degradate features. + Args: + content_feat (Tensor): The reference feature. + style_feat (Tensor): The degradate features. + """ + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + +def wavelet_blur(image: Tensor, radius: int): + """ + Apply wavelet blur to the input tensor. + """ + # input shape: (1, 3, H, W) + # convolution kernel + kernel_vals = [ + [0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125], + [0.0625, 0.125, 0.0625], + ] + kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) + # add channel dimensions to the kernel to make it a 4D tensor + kernel = kernel[None, None] + # repeat the kernel across all input channels + kernel = kernel.repeat(3, 1, 1, 1) + image = F.pad(image, (radius, radius, radius, radius), mode='replicate') + # apply convolution + output = F.conv2d(image, kernel, groups=3, dilation=radius) + return output + +def wavelet_decomposition(image: Tensor, levels=5): + """ + Apply wavelet decomposition to the input tensor. + This function only returns the low frequency & the high frequency. + """ + high_freq = torch.zeros_like(image) + for i in range(levels): + radius = 2 ** i + low_freq = wavelet_blur(image, radius) + high_freq += (image - low_freq) + image = low_freq + + return high_freq, low_freq + +def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): + """ + Apply wavelet decomposition, so that the content will have the same color as the style. 
+ """ + # calculate the wavelet decomposition of the content feature + content_high_freq, content_low_freq = wavelet_decomposition(content_feat) + del content_low_freq + # calculate the wavelet decomposition of the style feature + style_high_freq, style_low_freq = wavelet_decomposition(style_feat) + del style_high_freq + # reconstruct the content feature with the style's high frequency + return content_high_freq + style_low_freq \ No newline at end of file