import atexit
from collections import defaultdict
import logging
import math
import typing
import torch
import time
from dp2.utils import vis_utils
from dp2 import utils
from tops import logger, checkpointer
import tops
from easydict import EasyDict
def accumulate_gradients(params, fp16_ddp_accumulate):
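    """Average gradients across DDP ranks with a single all_reduce over flattened grads.

    If fp16_ddp_accumulate is set, the reduction happens in half precision and the
    result is cast back to the original dtype.
    """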
    if len(params) == 0:
        return
    params = [param for param in params if param.grad is not None]
    flat = torch.cat([param.grad.flatten() for param in params])
    orig_dtype = flat.dtype
    if tops.world_size() > 1:
        if fp16_ddp_accumulate:
            flat = flat.half() / tops.world_size()
        else:
            flat /= tops.world_size()
        torch.distributed.all_reduce(flat)
        flat = flat.to(orig_dtype)
    grads = flat.split([param.numel() for param in params])
    for param, grad in zip(params, grads):
        param.grad = grad.reshape(param.shape)


def accumulate_buffers(module: torch.nn.Module):
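    """Average all module buffers (e.g. normalization statistics) across DDP ranks."""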
    buffers = [buf for buf in module.buffers()]
    if len(buffers) == 0:
        return
    flat = torch.cat([buf.flatten() for buf in buffers])
    if tops.world_size() > 1:
        torch.distributed.all_reduce(flat)
        flat /= tops.world_size()
    bufs = flat.split([buf.numel() for buf in buffers])
    for old, new in zip(buffers, bufs):
        old.copy_(new.reshape(old.shape), non_blocking=True)


def check_ddp_consistency(module):
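    """Assert that all parameters and buffers match rank 0 across DDP ranks."""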
    if tops.world_size() == 1:
        return
    assert isinstance(module, torch.nn.Module)
    params_buffs = list(module.named_parameters()) + list(module.named_buffers())
    for name, tensor in params_buffs:
        fullname = type(module).__name__ + '.' + name
        tensor = tensor.detach()
        if tensor.is_floating_point():
            tensor = torch.nan_to_num(tensor)
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (tensor == other).all(), fullname


class AverageMeter:
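    """Accumulates per-key scalar tensors and returns their running averages for logging."""
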
    def __init__(self) -> None:
        self.to_log = dict()
        self.n = defaultdict(int)

    def update(self, values: dict):
        for key, value in values.items():
            self.n[key] += 1
            if key in self.to_log:
                self.to_log[key] += value.mean().detach()
            else:
                self.to_log[key] = value.mean().detach()

    def get_average(self):
        return {key: value / self.n[key] for key, value in self.to_log.items()}


class GANTrainer:
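    """Alternating GAN training loop with AMP gradient scaling, an EMA generator,
    periodic evaluation/visualization, and checkpointing via tops.checkpointer.
    """
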
    def __init__(
            self,
            G: torch.nn.Module,
            D: torch.nn.Module,
            G_EMA: torch.nn.Module,
            D_optim: torch.optim.Optimizer,
            G_optim: torch.optim.Optimizer,
            dl_train: typing.Iterator,
            dl_val: typing.Iterable,
            scaler_D: torch.cuda.amp.GradScaler,
            scaler_G: torch.cuda.amp.GradScaler,
            ims_per_log: int,
            max_images_to_train: int,
            loss_handler,
            ims_per_val: int,
            evaluate_fn,
            batch_size: int,
            broadcast_buffers: bool,
            fp16_ddp_accumulate: bool,
            save_state: bool,
            *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.G = G
        self.D = D
        self.G_EMA = G_EMA
        self.D_optim = D_optim
        self.G_optim = G_optim
        self.dl_train = dl_train
        self.dl_val = dl_val
        self.scaler_D = scaler_D
        self.scaler_G = scaler_G
        self.loss_handler = loss_handler
        self.max_images_to_train = max_images_to_train
        self.images_per_val = ims_per_val
        self.images_per_log = ims_per_log
        self.evaluate_fn = evaluate_fn
        self.batch_size = batch_size
        self.broadcast_buffers = broadcast_buffers
        self.fp16_ddp_accumulate = fp16_ddp_accumulate
        self.train_state = EasyDict(
            next_log_step=0,
            next_val_step=ims_per_val,
            total_time=0
        )

        checkpointer.register_models(dict(
            generator=G, discriminator=D, EMA_generator=G_EMA,
            D_optimizer=D_optim,
            G_optimizer=G_optim,
            train_state=self.train_state,
            scaler_D=self.scaler_D,
            scaler_G=self.scaler_G
        ))
        if checkpointer.has_checkpoint():
            checkpointer.load_registered_models()
            logger.log(f"Resuming training from: global step: {logger.global_step()}")
        else:
            logger.add_dict({
                "stats/discriminator_parameters": tops.num_parameters(self.D),
                "stats/generator_parameters": tops.num_parameters(self.G),
            }, commit=False)

        if save_state:
            # Save at exit so that an unexpectedly killed job does not leave the
            # saved checkpoint out of sync with the current training state.
            atexit.register(checkpointer.save_registered_models)
        self._ims_per_log = ims_per_log
        self.to_log = AverageMeter()
        self.trainable_params_D = [param for param in self.D.parameters() if param.requires_grad]
        self.trainable_params_G = [param for param in self.G.parameters() if param.requires_grad]
        logger.add_dict({
            "stats/discriminator_trainable_parameters": sum(p.numel() for p in self.trainable_params_D),
            "stats/generator_trainable_parameters": sum(p.numel() for p in self.trainable_params_G),
        }, commit=False, level=logging.INFO)
        check_ddp_consistency(self.D)
        check_ddp_consistency(self.G)
        check_ddp_consistency(self.G_EMA.generator)

    def train_loop(self):
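        """Train until the global step exceeds max_images_to_train.

        Logs averaged losses and AMP scales every images_per_log images, runs
        evaluation and image logging every images_per_val images, and stops early
        if either gradient scaler drops below 1e-8.
        """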
        self.log_time()
        while logger.global_step() <= self.max_images_to_train:
            batch = next(self.dl_train)
            self.G_EMA.update_beta()
            self.to_log.update(self.step_D(batch))
            self.to_log.update(self.step_G(batch))
            self.G_EMA.update(self.G)
            if logger.global_step() >= self.train_state.next_log_step:
                to_log = {f"loss/{key}": item.item() for key, item in self.to_log.get_average().items()}
                to_log.update({"amp/grad_scale_G": self.scaler_G.get_scale()})
                to_log.update({"amp/grad_scale_D": self.scaler_D.get_scale()})
                self.to_log = AverageMeter()
                logger.add_dict(to_log, commit=True)
                self.train_state.next_log_step += self.images_per_log
            if self.scaler_D.get_scale() < 1e-8 or self.scaler_G.get_scale() < 1e-8:
                print("Stopping training as gradient scale < 1e-8")
                logger.log("Stopping training as gradient scale < 1e-8")
                break
            if logger.global_step() >= self.train_state.next_val_step:
                self.evaluate()
                self.log_time()
                self.save_images()
                self.train_state.next_val_step += self.images_per_val
            logger.step(self.batch_size * tops.world_size())
        logger.log(f"Reached end of training at step {logger.global_step()}.")
        checkpointer.save_registered_models()

    def estimate_ims_per_hour(self):
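        """Benchmark throughput on a single repeated batch and log images per hour/day."""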
        batch = next(self.dl_train)
        n_ims = int(100e3)
        n_steps = int(n_ims / (self.batch_size * tops.world_size()))
        n_ims = n_steps * self.batch_size * tops.world_size()
        for i in range(10):  # Warmup
            self.G_EMA.update_beta()
            self.step_D(batch)
            self.step_G(batch)
            self.G_EMA.update(self.G)
        start_time = time.time()
        for i in utils.tqdm_(list(range(n_steps))):
            self.G_EMA.update_beta()
            self.step_D(batch)
            self.step_G(batch)
            self.G_EMA.update(self.G)
        total_time = time.time() - start_time
        ims_per_sec = n_ims / total_time
        ims_per_hour = ims_per_sec * 60 * 60
        ims_per_day = ims_per_hour * 24
        logger.log(f"Images per hour: {ims_per_hour/1e6:.3f}M")
        logger.log(f"Images per day: {ims_per_day/1e6:.3f}M")
        ims_per_4_day = int(math.ceil(ims_per_day / tops.world_size() * 4))
        logger.log(f"Images per 4 days: {ims_per_4_day}")
        logger.add_dict({
            "stats/ims_per_day": ims_per_day,
            "stats/ims_per_4_day": ims_per_4_day
        })

    def log_time(self):
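        """Log images/sec, accumulated training time and estimated remaining time since the last call."""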
        if not hasattr(self, "start_time"):
            self.start_time = time.time()
            self.last_time_step = logger.global_step()
            return
        n_images = logger.global_step() - self.last_time_step
        if n_images == 0:
            return
        n_secs = time.time() - self.start_time
        n_ims_per_sec = n_images / n_secs
        training_time_hours = n_secs / 60 / 60
        self.train_state.total_time += training_time_hours
        remaining_images = self.max_images_to_train - logger.global_step()
        remaining_time = remaining_images / n_ims_per_sec / 60 / 60
        logger.add_dict({
            "stats/n_ims_per_sec": n_ims_per_sec,
            "stats/total_training_time_hours": self.train_state.total_time,
            "stats/remaining_time_hours": remaining_time
        })
        self.last_time_step = logger.global_step()
        self.start_time = time.time()

    def save_images(self):
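        """Log qualitative samples: truncated outputs next to reals, and diverse samples drawn with different latent codes."""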
        dl_val = iter(self.dl_val)
        batch = next(dl_val)

        # TRUNCATED visualization
        ims_to_log = 8
        self.G_EMA.eval()
        z = self.G.get_z(batch["img"])
        fakes_truncated = self.G_EMA.sample(**batch, truncation_value=0)["img"]
        fakes_truncated = utils.denormalize_img(fakes_truncated).mul(255).byte()[:ims_to_log].cpu()
        if "__key__" in batch:
            batch.pop("__key__")
        real = vis_utils.visualize_batch(**tops.to_cpu(batch))[:ims_to_log]
        to_vis = torch.cat((real, fakes_truncated))
        logger.add_images("images/truncated", to_vis, nrow=2)

        # Diverse images
        ims_diverse = 3
        batch = next(dl_val)
        to_vis = []
        for i in range(ims_diverse):
            z = self.G.get_z(batch["img"])[:1].repeat(batch["img"].shape[0], 1)
            fakes = utils.denormalize_img(self.G_EMA(**batch, z=z)["img"]).mul(255).byte()[:ims_to_log].cpu()
            to_vis.append(fakes)
        if "__key__" in batch:
            batch.pop("__key__")
        reals = vis_utils.visualize_batch(**tops.to_cpu(batch))[:ims_to_log]
        to_vis.insert(0, reals)
        to_vis = torch.cat(to_vis)
        logger.add_images("images/diverse", to_vis, nrow=ims_diverse + 1)
        self.G_EMA.train()

    def evaluate(self):
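        """Save a checkpoint, optionally verify DDP consistency, and log metrics from evaluate_fn."""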
logger.log("Stating evaluation.") | |
self.G_EMA.eval() | |
try: | |
checkpointer.save_registered_models(max_keep=3) | |
except Exception: | |
logger.log("Could not save checkpoint.") | |
if self.broadcast_buffers: | |
check_ddp_consistency(self.G) | |
check_ddp_consistency(self.D) | |
metrics = self.evaluate_fn(generator=self.G_EMA, dataloader=self.dl_val) | |
metrics = {f"metrics/{k}": v for k, v in metrics.items()} | |
logger.add_dict(metrics, level=logger.logger.INFO) | |

    def step_D(self, batch):
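        """One discriminator update: D loss, scaled backward, DDP gradient all-reduce, optimizer step."""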
        utils.set_requires_grad(self.trainable_params_D, True)
        utils.set_requires_grad(self.trainable_params_G, False)
        tops.zero_grad(self.D)
        loss, to_log = self.loss_handler.D_loss(batch, grad_scaler=self.scaler_D)
        with torch.autograd.profiler.record_function("D_step"):
            self.scaler_D.scale(loss).backward()
            accumulate_gradients(self.trainable_params_D, fp16_ddp_accumulate=self.fp16_ddp_accumulate)
            if self.broadcast_buffers:
                accumulate_buffers(self.D)
                accumulate_buffers(self.G)
            # scaler.step will not unscale again if unscale_ was called previously.
            self.scaler_D.step(self.D_optim)
            self.scaler_D.update()
        utils.set_requires_grad(self.trainable_params_D, False)
        utils.set_requires_grad(self.trainable_params_G, False)
        return to_log

    def step_G(self, batch):
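        """One generator update: G loss, scaled backward, DDP gradient all-reduce, optimizer step."""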
        utils.set_requires_grad(self.trainable_params_D, False)
        utils.set_requires_grad(self.trainable_params_G, True)
        tops.zero_grad(self.G)
        loss, to_log = self.loss_handler.G_loss(batch, grad_scaler=self.scaler_G)
        with torch.autograd.profiler.record_function("G_step"):
            self.scaler_G.scale(loss).backward()
            accumulate_gradients(self.trainable_params_G, fp16_ddp_accumulate=self.fp16_ddp_accumulate)
            if self.broadcast_buffers:
                accumulate_buffers(self.G)
                accumulate_buffers(self.D)
            self.scaler_G.step(self.G_optim)
            self.scaler_G.update()
        utils.set_requires_grad(self.trainable_params_D, False)
        utils.set_requires_grad(self.trainable_params_G, False)
        return to_log