# Spaces:
# Running
# Running
from __future__ import print_function, division

import numbers
import os
import random
import time
from glob import glob

import numpy as np
import rawpy
import torch
from PIL import Image as PILImage
# NOTE(review): scipy.misc.imread was removed in SciPy 1.2; this import only
# works with an older SciPy (imageio.imread or PIL are the usual replacements).
from scipy.misc import imread
from torch.utils.data import Dataset
from torchvision import transforms, utils

from .base_dataset import BaseDataset
class FiveKDatasetTrain(BaseDataset):
    """Training split of the FiveK RAW-to-RGB dataset.

    Each item pairs a white-balanced RAW array (read from an archive holding
    ``'raw'`` and ``'wb'`` entries) with its target RGB image, cuts an aligned
    random patch, applies a random 90-degree rotation / mirror, and normalizes
    both arrays via ``self.norm_img``.

    ``self.load``, ``self.gamma``, ``self.camera_name``, ``self.norm_img`` and
    ``self.np2tensor`` come from ``BaseDataset`` (defined elsewhere).
    """

    def __init__(self, opt):
        super().__init__(opt=opt)
        # Side length of the square training crop, in input-RAW pixels.
        self.patch_size = 256
        input_RAWs_WBs, target_RGBs = self.load(is_train=True)
        assert len(input_RAWs_WBs) == len(target_RGBs), (
            'mismatched dataset: %d RAW inputs vs %d RGB targets'
            % (len(input_RAWs_WBs), len(target_RGBs)))
        self.data = {'input_RAWs_WBs': input_RAWs_WBs, 'target_RGBs': target_RGBs}

    def random_flip(self, input_raw, target_rgb):
        """Mirror both arrays along the same random spatial axis, half of the time.

        As applied by ``aug`` (after ``random_rotate``), this makes each of the
        eight 90-degree rotation / mirror orientations equally likely. Flipping
        unconditionally would combine with any rotation into a mirror image,
        so un-mirrored patches would never be produced.

        Returns contiguous copies of both arrays.
        """
        if random.random() < 0.5:
            axis = random.randrange(2)  # 0: flip rows, 1: flip columns
            input_raw = np.flip(input_raw, axis=axis)
            target_rgb = np.flip(target_rgb, axis=axis)
        # np.flip / np.rot90 return strided views (possibly with negative
        # strides); copy so downstream tensor conversion gets plain arrays.
        return input_raw.copy(), target_rgb.copy()

    def random_rotate(self, input_raw, target_rgb):
        """Rotate both arrays in their first two axes by the same random k * 90 degrees.

        ``np.rot90`` returns views, not copies; ``aug`` makes the result
        contiguous afterwards via ``random_flip``.
        """
        k = random.randrange(4)
        return np.rot90(input_raw, k=k), np.rot90(target_rgb, k=k)

    def random_crop(self, patch_size, input_raw, target_rgb, flow=False, demos=False):
        """Cut an aligned random patch from the input and the target.

        ``input_raw`` is indexed as (H, W, C). When ``flow`` or ``demos`` is
        true, the same window is cut from ``target_rgb`` (same resolution
        assumed). Otherwise a ``2 * patch_size`` window at doubled coordinates
        is cut (target assumed to be twice the input's resolution).

        If the input is smaller than ``patch_size`` along an axis, the patch
        is truncated to what is available.
        """
        H, W, _ = input_raw.shape
        top = random.randint(0, max(0, H - patch_size))
        left = random.randint(0, max(0, W - patch_size))
        patch_input_raw = input_raw[top:top + patch_size, left:left + patch_size, :]
        if flow or demos:
            patch_target_rgb = target_rgb[top:top + patch_size, left:left + patch_size, :]
        else:
            patch_target_rgb = target_rgb[top * 2:top * 2 + patch_size * 2,
                                          left * 2:left * 2 + patch_size * 2, :]
        return patch_input_raw, patch_target_rgb

    def aug(self, patch_size, input_raw, target_rgb, flow=False, demos=False):
        """Random crop, then random rotation, then random flip, applied jointly to both arrays."""
        input_raw, target_rgb = self.random_crop(patch_size, input_raw, target_rgb,
                                                 flow=flow, demos=demos)
        input_raw, target_rgb = self.random_rotate(input_raw, target_rgb)
        input_raw, target_rgb = self.random_flip(input_raw, target_rgb)
        return input_raw, target_rgb

    def __len__(self):
        return len(self.data['input_RAWs_WBs'])

    def __getitem__(self, idx):
        """Load, augment and normalize the ``idx``-th RAW/RGB pair.

        Returns a dict with float tensors ``'input_raw'``, ``'target_rgb'`` and
        ``'target_raw'``. ``'target_raw'`` is a copy of the normalized input
        RAW. The dict also holds ``'file_name'``, the input file's base name
        without its extension.
        """
        input_raw_wb_path = self.data['input_RAWs_WBs'][idx]
        target_rgb_img = imread(self.data['target_RGBs'][idx])
        input_raw_img = self._load_wb_raw(input_raw_wb_path)
        # Honour the configured crop size (it used to be reset to 256 here on
        # every call, silently overriding any other value).
        input_raw_img, target_rgb_img = self.aug(
            self.patch_size, input_raw_img, target_rgb_img, flow=True, demos=True)
        return self._make_sample(input_raw_img, target_rgb_img, input_raw_wb_path)

    @staticmethod
    def _load_wb_raw(path):
        """Load the ``'raw'`` array stored at ``path`` and apply its white balance.

        The ``'wb'`` coefficients are scaled so their maximum is 1, and the
        last coefficient is dropped before multiplying. The RAW is therefore
        expected to have ``len(wb) - 1`` trailing channels. NOTE(review):
        presumably an RGB array with a 4-entry RGBG white-balance vector;
        confirm against the data-prep script.
        """
        # np.load on an .npz returns an NpzFile that keeps the file open;
        # the context manager closes it as soon as both arrays are read.
        with np.load(path) as archive:
            raw = archive['raw']
            wb = archive['wb']
        wb = wb / wb.max()
        return raw * wb[:-1]

    def _make_sample(self, input_raw_img, target_rgb_img, input_raw_wb_path):
        """Normalize the pair (optionally gamma-encoding the RAW) and build the sample dict."""
        # 4095 = 2**12 - 1 and 16383 = 2**14 - 1; presumably the RAW white
        # level (12-bit for Canon_EOS_5D, 14-bit otherwise) -- confirm.
        max_value = 4095 if self.camera_name == 'Canon_EOS_5D' else 16383
        if self.gamma:
            # Gamma-encode the RAW and normalize by the gamma-encoded full scale.
            norm_value = np.power(max_value, 1/2.2)
            input_raw_img = np.power(input_raw_img, 1/2.2)
        else:
            norm_value = max_value
        target_rgb_img = self.norm_img(target_rgb_img, max_value=255)
        input_raw_img = self.norm_img(input_raw_img, max_value=norm_value)
        target_raw_img = input_raw_img.copy()
        return {
            'input_raw': self.np2tensor(input_raw_img).float(),
            'target_rgb': self.np2tensor(target_rgb_img).float(),
            'target_raw': self.np2tensor(target_raw_img).float(),
            # Robust to dotted names and OS-specific separators.
            'file_name': os.path.splitext(os.path.basename(input_raw_wb_path))[0],
        }
