PAN / data /common.py
Tellll's picture
Update code and model to support NHWC input format
90e4acb
raw
history blame
904 Bytes
import random
import numpy as np
import skimage.color as sc
import torch
def set_channel(*args, n_channels=3):
    """Coerce every image in ``args`` to the requested channel count.

    A 2-D array first gains a trailing channel axis of length 1. Then:

    * ``n_channels == 1`` with a 3-channel image keeps only channel 0 of
      ``skimage.color.rgb2ycbcr`` (the Y plane), as an ``(H, W, 1)`` array;
    * ``n_channels == 3`` with a 1-channel image repeats that channel
      along the last axis;
    * every other combination is passed through unchanged.

    Args:
        *args: numpy images to convert.
        n_channels: desired number of channels (keyword-only).

    Returns:
        list: the converted images, in the same order as ``args``.
    """
    converted = []
    for image in args:
        if image.ndim == 2:
            image = np.expand_dims(image, axis=2)

        wanted_and_actual = (n_channels, image.shape[2])
        if wanted_and_actual == (1, 3):
            luma = sc.rgb2ycbcr(image)[:, :, 0]
            image = np.expand_dims(luma, 2)
        elif wanted_and_actual == (3, 1):
            image = np.concatenate([image] * n_channels, 2)

        converted.append(image)
    return converted
def np2Tensor(*args, rgb_range=255, format='NCHW'):
    """Convert numpy images to float32 torch tensors scaled by ``rgb_range / 255``.

    Each image must be a 3-D array, interpreted as ``(H, W, C)``. With
    ``format='NCHW'`` the channel axis is moved first, giving ``(C, H, W)``.
    With ``format='NHWC'`` the ``(H, W, C)`` layout is kept.

    Args:
        *args: numpy images, each a 3-D array.
        rgb_range: pixel values are multiplied by ``rgb_range / 255``, so an
            input in ``[0, 255]`` ends up in ``[0, rgb_range]``.
        format: ``'NCHW'`` (channels-first output) or ``'NHWC'``
            (channels-last output). The name shadows the builtin ``format``
            but is kept so existing keyword callers keep working.

    Returns:
        list[torch.Tensor]: one float32 tensor per input, in order. Every
        tensor owns its own storage, so the input arrays are never modified.

    Raises:
        ValueError: if ``format`` is neither ``'NCHW'`` nor ``'NHWC'``.
    """
    if format not in ('NCHW', 'NHWC'):
        # A real exception instead of ``assert``, so the check survives
        # ``python -O``; done once rather than once per image.
        raise ValueError(
            "format must be 'NCHW' or 'NHWC', got {!r}".format(format))
    scale = rgb_range / 255

    def _np2Tensor(img):
        if format == 'NCHW':
            img = img.transpose((2, 0, 1))
        # Needed for both layouts: it makes the transposed view contiguous,
        # and it copies arrays with negative strides (e.g. an ``img[:, ::-1]``
        # flip), which ``torch.from_numpy`` refuses to wrap.
        img = np.ascontiguousarray(img)
        # Out-of-place ``mul``: ``.float()`` does not copy a float32 input, so
        # an in-place ``mul_`` would rescale the caller's numpy array.
        return torch.from_numpy(img).float().mul(scale)

    return [_np2Tensor(a) for a in args]