import math
import os

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from main import instantiate_from_config
from taming.modules.util import SOSProvider


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class Net2NetTransformer(pl.LightningModule):
    """Autoregressive transformer over first-stage codebook indices,
    conditioned on the codebook indices produced by a conditioning stage.

    The first-stage and cond-stage models are put in eval mode with their
    ``train`` method disabled; only ``self.transformer`` parameters are handed
    to the optimizer in ``configure_optimizers``.
    """

    def __init__(self,
                 transformer_config,
                 first_stage_config,
                 cond_stage_config,
                 permuter_config=None,
                 ckpt_path=None,
                 ignore_keys=(),
                 first_stage_key="image",
                 cond_stage_key="depth",
                 downsample_cond_size=-1,
                 pkeep=1.0,
                 sos_token=0,
                 unconditional=False,
                 ):
        super().__init__()
        self.be_unconditional = unconditional
        self.sos_token = sos_token
        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key
        self.init_first_stage_from_ckpt(first_stage_config)
        # Must run after be_unconditional / first_stage_model are set: the
        # cond stage may reuse either of them.
        self.init_cond_stage_from_ckpt(cond_stage_config)
        if permuter_config is None:
            permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
        self.permuter = instantiate_from_config(config=permuter_config)
        self.transformer = instantiate_from_config(config=transformer_config)

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.downsample_cond_size = downsample_cond_size
        self.pkeep = pkeep

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load the ``state_dict`` stored in the checkpoint at ``path``
        (non-strict), after dropping every key that starts with one of the
        prefixes in ``ignore_keys``."""
        sd = torch.load(path, map_location="cpu")["state_dict"]
        # Iterate over a snapshot of the keys: deleting from a dict while
        # iterating its live keys() view raises RuntimeError. any() also
        # prevents deleting the same key twice when several prefixes match.
        for k in list(sd.keys()):
            if any(k.startswith(ik) for ik in ignore_keys):
                self.print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def init_first_stage_from_ckpt(self, config):
        """Instantiate the first-stage model from ``config`` and freeze its
        train/eval mode in eval."""
        model = instantiate_from_config(config)
        model = model.eval()
        model.train = disabled_train
        self.first_stage_model = model

    def init_cond_stage_from_ckpt(self, config):
        """Set up ``self.cond_stage_model``.

        ``"__is_first_stage__"`` reuses the first-stage model;
        ``"__is_unconditional__"`` (or ``unconditional=True``) uses an
        ``SOSProvider`` and switches ``cond_stage_key`` to the first-stage
        key; anything else is instantiated as a frozen model.
        """
        if config == "__is_first_stage__":
            print("Using first stage also as cond stage.")
            self.cond_stage_model = self.first_stage_model
        elif config == "__is_unconditional__" or self.be_unconditional:
            print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
                  f"Prepending {self.sos_token} as a sos token.")
            self.be_unconditional = True
            self.cond_stage_key = self.first_stage_key
            self.cond_stage_model = SOSProvider(self.sos_token)
        else:
            model = instantiate_from_config(config)
            model = model.eval()
            model.train = disabled_train
            self.cond_stage_model = model

    # NOTE(review): the source this file was recovered from lost the text
    # between the middle of ``forward`` and the start of ``encode_to_c`` (the
    # tail of ``forward``, ``top_k_logits``, ``sample``, ``encode_to_z`` and
    # the head of ``encode_to_c``). Those parts are restored from upstream
    # taming-transformers ``cond_transformer.py``, whose API matches the call
    # sites in this class — verify against the fork's history in case it had
    # local changes there.

    def forward(self, x, c):
        """One training step's worth of logits.

        Encodes ``x`` (target) and ``c`` (conditioning) to codebook indices,
        optionally corrupts the target indices (``pkeep < 1`` in training),
        and returns ``(logits, target)`` where ``target`` is the uncorrupted
        first-stage index sequence and ``logits`` has one entry per target
        position.
        """
        # x = target
        # c = nucleus
        _, z_indices = self.encode_to_z(x)
        _, c_indices = self.encode_to_c(c)

        if self.training and self.pkeep < 1.0:
            # Replace roughly a (1 - pkeep) fraction of the input tokens with
            # random codebook indices; the target itself is left untouched.
            mask = torch.bernoulli(self.pkeep * torch.ones(z_indices.shape, device=z_indices.device))
            mask = mask.round().to(dtype=torch.int64)
            r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
            a_indices = mask * z_indices + (1 - mask) * r_indices
        else:
            a_indices = z_indices

        cz_indices = torch.cat((c_indices, a_indices), dim=1)

        # target includes all sequence elements (no need to handle first one
        # differently because we are conditioning)
        target = z_indices
        # make the prediction
        logits, _ = self.transformer(cz_indices[:, :-1])
        # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
        logits = logits[:, c_indices.shape[1] - 1:]

        return logits, target

    def top_k_logits(self, logits, k):
        """Return a copy of ``logits`` with everything below the k-th largest
        value (along the last dim) set to -inf."""
        v, ix = torch.topk(logits, k)
        out = logits.clone()
        out[out < v[..., [-1]]] = -float('Inf')
        return out

    @torch.no_grad()
    def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
               callback=lambda k: None):
        """Extend index sequence ``x`` by ``steps`` tokens conditioned on the
        index sequence ``c``; returns the sequence without the conditioning
        prefix. ``sample=False`` picks the argmax token at each step."""
        x = torch.cat((c, x), dim=1)
        block_size = self.transformer.get_block_size()
        assert not self.transformer.training
        if self.pkeep <= 0.0:
            # one pass suffices since input is pure noise anyway
            assert len(x.shape) == 2
            noise = c.clone()[:, x.shape[1] - c.shape[1]:-1]
            x = torch.cat((x, noise), dim=1)
            logits, _ = self.transformer(x)
            # take all logits for now and scale by temp
            logits = logits / temperature
            # optionally crop probabilities to only the top k options
            if top_k is not None:
                logits = self.top_k_logits(logits, top_k)
            # apply softmax to convert to probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution or take the most likely
            if sample:
                shape = probs.shape
                probs = probs.reshape(shape[0] * shape[1], shape[2])
                ix = torch.multinomial(probs, num_samples=1)
                probs = probs.reshape(shape[0], shape[1], shape[2])
                ix = ix.reshape(shape[0], shape[1])
            else:
                _, ix = torch.topk(probs, k=1, dim=-1)
            # cut off conditioning
            x = ix[:, c.shape[1] - 1:]
        else:
            for k in range(steps):
                callback(k)
                assert x.size(1) <= block_size  # make sure model can see conditioning
                x_cond = x if x.size(1) <= block_size else x[:, -block_size:]  # crop context if needed
                logits, _ = self.transformer(x_cond)
                # pluck the logits at the final step and scale by temperature
                logits = logits[:, -1, :] / temperature
                # optionally crop probabilities to only the top k options
                if top_k is not None:
                    logits = self.top_k_logits(logits, top_k)
                # apply softmax to convert to probabilities
                probs = F.softmax(logits, dim=-1)
                # sample from the distribution or take the most likely
                if sample:
                    ix = torch.multinomial(probs, num_samples=1)
                else:
                    _, ix = torch.topk(probs, k=1, dim=-1)
                # append to the sequence and continue
                x = torch.cat((x, ix), dim=1)
            # cut off conditioning
            x = x[:, c.shape[1]:]
        return x

    @torch.no_grad()
    def encode_to_z(self, x):
        """Encode ``x`` with the first-stage model; return
        ``(quant_z, indices)`` with indices flattened per sample and permuted."""
        quant_z, _, info = self.first_stage_model.encode(x)
        indices = info[2].view(quant_z.shape[0], -1)
        indices = self.permuter(indices)
        return quant_z, indices

    @torch.no_grad()
    def encode_to_c(self, c):
        """Encode ``c`` with the cond-stage model; return
        ``(quant_c, indices)`` with indices flattened to 2-D per sample.

        Unlike ``encode_to_z``, the conditioning indices are not permuted.
        """
        if self.downsample_cond_size > -1:
            c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
        quant_c, _, [_, _, indices] = self.cond_stage_model.encode(c)
        if len(indices.shape) != 2:
            indices = indices.view(c.shape[0], -1)
        return quant_c, indices

    @torch.no_grad()
    def decode_to_img(self, index, zshape):
        """Decode (permuted) first-stage indices back to an image tensor.

        ``zshape`` is the shape of the quantized latent as returned by
        ``encode_to_z``; it is reordered to (b, h, w, c) for the codebook
        lookup.
        """
        index = self.permuter(index, reverse=True)
        bhwc = (zshape[0], zshape[2], zshape[3], zshape[1])
        quant_z = self.first_stage_model.quantize.get_codebook_entry(
            index.reshape(-1), shape=bhwc)
        x = self.first_stage_model.decode(quant_z)
        return x

    @torch.no_grad()
    def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
        """Build a dict of image tensors for logging: inputs, reconstructions,
        conditioning (and its reconstruction when decodable), and half /
        unconditional-start / deterministic samples for the first 4 items.

        ``lr_interface`` is accepted for compatibility but no longer changes
        how inputs are fetched: it used to call ``get_xc`` with
        ``diffuse``/``upsample_factor`` keywords that ``get_xc`` does not
        accept, which always raised TypeError.
        """
        log = dict()

        N = 4
        x, c = self.get_xc(batch, N)
        x = x.to(device=self.device)
        c = c.to(device=self.device)

        quant_z, z_indices = self.encode_to_z(x)
        quant_c, c_indices = self.encode_to_c(c)

        sample_temperature = temperature if temperature is not None else 1.0
        sample_top_k = top_k if top_k is not None else 100
        callback = callback if callback is not None else lambda k: None

        # create a "half" sample: keep the first half of the true indices
        z_start_indices = z_indices[:, :z_indices.shape[1] // 2]
        index_sample = self.sample(z_start_indices, c_indices,
                                   steps=z_indices.shape[1] - z_start_indices.shape[1],
                                   temperature=sample_temperature,
                                   sample=True,
                                   top_k=sample_top_k,
                                   callback=callback)
        x_sample = self.decode_to_img(index_sample, quant_z.shape)

        # sample from scratch (no image tokens given)
        z_start_indices = z_indices[:, :0]
        index_sample = self.sample(z_start_indices, c_indices,
                                   steps=z_indices.shape[1],
                                   temperature=sample_temperature,
                                   sample=True,
                                   top_k=sample_top_k,
                                   callback=callback)
        x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)

        # det sample (argmax at every step, sample() defaults otherwise)
        z_start_indices = z_indices[:, :0]
        index_sample = self.sample(z_start_indices, c_indices,
                                   steps=z_indices.shape[1],
                                   sample=False,
                                   callback=callback)
        x_sample_det = self.decode_to_img(index_sample, quant_z.shape)

        # reconstruction
        x_rec = self.decode_to_img(z_indices, quant_z.shape)

        log["inputs"] = x
        log["reconstructions"] = x_rec

        # The previous `!= a or != b or != c` chain was always True, so decode
        # was attempted even for cond stages without a decoder (e.g. the
        # SOSProvider used when unconditional, where the key is "image").
        if self.cond_stage_key not in ("image", "nucleus", "target"):
            cond_rec = self.cond_stage_model.decode(quant_c)
            if self.cond_stage_key == "segmentation":
                # get image from segmentation mask
                num_classes = cond_rec.shape[1]

                c = torch.argmax(c, dim=1, keepdim=True)
                c = F.one_hot(c, num_classes=num_classes)
                c = c.squeeze(1).permute(0, 3, 1, 2).float()
                c = self.cond_stage_model.to_rgb(c)

                cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
                cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
                cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
                cond_rec = self.cond_stage_model.to_rgb(cond_rec)
            log["conditioning_rec"] = cond_rec
            log["conditioning"] = c

        log["samples_half"] = x_sample
        log["samples_nopix"] = x_sample_nopix
        log["samples_det"] = x_sample_det
        return log

    def get_input(self, key, batch):
        """Fetch ``batch[key]``, add a trailing axis to 3-D tensors and cast
        float64 to float32. No layout permutation is applied."""
        x = batch[key]
        if len(x.shape) == 3:
            # NOTE(review): with the NHWC->NCHW permute removed, a 3-D input
            # becomes (..., 1) with the new axis last — confirm the data
            # pipeline never feeds 3-D tensors, or that this layout is intended.
            x = x[..., None]
        if x.dtype == torch.double:
            x = x.float()
        return x

    def get_xc(self, batch, N=None):
        """Return ``(x, c)`` for the first/cond stage keys, optionally
        truncated to the first ``N`` items."""
        x = self.get_input(self.first_stage_key, batch)
        c = self.get_input(self.cond_stage_key, batch)
        if N is not None:
            x = x[:N]
            c = c[:N]
        return x, c

    def shared_step(self, batch):
        """Cross-entropy between predicted logits and target indices."""
        x, c = self.get_xc(batch)
        logits, target = self(x, c)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """
        Following minGPT:
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.

        Only ``self.transformer`` parameters are optimized. ``self.learning_rate``
        is not set in this class; presumably the training script assigns it
        before this is called — confirm.
        """
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.transformer.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
        return optimizer