#!/usr/bin/env python3 | |
# -*- coding:utf-8 -*- | |
############################################################# | |
# File: pixelshuffle.py | |
# Created Date: Friday July 1st 2022 | |
# Author: Chen Xuanhong | |
# Email: chenxuanhongzju@outlook.com | |
# Last Modified: Friday, 1st July 2022 10:18:39 am | |
# Modified By: Chen Xuanhong | |
# Copyright (c) 2022 Shanghai Jiao Tong University | |
############################################################# | |
import torch.nn as nn | |
def pixelshuffle_block(
    in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False
):
    """
    Upsample features according to `upscale_factor`.

    Builds a ``Conv2d`` that expands the channels to
    ``out_channels * upscale_factor**2``. It is followed by an
    ``nn.PixelShuffle`` that rearranges those channels into a spatial grid
    ``upscale_factor`` times larger in each dimension.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels of the upsampled output.
        upscale_factor: Spatial upsampling factor applied to height and width.
        kernel_size: Convolution kernel size. The padding is derived from it
            (``kernel_size // 2``). For odd sizes, the convolution therefore
            preserves the spatial size, and the output is exactly
            ``upscale_factor`` times the input in each dimension. For even
            sizes, the convolution output is one pixel larger per dimension.
        bias: Whether the convolution has a learnable bias.

    Returns:
        nn.Sequential: ``(conv, pixel_shuffle)``.
    """
    # Derive the padding from the kernel size (previously this value was
    # computed but a hard-coded 1 was passed, breaking every kernel_size != 3).
    padding = kernel_size // 2
    conv = nn.Conv2d(
        in_channels,
        out_channels * (upscale_factor**2),
        kernel_size,
        padding=padding,
        bias=bias,
    )
    pixel_shuffle = nn.PixelShuffle(upscale_factor)
    return nn.Sequential(conv, pixel_shuffle)