from config import *
from dataset import *

import pandas as pd
import json
import os
import copy

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import *
import torch
# `nn` and `F` are used below but were not imported explicitly; import them
# directly rather than relying on the star imports above to provide them.
# `get_world_size` is expected to come from the project's star imports
# (its distributed helpers).
from torch import nn
from torch.nn import functional as F


class ZipLoader:
    """Iterates several dataloaders in lockstep, yielding one batch from each."""
    def __init__(self, loaders):
        self.loaders = loaders

    def __len__(self):
        return len(self.loaders[0])

    def __iter__(self):
        for each in zip(*self.loaders):
            yield each
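
# Usage sketch (assumption: plain DataLoaders over toy data), showing how the
# fewshot_allneg mode below pairs a "positive" and a "negative" loader so each
# training step sees one batch of each; zip() stops at the shorter loader,
# which is why load_dataset()/setup() wrap the datasets in Repeat to equalize
# their lengths:
#
#   from torch.utils.data import DataLoader
#   pos = DataLoader(list(range(8)), batch_size=2)
#   neg = DataLoader(list(range(8)), batch_size=2)
#   for pos_batch, neg_batch in ZipLoader([pos, neg]):
#       ...  # one batch from each loader per step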


class ClsModel(pl.LightningModule):
    def __init__(self, conf: TrainConfig):
        super().__init__()
        assert conf.train_mode.is_manipulate()
        if conf.seed is not None:
            pl.seed_everything(conf.seed)

        self.save_hyperparameters(conf.as_dict_jsonable())
        self.conf = conf

        if conf.train_mode == TrainMode.manipulate:
            # manipulation trains a classifier on top of the frozen semantic
            # encoder; the base autoencoder is loaded but never updated
            self.model = conf.make_model_conf().make_model()
            self.ema_model = copy.deepcopy(self.model)
            self.model.requires_grad_(False)
            self.ema_model.requires_grad_(False)
            self.ema_model.eval()

            if conf.pretrain is not None:
                print(f'loading pretrain ... {conf.pretrain.name}')
                state = torch.load(conf.pretrain.path, map_location='cpu')
                print('step:', state['global_step'])
                self.load_state_dict(state['state_dict'], strict=False)

        if conf.manipulate_znormalize:
            # statistics of the inferred latents, used to z-normalize the
            # encoder output before classification
            print('loading latent stats ...')
            state = torch.load(conf.latent_infer_path)
            self.conds = state['conds']
            self.register_buffer('conds_mean', state['conds_mean'][None, :])
            self.register_buffer('conds_std', state['conds_std'][None, :])
        else:
            self.conds_mean = None
            self.conds_std = None

        if conf.manipulate_mode in [ManipulateMode.celebahq_all]:
            num_cls = len(CelebAttrDataset.id_to_cls)
        elif conf.manipulate_mode.is_single_class():
            num_cls = 1
        else:
            raise NotImplementedError()

        if conf.train_mode == TrainMode.manipulate:
            # a linear probe on top of the latent code
            self.classifier = nn.Linear(conf.style_ch, num_cls)
        else:
            raise NotImplementedError()

        self.ema_classifier = copy.deepcopy(self.classifier)

    def state_dict(self, *args, **kwargs):
        # exclude the frozen base model and its EMA copy from checkpoints;
        # only the classifier (and other trainable state) is saved
        out = {}
        for k, v in super().state_dict(*args, **kwargs).items():
            if k.startswith('model.'):
                pass
            elif k.startswith('ema_model.'):
                pass
            else:
                out[k] = v
        return out

    def load_state_dict(self, state_dict, strict: bool = None):
        if self.conf.train_mode == TrainMode.manipulate:
            # checkpoints deliberately omit the base model (see state_dict
            # above), so default to non-strict loading in manipulate mode
            if strict is None:
                strict = False
        else:
            if strict is None:
                strict = True
        return super().load_state_dict(state_dict, strict=strict)

    def normalize(self, cond):
        cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to(
            self.device)
        return cond

    def denormalize(self, cond):
        cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to(
            self.device)
        return cond

    def load_dataset(self):
        if self.conf.manipulate_mode == ManipulateMode.d2c_fewshot:
            return CelebD2CAttrFewshotDataset(
                cls_name=self.conf.manipulate_cls,
                K=self.conf.manipulate_shots,
                img_folder=data_paths['celeba'],
                img_size=self.conf.img_size,
                seed=self.conf.manipulate_seed,
                all_neg=False,
                do_augment=True,
            )
        elif self.conf.manipulate_mode == ManipulateMode.d2c_fewshot_allneg:
            # two datasets: only the positives of the class and only the
            # negatives; they are later paired batch-by-batch with ZipLoader
            img_folder = data_paths['celeba']

            return [
                CelebD2CAttrFewshotDataset(
                    cls_name=self.conf.manipulate_cls,
                    K=self.conf.manipulate_shots,
                    img_folder=img_folder,
                    img_size=self.conf.img_size,
                    only_cls_name=self.conf.manipulate_cls,
                    only_cls_value=1,
                    seed=self.conf.manipulate_seed,
                    all_neg=True,
                    do_augment=True),
                CelebD2CAttrFewshotDataset(
                    cls_name=self.conf.manipulate_cls,
                    K=self.conf.manipulate_shots,
                    img_folder=img_folder,
                    img_size=self.conf.img_size,
                    only_cls_name=self.conf.manipulate_cls,
                    only_cls_value=-1,
                    seed=self.conf.manipulate_seed,
                    all_neg=True,
                    do_augment=True),
            ]
        elif self.conf.manipulate_mode == ManipulateMode.celebahq_all:
            return CelebHQAttrDataset(data_paths['celebahq'],
                                      self.conf.img_size,
                                      data_paths['celebahq_anno'],
                                      do_augment=True)
        else:
            raise NotImplementedError()

    def setup(self, stage=None) -> None:
        # give each process its own seed derived from the global seed
        if self.conf.seed is not None:
            seed = self.conf.seed * get_world_size() + self.global_rank
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            print('local seed:', seed)

        self.train_data = self.load_dataset()
        if self.conf.manipulate_mode.is_fewshot():
            if isinstance(self.train_data, list):
                # repeat both few-shot datasets to the same (larger) length so
                # they can be zipped together without truncation
                a, b = self.train_data
                self.train_data = [
                    Repeat(a, max(len(a), len(b))),
                    Repeat(b, max(len(a), len(b))),
                ]
            else:
                self.train_data = Repeat(self.train_data, 100_000)

    def train_dataloader(self):
        conf = self.conf.clone()
        conf.batch_size = self.batch_size
        if isinstance(self.train_data, list):
            dataloader = []
            for each in self.train_data:
                dataloader.append(
                    conf.make_loader(each, shuffle=True, drop_last=True))
            dataloader = ZipLoader(dataloader)
        else:
            dataloader = conf.make_loader(self.train_data,
                                          shuffle=True,
                                          drop_last=True)
        return dataloader

    @property
    def batch_size(self):
        # per-process batch size: the configured global batch size divided by
        # the number of processes
        ws = get_world_size()
        assert self.conf.batch_size % ws == 0
        return self.conf.batch_size // ws

    def training_step(self, batch, batch_idx):
        self.ema_model: BeatGANsAutoencModel
        if isinstance(batch, tuple):
            # zipped positive/negative batches (fewshot_allneg mode)
            a, b = batch
            imgs = torch.cat([a['img'], b['img']])
            labels = torch.cat([a['labels'], b['labels']])
        else:
            imgs = batch['img']
            labels = batch['labels']

        if self.conf.train_mode == TrainMode.manipulate:
            # encode with the frozen EMA encoder, then classify the latent
            self.ema_model.eval()
            with torch.no_grad():
                cond = self.ema_model.encoder(imgs)

            if self.conf.manipulate_znormalize:
                cond = self.normalize(cond)

            pred = self.classifier.forward(cond)
            pred_ema = self.ema_classifier.forward(cond)
        elif self.conf.train_mode == TrainMode.manipulate_img:
            # classify the image directly
            pred = self.classifier.forward(imgs)
            pred_ema = None
        elif self.conf.train_mode == TrainMode.manipulate_imgt:
            # classify a noised image at a random diffusion timestep
            t, weight = self.T_sampler.sample(len(imgs), imgs.device)
            imgs_t = self.sampler.q_sample(imgs, t)
            pred = self.classifier.forward(imgs_t, t=t)
            pred_ema = None
            print('pred:', pred.shape)
        else:
            raise NotImplementedError()

        if self.conf.manipulate_mode.is_celeba_attr():
            # attribute labels are in {-1, 1}; map them to {0, 1} targets
            gt = torch.where(labels > 0,
                             torch.ones_like(labels).float(),
                             torch.zeros_like(labels).float())
        elif self.conf.manipulate_mode == ManipulateMode.relighting:
            gt = labels
        else:
            raise NotImplementedError()

        loss_ema = None
        if self.conf.manipulate_loss == ManipulateLossType.bce:
            loss = F.binary_cross_entropy_with_logits(pred, gt)
            if pred_ema is not None:
                loss_ema = F.binary_cross_entropy_with_logits(pred_ema, gt)
        elif self.conf.manipulate_loss == ManipulateLossType.mse:
            loss = F.mse_loss(pred, gt)
            if pred_ema is not None:
                loss_ema = F.mse_loss(pred_ema, gt)
        else:
            raise NotImplementedError()

        self.log('loss', loss)
        # loss_ema only exists when an EMA prediction was computed
        if loss_ema is not None:
            self.log('loss_ema', loss_ema)
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx: int,
                           dataloader_idx: int) -> None:
        ema(self.classifier, self.ema_classifier, self.conf.ema_decay)

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.classifier.parameters(),
                                 lr=self.conf.lr,
                                 weight_decay=self.conf.weight_decay)
        return optim


def ema(source, target, decay):
    """In-place EMA update: target = decay * target + (1 - decay) * source."""
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key in source_dict.keys():
        target_dict[key].data.copy_(target_dict[key].data * decay +
                                    source_dict[key].data * (1 - decay))


def train_cls(conf: TrainConfig, gpus):
    print('conf:', conf.name)
    model = ClsModel(conf)

    if not os.path.exists(conf.logdir):
        os.makedirs(conf.logdir)
    checkpoint = ModelCheckpoint(
        dirpath=f'{conf.logdir}',
        save_last=True,
        save_top_k=1,
    )
    # resume from the latest checkpoint if one exists, otherwise fall back to
    # an explicit continue_from checkpoint (if configured)
    checkpoint_path = f'{conf.logdir}/last.ckpt'
    if os.path.exists(checkpoint_path):
        resume = checkpoint_path
    else:
        if conf.continue_from is not None:
            resume = conf.continue_from.path
        else:
            resume = None

    tb_logger = pl_loggers.TensorBoardLogger(save_dir=conf.logdir,
                                             name=None,
                                             version='')

    plugins = []
    if len(gpus) == 1:
        accelerator = None
    else:
        accelerator = 'ddp'
        from pytorch_lightning.plugins import DDPPlugin

        # skip DDP's unused-parameter search; all classifier params are used
        plugins.append(DDPPlugin(find_unused_parameters=False))

    trainer = pl.Trainer(
        max_steps=conf.total_samples // conf.batch_size_effective,
        resume_from_checkpoint=resume,
        gpus=gpus,
        accelerator=accelerator,
        precision=16 if conf.fp16 else 32,
        callbacks=[
            checkpoint,
        ],
        replace_sampler_ddp=True,
        logger=tb_logger,
        accumulate_grad_batches=conf.accum_batches,
        plugins=plugins,
    )
    trainer.fit(model)
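

# Usage sketch (hypothetical names): `train_cls` is expected to be driven from
# a small runner script with a TrainConfig built by one of the project's config
# factories together with a list of GPU ids, e.g.:
#
#   conf = make_manipulate_cls_config()  # hypothetical factory returning TrainConfig
#   train_cls(conf, gpus=[0])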