zhangap committed on
Commit
0bbd780
1 Parent(s): 1407acc

Delete testing_utils.py

Browse files
Files changed (1) hide show
  1. testing_utils.py +0 -210
testing_utils.py DELETED
@@ -1,210 +0,0 @@
1
- import argparse
2
- import json
3
- from PIL import Image
4
- from torchvision import transforms
5
- import torch.nn.functional as F
6
- from glob import glob
7
-
8
- import cv2
9
- import math
10
- import numpy as np
11
- import os
12
- import os.path as osp
13
- import random
14
- import time
15
- import torch
16
- from pathlib import Path
17
- from torch.utils import data as data
18
-
19
- from basicsr.utils import DiffJPEG, USMSharp
20
- from basicsr.utils.img_process_util import filter2D
21
- from basicsr.data.transforms import paired_random_crop, triplet_random_crop
22
- from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt, random_add_speckle_noise_pt, random_add_saltpepper_noise_pt, bivariate_Gaussian
23
-
24
- from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
25
- from basicsr.data.transforms import augment
26
- from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
27
- from basicsr.utils.registry import DATASET_REGISTRY
28
-
29
-
30
def parse_args_paired_testing(input_args=None):
    """Parse the command-line options of the paired testing script (pix2pix-Turbo style).

    Args:
        input_args (list[str] | None): Explicit argument list to parse. When
            ``None``, argparse reads ``sys.argv[1:]`` instead.

    Returns:
        argparse.Namespace: The parsed command-line arguments.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, on a missing
            ``--output_dir`` or after printing ``--help``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ref_path", type=str, default=None,)
    parser.add_argument("--base_config", default="./configs/sr_test.yaml", type=str)
    parser.add_argument("--tracker_project_name", type=str, default="train_pix2pix_turbo", help="The name of the wandb project to log to.")

    # details about the model architecture
    parser.add_argument("--sd_path")
    parser.add_argument("--de_net_path")
    parser.add_argument("--pretrained_path", type=str, default=None,)
    parser.add_argument("--revision", type=str, default=None,)
    parser.add_argument("--variant", type=str, default=None,)
    parser.add_argument("--tokenizer_name", type=str, default=None)
    parser.add_argument("--lora_rank_unet", default=32, type=int)
    parser.add_argument("--lora_rank_vae", default=16, type=int)

    # super-resolution scale and chopped (patch-wise) forward settings
    parser.add_argument("--scale", type=int, default=4, help="Scale factor for SR.")
    parser.add_argument("--chop_size", type=int, default=128, choices=[512, 256, 128], help="Chopping forward.")
    parser.add_argument("--chop_stride", type=int, default=96, help="Chopping stride.")
    parser.add_argument("--padding_offset", type=int, default=32, help="padding offset.")

    # tiling settings
    parser.add_argument("--vae_decoder_tiled_size", type=int, default=224)
    parser.add_argument("--vae_encoder_tiled_size", type=int, default=1024)
    parser.add_argument("--latent_tiled_size", type=int, default=96)
    parser.add_argument("--latent_tiled_overlap", type=int, default=32)

    parser.add_argument("--align_method", type=str, default="wavelet")

    parser.add_argument("--pos_prompt", type=str, default="A high-resolution, 8K, ultra-realistic image with sharp focus, vibrant colors, and natural lighting.")
    parser.add_argument("--neg_prompt", type=str, default="oil painting, cartoon, blur, dirty, messy, low quality, deformation, low resolution, oversmooth")

    # run details (several options are carried over from the training script)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--cache_dir", default=None,)
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument("--resolution", type=int, default=512,)
    parser.add_argument("--checkpointing_steps", type=int, default=500,)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.",)
    parser.add_argument("--gradient_checkpointing", action="store_true",)

    parser.add_argument("--dataloader_num_workers", type=int, default=0,)
    parser.add_argument("--allow_tf32", action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    # The help text used to name "tensorboard" as the default although the
    # actual default below is "wandb"; the text now matches the code.
    parser.add_argument("--report_to", type=str, default="wandb",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"` (default) and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],)
    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
    parser.add_argument("--set_grads_to_none", action="store_true",)

    # distributed execution
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    # parse_args(None) already falls back to sys.argv[1:], so no branching on
    # input_args is required.
    return parser.parse_args(input_args)
106
-
107
-
108
class PlainDataset(data.Dataset):
    """Dataset that serves low-resolution (LR) images found in one or more folders.

    Derived from the Real-ESRGAN training dataset, but reduced to plain image
    loading: this class builds no ground-truth images, augmentations or
    degradation kernels.

    Args:
        opt (dict): Dataset config. Keys read by this class:
            io_backend (dict): IO backend options; ``type`` selects the
                backend and the remaining items are passed to ``FileClient``.
            lr_path (str | list[str], optional): Image folder(s). A single
                string is searched for ``*.png``, ``*.jpg`` and ``*.jpeg``
                files; each folder of a list is searched for
                ``*.<image_type>`` files only. Without this key the dataset
                is empty.
            image_type (str, optional): Extension used for a list-style
                ``lr_path``. Default: 'png' (the default is also written
                back into ``opt``).
            max_retries (int, optional): Number of read attempts per item,
                at least 1. Default: 3.
            retry_delay (float, optional): Seconds to wait between read
                attempts. Default: 1.
    """

    # Class-level fallbacks so instances always expose the retry settings.
    max_retries = 3
    retry_delay = 1

    def __init__(self, opt):
        super(PlainDataset, self).__init__()
        self.opt = opt
        # Created lazily on the first __getitem__ call.
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        if 'max_retries' in opt:
            self.max_retries = max(1, int(opt['max_retries']))
        if 'retry_delay' in opt:
            self.retry_delay = opt['retry_delay']

        if 'image_type' not in opt:
            opt['image_type'] = 'png'

        # Only plain folders are supported as a data source (no lmdb).
        self.lr_paths = []
        if 'lr_path' in opt:
            lr_path = opt['lr_path']
            if isinstance(lr_path, str):
                # Single folder: gather the common extensions and sort them together.
                patterns = ('*.png', '*.jpg', '*.jpeg')
                self.lr_paths.extend(sorted(
                    str(p) for pattern in patterns for p in Path(lr_path).glob(pattern)
                ))
            else:
                # Several folders: each one is sorted on its own and appended
                # in the order given; an empty list yields an empty dataset.
                pattern = '*.' + opt['image_type']
                for folder in lr_path:
                    self.lr_paths.extend(sorted(str(p) for p in Path(folder).glob(pattern)))

    def __getitem__(self, index):
        """Load one LR image.

        Args:
            index (int): Position in ``self.lr_paths``.

        Returns:
            dict: ``'lr'`` holds the image converted by ``img2tensor`` and
            ``'lr_path'`` the path that was actually read. After a failed
            read the path belongs to a randomly chosen replacement image.

        Raises:
            OSError: If every read attempt failed.
        """
        if self.file_client is None:
            # Work on a copy so the caller's config keeps its 'type' entry.
            backend_opt = dict(self.io_backend_opt)
            backend_type = backend_opt.pop('type')
            self.file_client = FileClient(backend_type, **backend_opt)

        # -------------------------------- Load lr images -------------------------------- #
        lr_path = self.lr_paths[index]

        # Retry to avoid errors caused by high latency in reading files.
        lr_img_bytes = None
        for attempt in range(self.max_retries):
            try:
                lr_img_bytes = self.file_client.get(lr_path, 'gt')
                break
            except OSError as err:  # IOError is an alias of OSError in Python 3
                if attempt + 1 >= self.max_retries:
                    # Surface the real I/O failure instead of continuing with
                    # an unbound buffer.
                    raise OSError(
                        f'Failed to read an image after {self.max_retries} attempts; '
                        f'last path tried: {lr_path}'
                    ) from err
                # Switch to another file and wait out occasional server congestion.
                index = random.randint(0, len(self) - 1)
                lr_path = self.lr_paths[index]
                time.sleep(self.retry_delay)

        # Decoded as float32; per the original notes: shape (h, w, c), BGR, range [0, 1].
        img_lr = imfrombytes(lr_img_bytes, float32=True)

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_lr = img2tensor([img_lr], bgr2rgb=True, float32=True)[0]

        return {'lr': img_lr, 'lr_path': lr_path}

    def __len__(self):
        """Return the number of image paths collected at construction time."""
        return len(self.lr_paths)
185
-
186
-
187
def lr_proc(config, batch, device):
    """Upscale a batch of LR images and pad it to a multiple of 64 pixels.

    The LR tensor is moved to ``device``, bicubically resized by
    ``config.sf``, mapped with ``x * 2 - 1`` and clamped to [-1, 1], then
    reflect-padded on the bottom and right so height and width are
    multiples of 64.

    Args:
        config: Object exposing ``sf``, the upscaling factor (multiplied with
            the spatial sizes, so presumably an int - confirm with callers).
        batch (dict): Must contain ``'lr'``, the LR image tensor. Presumably
            shaped (N, C, H, W) with values in [0, 1] - confirm with the
            dataset that produces it.
        device (torch.device | str): Device on which all work is done and on
            which the results are returned.

    Returns:
        tuple: ``(im_lr, ori_lr, (ori_h, ori_w))`` where ``im_lr`` is the
        upscaled, normalized and padded tensor, ``ori_lr`` the float LR
        tensor before upscaling, and ``(ori_h, ori_w)`` the upscaled size
        before padding (useful for cropping the padding away again).
    """
    # Use the requested device from the start; the previous hard-coded
    # .cuda() ignored `device` and failed on CPU-only hosts.
    im_lr = batch['lr'].to(device)
    im_lr = im_lr.to(memory_format=torch.contiguous_format).float()

    ori_lr = im_lr

    up_h = im_lr.size(-2) * config.sf
    up_w = im_lr.size(-1) * config.sf
    im_lr = F.interpolate(im_lr, size=(up_h, up_w), mode='bicubic')

    im_lr = im_lr.contiguous()
    # Bicubic resampling can overshoot the input range, hence the clamp
    # after the rescaling.
    im_lr = torch.clamp(im_lr * 2 - 1.0, -1.0, 1.0)

    ori_h, ori_w = im_lr.size(-2), im_lr.size(-1)

    # Amount needed to reach the next multiple of 64 (0 when already aligned).
    pad_h = math.ceil(ori_h / 64) * 64 - ori_h
    pad_w = math.ceil(ori_w / 64) * 64 - ori_w
    # NOTE(review): reflect padding needs each pad to be smaller than the
    # matching dimension, so very small upscaled inputs may be rejected by
    # F.pad - confirm the minimum input size with callers.
    im_lr = F.pad(im_lr, pad=(0, pad_w, 0, pad_h), mode='reflect')

    return im_lr.to(device), ori_lr.to(device), (ori_h, ori_w)