OAOA committed
Commit 2868b95
1 Parent(s): 703788e

add func of get_torch_dtype

Files changed (2):
  1. sampler_invsr.py +10 -2
  2. trainer.py +0 -1643
sampler_invsr.py CHANGED
@@ -10,8 +10,6 @@ from pathlib import Path
 from loguru import logger
 from omegaconf import OmegaConf
 
-from trainer import get_torch_dtype
-
 from utils import util_net
 from utils import util_image
 from utils import util_common
@@ -30,6 +28,16 @@ _positive= 'Cinematic, high-contrast, photo-realistic, 8k, ultra HD, ' +\
 _negative= 'Low quality, blurring, jpeg artifacts, deformed, over-smooth, cartoon, noisy,' +\
            'painting, drawing, sketch, oil painting'
 
+def get_torch_dtype(torch_dtype: str):
+    if torch_dtype == 'torch.float16':
+        return torch.float16
+    elif torch_dtype == 'torch.bfloat16':
+        return torch.bfloat16
+    elif torch_dtype == 'torch.float32':
+        return torch.float32
+    else:
+        raise ValueError(f'Unexpected torch dtype:{torch_dtype}')
+
 class BaseSampler:
     def __init__(self, configs):
         '''
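With trainer.py deleted below, sampler_invsr.py now carries its own copy of `get_torch_dtype`. A minimal usage sketch follows; the config snippet is hypothetical, but the conversion pattern mirrors how the deleted trainer.py consumed the helper when assembling the diffusers `from_pretrained` kwargs (OmegaConf/YAML can only store the dtype as a string):

import torch
from omegaconf import OmegaConf

def get_torch_dtype(torch_dtype: str):
    # Same mapping as the helper added above.
    if torch_dtype == 'torch.float16':
        return torch.float16
    elif torch_dtype == 'torch.bfloat16':
        return torch.bfloat16
    elif torch_dtype == 'torch.float32':
        return torch.float32
    else:
        raise ValueError(f'Unexpected torch dtype:{torch_dtype}')

# Hypothetical config: the dtype arrives as a plain string and must be
# converted to a real torch dtype before reaching from_pretrained().
configs = OmegaConf.create({'sd_pipe': {'params': {'torch_dtype': 'torch.float16'}}})
params = dict(configs.sd_pipe.params)
params['torch_dtype'] = get_torch_dtype(params.pop('torch_dtype'))
assert params['torch_dtype'] is torch.float16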
trainer.py DELETED
@@ -1,1643 +0,0 @@
-#!/usr/bin/env python
-# -*- coding:utf-8 -*-
-# Power by Zongsheng Yue 2022-05-18 13:04:06
-
-import os, sys, math, time, random, datetime
-import numpy as np
-from box import Box
-from pathlib import Path
-from loguru import logger
-from copy import deepcopy
-from omegaconf import OmegaConf
-from einops import rearrange
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-from datapipe.datasets import create_dataset
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.data as udata
-import torch.distributed as dist
-import torch.multiprocessing as mp
-import torchvision.utils as vutils
-from torch.nn.parallel import DistributedDataParallel as DDP
-
-from utils import util_net
-from utils import util_common
-from utils import util_image
-from utils.util_ops import append_dims
-
-import pyiqa
-from basicsr.utils import DiffJPEG, USMSharp
-from basicsr.utils.img_process_util import filter2D
-from basicsr.data.transforms import paired_random_crop
-from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
-
-from diffusers import EulerDiscreteScheduler
-from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
-from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import retrieve_timesteps
-
-_base_seed = 10**6
-_INTERPOLATION_MODE = 'bicubic'
-_Latent_bound = {'min':-10.0, 'max':10.0}
-_positive= 'Cinematic, high-contrast, photo-realistic, 8k, ultra HD, ' +\
-           'meticulous detailing, hyper sharpness, perfect without deformations'
-_negative= 'Low quality, blurring, jpeg artifacts, deformed, over-smooth, cartoon, noisy,' +\
-           'painting, drawing, sketch, oil painting'
-
-class TrainerBase:
-    def __init__(self, configs):
-        self.configs = configs
-
-        # setup distributed training: self.num_gpus, self.rank
-        self.setup_dist()
-
-        # setup seed
-        self.setup_seed()
-
-    def setup_dist(self):
-        num_gpus = torch.cuda.device_count()
-
-        if num_gpus > 1:
-            if mp.get_start_method(allow_none=True) is None:
-                mp.set_start_method('spawn')
-            rank = int(os.environ['LOCAL_RANK'])
-            torch.cuda.set_device(rank % num_gpus)
-            dist.init_process_group(
-                timeout=datetime.timedelta(seconds=3600),
-                backend='nccl',
-                init_method='env://',
-            )
-
-        self.num_gpus = num_gpus
-        self.rank = int(os.environ['LOCAL_RANK']) if num_gpus > 1 else 0
-
-    def setup_seed(self, seed=None, global_seeding=None):
-        if seed is None:
-            seed = self.configs.train.get('seed', 12345)
-        if global_seeding is None:
-            global_seeding = self.configs.train.get('global_seeding', False)
-        if not global_seeding:
-            seed += self.rank
-            torch.cuda.manual_seed(seed)
-        else:
-            torch.cuda.manual_seed_all(seed)
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-
-    def init_logger(self):
-        if self.configs.resume:
-            assert self.configs.resume.endswith(".pth")
-            save_dir = Path(self.configs.resume).parents[1]
-            project_id = save_dir.name
-        else:
-            project_id = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")
-            save_dir = Path(self.configs.save_dir) / project_id
-            if not save_dir.exists() and self.rank == 0:
-                save_dir.mkdir(parents=True)
-
-        # setting log counter
-        if self.rank == 0:
-            self.log_step = {phase: 1 for phase in ['train', 'val']}
-            self.log_step_img = {phase: 1 for phase in ['train', 'val']}
-
-        # text logging
-        logtxet_path = save_dir / 'training.log'
-        if self.rank == 0:
-            if logtxet_path.exists():
-                assert self.configs.resume
-            self.logger = logger
-            self.logger.remove()
-            self.logger.add(logtxet_path, format="{message}", mode='a', level='INFO')
-            self.logger.add(sys.stdout, format="{message}")
-
-        # tensorboard logging
-        log_dir = save_dir / 'tf_logs'
-        self.tf_logging = self.configs.train.tf_logging
-        if self.rank == 0 and self.tf_logging:
-            if not log_dir.exists():
-                log_dir.mkdir()
-            self.writer = SummaryWriter(str(log_dir))
-
-        # checkpoint saving
-        ckpt_dir = save_dir / 'ckpts'
-        self.ckpt_dir = ckpt_dir
-        if self.rank == 0 and (not ckpt_dir.exists()):
-            ckpt_dir.mkdir()
-        if 'ema_rate' in self.configs.train:
-            self.ema_rate = self.configs.train.ema_rate
-            assert isinstance(self.ema_rate, float), "Ema rate must be a float number"
-            ema_ckpt_dir = save_dir / 'ema_ckpts'
-            self.ema_ckpt_dir = ema_ckpt_dir
-            if self.rank == 0 and (not ema_ckpt_dir.exists()):
-                ema_ckpt_dir.mkdir()
-
-        # save images into local disk
-        self.local_logging = self.configs.train.local_logging
-        if self.rank == 0 and self.local_logging:
-            image_dir = save_dir / 'images'
-            if not image_dir.exists():
-                (image_dir / 'train').mkdir(parents=True)
-                (image_dir / 'val').mkdir(parents=True)
-            self.image_dir = image_dir
-
-        # logging the configurations
-        if self.rank == 0:
-            self.logger.info(OmegaConf.to_yaml(self.configs))
-
-    def close_logger(self):
-        if self.rank == 0 and self.tf_logging:
-            self.writer.close()
-
-    def resume_from_ckpt(self):
-        if self.configs.resume:
-            assert self.configs.resume.endswith(".pth") and os.path.isfile(self.configs.resume)
-
-            if self.rank == 0:
-                self.logger.info(f"=> Loading checkpoint from {self.configs.resume}")
-            ckpt = torch.load(self.configs.resume, map_location=f"cuda:{self.rank}")
-            util_net.reload_model(self.model, ckpt['state_dict'])
-            if self.configs.train.loss_coef.get('ldis', 0) > 0:
-                util_net.reload_model(self.discriminator, ckpt['state_dict_dis'])
-            torch.cuda.empty_cache()
-
-            # learning rate scheduler
-            self.iters_start = ckpt['iters_start']
-            for ii in range(1, self.iters_start+1):
-                self.adjust_lr(ii)
-
-            # logging
-            if self.rank == 0:
-                self.log_step = ckpt['log_step']
-                self.log_step_img = ckpt['log_step_img']
-
-            # EMA model
-            if self.rank == 0 and hasattr(self.configs.train, 'ema_rate'):
-                ema_ckpt_path = self.ema_ckpt_dir / ("ema_"+Path(self.configs.resume).name)
-                self.logger.info(f"=> Loading EMA checkpoint from {str(ema_ckpt_path)}")
-                ema_ckpt = torch.load(ema_ckpt_path, map_location=f"cuda:{self.rank}")
-                util_net.reload_model(self.ema_model, ema_ckpt)
-                torch.cuda.empty_cache()
-
-            # AMP scaler
-            if self.amp_scaler is not None:
-                if "amp_scaler" in ckpt:
-                    self.amp_scaler.load_state_dict(ckpt["amp_scaler"])
-                    if self.rank == 0:
-                        self.logger.info("Loading scaler from resumed state...")
-            if self.configs.get('discriminator', None) is not None:
-                if "amp_scaler_dis" in ckpt:
-                    self.amp_scaler_dis.load_state_dict(ckpt["amp_scaler_dis"])
-                    if self.rank == 0:
-                        self.logger.info("Loading scaler (discriminator) from resumed state...")
-
-            # reset the seed
-            self.setup_seed(seed=self.iters_start)
-        else:
-            self.iters_start = 0
-
-    def setup_optimizaton(self):
-        self.optimizer = torch.optim.AdamW(self.model.parameters(),
-                                           lr=self.configs.train.lr,
-                                           weight_decay=self.configs.train.weight_decay)
-
-        # amp settings
-        self.amp_scaler = torch.amp.GradScaler('cuda') if self.configs.train.use_amp else None
-
-        if self.configs.train.lr_schedule == 'cosin':
-            self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
-                optimizer=self.optimizer,
-                T_max=self.configs.train.iterations - self.configs.train.warmup_iterations,
-                eta_min=self.configs.train.lr_min,
-            )
-
-        if self.configs.train.loss_coef.get('ldis', 0) > 0:
-            self.optimizer_dis = torch.optim.Adam(
-                self.discriminator.parameters(),
-                lr=self.configs.train.lr_dis,
-                weight_decay=self.configs.train.weight_decay_dis,
-            )
-            self.amp_scaler_dis = torch.amp.GradScaler('cuda') if self.configs.train.use_amp else None
-
-    def prepare_compiling(self):
-        # https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_3#stable-diffusion-3
-        if not hasattr(self, "prepare_compiling_well") or (not self.prepare_compiling_well):
-            torch.set_float32_matmul_precision("high")
-            torch._inductor.config.conv_1x1_as_mm = True
-            torch._inductor.config.coordinate_descent_tuning = True
-            torch._inductor.config.epilogue_fusion = False
-            torch._inductor.config.coordinate_descent_check_all_directions = True
-            self.prepare_compiling_well = True
-
-    def build_model(self):
-        if self.configs.train.get("compile", True):
-            self.prepare_compiling()
-
-        params = self.configs.model.get('params', dict)
-        model = util_common.get_obj_from_str(self.configs.model.target)(**params)
-        model.cuda()
-        if not self.configs.train.start_mode:  # Loading the starting model for evaluation
-            self.start_model = deepcopy(model)
-            assert self.configs.model.ckpt_start_path is not None
-            ckpt_start_path = self.configs.model.ckpt_start_path
-            if self.rank == 0:
-                self.logger.info(f"Loading the starting model from {ckpt_start_path}")
-            ckpt = torch.load(ckpt_start_path, map_location=f"cuda:{self.rank}")
-            if 'state_dict' in ckpt:
-                ckpt = ckpt['state_dict']
-            util_net.reload_model(self.start_model, ckpt)
-            self.freeze_model(self.start_model)
-            self.start_model.eval()
-            # delete the started timestep
-            start_timestep = max(self.configs.train.timesteps)
-            self.configs.train.timesteps.remove(start_timestep)
-            # end_timestep = min(self.configs.train.timesteps)
-            # self.configs.train.timesteps.remove(end_timestep)
-
-        # setting the training model
-        if self.configs.model.get('ckpt_path', None):  # initialize if necessary
-            ckpt_path = self.configs.model.ckpt_path
-            if self.rank == 0:
-                self.logger.info(f"Initializing model from {ckpt_path}")
-            ckpt = torch.load(ckpt_path, map_location=f"cuda:{self.rank}")
-            if 'state_dict' in ckpt:
-                ckpt = ckpt['state_dict']
-            util_net.reload_model(model, ckpt)
-        if self.configs.model.get("compile", False):
-            if self.rank == 0:
-                self.logger.info("Compile the model...")
-            model.to(memory_format=torch.channels_last)
-            model = torch.compile(model, mode="max-autotune", fullgraph=False)
-        if self.num_gpus > 1:
-            model = DDP(model, device_ids=[self.rank,])  # wrap the network
-        if self.rank == 0 and hasattr(self.configs.train, 'ema_rate'):
-            self.ema_model = deepcopy(model)
-            self.freeze_model(self.ema_model)
-        self.model = model
-
-        # discriminator if necessary
-        if self.configs.train.loss_coef.get('ldis', 0) > 0:
-            assert hasattr(self.configs, 'discriminator')
-            params = self.configs.discriminator.get('params', dict)
-            discriminator = util_common.get_obj_from_str(self.configs.discriminator.target)(**params)
-            discriminator.cuda()
-            if self.configs.discriminator.get("compile", False):
-                if self.rank == 0:
-                    self.logger.info("Compile the discriminator...")
-                discriminator.to(memory_format=torch.channels_last)
-                discriminator = torch.compile(discriminator, mode="max-autotune", fullgraph=False)
-            if self.num_gpus > 1:
-                discriminator = DDP(discriminator, device_ids=[self.rank,])  # wrap the network
-            if self.configs.train.loss_coef.get('ldis', 0) > 0:
-                if self.configs.discriminator.enable_grad_checkpoint:
-                    if self.rank == 0:
-                        self.logger.info("Activating gradient checkpointing for discriminator...")
-                    self.set_grad_checkpointing(discriminator)
-            self.discriminator = discriminator
-
-        # build the stable diffusion
-        params = dict(self.configs.sd_pipe.params)
-        torch_dtype = params.pop('torch_dtype')
-        params['torch_dtype'] = get_torch_dtype(torch_dtype)
-        # loading the fp16 robust vae for sdxl: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
-        if self.configs.get('vae_fp16', None) is not None:
-            params_vae = dict(self.configs.vae_fp16.params)
-            params_vae['torch_dtype'] = torch.float16
-            pipe_id = self.configs.vae_fp16.params.pretrained_model_name_or_path
-            if self.rank == 0:
-                self.logger.info(f'Loading improved vae from {pipe_id}...')
-            vae_pipe = util_common.get_obj_from_str(self.configs.vae_fp16.target).from_pretrained(**params_vae)
-            if self.rank == 0:
-                self.logger.info('Loaded Done')
-            params['vae'] = vae_pipe
-        if ("StableDiffusion3" in self.configs.sd_pipe.target.split('.')[-1]
-                and self.configs.sd_pipe.get("model_quantization", False)):
-            if self.rank == 0:
-                self.logger.info(f'Loading the quantized transformer for SD3...')
-            nf4_config = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_quant_type="nf4",
-                bnb_4bit_compute_dtype=torch.bfloat16
-            )
-            params_model = dict(self.configs.model_nf4.params)
-            torch_dtype = params_model.pop('torch_dtype')
-            params_model['torch_dtype'] = get_torch_dtype(torch_dtype)
-            params_model['quantization_config'] = nf4_config
-            model_nf4 = util_common.get_obj_from_str(self.configs.model_nf4.target).from_pretrained(
-                **params_model
-            )
-            params['transformer'] = model_nf4
-        sd_pipe = util_common.get_obj_from_str(self.configs.sd_pipe.target).from_pretrained(**params)
-        if self.configs.get('scheduler', None) is not None:
-            pipe_id = self.configs.scheduler.target.split('.')[-1]
-            if self.rank == 0:
-                self.logger.info(f'Loading scheduler of {pipe_id}...')
-            sd_pipe.scheduler = util_common.get_obj_from_str(self.configs.scheduler.target).from_config(
-                sd_pipe.scheduler.config
-            )
-            if self.rank == 0:
-                self.logger.info('Loaded Done')
-        if ("StableDiffusion3" in self.configs.sd_pipe.target.split('.')[-1]
-                and self.configs.sd_pipe.get("model_quantization", False)):
-            sd_pipe.enable_model_cpu_offload(gpu_id=self.rank, device='cuda')
-        else:
-            sd_pipe.to(f"cuda:{self.rank}")
-        # freezing model parameters
-        if hasattr(sd_pipe, 'unet'):
-            self.freeze_model(sd_pipe.unet)
-        if hasattr(sd_pipe, 'transformer'):
-            self.freeze_model(sd_pipe.transformer)
-        self.freeze_model(sd_pipe.vae)
-        # compiling
-        if self.configs.sd_pipe.get('compile', True):
-            if self.rank == 0:
-                self.logger.info('Compile the SD model...')
-            sd_pipe.set_progress_bar_config(disable=True)
-            if hasattr(sd_pipe, 'unet'):
-                sd_pipe.unet.to(memory_format=torch.channels_last)
-                sd_pipe.unet = torch.compile(sd_pipe.unet, mode="max-autotune", fullgraph=False)
-            if hasattr(sd_pipe, 'transformer'):
-                sd_pipe.transformer.to(memory_format=torch.channels_last)
-                sd_pipe.transformer = torch.compile(sd_pipe.transformer, mode="max-autotune", fullgraph=False)
-            sd_pipe.vae.to(memory_format=torch.channels_last)
-            sd_pipe.vae = torch.compile(sd_pipe.vae, mode="max-autotune", fullgraph=True)
-        # setting gradient checkpoint for vae
-        if self.configs.sd_pipe.get("enable_grad_checkpoint_vae", True):
-            if self.rank == 0:
-                self.logger.info("Activating gradient checkpointing for VAE...")
-            sd_pipe.vae._set_gradient_checkpointing(sd_pipe.vae.encoder)
-            sd_pipe.vae._set_gradient_checkpointing(sd_pipe.vae.decoder)
-        # setting gradient checkpoint for diffusion model
-        if self.configs.sd_pipe.enable_grad_checkpoint:
-            if self.rank == 0:
-                self.logger.info("Activating gradient checkpointing for SD...")
-            if hasattr(sd_pipe, 'unet'):
-                self.set_grad_checkpointing(sd_pipe.unet)
-            if hasattr(sd_pipe, 'transformer'):
-                self.set_grad_checkpointing(sd_pipe.transformer)
-        self.sd_pipe = sd_pipe
-
-        # latent LPIPS loss
-        if self.configs.train.loss_coef.get('llpips', 0) > 0:
-            params = self.configs.llpips.get('params', dict)
-            llpips_loss = util_common.get_obj_from_str(self.configs.llpips.target)(**params)
-            llpips_loss.cuda()
-            self.freeze_model(llpips_loss)
-
-            # loading the pre-trained model
-            ckpt_path = self.configs.llpips.ckpt_path
-            self.load_model(llpips_loss, ckpt_path, tag='latent lpips')
-
-            if self.configs.llpips.get("compile", True):
-                if self.rank == 0:
-                    self.logger.info('Compile the llpips loss...')
-                llpips_loss.to(memory_format=torch.channels_last)
-                llpips_loss = torch.compile(llpips_loss, mode="max-autotune", fullgraph=True)
-
-            self.llpips_loss = llpips_loss
-
-        # model information
-        self.print_model_info()
-
-        torch.cuda.empty_cache()
-
-    def set_grad_checkpointing(self, model):
-        if hasattr(model, 'down_blocks'):
-            for module in model.down_blocks:
-                module.gradient_checkpointing = True
-                module.training = True
-
-        if hasattr(model, 'up_blocks'):
-            for module in model.up_blocks:
-                module.gradient_checkpointing = True
-                module.training = True
-
-        if hasattr(model, 'mid_blocks'):
-            model.mid_block.gradient_checkpointing = True
-            model.mid_block.training = True
-
-    def build_dataloader(self):
-        def _wrap_loader(loader):
-            while True: yield from loader
-
-        # make datasets
-        datasets = {'train': create_dataset(self.configs.data.get('train', dict)), }
-        if hasattr(self.configs.data, 'val') and self.rank == 0:
-            datasets['val'] = create_dataset(self.configs.data.get('val', dict))
-        if self.rank == 0:
-            for phase in datasets.keys():
-                length = len(datasets[phase])
-                self.logger.info('Number of images in {:s} data set: {:d}'.format(phase, length))
-
-        # make dataloaders
-        if self.num_gpus > 1:
-            sampler = udata.distributed.DistributedSampler(
-                datasets['train'],
-                num_replicas=self.num_gpus,
-                rank=self.rank,
-            )
-        else:
-            sampler = None
-        dataloaders = {'train': _wrap_loader(udata.DataLoader(
-            datasets['train'],
-            batch_size=self.configs.train.batch // self.num_gpus,
-            shuffle=False if self.num_gpus > 1 else True,
-            drop_last=True,
-            num_workers=min(self.configs.train.num_workers, 4),
-            pin_memory=True,
-            prefetch_factor=self.configs.train.get('prefetch_factor', 2),
-            worker_init_fn=my_worker_init_fn,
-            sampler=sampler,
-        ))}
-        if hasattr(self.configs.data, 'val') and self.rank == 0:
-            dataloaders['val'] = udata.DataLoader(datasets['val'],
-                                                  batch_size=self.configs.validate.batch,
-                                                  shuffle=False,
-                                                  drop_last=False,
-                                                  num_workers=0,
-                                                  pin_memory=True,
-                                                  )
-
-        self.datasets = datasets
-        self.dataloaders = dataloaders
-        self.sampler = sampler
-
-    def print_model_info(self):
-        if self.rank == 0:
-            num_params = util_net.calculate_parameters(self.model) / 1000**2
-            # self.logger.info("Detailed network architecture:")
-            # self.logger.info(self.model.__repr__())
-            if self.configs.train.get('use_fsdp', False):
-                num_params *= self.num_gpus
-            self.logger.info(f"Number of parameters: {num_params:.2f}M")
-
-            if hasattr(self, 'discriminator'):
-                num_params = util_net.calculate_parameters(self.discriminator) / 1000**2
-                self.logger.info(f"Number of parameters in discriminator: {num_params:.2f}M")
-
-    def prepare_data(self, data, dtype=torch.float32, phase='train'):
-        data = {key:value.cuda().to(dtype=dtype) for key, value in data.items()}
-        return data
-
-    def validation(self):
-        pass
-
-    def train(self):
-        self.init_logger()  # setup logger: self.logger
-
-        self.build_dataloader()  # prepare data: self.dataloaders, self.datasets, self.sampler
-
-        self.build_model()  # build model: self.model, self.loss
-
-        self.setup_optimizaton()  # setup optimization: self.optimzer, self.sheduler
-
-        self.resume_from_ckpt()  # resume if necessary
-
-        self.model.train()
-        num_iters_epoch = math.ceil(len(self.datasets['train']) / self.configs.train.batch)
-        for ii in range(self.iters_start, self.configs.train.iterations):
-            self.current_iters = ii + 1
-
-            # prepare data
-            data = self.prepare_data(next(self.dataloaders['train']), phase='train')
-
-            # training phase
-            self.training_step(data)
-
-            # update ema model
-            if hasattr(self.configs.train, 'ema_rate') and self.rank == 0:
-                self.update_ema_model()
-
-            # validation phase
-            if ((ii+1) % self.configs.train.save_freq == 0 and
-                'val' in self.dataloaders and
-                self.rank == 0
-                ):
-                self.validation()
-
-            # update learning rate
-            self.adjust_lr()
-
-            # save checkpoint
-            if (ii+1) % self.configs.train.save_freq == 0 and self.rank == 0:
-                self.save_ckpt()
-
-            if (ii+1) % num_iters_epoch == 0 and self.sampler is not None:
-                self.sampler.set_epoch(ii+1)
-
-        # close the tensorboard
-        self.close_logger()
-
-    def adjust_lr(self, current_iters=None):
-        base_lr = self.configs.train.lr
-        warmup_steps = self.configs.train.get("warmup_iterations", 0)
-        current_iters = self.current_iters if current_iters is None else current_iters
-        if current_iters <= warmup_steps:
-            for params_group in self.optimizer.param_groups:
-                params_group['lr'] = (current_iters / warmup_steps) * base_lr
-        else:
-            if hasattr(self, 'lr_scheduler'):
-                self.lr_scheduler.step()
-
-    def save_ckpt(self):
-        ckpt_path = self.ckpt_dir / 'model_{:d}.pth'.format(self.current_iters)
-        ckpt = {
-            'iters_start': self.current_iters,
-            'log_step': {phase:self.log_step[phase] for phase in ['train', 'val']},
-            'log_step_img': {phase:self.log_step_img[phase] for phase in ['train', 'val']},
-            'state_dict': self.model.state_dict(),
-        }
-        if self.amp_scaler is not None:
-            ckpt['amp_scaler'] = self.amp_scaler.state_dict()
-        if self.configs.train.loss_coef.get('ldis', 0) > 0:
-            ckpt['state_dict_dis'] = self.discriminator.state_dict()
-            if self.amp_scaler_dis is not None:
-                ckpt['amp_scaler_dis'] = self.amp_scaler_dis.state_dict()
-        torch.save(ckpt, ckpt_path)
-        if hasattr(self.configs.train, 'ema_rate'):
-            ema_ckpt_path = self.ema_ckpt_dir / 'ema_model_{:d}.pth'.format(self.current_iters)
-            torch.save(self.ema_model.state_dict(), ema_ckpt_path)
-
-    def logging_image(self, im_tensor, tag, phase, add_global_step=False, nrow=8):
-        """
-        Args:
-            im_tensor: b x c x h x w tensor
-            im_tag: str
-            phase: 'train' or 'val'
-            nrow: number of displays in each row
-        """
-        assert self.tf_logging or self.local_logging
-        im_tensor = vutils.make_grid(im_tensor, nrow=nrow, normalize=True, scale_each=True)  # c x H x W
-        if self.local_logging:
-            im_path = str(self.image_dir / phase / f"{tag}-{self.log_step_img[phase]}.png")
-            im_np = im_tensor.cpu().permute(1,2,0).numpy()
-            util_image.imwrite(im_np, im_path)
-        if self.tf_logging:
-            self.writer.add_image(
-                f"{phase}-{tag}-{self.log_step_img[phase]}",
-                im_tensor,
-                self.log_step_img[phase],
-            )
-        if add_global_step:
-            self.log_step_img[phase] += 1
-
-    def logging_text(self, text_list, phase):
-        """
-        Args:
-            text_list: (b,) list
-            phase: 'train' or 'val'
-        """
-        assert self.local_logging
-        if self.local_logging:
-            text_path = str(self.image_dir / phase / f"text-{self.log_step_img[phase]}.txt")
-            with open(text_path, 'w') as ff:
-                for text in text_list:
-                    ff.write(text + '\n')
-
-    def logging_metric(self, metrics, tag, phase, add_global_step=False):
-        """
-        Args:
-            metrics: dict
-            tag: str
-            phase: 'train' or 'val'
-        """
-        if self.tf_logging:
-            tag = f"{phase}-{tag}"
-            if isinstance(metrics, dict):
-                self.writer.add_scalars(tag, metrics, self.log_step[phase])
-            else:
-                self.writer.add_scalar(tag, metrics, self.log_step[phase])
-            if add_global_step:
-                self.log_step[phase] += 1
-        else:
-            pass
-
-    def load_model(self, model, ckpt_path=None, tag='model'):
-        if self.rank == 0:
-            self.logger.info(f'Loading {tag} from {ckpt_path}...')
-        ckpt = torch.load(ckpt_path, map_location=f"cuda:{self.rank}")
-        if 'state_dict' in ckpt:
-            ckpt = ckpt['state_dict']
-        util_net.reload_model(model, ckpt)
-        if self.rank == 0:
-            self.logger.info('Loaded Done')
-
-    def freeze_model(self, net):
-        for params in net.parameters():
-            params.requires_grad = False
-
-    def unfreeze_model(self, net):
-        for params in net.parameters():
-            params.requires_grad = True
-
-    @torch.no_grad()
-    def update_ema_model(self):
-        decay = min(self.configs.train.ema_rate, (1 + self.current_iters) / (10 + self.current_iters))
-        target_params = dict(self.model.named_parameters())
-        # if hasattr(self.configs.train, 'ema_rate'):
-        #     with FSDP.summon_full_params(self.model, writeback=True):
-        #         target_params = dict(self.model.named_parameters())
-        # else:
-        #     target_params = dict(self.model.named_parameters())
-
-        one_minus_decay = 1.0 - decay
-
-        for key, source_value in self.ema_model.named_parameters():
-            target_value = target_params[key]
-            if target_value.requires_grad:
-                source_value.sub_(one_minus_decay * (source_value - target_value.data))
-
-class TrainerBaseSR(TrainerBase):
-    @torch.no_grad()
-    def _dequeue_and_enqueue(self):
-        """It is the training pair pool for increasing the diversity in a batch.
-
-        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
-        batch could not have different resize scaling factors. Therefore, we employ this training pair pool
-        to increase the degradation diversity in a batch.
-        """
-        # initialize
-        b, c, h, w = self.lq.size()
-        if not hasattr(self, 'queue_size'):
-            self.queue_size = self.configs.degradation.get('queue_size', b*10)
-        if not hasattr(self, 'queue_lr'):
-            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
-            self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
-            _, c, h, w = self.gt.size()
-            self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
-            _, c, h, w = self.gt_latent.size()
-            self.queue_gt_latent = torch.zeros(self.queue_size, c, h, w).cuda()
-            self.queue_txt = ["", ] * self.queue_size
-            self.queue_ptr = 0
-        if self.queue_ptr == self.queue_size:  # the pool is full
-            # do dequeue and enqueue
-            # shuffle
-            idx = torch.randperm(self.queue_size)
-            self.queue_lr = self.queue_lr[idx]
-            self.queue_gt = self.queue_gt[idx]
-            self.queue_gt_latent = self.queue_gt_latent[idx]
-            self.queue_txt = [self.queue_txt[ii] for ii in idx]
-            # get first b samples
-            lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
-            gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
-            gt_latent_dequeue = self.queue_gt_latent[0:b, :, :, :].clone()
-            txt_dequeue = deepcopy(self.queue_txt[0:b])
-            # update the queue
-            self.queue_lr[0:b, :, :, :] = self.lq.clone()
-            self.queue_gt[0:b, :, :, :] = self.gt.clone()
-            self.queue_gt_latent[0:b, :, :, :] = self.gt_latent.clone()
-            self.queue_txt[0:b] = deepcopy(self.txt)
-
-            self.lq = lq_dequeue
-            self.gt = gt_dequeue
-            self.gt_latent = gt_latent_dequeue
-            self.txt = txt_dequeue
-        else:
-            # only do enqueue
-            self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
-            self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
-            self.queue_gt_latent[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt_latent.clone()
-            self.queue_txt[self.queue_ptr:self.queue_ptr + b] = deepcopy(self.txt)
-            self.queue_ptr = self.queue_ptr + b
-
-    @torch.no_grad()
-    def prepare_data(self, data, phase='train'):
-        if phase == 'train' and self.configs.data.get(phase).get('type') == 'realesrgan':
-            if not hasattr(self, 'jpeger'):
-                self.jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts
-            if (not hasattr(self, 'sharpener')) and self.configs.degradation.get('use_sharp', False):
-                self.sharpener = USMSharp().cuda()
-
-            im_gt = data['gt'].cuda()
-            kernel1 = data['kernel1'].cuda()
-            kernel2 = data['kernel2'].cuda()
-            sinc_kernel = data['sinc_kernel'].cuda()
-
-            ori_h, ori_w = im_gt.size()[2:4]
-            if isinstance(self.configs.degradation.sf, int):
-                sf = self.configs.degradation.sf
-            else:
-                assert len(self.configs.degradation.sf) == 2
-                sf = random.uniform(*self.configs.degradation.sf)
-
-            if self.configs.degradation.use_sharp:
-                im_gt = self.sharpener(im_gt)
-
-            # ----------------------- The first degradation process ----------------------- #
-            # blur
-            out = filter2D(im_gt, kernel1)
-            # random resize
-            updown_type = random.choices(
-                ['up', 'down', 'keep'],
-                self.configs.degradation['resize_prob'],
-            )[0]
-            if updown_type == 'up':
-                scale = random.uniform(1, self.configs.degradation['resize_range'][1])
-            elif updown_type == 'down':
-                scale = random.uniform(self.configs.degradation['resize_range'][0], 1)
-            else:
-                scale = 1
-            mode = random.choice(['area', 'bilinear', 'bicubic'])
-            out = F.interpolate(out, scale_factor=scale, mode=mode)
-            # add noise
-            gray_noise_prob = self.configs.degradation['gray_noise_prob']
-            if random.random() < self.configs.degradation['gaussian_noise_prob']:
-                out = random_add_gaussian_noise_pt(
-                    out,
-                    sigma_range=self.configs.degradation['noise_range'],
-                    clip=True,
-                    rounds=False,
-                    gray_prob=gray_noise_prob,
-                )
-            else:
-                out = random_add_poisson_noise_pt(
-                    out,
-                    scale_range=self.configs.degradation['poisson_scale_range'],
-                    gray_prob=gray_noise_prob,
-                    clip=True,
-                    rounds=False)
-            # JPEG compression
-            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.configs.degradation['jpeg_range'])
-            out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
-            out = self.jpeger(out, quality=jpeg_p)
-
-            # ----------------------- The second degradation process ----------------------- #
-            if random.random() < self.configs.degradation['second_order_prob']:
-                # blur
-                if random.random() < self.configs.degradation['second_blur_prob']:
-                    out = filter2D(out, kernel2)
-                # random resize
-                updown_type = random.choices(
-                    ['up', 'down', 'keep'],
-                    self.configs.degradation['resize_prob2'],
-                )[0]
-                if updown_type == 'up':
-                    scale = random.uniform(1, self.configs.degradation['resize_range2'][1])
-                elif updown_type == 'down':
-                    scale = random.uniform(self.configs.degradation['resize_range2'][0], 1)
-                else:
-                    scale = 1
-                mode = random.choice(['area', 'bilinear', 'bicubic'])
-                out = F.interpolate(
-                    out,
-                    size=(int(ori_h / sf * scale), int(ori_w / sf * scale)),
-                    mode=mode,
-                )
-                # add noise
-                gray_noise_prob = self.configs.degradation['gray_noise_prob2']
-                if random.random() < self.configs.degradation['gaussian_noise_prob2']:
-                    out = random_add_gaussian_noise_pt(
-                        out,
-                        sigma_range=self.configs.degradation['noise_range2'],
-                        clip=True,
-                        rounds=False,
-                        gray_prob=gray_noise_prob,
-                    )
-                else:
-                    out = random_add_poisson_noise_pt(
-                        out,
-                        scale_range=self.configs.degradation['poisson_scale_range2'],
-                        gray_prob=gray_noise_prob,
-                        clip=True,
-                        rounds=False,
-                    )
-
-            # JPEG compression + the final sinc filter
-            # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
-            # as one operation.
-            # We consider two orders:
-            #   1. [resize back + sinc filter] + JPEG compression
-            #   2. JPEG compression + [resize back + sinc filter]
-            # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
-            if random.random() < 0.5:
-                # resize back + the final sinc filter
-                mode = random.choice(['area', 'bilinear', 'bicubic'])
-                out = F.interpolate(
-                    out,
-                    size=(ori_h // sf, ori_w // sf),
-                    mode=mode,
-                )
-                out = filter2D(out, sinc_kernel)
-                # JPEG compression
-                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.configs.degradation['jpeg_range2'])
-                out = torch.clamp(out, 0, 1)
-                out = self.jpeger(out, quality=jpeg_p)
-            else:
-                # JPEG compression
-                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.configs.degradation['jpeg_range2'])
-                out = torch.clamp(out, 0, 1)
-                out = self.jpeger(out, quality=jpeg_p)
-                # resize back + the final sinc filter
-                mode = random.choice(['area', 'bilinear', 'bicubic'])
-                out = F.interpolate(
-                    out,
-                    size=(ori_h // sf, ori_w // sf),
-                    mode=mode,
-                )
-                out = filter2D(out, sinc_kernel)
-
-            # resize back
-            if self.configs.degradation.resize_back:
-                out = F.interpolate(out, size=(ori_h, ori_w), mode=_INTERPOLATION_MODE)
-
-            # clamp and round
-            im_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
-
-            self.lq, self.gt, self.txt = im_lq, im_gt, data['txt']
-            if "gt_moment" not in data:
-                self.gt_latent = self.encode_first_stage(
-                    im_gt.cuda(),
-                    center_input_sample=True,
-                    deterministic=self.configs.train.loss_coef.get('rkl', 0) > 0,
-                )
-            else:
-                self.gt_latent = self.encode_from_moment(
-                    data['gt_moment'].cuda(),
-                    deterministic=self.configs.train.loss_coef.get('rkl', 0) > 0,
-                )
-
-            if (not self.configs.train.use_text) or self.configs.data.train.params.random_crop:
-                self.txt = [_positive,] * im_lq.shape[0]
-
-            # training pair pool
-            self._dequeue_and_enqueue()
-            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
-
-            batch = {'lq':self.lq, 'gt':self.gt, 'gt_latent':self.gt_latent, 'txt':self.txt}
-        elif phase == 'val':
-            resolution = self.configs.data.train.params.gt_size // self.configs.degradation.sf
-            batch = {}
-            batch['lq'] = data['lq'].cuda()
-            if 'gt' in data:
-                batch['gt'] = data['gt'].cuda()
-            batch['txt'] = [_positive, ] * data['lq'].shape[0]
-        else:
-            batch = {key:value.cuda().to(dtype=torch.float32) for key, value in data.items()}
-
-        return batch
-
-    @torch.no_grad()
-    def encode_from_moment(self, z, deterministic=True):
-        dist = DiagonalGaussianDistribution(z)
-        init_latents = dist.mode() if deterministic else dist.sample()
-
-        latents_mean = latents_std = None
-        if hasattr(self.sd_pipe.vae.config, "latents_mean") and self.sd_pipe.vae.config.latents_mean is not None:
-            latents_mean = torch.tensor(self.sd_pipe.vae.config.latents_mean).view(1, 4, 1, 1)
-        if hasattr(self.sd_pipe.vae.config, "latents_std") and self.sd_pipe.vae.config.latents_std is not None:
-            latents_std = torch.tensor(self.sd_pipe.vae.config.latents_std).view(1, 4, 1, 1)
-
-        scaling_factor = self.sd_pipe.vae.config.scaling_factor
-        if latents_mean is not None and latents_std is not None:
-            latents_mean = latents_mean.to(device=z.device, dtype=z.dtype)
-            latents_std = latents_std.to(device=z.device, dtype=z.dtype)
-            init_latents = (init_latents - latents_mean) * scaling_factor / latents_std
-        else:
-            init_latents = init_latents * scaling_factor
-
-        return init_latents
-
-    @torch.no_grad()
-    @torch.amp.autocast('cuda')
-    def encode_first_stage(self, x, deterministic=False, center_input_sample=True):
-        if center_input_sample:
-            x = x * 2.0 - 1.0
-        latents_mean = latents_std = None
-        if hasattr(self.sd_pipe.vae.config, "latents_mean") and self.sd_pipe.vae.config.latents_mean is not None:
-            latents_mean = torch.tensor(self.sd_pipe.vae.config.latents_mean).view(1, -1, 1, 1)
-        if hasattr(self.sd_pipe.vae.config, "latents_std") and self.sd_pipe.vae.config.latents_std is not None:
-            latents_std = torch.tensor(self.sd_pipe.vae.config.latents_std).view(1, -1, 1, 1)
-
-        if deterministic:
-            partial_encode = lambda xx: self.sd_pipe.vae.encode(xx).latent_dist.mode()
-        else:
-            partial_encode = lambda xx: self.sd_pipe.vae.encode(xx).latent_dist.sample()
-
-        trunk_size = self.configs.sd_pipe.vae_split
-        if trunk_size < x.shape[0]:
-            init_latents = torch.cat([partial_encode(xx) for xx in x.split(trunk_size, 0)], dim=0)
-        else:
-            init_latents = partial_encode(x)
-
-        scaling_factor = self.sd_pipe.vae.config.scaling_factor
-        if latents_mean is not None and latents_std is not None:
-            latents_mean = latents_mean.to(device=x.device, dtype=x.dtype)
-            latents_std = latents_std.to(device=x.device, dtype=x.dtype)
-            init_latents = (init_latents - latents_mean) * scaling_factor / latents_std
-        else:
-            init_latents = init_latents * scaling_factor
-
-        return init_latents
-
-    @torch.no_grad()
-    @torch.amp.autocast('cuda')
-    def decode_first_stage(self, z, clamp=True):
-        z = z / self.sd_pipe.vae.config.scaling_factor
-
-        trunk_size = 1
-        if trunk_size < z.shape[0]:
-            out = torch.cat(
-                [self.sd_pipe.vae.decode(xx).sample for xx in z.split(trunk_size, 0)], dim=0,
-            )
-        else:
-            out = self.sd_pipe.vae.decode(z).sample
-        if clamp:
-            out = out.clamp(-1.0, 1.0)
-        return out
-
-    def get_loss_from_discrimnator(self, logits_fake):
-        if not (isinstance(logits_fake, list) or isinstance(logits_fake, tuple)):
-            g_loss = -torch.mean(logits_fake, dim=list(range(1, logits_fake.ndim)))
-        else:
-            g_loss = -torch.mean(logits_fake[0], dim=list(range(1, logits_fake[0].ndim)))
-            for current_logits in logits_fake[1:]:
-                g_loss += -torch.mean(current_logits, dim=list(range(1, current_logits.ndim)))
-            g_loss /= len(logits_fake)
-
-        return g_loss
-
- def training_step(self, data):
962
- current_bs = data['gt'].shape[0]
963
- micro_bs = self.configs.train.microbatch
964
- num_grad_accumulate = math.ceil(current_bs / micro_bs)
965
-
966
- # grad zero
967
- self.model.zero_grad()
968
-
969
- # update generator
970
- if self.configs.train.loss_coef.get('ldis', 0) > 0:
971
- self.freeze_model(self.discriminator) # freeze discriminator
972
- z0_pred_list = []
973
- tt_list = []
974
- prompt_embeds_list = []
975
- for jj in range(0, current_bs, micro_bs):
976
- micro_data = {key:value[jj:jj+micro_bs] for key, value in data.items()}
977
- last_batch = (jj+micro_bs >= current_bs)
978
- if last_batch or self.num_gpus <= 1:
979
- losses, z0_pred, zt_noisy, tt = self.backward_step(micro_data, num_grad_accumulate)
980
- else:
981
- with self.model.no_sync():
982
- losses, z0_pred, zt_noisy, tt = self.backward_step(micro_data, num_grad_accumulate)
983
- if self.configs.train.loss_coef.get('ldis', 0) > 0:
984
- z0_pred_list.append(z0_pred.detach())
985
- tt_list.append(tt)
986
- prompt_embeds_list.append(self.prompt_embeds.detach())
987
-
988
- if self.configs.train.use_amp:
989
- self.amp_scaler.step(self.optimizer)
990
- self.amp_scaler.update()
991
- else:
992
- self.optimizer.step()
993
-
994
- # update discriminator
995
- if (self.configs.train.loss_coef.get('ldis', 0) > 0 and
996
- (self.current_iters < self.configs.train.dis_init_iterations
997
- or self.current_iters % self.configs.train.dis_update_freq == 0)
998
- ):
999
- # grad zero
1000
- self.unfreeze_model(self.discriminator) # update discriminator
1001
- self.discriminator.zero_grad()
1002
- for ii, jj in enumerate(range(0, current_bs, micro_bs)):
1003
- micro_data = {key:value[jj:jj+micro_bs] for key, value in data.items()}
1004
- last_batch = (jj+micro_bs >= current_bs)
1005
- target = micro_data['gt_latent']
1006
- inputs = z0_pred_list[ii]
1007
- if last_batch or self.num_gpus <= 1:
1008
- logits = self.dis_backward_step(target, inputs, tt_list[ii], prompt_embeds_list[ii])
1009
- else:
1010
- with self.discriminator.no_sync():
1011
- logits = self.dis_backward_step(
1012
- target, inputs, tt_list[ii], prompt_embeds_list[ii]
1013
- )
1014
-
1015
- # make logging
1016
- if self.current_iters % self.configs.train.dis_update_freq == 0 and self.rank == 0:
1017
- ndim = logits[0].ndim
1018
- losses['real'] = logits[0].detach().mean(dim=list(range(1, ndim)))
1019
- losses['fake'] = logits[1].detach().mean(dim=list(range(1, ndim)))
1020
-
1021
- if self.configs.train.use_amp:
1022
- self.amp_scaler_dis.step(self.optimizer_dis)
1023
- self.amp_scaler_dis.update()
1024
- else:
1025
- self.optimizer_dis.step()
1026
-
1027
- # make logging
1028
- if self.rank == 0:
1029
- self.log_step_train(
1030
- losses, tt, micro_data, z0_pred, zt_noisy, z0_gt=micro_data['gt_latent'],
1031
- )
1032
-
1033
- @torch.no_grad()
1034
- def log_step_train(self, losses, tt, micro_data, z0_pred, zt_noisy, z0_gt=None, phase='train'):
1035
- '''
1036
- param losses: a dict recording the loss informations
1037
- '''
1038
- '''
1039
- param loss: a dict recording the loss informations
1040
- param micro_data: batch data
1041
- param tt: 1-D tensor, time steps
1042
- '''
1043
- if hasattr(self.configs.train, 'timesteps'):
1044
- if len(self.configs.train.timesteps) < 3:
1045
- record_steps = sorted(self.configs.train.timesteps)
1046
- else:
1047
- record_steps = [min(self.configs.train.timesteps),
1048
- max(self.configs.train.timesteps)]
1049
- else:
1050
- max_inference_steps = self.configs.train.max_inference_steps
1051
- record_steps = [1, max_inference_steps//2, max_inference_steps]
1052
- if ((self.current_iters // self.configs.train.dis_update_freq) %
1053
- (self.configs.train.log_freq[0] // self.configs.train.dis_update_freq) == 1):
1054
- self.loss_mean = {key:torch.zeros(size=(len(record_steps),), dtype=torch.float64)
1055
- for key in losses.keys() if key not in ['real', 'fake']}
1056
- if self.configs.train.loss_coef.get('ldis', 0) > 0:
1057
- self.logit_mean = {key:torch.zeros(size=(len(record_steps),), dtype=torch.float64)
1058
- for key in ['real', 'fake']}
1059
- self.loss_count = torch.zeros(size=(len(record_steps),), dtype=torch.float64)
1060
- for jj in range(len(record_steps)):
1061
- for key, value in losses.items():
1062
- index = record_steps[jj] - 1
1063
- mask = torch.where(tt == index, torch.ones_like(tt), torch.zeros_like(tt))
1064
- assert value.shape == mask.shape
1065
- current_loss = torch.sum(value.detach() * mask)
1066
- if key in ['real', 'fake']:
1067
- self.logit_mean[key][jj] += current_loss.item()
1068
- else:
1069
- self.loss_mean[key][jj] += current_loss.item()
1070
- self.loss_count[jj] += mask.sum().item()
1071
-
1072
- if ((self.current_iters // self.configs.train.dis_update_freq) %
1073
- (self.configs.train.log_freq[0] // self.configs.train.dis_update_freq) == 0):
1074
- if torch.any(self.loss_count == 0):
1075
- self.loss_count += 1e-4
1076
- for key in losses.keys():
1077
- if key in ['real', 'fake']:
1078
- self.logit_mean[key] /= self.loss_count
1079
- else:
1080
- self.loss_mean[key] /= self.loss_count
1081
- log_str = f"Train: {self.current_iters:06d}/{self.configs.train.iterations:06d}, "
1082
- valid_keys = sorted([key for key in losses.keys() if key not in ['loss', 'real', 'fake']])
1083
- for ii, key in enumerate(valid_keys):
1084
- if ii == 0:
1085
- log_str += f"{key}"
1086
- else:
1087
- log_str += f"/{key}"
1088
- if self.configs.train.loss_coef.get('ldis', 0) > 0:
1089
- log_str += "/real/fake:"
1090
- else:
1091
- log_str += ":"
1092
- for jj, current_record in enumerate(record_steps):
1093
- for ii, key in enumerate(valid_keys):
1094
- if ii == 0:
1095
- if key in ['dis', 'ldis']:
1096
- log_str += 't({:d}):{:+6.4f}'.format(
1097
- current_record,
1098
- self.loss_mean[key][jj].item(),
1099
- )
1100
- elif key in ['lpips', 'ldif']:
1101
- log_str += 't({:d}):{:4.2f}'.format(
1102
- current_record,
1103
- self.loss_mean[key][jj].item(),
1104
- )
1105
- elif key == 'llpips':
1106
- log_str += 't({:d}):{:5.3f}'.format(
1107
- current_record,
1108
- self.loss_mean[key][jj].item(),
1109
- )
1110
- else:
1111
- log_str += 't({:d}):{:.1e}'.format(
1112
- current_record,
1113
- self.loss_mean[key][jj].item(),
1114
- )
1115
- else:
1116
- if key in ['dis', 'ldis']:
1117
- log_str += f"/{self.loss_mean[key][jj].item():+6.4f}"
1118
- elif key in ['lpips', 'ldif']:
1119
- log_str += f"/{self.loss_mean[key][jj].item():4.2f}"
1120
- elif key == 'llpips':
1121
- log_str += f"/{self.loss_mean[key][jj].item():5.3f}"
1122
- else:
1123
- log_str += f"/{self.loss_mean[key][jj].item():.1e}"
1124
- if self.configs.train.loss_coef.get('ldis', 0) > 0:
1125
- log_str += f"/{self.logit_mean['real'][jj].item():+4.2f}"
1126
- log_str += f"/{self.logit_mean['fake'][jj].item():+4.2f}, "
1127
- else:
1128
- log_str += f", "
1129
- log_str += 'lr:{:.1e}'.format(self.optimizer.param_groups[0]['lr'])
1130
- self.logger.info(log_str)
1131
- self.logging_metric(self.loss_mean, tag='Loss', phase=phase, add_global_step=True)
1132
- if ((self.current_iters // self.configs.train.dis_update_freq) %
1133
- (self.configs.train.log_freq[1] // self.configs.train.dis_update_freq) == 0):
1134
- if zt_noisy is not None:
1135
- xt_pred = self.decode_first_stage(zt_noisy.detach())
1136
- self.logging_image(xt_pred, tag='xt-noisy', phase=phase, add_global_step=False)
1137
- if z0_pred is not None:
1138
- x0_pred = self.decode_first_stage(z0_pred.detach())
1139
- self.logging_image(x0_pred, tag='x0-pred', phase=phase, add_global_step=False)
1140
- if z0_gt is not None:
1141
- x0_recon = self.decode_first_stage(z0_gt.detach())
1142
- self.logging_image(x0_recon, tag='x0-recons', phase=phase, add_global_step=False)
1143
- if 'txt' in micro_data:
1144
- self.logging_text(micro_data['txt'], phase=phase)
1145
- self.logging_image(micro_data['lq'], tag='LQ', phase=phase, add_global_step=False)
1146
- self.logging_image(micro_data['gt'], tag='GT', phase=phase, add_global_step=True)
1147
-
1148
- if ((self.current_iters // self.configs.train.dis_update_freq) %
1149
- (self.configs.train.save_freq // self.configs.train.dis_update_freq) == 1):
1150
- self.tic = time.time()
1151
- if ((self.current_iters // self.configs.train.dis_update_freq) %
1152
- (self.configs.train.save_freq // self.configs.train.dis_update_freq) == 0):
1153
- self.toc = time.time()
1154
- elaplsed = (self.toc - self.tic)
1155
- self.logger.info(f"Elapsed time: {elaplsed:.2f}s")
1156
- self.logger.info("="*100)
1157
-
1158
- @torch.no_grad()
1159
- def validation(self, phase='val'):
1160
- torch.cuda.empty_cache()
1161
- if not (self.configs.validate.use_ema and hasattr(self.configs.train, 'ema_rate')):
1162
- self.model.eval()
1163
-
1164
- if self.configs.train.start_mode:
1165
- start_noise_predictor = self.ema_model if self.configs.validate.use_ema else self.model
1166
- intermediate_noise_predictor = None
1167
- else:
1168
- start_noise_predictor = self.start_model
1169
- intermediate_noise_predictor = self.ema_model if self.configs.validate.use_ema else self.model
1170
- num_iters_epoch = math.ceil(len(self.datasets[phase]) / self.configs.validate.batch)
1171
- mean_psnr = mean_lpips = 0
1172
- for jj, data in enumerate(self.dataloaders[phase]):
1173
- data = self.prepare_data(data, phase='val')
1174
- with torch.amp.autocast('cuda'):
1175
- xt_progressive, x0_progressive = self.sample(
1176
- image_lq=data['lq'],
1177
- prompt=[_positive,]*data['lq'].shape[0],
1178
- target_size=tuple(data['gt'].shape[-2:]),
1179
- start_noise_predictor=start_noise_predictor,
1180
- intermediate_noise_predictor=intermediate_noise_predictor,
1181
- )
1182
- x0 = xt_progressive[-1]
1183
- num_inference_steps = len(xt_progressive)
1184
-
1185
- if 'gt' in data:
1186
- if not hasattr(self, 'psnr_metric'):
1187
- self.psnr_metric = pyiqa.create_metric(
1188
- 'psnr',
1189
- test_y_channel=self.configs.train.get('val_y_channel', True),
1190
- color_space='ycbcr',
1191
- device=torch.device("cuda"),
1192
- )
1193
- if not hasattr(self, 'lpips_metric'):
1194
- self.lpips_metric = pyiqa.create_metric(
1195
- 'lpips-vgg',
1196
- device=torch.device("cuda"),
1197
- as_loss=False,
1198
- )
1199
- x0_normalize = util_image.normalize_th(x0, mean=0.5, std=0.5, reverse=True)
1200
- mean_psnr += self.psnr_metric(x0_normalize, data['gt']).sum().item()
1201
- with torch.amp.autocast('cuda'), torch.no_grad():
1202
- mean_lpips += self.lpips_metric(x0_normalize, data['gt']).sum().item()
1203
-
1204
- if (jj + 1) % self.configs.validate.log_freq == 0:
1205
- self.logger.info(f'Validation: {jj+1:02d}/{num_iters_epoch:02d}...')
1206
-
1207
- self.logging_image(data['gt'], tag='GT', phase=phase, add_global_step=False)
1208
- xt_progressive = rearrange(torch.cat(xt_progressive, dim=1), 'b (k c) h w -> (b k) c h w', c=3)
1209
- self.logging_image(
1210
- xt_progressive,
1211
- tag='sample-progress',
1212
- phase=phase,
1213
- add_global_step=False,
1214
- nrow=num_inference_steps,
1215
- )
1216
- x0_progressive = rearrange(torch.cat(x0_progressive, dim=1), 'b (k c) h w -> (b k) c h w', c=3)
1217
- self.logging_image(
1218
- x0_progressive,
1219
- tag='x0-progress',
1220
- phase=phase,
1221
- add_global_step=False,
1222
- nrow=num_inference_steps,
1223
- )
1224
- self.logging_image(data['lq'], tag='LQ', phase=phase, add_global_step=True)
1225
-
1226
- if 'gt' in data:
1227
- mean_psnr /= len(self.datasets[phase])
1228
- mean_lpips /= len(self.datasets[phase])
1229
- self.logger.info(f'Validation Metric: PSNR={mean_psnr:5.2f}, LPIPS={mean_lpips:6.4f}...')
1230
- self.logging_metric(mean_psnr, tag='PSNR', phase=phase, add_global_step=False)
1231
- self.logging_metric(mean_lpips, tag='LPIPS', phase=phase, add_global_step=True)
1232
-
1233
- self.logger.info("="*100)
1234
-
1235
- if not (self.configs.validate.use_ema and hasattr(self.configs.train, 'ema_rate')):
1236
- self.model.train()
1237
- torch.cuda.empty_cache()
1238
-
1239
- def backward_step(self, micro_data, num_grad_accumulate):
1240
- loss_coef = self.configs.train.loss_coef
1241
-
1242
- losses = {}
1243
- z0_gt = micro_data['gt_latent']
1244
- tt = torch.tensor(
1245
- random.choices(self.configs.train.timesteps, k=z0_gt.shape[0]),
1246
- dtype=torch.int64,
1247
- device=f"cuda:{self.rank}",
1248
- ) - 1
1249
-
1250
- with torch.autocast(device_type="cuda", enabled=self.configs.train.use_amp):
1251
- model_pred = self.model(
1252
- micro_data['lq'], tt, sample_posterior=False, center_input_sample=True,
1253
- )
1254
- z0_pred, zt_noisy_pred, z0_lq = self.sd_forward_step(
1255
- prompt=micro_data['txt'],
1256
- latents_hq=micro_data['gt_latent'],
1257
- image_lq=micro_data['lq'],
1258
- image_hq=micro_data['gt'],
1259
- model_pred=model_pred,
1260
- timesteps=tt,
1261
- )
1262
- # diffusion loss
1263
- if loss_coef.get('ldif', 0) > 0:
1264
- if self.configs.train.loss_type == 'L2':
1265
- ldif_loss = F.mse_loss(z0_pred, z0_gt, reduction='none')
1266
- elif self.configs.train.loss_type == 'L1':
1267
- ldif_loss = F.l1_loss(z0_pred, z0_gt, reduction='none')
1268
- else:
1269
- raise TypeError(f"Unsupported Loss type for Diffusion: {self.configs.train.loss_type}")
1270
- ldif_loss = torch.mean(ldif_loss, dim=list(range(1, z0_gt.ndim)))
1271
- losses['ldif'] = ldif_loss * loss_coef['ldif']
1272
- # Gaussian constraints
1273
- if loss_coef.get('kl', 0) > 0:
1274
- losses['kl'] = model_pred.kl() * loss_coef['kl']
1275
- if loss_coef.get('pkl', 0) > 0:
1276
- losses['pkl'] = model_pred.partial_kl() * loss_coef['pkl']
1277
- if loss_coef.get('rkl', 0) > 0:
1278
- other = Box(
1279
- {'mean': z0_gt-z0_lq,
1280
- 'var':torch.ones_like(z0_gt),
1281
- 'logvar':torch.zeros_like(z0_gt)}
1282
- )
1283
- losses['rkl'] = model_pred.kl(other) * loss_coef['rkl']
1284
- # discriminator loss
1285
- if loss_coef.get('ldis', 0) > 0:
1286
- if self.current_iters > self.configs.train.dis_init_iterations:
1287
- logits_fake = self.discriminator(
1288
- torch.clamp(z0_pred, min=_Latent_bound['min'], max=_Latent_bound['max']),
1289
- timestep=tt,
1290
- encoder_hidden_states=self.prompt_embeds,
1291
- )
1292
- losses['ldis'] = self.get_loss_from_discrimnator(logits_fake) * loss_coef['ldis']
1293
- else:
1294
- losses['ldis'] = torch.zeros((z0_gt.shape[0], ), dtype=torch.float32).cuda()
1295
- # perceptual loss
1296
- if loss_coef.get('llpips', 0) > 0:
1297
- losses['llpips'] = self.llpips_loss(z0_pred, z0_gt).view(-1) * loss_coef['llpips']
1298
-
1299
- for key in ['ldif', 'kl', 'rkl', 'pkl', 'ldis', 'llpips']:
1300
- if loss_coef.get(key, 0) > 0:
1301
- if not 'loss' in losses:
1302
- losses['loss'] = losses[key]
1303
- else:
1304
- losses['loss'] = losses['loss'] + losses[key]
1305
- loss = losses['loss'].mean() / num_grad_accumulate
1306
-
1307
- if self.amp_scaler is None:
1308
- loss.backward()
1309
- else:
1310
- self.amp_scaler.scale(loss).backward()
1311
-
1312
- return losses, z0_pred, zt_noisy_pred, tt
1313
-
1314
- def dis_backward_step(self, target, inputs, tt, prompt_embeds):
1315
- with torch.autocast(device_type="cuda", enabled=self.configs.train.use_amp):
1316
- logits_real = self.discriminator(target, tt, prompt_embeds)
1317
- inputs = inputs.clamp(min=_Latent_bound['min'], max=_Latent_bound['max'])
1318
- logits_fake = self.discriminator(inputs, tt, prompt_embeds)
1319
-
1320
- loss = hinge_d_loss(logits_real, logits_fake).mean()
1321
-
1322
- if self.amp_scaler_dis is None:
1323
- loss.backward()
1324
- else:
1325
- self.amp_scaler_dis.scale(loss).backward()
1326
-
1327
- return logits_real[-1], logits_fake[-1]
1328
-
1329
- def scale_sd_input(
1330
- self,
1331
- x:torch.Tensor,
1332
- sigmas: torch.Tensor = None,
1333
- timestep: torch.Tensor = None,
1334
- ) :
1335
- if sigmas is None:
1336
- if not self.sd_pipe.scheduler.sigmas.numel() == (self.configs.sd_pipe.num_train_steps + 1):
1337
- self.sd_pipe.scheduler = EulerDiscreteScheduler.from_pipe(
1338
- self.configs.sd_pipe.params.pretrained_model_name_or_path,
1339
- cache_dir=self.configs.sd_pipe.params.cache_dir,
1340
- subfolder='scheduler',
1341
- )
1342
- assert self.sd_pipe.scheduler.sigmas.numel() == (self.configs.sd_pipe.num_train_steps + 1)
1343
- sigmas = self.sd_pipe.scheduler.sigmas.flip(0).to(x.device)[timestep] # (b,)
1344
- sigmas = append_dims(sigmas, x.ndim)
1345
-
1346
- if sigmas.ndim < x.ndim:
1347
- sigmas = append_dims(sigmas, x.ndim)
1348
- out = x / ((sigmas**2 + 1) ** 0.5)
1349
- return out
1350
-
1351
-    def prepare_lq_latents(
-        self,
-        image_lq: torch.Tensor,
-        timestep: torch.Tensor,
-        height: int = 512,
-        width: int = 512,
-        start_noise_predictor: torch.nn.Module = None,
-    ):
-        """
-        Input:
-            image_lq: low-quality image, torch.Tensor, range in [0, 1]
-            height, width: resolution of the high-quality image
-        """
-        image_lq_up = F.interpolate(image_lq, size=(height, width), mode='bicubic')
-        init_latents = self.encode_first_stage(
-            image_lq_up, deterministic=False, center_input_sample=True,
-        )
-
-        if start_noise_predictor is None:
-            model_pred = None
-        else:
-            model_pred = start_noise_predictor(
-                image_lq, timestep, sample_posterior=False, center_input_sample=True,
-            )
-
-        # get latents
-        sigmas = self.sigmas_cache[timestep]
-        sigmas = append_dims(sigmas, init_latents.ndim)
-        latents = self.add_noise(init_latents, sigmas, model_pred)
-
-        return latents
-
-    def add_noise(self, latents, sigmas, model_pred=None):
-        if sigmas.ndim < latents.ndim:
-            sigmas = append_dims(sigmas, latents.ndim)
-
-        if model_pred is None:
-            noise = torch.randn_like(latents)
-            zt_noisy = latents + sigmas * noise
-        else:
-            if self.configs.train.loss_coef.get('rkl', 0) > 0:
-                mean, std = model_pred.mean, model_pred.std
-                zt_noisy = latents + mean + sigmas * std * torch.randn_like(latents)
-            else:
-                zt_noisy = latents + sigmas * model_pred.sample()
-
-        return zt_noisy
-
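`add_noise` implements the variance-exploding forward process z_t = z_0 + σ·ε; the `model_pred` branches merely replace the standard-normal ε with a sample from a learned Gaussian. A minimal sketch of the default branch (standalone; `add_noise_ve` is an illustrative name, not part of this file):

```python
import torch

def add_noise_ve(latents: torch.Tensor, sigma: float, eps: torch.Tensor = None):
    # Variance-exploding forward process: z_t = z_0 + sigma * eps.
    # With eps ~ N(0, I) this matches the `model_pred is None` branch above;
    # a learned predictor simply supplies a different eps (or mean/std pair).
    if eps is None:
        eps = torch.randn_like(latents)
    return latents + sigma * eps

z0 = torch.randn(1, 4, 64, 64)
zt = add_noise_ve(z0, sigma=2.5)
```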
-    def retrieve_timesteps(self):
-        device = torch.device(f"cuda:{self.rank}")
-
-        num_inference_steps = self.configs.train.get('num_inference_steps', 5)
-        timesteps = np.linspace(
-            max(self.configs.train.timesteps), 0, num_inference_steps,
-            endpoint=False, dtype=np.int64,
-        ) - 1
-        timesteps = torch.from_numpy(timesteps).to(device)
-        self.sd_pipe.scheduler.timesteps = timesteps
-
-        sigmas = self.sigmas_cache[timesteps.long()]
-        sigma_last = torch.tensor([0,], dtype=torch.float32).to(device=sigmas.device)
-        sigmas = torch.cat([sigmas, sigma_last]).type(torch.float32)
-        self.sd_pipe.scheduler.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
-
-        self.sd_pipe.scheduler._step_index = None
-        self.sd_pipe.scheduler._begin_index = None
-
-        return self.sd_pipe.scheduler.timesteps, num_inference_steps
-
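The spacing logic above picks `num_inference_steps` evenly spaced timesteps, descending, ending just above 0. Assuming `max(configs.train.timesteps)` equals 1000 (an illustrative value; the actual config is not shown in this diff), the result works out as:

```python
import numpy as np

num_inference_steps = 5
t_max = 1000  # ASSUMPTION: illustrative value; the real config is not in this diff
timesteps = np.linspace(t_max, 0, num_inference_steps, endpoint=False, dtype=np.int64) - 1
print(timesteps)  # [999 799 599 399 199] -- descending, as EulerDiscreteScheduler expects
```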
- class TrainerSDTurboSR(TrainerBaseSR):
-    def sd_forward_step(
-        self,
-        prompt: Union[str, List[str]] = None,
-        latents_hq: Optional[torch.Tensor] = None,
-        image_lq: torch.Tensor = None,
-        image_hq: torch.Tensor = None,
-        model_pred: DiagonalGaussianDistribution = None,
-        timesteps: List[int] = None,
-        **kwargs,
-    ):
-        r"""
-        One training forward pass through the diffusion UNet.
-
-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts to guide the image generation. If not defined, one has to pass
-                `prompt_embeds` instead.
-            image_lq (`torch.Tensor`): The low-quality image(s) to be enhanced, range in [0, 1].
-            image_hq (`torch.Tensor`): The corresponding high-quality target image(s), range in [0, 1].
-            model_pred (`DiagonalGaussianDistribution`, *optional*):
-                Output of the noise prediction model, used to sample the noise added to the latents.
-            latents_hq (`torch.Tensor`, *optional*):
-                Pre-generated high-quality latents to be used as inputs for image generation. If not
-                provided, a latents tensor will be generated by encoding `image_hq` with the VAE.
-            timesteps (`List[int]`):
-                Timesteps at which the latents are noised; they index the cached sigma schedule and
-                must be in descending order.
-        """
-        device = torch.device(f"cuda:{self.rank}")
-        # Encode input prompt
-        prompt_embeds, negative_prompt_embeds = self.sd_pipe.encode_prompt(
-            prompt=prompt,
-            device=device,
-            num_images_per_prompt=1,
-            do_classifier_free_guidance=False,
-        )
-        self.prompt_embeds = prompt_embeds
-
-        # select the noise level; self.sd_pipe.scheduler.sigmas has shape [1001], descending
-        if not hasattr(self, 'sigmas_cache'):
-            assert self.sd_pipe.scheduler.sigmas.numel() == (self.configs.sd_pipe.num_train_steps + 1)
-            self.sigmas_cache = self.sd_pipe.scheduler.sigmas.flip(0)[1:].to(device)  # ascending, shape [1000]
-        sigmas = self.sigmas_cache[timesteps]  # (b,)
-
-        # Prepare input for SD
-        height, width = image_hq.shape[-2:]
-        if self.configs.train.start_mode:
-            image_lq_up = F.interpolate(image_lq, size=(height, width), mode='bicubic')
-            zt_clean = self.encode_first_stage(
-                image_lq_up, center_input_sample=True,
-                deterministic=self.configs.train.loss_coef.get('rkl', 0) > 0,
-            )
-        else:
-            if latents_hq is None:
-                zt_clean = self.encode_first_stage(
-                    image_hq, center_input_sample=True, deterministic=False,
-                )
-            else:
-                zt_clean = latents_hq
-
-        sigmas = append_dims(sigmas, zt_clean.ndim)
-        zt_noisy = self.add_noise(zt_clean, sigmas, model_pred)
-
-        prompt_embeds = prompt_embeds.to(device)
-
-        zt_noisy_scale = self.scale_sd_input(zt_noisy, sigmas)
-        eps_pred = self.sd_pipe.unet(
-            zt_noisy_scale,
-            timesteps,
-            encoder_hidden_states=prompt_embeds,
-            timestep_cond=None,
-            cross_attention_kwargs=None,
-            added_cond_kwargs=None,
-            return_dict=False,
-        )[0]  # eps-mode for sdxl and sdxl-refiner
-
-        if self.configs.train.noise_detach:
-            z0_pred = zt_noisy.detach() - sigmas * eps_pred
-        else:
-            z0_pred = zt_noisy - sigmas * eps_pred
-
-        return z0_pred, zt_noisy, zt_clean
-
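The final lines invert the forward process: with z_t = z_0 + σ·ε, a perfect noise estimate gives z_0 = z_t − σ·ε̂. A standalone sanity check of that identity:

```python
import torch

sigma = 3.0
z0 = torch.randn(2, 4, 8, 8)
eps = torch.randn_like(z0)
zt = z0 + sigma * eps        # forward: add noise at level sigma
z0_rec = zt - sigma * eps    # inverse, assuming a perfect eps prediction
assert torch.allclose(z0_rec, z0, atol=1e-6)
```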
-    @torch.no_grad()
-    def sample(
-        self,
-        image_lq: torch.Tensor,
-        prompt: Union[str, List[str]] = None,
-        target_size: Tuple[int, int] = (1024, 1024),
-        start_noise_predictor: torch.nn.Module = None,
-        intermediate_noise_predictor: torch.nn.Module = None,
-        **kwargs,
-    ):
-        r"""
-        Restore a low-quality image with the trained pipeline.
-
-        Args:
-            image_lq (`torch.Tensor`):
-                The low-quality image(s) to be enhanced, range in [0, 1].
-            prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts to guide the image generation. If not defined, one has to pass
-                `prompt_embeds` instead.
-            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
-                The required height and width of the super-resolved image.
-            start_noise_predictor (`torch.nn.Module`, *optional*):
-                Model predicting the noise used to initialize the latents from `image_lq`.
-            intermediate_noise_predictor (`torch.nn.Module`, *optional*):
-                Model predicting the noise injected at the intermediate denoising steps.
-        """
-        device = torch.device(f"cuda:{self.rank}")
-        batch_size = image_lq.shape[0]
-
-        # Encode input prompt
-        prompt_embeds, negative_prompt_embeds = self.sd_pipe.encode_prompt(
-            prompt=prompt,
-            device=device,
-            num_images_per_prompt=1,
-            do_classifier_free_guidance=False,
-        )
-
-        timesteps, num_inference_steps = self.retrieve_timesteps()
-        latent_timestep = timesteps[:1].repeat(batch_size)
-
-        # Prepare latent variables
-        height, width = target_size
-        latents = self.prepare_lq_latents(image_lq, latent_timestep.long(), height, width, start_noise_predictor)
-
-        # Prepare extra step kwargs.
-        extra_step_kwargs = self.sd_pipe.prepare_extra_step_kwargs(None, 0.0)
-
-        prompt_embeds = prompt_embeds.to(device)
-
-        x0_progressive = []
-        images_progressive = []
-        for i, t in enumerate(timesteps):
-            latents_scaled = self.sd_pipe.scheduler.scale_model_input(latents, t)
-
-            # predict the noise residual
-            eps_pred = self.sd_pipe.unet(
-                latents_scaled,
-                t,
-                encoder_hidden_states=prompt_embeds,
-                timestep_cond=None,
-                added_cond_kwargs=None,
-                return_dict=False,
-            )[0]
-            z0_pred = latents - self.sigmas_cache[t.long()] * eps_pred
-
-            # compute the previous noisy sample x_t -> x_t-1
-            if intermediate_noise_predictor is not None and i + 1 < len(timesteps):
-                t_next = timesteps[i + 1]
-                noise = intermediate_noise_predictor(image_lq, t_next, center_input_sample=True)
-            else:
-                noise = None
-            extra_step_kwargs['noise'] = noise
-            latents = self.sd_pipe.scheduler.step(eps_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-
-            image = self.decode_first_stage(latents)
-            images_progressive.append(image)
-
-            x0_pred = self.decode_first_stage(z0_pred)
-            x0_progressive.append(x0_pred)
-
-        return images_progressive, x0_progressive
-
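A hedged usage sketch for `sample` (it assumes an already-initialized `trainer` with the SD pipeline and VAE loaded, which a standalone snippet cannot construct; the tensor shapes are illustrative):

```python
import torch

# ASSUMPTION: `trainer` is an initialized TrainerSDTurboSR with weights resumed.
image_lq = torch.rand(1, 3, 256, 256).cuda()   # dummy low-quality input in [0, 1]
images, x0_preds = trainer.sample(
    image_lq,
    prompt=_positive,            # the module-level positive prompt defined in this file
    target_size=(1024, 1024),
)
final = images[-1].clamp(0, 1)   # output of the last denoising step is the restoration
```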
- def my_worker_init_fn(worker_id):
-     np.random.seed(np.random.get_state()[1][0] + worker_id)
-
- def hinge_d_loss(
-     logits_real: Union[torch.Tensor, List[torch.Tensor]],
-     logits_fake: Union[torch.Tensor, List[torch.Tensor]],
- ):
-     def _hinge_d_loss(logits_real, logits_fake):
-         loss_real = F.relu(1.0 - logits_real)
-         loss_fake = F.relu(1.0 + logits_fake)
-         d_loss = 0.5 * (loss_real + loss_fake)
-         loss = d_loss.mean(dim=list(range(1, logits_real.ndim)))
-
-         return loss
-
-     if not isinstance(logits_real, (list, tuple)):
-         loss = _hinge_d_loss(logits_real, logits_fake)
-     else:
-         loss = _hinge_d_loss(logits_real[0], logits_fake[0])
-         for xx, yy in zip(logits_real[1:], logits_fake[1:]):
-             loss += _hinge_d_loss(xx, yy)
-
-         loss /= len(logits_real)
-
-     return loss
-
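`hinge_d_loss` is the standard hinge objective for GAN discriminators: real logits below +1 and fake logits above −1 are penalized linearly. A standalone numerical check:

```python
import torch
import torch.nn.functional as F

logits_real = torch.tensor([[2.0], [0.0]])   # confident real, uncertain real
logits_fake = torch.tensor([[-2.0], [0.0]])  # confident fake, uncertain fake
loss = 0.5 * (F.relu(1.0 - logits_real) + F.relu(1.0 + logits_fake))
print(loss.squeeze())  # tensor([0., 1.]) -- only the uncertain pair is penalized
```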
- def get_torch_dtype(torch_dtype: str):
-     if torch_dtype == 'torch.float16':
-         return torch.float16
-     elif torch_dtype == 'torch.bfloat16':
-         return torch.bfloat16
-     elif torch_dtype == 'torch.float32':
-         return torch.float32
-     else:
-         raise ValueError(f'Unexpected torch dtype:{torch_dtype}')
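A small usage example, assuming `get_torch_dtype` is in scope; it maps the string form used in YAML/OmegaConf configs to the corresponding dtype object:

```python
import torch

# e.g. a config might set torch_dtype: 'torch.float16' as a plain string
assert get_torch_dtype('torch.float16') is torch.float16
assert get_torch_dtype('torch.bfloat16') is torch.bfloat16
```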