zhangap commited on
Commit
f76a7f4
·
verified ·
1 Parent(s): 1fb742e

Upload testing_utils.py

Browse files
Files changed (1) hide show
  1. testing_utils.py +210 -0
testing_utils.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ import torch.nn.functional as F
6
+ from glob import glob
7
+
8
+ import cv2
9
+ import math
10
+ import numpy as np
11
+ import os
12
+ import os.path as osp
13
+ import random
14
+ import time
15
+ import torch
16
+ from pathlib import Path
17
+ from torch.utils import data as data
18
+
19
+ from basicsr.utils import DiffJPEG, USMSharp
20
+ from basicsr.utils.img_process_util import filter2D
21
+ from basicsr.data.transforms import paired_random_crop, triplet_random_crop
22
+ from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt, random_add_speckle_noise_pt, random_add_saltpepper_noise_pt, bivariate_Gaussian
23
+
24
+ from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
25
+ from basicsr.data.transforms import augment
26
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
27
+ from basicsr.utils.registry import DATASET_REGISTRY
28
+
29
+
30
+ def parse_args_paired_testing(input_args=None):
31
+ """
32
+ Parses command-line arguments used for configuring an paired session (pix2pix-Turbo).
33
+ This function sets up an argument parser to handle various training options.
34
+
35
+ Returns:
36
+ argparse.Namespace: The parsed command-line arguments.
37
+ """
38
+ parser = argparse.ArgumentParser()
39
+ parser.add_argument("--ref_path", type=str, default=None,)
40
+ parser.add_argument("--base_config", default="./configs/sr_test.yaml", type=str)
41
+ parser.add_argument("--tracker_project_name", type=str, default="train_pix2pix_turbo", help="The name of the wandb project to log to.")
42
+
43
+ # details about the model architecture
44
+ parser.add_argument("--sd_path")
45
+ parser.add_argument("--de_net_path")
46
+ parser.add_argument("--pretrained_path", type=str, default=None,)
47
+ parser.add_argument("--revision", type=str, default=None,)
48
+ parser.add_argument("--variant", type=str, default=None,)
49
+ parser.add_argument("--tokenizer_name", type=str, default=None)
50
+ parser.add_argument("--lora_rank_unet", default=32, type=int)
51
+ parser.add_argument("--lora_rank_vae", default=16, type=int)
52
+
53
+ parser.add_argument("--scale", type=int, default=4, help="Scale factor for SR.")
54
+ parser.add_argument("--chop_size", type=int, default=128, choices=[512, 256, 128], help="Chopping forward.")
55
+ parser.add_argument("--chop_stride", type=int, default=96, help="Chopping stride.")
56
+ parser.add_argument("--padding_offset", type=int, default=32, help="padding offset.")
57
+
58
+ parser.add_argument("--vae_decoder_tiled_size", type=int, default=224)
59
+ parser.add_argument("--vae_encoder_tiled_size", type=int, default=1024)
60
+ parser.add_argument("--latent_tiled_size", type=int, default=96)
61
+ parser.add_argument("--latent_tiled_overlap", type=int, default=32)
62
+
63
+ parser.add_argument("--align_method", type=str, default="wavelet")
64
+
65
+ parser.add_argument("--pos_prompt", type=str, default="A high-resolution, 8K, ultra-realistic image with sharp focus, vibrant colors, and natural lighting.")
66
+ parser.add_argument("--neg_prompt", type=str, default="oil painting, cartoon, blur, dirty, messy, low quality, deformation, low resolution, oversmooth")
67
+
68
+ # training details
69
+ parser.add_argument("--output_dir", required=True)
70
+ parser.add_argument("--cache_dir", default=None,)
71
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
72
+ parser.add_argument("--resolution", type=int, default=512,)
73
+ parser.add_argument("--checkpointing_steps", type=int, default=500,)
74
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.",)
75
+ parser.add_argument("--gradient_checkpointing", action="store_true",)
76
+
77
+ parser.add_argument("--dataloader_num_workers", type=int, default=0,)
78
+ parser.add_argument("--allow_tf32", action="store_true",
79
+ help=(
80
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
81
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
82
+ ),
83
+ )
84
+ parser.add_argument("--report_to", type=str, default="wandb",
85
+ help=(
86
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
87
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
88
+ ),
89
+ )
90
+ parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],)
91
+ parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
92
+ parser.add_argument("--set_grads_to_none", action="store_true",)
93
+
94
+ parser.add_argument('--world_size', default=1, type=int,
95
+ help='number of distributed processes')
96
+ parser.add_argument('--local_rank', default=-1, type=int)
97
+ parser.add_argument('--dist_url', default='env://',
98
+ help='url used to set up distributed training')
99
+
100
+ if input_args is not None:
101
+ args = parser.parse_args(input_args)
102
+ else:
103
+ args = parser.parse_args()
104
+
105
+ return args
106
+
107
+
108
+ class PlainDataset(data.Dataset):
109
+ """Modified dataset based on the dataset used for Real-ESRGAN model:
110
+ Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
111
+
112
+ It loads gt (Ground-Truth) images, and augments them.
113
+ It also generates blur kernels and sinc kernels for generating low-quality images.
114
+ Note that the low-quality images are processed in tensors on GPUS for faster processing.
115
+
116
+ Args:
117
+ opt (dict): Config for train datasets. It contains the following keys:
118
+ dataroot_gt (str): Data root path for gt.
119
+ meta_info (str): Path for meta information file.
120
+ io_backend (dict): IO backend type and other kwarg.
121
+ use_hflip (bool): Use horizontal flips.
122
+ use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
123
+ Please see more options in the codes.
124
+ """
125
+
126
+ def __init__(self, opt):
127
+ super(PlainDataset, self).__init__()
128
+ self.opt = opt
129
+ self.file_client = None
130
+ self.io_backend_opt = opt['io_backend']
131
+
132
+ if 'image_type' not in opt:
133
+ opt['image_type'] = 'png'
134
+
135
+ # support multiple type of data: file path and meta data, remove support of lmdb
136
+ self.lr_paths = []
137
+ if 'lr_path' in opt:
138
+ if isinstance(opt['lr_path'], str):
139
+ self.lr_paths.extend(sorted(
140
+ [str(x) for x in Path(opt['lr_path']).glob('*.png')] +
141
+ [str(x) for x in Path(opt['lr_path']).glob('*.jpg')] +
142
+ [str(x) for x in Path(opt['lr_path']).glob('*.jpeg')]
143
+ ))
144
+ else:
145
+ self.lr_paths.extend(sorted([str(x) for x in Path(opt['lr_path'][0]).glob('*.'+opt['image_type'])]))
146
+ if len(opt['lr_path']) > 1:
147
+ for i in range(len(opt['lr_path'])-1):
148
+ self.lr_paths.extend(sorted([str(x) for x in Path(opt['lr_path'][i+1]).glob('*.'+opt['image_type'])]))
149
+
150
+ def __getitem__(self, index):
151
+ if self.file_client is None:
152
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
153
+
154
+ # -------------------------------- Load gt images -------------------------------- #
155
+ # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
156
+ lr_path = self.lr_paths[index]
157
+
158
+ # avoid errors caused by high latency in reading files
159
+ retry = 3
160
+ while retry > 0:
161
+ try:
162
+ lr_img_bytes = self.file_client.get(lr_path, 'gt')
163
+ except (IOError, OSError) as e:
164
+ # logger = get_root_logger()
165
+ # logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
166
+ # change another file to read
167
+ index = random.randint(0, self.__len__()-1)
168
+ lr_path = self.lr_paths[index]
169
+ time.sleep(1) # sleep 1s for occasional server congestion
170
+ else:
171
+ break
172
+ finally:
173
+ retry -= 1
174
+
175
+ img_lr = imfrombytes(lr_img_bytes, float32=True)
176
+
177
+ # BGR to RGB, HWC to CHW, numpy to tensor
178
+ img_lr = img2tensor([img_lr], bgr2rgb=True, float32=True)[0]
179
+
180
+ return_d = {'lr': img_lr, 'lr_path': lr_path}
181
+ return return_d
182
+
183
+ def __len__(self):
184
+ return len(self.lr_paths)
185
+
186
+
187
+ def lr_proc(config, batch, device):
188
+ im_lr = batch['lr'].cuda()
189
+ im_lr = im_lr.to(memory_format=torch.contiguous_format).float()
190
+
191
+ ori_lr = im_lr
192
+
193
+ im_lr = F.interpolate(
194
+ im_lr,
195
+ size=(im_lr.size(-2) * config.sf,
196
+ im_lr.size(-1) * config.sf),
197
+ mode='bicubic',
198
+ )
199
+
200
+ im_lr = im_lr.contiguous()
201
+ im_lr = im_lr * 2 - 1.0
202
+ im_lr = torch.clamp(im_lr, -1.0, 1.0)
203
+
204
+ ori_h, ori_w = im_lr.size(-2), im_lr.size(-1)
205
+
206
+ pad_h = (math.ceil(ori_h / 64)) * 64 - ori_h
207
+ pad_w = (math.ceil(ori_w / 64)) * 64 - ori_w
208
+ im_lr = F.pad(im_lr, pad=(0, pad_w, 0, pad_h), mode='reflect')
209
+
210
+ return im_lr.to(device), ori_lr.to(device), (ori_h, ori_w)