# diffae / experiment_classifier.py
from config import *
from dataset import *
# get_world_size() used below is assumed to be provided by the repo's dist_utils helpers
from dist_utils import *
import pandas as pd
import json
import os
import copy
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import *
import torch
# added: nn and F are used below (nn.Linear, F.binary_cross_entropy_with_logits, F.mse_loss)
from torch import nn
from torch.nn import functional as F


class ZipLoader:
    def __init__(self, loaders):
        self.loaders = loaders

    def __len__(self):
        return len(self.loaders[0])

    def __iter__(self):
        for each in zip(*self.loaders):
            yield each
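

# Note (added): ZipLoader simply zips several dataloaders and yields one batch
# from each per step. It is used below for the fewshot-allneg
# (positive-unlabeled) mode, where every batch is a (positive, negative) pair.
# Illustrative usage, not part of the original source:
#
#   pos_loader = conf.make_loader(pos_data, shuffle=True, drop_last=True)
#   neg_loader = conf.make_loader(neg_data, shuffle=True, drop_last=True)
#   for pos_batch, neg_batch in ZipLoader([pos_loader, neg_loader]):
#       ...  # training_step receives this tuple and concatenates the halves
#
# Its length is that of the first loader, so the two datasets are Repeat-ed to
# equal size in setup() below.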


class ClsModel(pl.LightningModule):
    def __init__(self, conf: TrainConfig):
        super().__init__()
        assert conf.train_mode.is_manipulate()
        if conf.seed is not None:
            pl.seed_everything(conf.seed)

        self.save_hyperparameters(conf.as_dict_jsonable())
        self.conf = conf

        # preparations
        if conf.train_mode == TrainMode.manipulate:
            # this is only important for training!
            # the latent is freshly inferred to make sure it matches the image
            # manipulating latents requires the base model
            self.model = conf.make_model_conf().make_model()
            self.ema_model = copy.deepcopy(self.model)
            self.model.requires_grad_(False)
            self.ema_model.requires_grad_(False)
            self.ema_model.eval()

            if conf.pretrain is not None:
                print(f'loading pretrain ... {conf.pretrain.name}')
                state = torch.load(conf.pretrain.path, map_location='cpu')
                print('step:', state['global_step'])
                self.load_state_dict(state['state_dict'], strict=False)

        # load the latent stats
        if conf.manipulate_znormalize:
            print('loading latent stats ...')
            state = torch.load(conf.latent_infer_path)
            self.conds = state['conds']
            self.register_buffer('conds_mean',
                                 state['conds_mean'][None, :])
            self.register_buffer('conds_std',
                                 state['conds_std'][None, :])
        else:
            self.conds_mean = None
            self.conds_std = None

        if conf.manipulate_mode in [ManipulateMode.celebahq_all]:
            num_cls = len(CelebAttrDataset.id_to_cls)
        elif conf.manipulate_mode.is_single_class():
            num_cls = 1
        else:
            raise NotImplementedError()

        # classifier
        if conf.train_mode == TrainMode.manipulate:
            # latent manipulation requires only a linear classifier
            self.classifier = nn.Linear(conf.style_ch, num_cls)
        else:
            raise NotImplementedError()

        self.ema_classifier = copy.deepcopy(self.classifier)
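
    # Note (added): this is effectively a linear probe. The pretrained
    # diffusion autoencoder (model / ema_model) stays frozen and is only used
    # as a feature extractor; only `classifier` (a single nn.Linear on the
    # style_ch-dimensional semantic latent) is trained, with `ema_classifier`
    # tracking it via EMA (see on_train_batch_end and configure_optimizers).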

    def state_dict(self, *args, **kwargs):
        # don't save the base model
        out = {}
        for k, v in super().state_dict(*args, **kwargs).items():
            if k.startswith('model.'):
                pass
            elif k.startswith('ema_model.'):
                pass
            else:
                out[k] = v
        return out

    def load_state_dict(self, state_dict, strict: bool = None):
        if self.conf.train_mode == TrainMode.manipulate:
            # change the default strict => False
            if strict is None:
                strict = False
        else:
            if strict is None:
                strict = True
        return super().load_state_dict(state_dict, strict=strict)
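
    # Note (added): strict=False is required in manipulate mode because the
    # pretrain checkpoint was written by the base model and contains no
    # `classifier.` / `ema_classifier.` keys, while checkpoints written by this
    # class omit the frozen `model.` / `ema_model.` keys (see state_dict above);
    # either direction therefore loads with missing keys.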

    def normalize(self, cond):
        cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to(
            self.device)
        return cond

    def denormalize(self, cond):
        cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to(
            self.device)
        return cond
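
    # Note (added): cond is the semantic latent from the frozen encoder with
    # shape (n, style_ch); conds_mean / conds_std were registered with a
    # leading broadcast dim (assuming the stored stats are 1-D per-dimension
    # tensors), so these are the usual per-dimension transforms
    #   z = (cond - mean) / std    and    cond = z * std + mean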

    def load_dataset(self):
        if self.conf.manipulate_mode == ManipulateMode.d2c_fewshot:
            return CelebD2CAttrFewshotDataset(
                cls_name=self.conf.manipulate_cls,
                K=self.conf.manipulate_shots,
                img_folder=data_paths['celeba'],
                img_size=self.conf.img_size,
                seed=self.conf.manipulate_seed,
                all_neg=False,
                do_augment=True,
            )
        elif self.conf.manipulate_mode == ManipulateMode.d2c_fewshot_allneg:
            # the positive-unlabeled classifier needs to keep the class ratio at 1:1
            # we use two dataloaders, one for each class, to stabilize the training
            img_folder = data_paths['celeba']

            return [
                CelebD2CAttrFewshotDataset(
                    cls_name=self.conf.manipulate_cls,
                    K=self.conf.manipulate_shots,
                    img_folder=img_folder,
                    img_size=self.conf.img_size,
                    only_cls_name=self.conf.manipulate_cls,
                    only_cls_value=1,
                    seed=self.conf.manipulate_seed,
                    all_neg=True,
                    do_augment=True),
                CelebD2CAttrFewshotDataset(
                    cls_name=self.conf.manipulate_cls,
                    K=self.conf.manipulate_shots,
                    img_folder=img_folder,
                    img_size=self.conf.img_size,
                    only_cls_name=self.conf.manipulate_cls,
                    only_cls_value=-1,
                    seed=self.conf.manipulate_seed,
                    all_neg=True,
                    do_augment=True),
            ]
        elif self.conf.manipulate_mode == ManipulateMode.celebahq_all:
            return CelebHQAttrDataset(data_paths['celebahq'],
                                      self.conf.img_size,
                                      data_paths['celebahq_anno'],
                                      do_augment=True)
        else:
            raise NotImplementedError()

    def setup(self, stage=None) -> None:
        ##############################################
        # NEED TO SET THE SEED SEPARATELY HERE
        if self.conf.seed is not None:
            seed = self.conf.seed * get_world_size() + self.global_rank
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            print('local seed:', seed)
        ##############################################

        self.train_data = self.load_dataset()

        if self.conf.manipulate_mode.is_fewshot():
            # repeat the dataset to be larger (speeds up the training)
            if isinstance(self.train_data, list):
                # fewshot-allneg has two datasets
                # we resize them to be of equal sizes
                a, b = self.train_data
                self.train_data = [
                    Repeat(a, max(len(a), len(b))),
                    Repeat(b, max(len(a), len(b))),
                ]
            else:
                self.train_data = Repeat(self.train_data, 100_000)

    def train_dataloader(self):
        # self.conf.batch_size is the *global* batch size;
        # each rank must build its loader with its own fraction of it
        conf = self.conf.clone()
        conf.batch_size = self.batch_size

        if isinstance(self.train_data, list):
            dataloader = []
            for each in self.train_data:
                dataloader.append(
                    conf.make_loader(each, shuffle=True, drop_last=True))
            dataloader = ZipLoader(dataloader)
        else:
            dataloader = conf.make_loader(self.train_data,
                                          shuffle=True,
                                          drop_last=True)
        return dataloader

    @property
    def batch_size(self):
        ws = get_world_size()
        assert self.conf.batch_size % ws == 0
        return self.conf.batch_size // ws
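
    # Example (added): with a global conf.batch_size of 64 on 4 GPUs,
    # get_world_size() == 4, so each rank builds its dataloader with
    # batch_size 64 // 4 == 16; the assert rejects global batch sizes that
    # do not split evenly across ranks.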

    def training_step(self, batch, batch_idx):
        self.ema_model: BeatGANsAutoencModel
        if isinstance(batch, tuple):
            a, b = batch
            imgs = torch.cat([a['img'], b['img']])
            labels = torch.cat([a['labels'], b['labels']])
        else:
            imgs = batch['img']
            # print(f'({self.global_rank}) imgs:', imgs.shape)
            labels = batch['labels']

        if self.conf.train_mode == TrainMode.manipulate:
            self.ema_model.eval()
            with torch.no_grad():
                # (n, c)
                cond = self.ema_model.encoder(imgs)

            if self.conf.manipulate_znormalize:
                cond = self.normalize(cond)

            # (n, cls)
            pred = self.classifier.forward(cond)
            pred_ema = self.ema_classifier.forward(cond)
        elif self.conf.train_mode == TrainMode.manipulate_img:
            # (n, cls)
            pred = self.classifier.forward(imgs)
            pred_ema = None
        elif self.conf.train_mode == TrainMode.manipulate_imgt:
            t, weight = self.T_sampler.sample(len(imgs), imgs.device)
            imgs_t = self.sampler.q_sample(imgs, t)
            pred = self.classifier.forward(imgs_t, t=t)
            pred_ema = None
            print('pred:', pred.shape)
        else:
            raise NotImplementedError()

        if self.conf.manipulate_mode.is_celeba_attr():
            gt = torch.where(labels > 0,
                             torch.ones_like(labels).float(),
                             torch.zeros_like(labels).float())
        elif self.conf.manipulate_mode == ManipulateMode.relighting:
            gt = labels
        else:
            raise NotImplementedError()

        if self.conf.manipulate_loss == ManipulateLossType.bce:
            loss = F.binary_cross_entropy_with_logits(pred, gt)
            if pred_ema is not None:
                loss_ema = F.binary_cross_entropy_with_logits(pred_ema, gt)
        elif self.conf.manipulate_loss == ManipulateLossType.mse:
            loss = F.mse_loss(pred, gt)
            if pred_ema is not None:
                loss_ema = F.mse_loss(pred_ema, gt)
        else:
            raise NotImplementedError()

        self.log('loss', loss)
        # loss_ema only exists when an EMA prediction was computed
        if pred_ema is not None:
            self.log('loss_ema', loss_ema)
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx: int,
                           dataloader_idx: int) -> None:
        ema(self.classifier, self.ema_classifier, self.conf.ema_decay)

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.classifier.parameters(),
                                 lr=self.conf.lr,
                                 weight_decay=self.conf.weight_decay)
        return optim


def ema(source, target, decay):
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key in source_dict.keys():
        target_dict[key].data.copy_(target_dict[key].data * decay +
                                    source_dict[key].data * (1 - decay))
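
# Note (added): this is the standard exponential-moving-average update
#   target <- decay * target + (1 - decay) * source
# applied in place to every tensor in the state dict. For example, with
# decay = 0.999 the EMA classifier moves about 0.1% of the way toward the
# current classifier after each training batch (see on_train_batch_end above).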


def train_cls(conf: TrainConfig, gpus):
    print('conf:', conf.name)
    model = ClsModel(conf)

    if not os.path.exists(conf.logdir):
        os.makedirs(conf.logdir)
    checkpoint = ModelCheckpoint(
        dirpath=f'{conf.logdir}',
        save_last=True,
        save_top_k=1,
        # every_n_train_steps=conf.save_every_samples //
        # conf.batch_size_effective,
    )
    checkpoint_path = f'{conf.logdir}/last.ckpt'
    if os.path.exists(checkpoint_path):
        resume = checkpoint_path
    else:
        if conf.continue_from is not None:
            # continue from a checkpoint
            resume = conf.continue_from.path
        else:
            resume = None

    tb_logger = pl_loggers.TensorBoardLogger(save_dir=conf.logdir,
                                             name=None,
                                             version='')

    plugins = []
    if len(gpus) == 1:
        accelerator = None
    else:
        accelerator = 'ddp'
        from pytorch_lightning.plugins import DDPPlugin

        # important for working with gradient checkpoint
        plugins.append(DDPPlugin(find_unused_parameters=False))

    trainer = pl.Trainer(
        max_steps=conf.total_samples // conf.batch_size_effective,
        resume_from_checkpoint=resume,
        gpus=gpus,
        accelerator=accelerator,
        precision=16 if conf.fp16 else 32,
        callbacks=[
            checkpoint,
        ],
        replace_sampler_ddp=True,
        logger=tb_logger,
        accumulate_grad_batches=conf.accum_batches,
        plugins=plugins,
    )
    trainer.fit(model)
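

# Illustrative usage (added; not part of the original file). The template
# module and config factory names below are assumptions about how this repo
# organizes its classifier configs:
#
#   # from templates_cls import *      # hypothetical template module
#   # conf = ffhq256_autoenc_cls()     # hypothetical classifier TrainConfig
#   # train_cls(conf, gpus=[0])
#
# If {conf.logdir}/last.ckpt exists, training resumes from it automatically;
# otherwise conf.continue_from (if set) is used as the starting checkpoint.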