|
|
|
import argparse |
|
import cv2 |
|
import torch |
|
import os, shutil, time |
|
import sys |
|
from multiprocessing import Process, Queue |
|
from os import path as osp |
|
from tqdm import tqdm |
|
import copy |
|
import warnings |
|
import gc |
|
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
# Make repo-local packages importable when this file is run as a script
# (assumes the working directory is the repository root — TODO confirm).
root_path = os.path.abspath('.')

sys.path.append(root_path)

from degradation.ESR.utils import np2tensor

from degradation.ESR.degradations_functionality import *

from degradation.ESR.diffjpeg import *

from degradation.degradation_esr import degradation_v1

from opt import opt

# Restrict visible GPUs before any CUDA context is created.
os.environ['CUDA_VISIBLE_DEVICES'] = opt['CUDA_VISIBLE_DEVICES']
|
|
|
|
|
|
|
def crop_process(path, crop_size, lr_dataset_path, output_index):
    ''' Crop one image into non-overlapping crop_size x crop_size tiles.

    Tiles are written as losslessly-compressed PNGs named
    img_{output_index:06d}.png with a sequentially increasing index;
    any remainder at the right/bottom edge that does not fill a full
    tile is discarded.

    Args:
        path (str): Path of the image
        crop_size (int): Crop size
        lr_dataset_path (str): LR dataset path folder name
        output_index (int): The index we used to store images
    Returns:
        output_index (int): The next index we need to use to store images
    Raises:
        ValueError: If the image at `path` cannot be read.
    '''

    img = cv2.imread(path)
    if img is None:
        # cv2.imread returns None on failure instead of raising; surface it.
        raise ValueError(f"Cannot read image: {path}")
    height, width = img.shape[0:2]

    # Number of full tiles that fit per row; edge remainders are dropped.
    tiles_per_row = width // crop_size
    crop_num = (height // crop_size) * tiles_per_row

    for tile_idx in range(crop_num):
        # Row-major tile order: top-left corner of this tile.
        top = crop_size * (tile_idx // tiles_per_row)
        left = crop_size * (tile_idx % tiles_per_row)

        cropped_img = img[top : top + crop_size, left : left + crop_size, ...]
        cropped_img = np.ascontiguousarray(cropped_img)
        # PNG compression level 0: lossless and fast, at the cost of file size.
        cv2.imwrite(osp.join(lr_dataset_path, f'img_{output_index:06d}.png'), cropped_img, [cv2.IMWRITE_PNG_COMPRESSION, 0])

        output_index += 1

    return output_index
|
|
|
|
|
|
|
def single_process(queue, opt, process_id):
    ''' Multi Process instance: consume (input, output) path pairs until sentinel.

    Args:
        queue (multiprocessing.Queue): The input queue of
            (input_path, store_path) tuples; a None item is the stop sentinel.
        opt (dict): The setting we need to use
        process_id (int): The id we used to store temporary file
    '''

    obj_img = degradation_v1()

    while True:
        items = queue.get()
        if items is None:
            # Stop sentinel: one None is queued per worker by the producer.
            break
        input_path, store_path = items

        # Re-sample the degradation kernels per image so each image gets
        # an independently randomized degradation.
        obj_img.reset_kernels(opt)

        img_bgr = cv2.imread(input_path)

        out = np2tensor(img_bgr)

        # Degrades `out` and writes the result to store_path
        # (presumably — TODO confirm against degradation_v1 implementation).
        obj_img.degradate_process(out, opt, store_path, process_id, verbose=False)
|
|
|
|
|
|
|
@torch.no_grad()
def generate_low_res_esr(org_opt, verbose=False):
    ''' Generate LR dataset from HR ones by ESR degradation.

    Pipeline: wipe/recreate output folders -> degrade every HR image in
    parallel worker processes into tmp/ -> crop each degraded image into
    LR-sized patches in the save folder.

    Args:
        org_opt (dict): The setting we will use (must contain
            'input_folder', 'save_folder', 'parallel_num', 'hr_size', 'scale')
        verbose (bool): Whether we print out some information
    '''

    input_folder = org_opt['input_folder']
    save_folder = org_opt['save_folder']

    # Start from clean output and scratch directories.
    if osp.exists(save_folder):
        shutil.rmtree(save_folder)
    if osp.exists("tmp"):
        shutil.rmtree("tmp")
    os.makedirs(save_folder)
    os.makedirs("tmp")
    # Remove a stale log from a previous run, if any.
    if osp.exists("datasets/degradation_log.txt"):
        os.remove("datasets/degradation_log.txt")

    # Pair every HR input with its temporary degraded output under tmp/.
    input_img_lists, output_img_lists = [], []
    for file in sorted(os.listdir(input_folder)):
        input_img_lists.append(osp.join(input_folder, file))
        output_img_lists.append(osp.join("tmp", file))
    assert len(input_img_lists) == len(output_img_lists)

    # Fill the task queue before the workers start consuming.
    # NOTE: use org_opt consistently rather than the module-level opt so the
    # function honors whatever settings dict the caller passed in.
    parallel_num = org_opt['parallel_num']
    queue = Queue()
    for input_path, output_path in zip(input_img_lists, output_img_lists):
        queue.put((input_path, output_path))

    # Launch one worker per parallel slot, then enqueue one stop sentinel each.
    processes = []
    for process_id in range(parallel_num):
        worker = Process(target=single_process, args=(queue, org_opt, process_id))
        worker.start()
        processes.append(worker)
    for _ in range(parallel_num):
        queue.put(None)

    # Progress bar: poll until each expected output file appears.
    # NOTE(review): existence does not prove the file is fully written —
    # confirm the worker writes atomically before relying on this.
    for idx in tqdm(range(len(output_img_lists)), desc="Degradation"):
        while not os.path.exists(output_img_lists[idx]):
            time.sleep(0.1)

    for process in processes:
        process.join()

    # Crop every degraded image into LR patches with one global running index.
    output_index = 1
    for img_name in sorted(os.listdir("tmp")):
        path = osp.join("tmp", img_name)
        output_index = crop_process(path, org_opt['hr_size'] // org_opt['scale'], org_opt['save_folder'], output_index)
|
|
|
|
|
|
|
def main(args):
    """Copy the CLI input/output folders into the global option dict, then run the pipeline."""
    opt['input_folder'], opt['save_folder'] = args.input, args.output
    generate_low_res_esr(opt)
|
|
|
|
|
if __name__ == '__main__':

    # CLI: input/output folders default to the paths configured in opt.
    parser = argparse.ArgumentParser()

    parser.add_argument('--input', type=str, default = opt["full_patch_source"], help='Input folder')

    parser.add_argument('--output', type=str, default = opt["lr_dataset_path"], help='Output folder')

    args = parser.parse_args()



    main(args)