class FiveKDatasetTest(BaseDataset):
    """Evaluation split of the FiveK RAW-to-RGB dataset.

    Each item pairs a white-balanced RAW array (read from an archive holding
    ``'raw'`` and ``'wb'`` entries) with its target RGB image. Both are used
    at full size (no cropping or augmentation) and normalized via
    ``self.norm_img``.

    ``self.load``, ``self.gamma``, ``self.camera_name``, ``self.norm_img`` and
    ``self.np2tensor`` come from ``BaseDataset`` (defined elsewhere).
    """

    def __init__(self, opt):
        super().__init__(opt=opt)
        # Not used by this class's own methods; kept for callers that read it.
        self.patch_size = 256
        input_RAWs_WBs, target_RGBs = self.load(is_train=False)
        assert len(input_RAWs_WBs) == len(target_RGBs), (
            'mismatched dataset: %d RAW inputs vs %d RGB targets'
            % (len(input_RAWs_WBs), len(target_RGBs)))
        self.data = {'input_RAWs_WBs': input_RAWs_WBs, 'target_RGBs': target_RGBs}

    def __len__(self):
        return len(self.data['input_RAWs_WBs'])

    def __getitem__(self, idx):
        """Load and normalize the ``idx``-th RAW/RGB pair at full resolution.

        Returns a dict with float tensors ``'input_raw'``, ``'target_rgb'`` and
        ``'target_raw'``. ``'target_raw'`` is a copy of the normalized input
        RAW. The dict also holds ``'file_name'``, the input file's base name
        without its extension.
        """
        input_raw_wb_path = self.data['input_RAWs_WBs'][idx]
        target_rgb_img = imread(self.data['target_RGBs'][idx])
        input_raw_img = self._load_wb_raw(input_raw_wb_path)
        return self._make_sample(input_raw_img, target_rgb_img, input_raw_wb_path)

    @staticmethod
    def _load_wb_raw(path):
        """Load the ``'raw'`` array stored at ``path`` and apply its white balance.

        The ``'wb'`` coefficients are scaled so their maximum is 1, and the
        last coefficient is dropped before multiplying. The RAW is therefore
        expected to have ``len(wb) - 1`` trailing channels. NOTE(review):
        presumably an RGB array with a 4-entry RGBG white-balance vector;
        confirm against the data-prep script.
        """
        # np.load on an .npz returns an NpzFile that keeps the file open;
        # the context manager closes it as soon as both arrays are read.
        with np.load(path) as archive:
            raw = archive['raw']
            wb = archive['wb']
        wb = wb / wb.max()
        return raw * wb[:-1]

    def _make_sample(self, input_raw_img, target_rgb_img, input_raw_wb_path):
        """Normalize the pair (optionally gamma-encoding the RAW) and build the sample dict."""
        # 4095 = 2**12 - 1 and 16383 = 2**14 - 1; presumably the RAW white
        # level (12-bit for Canon_EOS_5D, 14-bit otherwise) -- confirm.
        max_value = 4095 if self.camera_name == 'Canon_EOS_5D' else 16383
        if self.gamma:
            # Gamma-encode the RAW and normalize by the gamma-encoded full scale.
            norm_value = np.power(max_value, 1/2.2)
            input_raw_img = np.power(input_raw_img, 1/2.2)
        else:
            norm_value = max_value
        target_rgb_img = self.norm_img(target_rgb_img, max_value=255)
        input_raw_img = self.norm_img(input_raw_img, max_value=norm_value)
        target_raw_img = input_raw_img.copy()
        return {
            'input_raw': self.np2tensor(input_raw_img).float(),
            'target_rgb': self.np2tensor(target_rgb_img).float(),
            'target_raw': self.np2tensor(target_raw_img).float(),
            # Robust to dotted names and OS-specific separators.
            'file_name': os.path.splitext(os.path.basename(input_raw_wb_path))[0],
        }