import os
import argparse
import sys
import datetime
import glob

import numpy as np

import torch
import torch.random
from torch.optim import AdamW
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer

from omegaconf import OmegaConf

from dataloader import CellLoader
from celle import VQGanVAE, CELLE
from celle.celle import gumbel_sample, top_k

torch.random.manual_seed(42)
np.random.seed(42)

from celle_taming_main import (
    instantiate_from_config,
    nondefault_trainer_args,
    get_parser,
)
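
# Training entry point for CELL-E. A YAML config (parsed with OmegaConf) defines the
# data module and model; PyTorch Lightning handles logging, checkpointing, and the
# training loop. CLI handling reuses the taming-transformers-style helpers imported
# from celle_taming_main (get_parser, instantiate_from_config, nondefault_trainer_args).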


class CellDataModule(pl.LightningDataModule):
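    """LightningDataModule that builds CellLoader train/val datasets from a CSV manifest."""
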
    def __init__(
        self,
        data_csv,
        dataset,
        sequence_mode="standard",
        vocab="bert",
        crop_size=256,
        resize=600,
        batch_size=1,
        threshold="median",
        text_seq_len=1000,
        num_workers=1,
        **kwargs,
    ):
        super().__init__()

        self.data_csv = data_csv
        self.dataset = dataset
        self.protein_sequence_length = 0
        self.image_folders = []
        self.crop_size = crop_size
        self.resize = resize
        self.batch_size = batch_size
        self.sequence_mode = sequence_mode
        self.threshold = threshold
        self.text_seq_len = int(text_seq_len)
        self.vocab = vocab
        self.num_workers = num_workers if num_workers is not None else batch_size * 2

    def setup(self, stage=None):
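        # Build train and val CellLoader datasets: random crops for training,
        # center crops for validation, otherwise identical settings.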
        self.cell_dataset_train = CellLoader(
            data_csv=self.data_csv,
            dataset=self.dataset,
            crop_size=self.crop_size,
            resize=self.resize,
            split_key="train",
            crop_method="random",
            sequence_mode=self.sequence_mode,
            vocab=self.vocab,
            text_seq_len=self.text_seq_len,
            threshold=self.threshold,
        )

        self.cell_dataset_val = CellLoader(
            data_csv=self.data_csv,
            dataset=self.dataset,
            crop_size=self.crop_size,
            resize=self.resize,
            crop_method="center",
            split_key="val",
            sequence_mode=self.sequence_mode,
            vocab=self.vocab,
            text_seq_len=self.text_seq_len,
            threshold=self.threshold,
        )

    def prepare_data(self):
        pass

    def train_dataloader(self):
        return DataLoader(
            self.cell_dataset_train,
            num_workers=self.num_workers,
            shuffle=True,
            batch_size=self.batch_size,
        )

    def val_dataloader(self):
        return DataLoader(
            self.cell_dataset_val,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )


class CELLE_trainer(pl.LightningModule):
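    """LightningModule wrapping the CELLE transformer with its image VQGAN VAE and an
    optional condition (nucleus) VQGAN VAE, plus the training/validation loops."""
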
    def __init__(
        self,
        vqgan_model_path,
        vqgan_config_path,
        ckpt_path=None,
        image_key="threshold",
        condition_model_path=None,
        condition_config_path=None,
        num_images=2,
        dim=2,
        num_text_tokens=30,
        text_seq_len=1000,
        depth=16,
        heads=16,
        dim_head=64,
        attn_dropout=0.1,
        ff_dropout=0.1,
        attn_types="full",
        loss_img_weight=7,
        stable=False,
        rotary_emb=True,
        text_embedding="bert",
        fixed_embedding=True,
        loss_cond_weight=1,
        learning_rate=3e-4,
        monitor="val_loss",
    ):
        super().__init__()

        vae = VQGanVAE(
            vqgan_model_path=vqgan_model_path, vqgan_config_path=vqgan_config_path
        )

        self.image_key = image_key

        if condition_config_path:
            condition_vae = VQGanVAE(
                vqgan_model_path=condition_model_path,
                vqgan_config_path=condition_config_path,
            )
        else:
            condition_vae = None

        self.celle = CELLE(
            dim=dim,
            vae=vae,
            condition_vae=condition_vae,
            num_images=num_images,
            num_text_tokens=num_text_tokens,
            text_seq_len=text_seq_len,
            depth=depth,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            loss_img_weight=loss_img_weight,
            stable=stable,
            rotary_emb=rotary_emb,
            text_embedding=text_embedding,
            fixed_embedding=fixed_embedding,
            loss_cond_weight=loss_cond_weight,
        )

        self.learning_rate = learning_rate
        self.num_text_tokens = num_text_tokens
        self.num_images = num_images

        if monitor is not None:
            self.monitor = monitor

        ignore_keys = []

        if condition_model_path:
            ignore_keys.append("celle.condition_vae")

        if vqgan_model_path:
            ignore_keys.append("celle.vae")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
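        # Load a Lightning checkpoint but skip keys in ignore_keys (e.g. VAE weights
        # that were already restored from their own VQGAN checkpoints above).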
        sd = torch.load(path, map_location="cpu")["state_dict"]
        ckpt = sd.copy()
        for k in sd.keys():
            for ik in ignore_keys:
                if k.startswith(ik):
                    del ckpt[k]
        self.load_state_dict(ckpt, strict=False)
        print(f"Restored from {path}")

    def forward(self, text, condition, target, return_loss=True):
        return self.celle(
            text=text, condition=condition, image=target, return_loss=return_loss
        )

    def get_input(self, batch):
        text = batch["sequence"].squeeze(1)
        condition = batch["nucleus"]
        target = batch[self.image_key]

        return text, condition, target

    def get_image_from_logits(self, logits, temperature=0.9):
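        # Filter logits with top-k, gumbel-sample token ids, then decode only the image
        # portion of the sequence (offset back into the VAE codebook index range).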
        filtered_logits = top_k(logits, thres=0.5)
        sample = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)

        self.celle.vae.eval()
        out = self.celle.vae.decode(
            sample[:, self.celle.text_seq_len + self.celle.condition_seq_len :]
            - (self.celle.num_text_tokens + self.celle.num_condition_tokens)
        )

        return out

    def get_loss(self, text, condition, target):
        loss, loss_dict, logits = self(text, condition, target, return_loss=True)

        return loss, loss_dict

    def total_loss(
        self,
        loss,
        loss_dict,
        mode="train",
    ):
        loss_dict = {f"{mode}/{key}": value for key, value in loss_dict.items()}

        for key, value in loss_dict.items():
            self.log(
                key,
                value,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True,
                sync_dist=True,
            )

        return loss

    def training_step(self, batch, batch_idx):
        text, condition, target = self.get_input(batch)
        loss, log_dict = self.get_loss(text, condition, target)

        loss = self.total_loss(loss, log_dict, mode="train")

        return loss

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            text, condition, target = self.get_input(batch)
            loss, log_dict = self.get_loss(text, condition, target)

            loss = self.total_loss(loss, log_dict, mode="val")

        return loss

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate, betas=(0.9, 0.95))

        return optimizer

    def scale_image(self, image):
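        # Min-max normalize each tensor in the batch to [0, 1] (in place) so the
        # images are displayable by the logger.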
        for tensor in image:
            if torch.min(tensor) < 0:
                tensor += -torch.min(tensor)
            else:
                tensor -= torch.min(tensor)

            tensor /= torch.max(tensor)

        return image

    @torch.no_grad()
    def log_images(self, batch, **kwargs):
        log = {}

        text, condition, target = self.get_input(batch)
        text = text.squeeze(1).to(self.device)
        condition = condition.to(self.device)

        out = self.celle.generate_images(text=text, condition=condition)

        log["condition"] = self.scale_image(condition)
        log["output"] = self.scale_image(out)
        if self.image_key == "threshold":
            log["threshold"] = self.scale_image(target)
            log["target"] = self.scale_image(batch["target"])
        else:
            log["target"] = self.scale_image(target)

        return log


if __name__ == "__main__":
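    # Typical invocation (script and flag names are assumptions based on the
    # taming-transformers-style parser from celle_taming_main; adjust to your setup):
    #   python celle_main.py --base configs/celle.yaml -t True --gpus 0,
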
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()

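    # -n/--name and -r/--resume are mutually exclusive: when resuming, the log
    # directory, last checkpoint, and saved configs are recovered from the resume
    # path instead of starting a fresh run directory.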
    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot both be specified. "
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            idx = len(paths) - paths[::-1].index("logs") + 1
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

        opt.resume_from_checkpoint = ckpt
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[_tmp.index("logs") + 1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = "_" + cfg_name
        else:
            name = ""
        nowname = now + name + opt.postfix
        logdir = os.path.join("logs", nowname)

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    seed_everything(opt.seed)

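    # Merge all configs passed via opt.base with dotlist overrides from the command
    # line, then split off the `lightning` section (trainer flags, logger, callbacks)
    # from the model/data config.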
    try:
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())

        trainer_config = lightning_config.get("trainer", OmegaConf.create())

        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)
        if "gpus" not in trainer_config:
            del trainer_config["distributed_backend"]
            cpu = True
        else:
            gpuinfo = trainer_config["gpus"]
            print(f"Running on GPUs {gpuinfo}")
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        model = instantiate_from_config(config.model)

        trainer_kwargs = dict()

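        # Logger: default to TensorBoard (the "testtube" entry) unless the lightning
        # config provides its own logger section; a WandbLogger default is also defined.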
        default_logger_cfgs = {
            "wandb": {
                "target": "pytorch_lightning.loggers.WandbLogger",
                "params": {
                    "name": nowname,
                    "save_dir": logdir,
                    "offline": opt.debug,
                    "id": nowname,
                },
            },
            "testtube": {
                "target": "pytorch_lightning.loggers.TensorBoardLogger",
                "params": {
                    "name": "testtube",
                    "save_dir": logdir,
                },
            },
        }
        default_logger_cfg = default_logger_cfgs["testtube"]

        try:
            logger_cfg = lightning_config.logger
        except Exception:
            logger_cfg = OmegaConf.create()
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

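        # ModelCheckpoint: always keep the last checkpoint; if the model exposes a
        # `monitor` metric, also track the top 3 checkpoints by that metric.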
        default_modelckpt_cfg = {
            "checkpoint_callback": {
                "target": "pytorch_lightning.callbacks.ModelCheckpoint",
                "params": {
                    "dirpath": ckptdir,
                    "filename": "{epoch:06}",
                    "verbose": True,
                    "save_last": True,
                },
            }
        }
        if hasattr(model, "monitor"):
            print(f"Monitoring {model.monitor} as checkpoint metric.")
            default_modelckpt_cfg["checkpoint_callback"]["params"][
                "monitor"
            ] = model.monitor
            default_modelckpt_cfg["checkpoint_callback"]["params"]["save_top_k"] = 3
        try:
            modelckpt_cfg = lightning_config.modelcheckpoint
        except Exception:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)

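        # Callbacks: SetupCallback creates the log/checkpoint/config directories and
        # dumps the merged configs; user callbacks and the checkpoint callback are
        # merged in from the lightning config.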
        default_callbacks_cfg = {
            "setup_callback": {
                "target": "celle_taming_main.SetupCallback",
                "params": {
                    "resume": opt.resume,
                    "now": now,
                    "logdir": logdir,
                    "ckptdir": ckptdir,
                    "cfgdir": cfgdir,
                    "config": config,
                    "lightning_config": lightning_config,
                },
            },
        }
        try:
            callbacks_cfg = lightning_config.callbacks
        except Exception:
            callbacks_cfg = OmegaConf.create()
        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
        callbacks_cfg = OmegaConf.merge(modelckpt_cfg, callbacks_cfg)
        trainer_kwargs["callbacks"] = [
            instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
        ]

        trainer = Trainer.from_argparse_args(
            trainer_opt, **trainer_kwargs, profiler="simple"
        )

        data = instantiate_from_config(config.data)
        data.setup()
        data.prepare_data()

        bs, lr = config.data.params.batch_size, config.model.learning_rate

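        # Scale the configured base learning rate by the effective batch size:
        # accumulate_grad_batches * num_gpus * batch_size.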
        if not cpu:
            ngpu = len(lightning_config.trainer.gpus.strip(",").split(","))
        else:
            ngpu = 1
        try:
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        except Exception:
            accumulate_grad_batches = 1
        print(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        model.learning_rate = accumulate_grad_batches * ngpu * bs * lr

        print(
            "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (lr)".format(
                model.learning_rate, accumulate_grad_batches, ngpu, bs, lr
            )
        )

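        # SIGUSR1 saves a "last.ckpt" checkpoint (melk); SIGUSR2 drops rank 0 into the
        # pudb debugger (divein), e.g. when a cluster signals a job before preemption.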
        def melk(*args, **kwargs):
            if trainer.global_rank == 0:
                print("Summoning checkpoint.")
                ckpt_path = os.path.join(ckptdir, "last.ckpt")
                trainer.save_checkpoint(ckpt_path)

        def divein(*args, **kwargs):
            if trainer.global_rank == 0:
                import pudb

                pudb.set_trace()

        import signal

        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)

        if opt.train:
            try:
                trainer.fit(model, data)
            except Exception:
                melk()
                raise
        if not opt.no_test and not trainer.interrupted:
            trainer.test(model, data)
    except Exception:
        if opt.debug and trainer.global_rank == 0:
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # Move debug runs out of logs/ into logs/debug_runs/.
        if opt.debug and not opt.resume and trainer.global_rank == 0:
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)