# CustomNet / customnet / customnet.py
# Uploaded by jiangyzy ("Upload 35 files"), commit 28b27d8 (verified); 5.21 kB.
import os
import einops
import torch
import torch as th
import torch.nn as nn
import cv2
from pytorch_lightning.utilities.distributed import rank_zero_only
import numpy as np
from torch.optim.lr_scheduler import LambdaLR
from ldm.modules.diffusionmodules.util import (
conv_nd,
linear,
zero_module,
timestep_embedding,
)
from einops import rearrange, repeat
from torchvision.utils import make_grid
from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import log_txt_as_img, exists, instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import load_state_dict
class CustomNet(LatentDiffusion):
    """Latent diffusion model with an extra, separately instantiated text encoder.

    During training, the captions in ``batch['txt']`` are encoded by
    ``self.text_encoder`` (without gradients) and the resulting embedding is
    appended to the ``c_crossattn`` conditioning list produced by ``get_input``.
    """

    def __init__(self,
                 text_encoder_config,
                 sd_15_ckpt=None,
                 use_cond_concat=False,
                 use_bbox_mask=False,
                 use_bg_inpainting=False,
                 learning_rate_scale=10,
                 *args, **kwargs):
        """Build the model.

        Args:
            text_encoder_config: Config passed to ``instantiate_from_config`` to
                build ``self.text_encoder``.
            sd_15_ckpt: Optional checkpoint path. When given, its
                ``cond_stage_model.*`` weights are copied into the text encoder
                (see ``load_model_from_ckpt``). Presumably a Stable Diffusion
                1.5 checkpoint, judging by the name — TODO confirm.
            use_cond_concat, use_bbox_mask, use_bg_inpainting: Flags stored as
                attributes. They are not read by the methods of this class
                (presumably consumed elsewhere — TODO confirm).
            learning_rate_scale: Multiplier on the base learning rate for
                ``self.model`` parameters whose names contain ``to_k_text`` or
                ``to_v_text``.
            *args, **kwargs: Forwarded to ``LatentDiffusion.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.text_encoder = instantiate_from_config(text_encoder_config)
        if sd_15_ckpt is not None:
            self.load_model_from_ckpt(ckpt=sd_15_ckpt)
        self.use_cond_concat = use_cond_concat
        self.use_bbox_mask = use_bbox_mask
        self.use_bg_inpainting = use_bg_inpainting
        self.learning_rate_scale = learning_rate_scale

    def load_model_from_ckpt(self, ckpt, verbose=True):
        """Initialise ``self.text_encoder`` from a checkpoint's cond-stage weights.

        Each text-encoder state entry ``k`` is taken from the checkpoint entry
        ``"cond_stage_model." + k`` when present. Entries absent from the
        checkpoint keep their current values and are reported when ``verbose``.

        Args:
            ckpt: Anything accepted by ``torch.load``. A Lightning-style
                checkpoint (weights under ``"state_dict"``) or a bare state dict.
            verbose: When true, print progress and any text-encoder keys that
                were not found in the checkpoint.
        """
        if verbose:
            print(" =========================== init Stable Diffusion pretrained checkpoint =========================== ")
            print(f"Loading model from {ckpt}")
        # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
        pl_sd = torch.load(ckpt, map_location="cpu")
        if verbose and "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        # Lightning checkpoints nest the weights under "state_dict"; otherwise
        # treat the loaded mapping itself as the state dict.
        sd = pl_sd.get("state_dict", pl_sd)

        text_encoder_sd = self.text_encoder.state_dict()
        missing = []
        for k in text_encoder_sd:
            sd_k = "cond_stage_model." + k
            if sd_k in sd:
                # Overwriting an existing key does not resize the dict, so this
                # is safe while iterating.
                text_encoder_sd[k] = sd[sd_k]
            else:
                missing.append(k)
        if verbose and missing:
            # Previously collected but silently discarded; surface it so a
            # wrong checkpoint does not go unnoticed.
            print(f"Text encoder keys not found in checkpoint ({len(missing)}): {missing}")
        self.text_encoder.load_state_dict(text_encoder_sd)

    def configure_optimizers(self):
        """Create AdamW with a boosted learning rate for the text K/V projections.

        Parameter groups (order matters for optimizer state saved in checkpoints):
          0. ``self.model`` parameters whose names contain ``to_k_text`` or
             ``to_v_text``, at ``learning_rate * learning_rate_scale``;
          1. ``self.cc_projection`` parameters plus all other ``self.model``
             parameters, at ``learning_rate``.

        Returns:
            The optimizer alone, or ``([optimizer], [scheduler_dict])`` with a
            per-step ``LambdaLR`` when ``use_scheduler`` is set.
        """
        lr = self.learning_rate
        # cc_projection is not created in this class; assumed to be provided by
        # the LatentDiffusion base class — TODO confirm.
        params = list(self.cc_projection.parameters())
        params_dualattn = []
        for k, v in self.model.named_parameters():
            if "to_k_text" in k or "to_v_text" in k:
                params_dualattn.append(v)
                print("training weight: ", k)
            else:
                params.append(v)
        opt = torch.optim.AdamW([
            {'params': params_dualattn, 'lr': lr * self.learning_rate_scale},
            {'params': params, 'lr': lr},
        ])
        if self.use_scheduler:
            assert 'target' in self.scheduler_config
            schedule = instantiate_from_config(self.scheduler_config)
            print("Setting up LambdaLR scheduler...")
            scheduler = [{
                'scheduler': LambdaLR(opt, lr_lambda=schedule.schedule),
                'interval': 'step',
                'frequency': 1,
            }]
            return [opt], scheduler
        return opt

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and log losses, step and learning rate."""
        loss, loss_dict = self.shared_step(batch)
        self.log_dict(loss_dict, prog_bar=True,
                      logger=True, on_step=True, on_epoch=True)
        self.log("global_step", self.global_step,
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)
        if self.use_scheduler:
            # NOTE(review): param group 0 is the text K/V group, whose lr is
            # multiplied by learning_rate_scale, so this logs the boosted rate
            # rather than the base rate — confirm this is intended.
            lr = self.optimizers().param_groups[0]['lr']
            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
        return loss

    def shared_step(self, batch, **kwargs):
        """Compute the diffusion loss for a batch with extra text conditioning.

        Caption dropout: if ``'txt'`` is in ``ucg_training``, each caption in
        ``batch['txt']`` is replaced in place by ``""`` with probability
        ``ucg_training['txt']``. This happens only when ``batch['txt']`` is a
        list, and is presumably for classifier-free guidance training.

        The captions are then encoded without gradients, and the embedding is
        appended to ``c['c_crossattn']`` from ``get_input``.

        Returns:
            Whatever ``self(x, c)`` returns; ``training_step`` unpacks it as
            ``(loss, loss_dict)``.
        """
        if 'txt' in self.ucg_training:
            p = self.ucg_training['txt']
            captions = batch['txt']
            for i in range(len(captions)):
                # The random draw is evaluated first, for every caption, so RNG
                # consumption does not depend on the container type.
                if self.ucg_prng.choice(2, p=[1 - p, p]) and isinstance(captions, list):
                    captions[i] = ""
        with torch.no_grad():
            # The text encoder is not trained: it runs without gradients and
            # its parameters are not in any optimizer group.
            text_embedding = self.text_encoder(batch['txt'])
        x, c = self.get_input(batch, self.first_stage_key)
        c["c_crossattn"].append(text_embedding)
        return self(x, c)

    def apply_model(self, x_noisy, t, cond, return_ids=False, **kwargs):
        """Run ``self.model`` on noisy latents with normalised conditioning.

        A dict ``cond`` (hybrid conditioning) is passed through as keyword
        arguments. Any other value is wrapped in a list if needed and stored
        under a single key: ``c_concat`` when
        ``self.model.conditioning_key == 'concat'``, otherwise ``c_crossattn``.

        Returns:
            The model output. If it is a tuple and ``return_ids`` is false,
            only its first element is returned.
        """
        if not isinstance(cond, dict):
            if not isinstance(cond, list):
                cond = [cond]
            key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
            cond = {key: cond}
        x_recon = self.model(x_noisy, t, **cond)
        if isinstance(x_recon, tuple) and not return_ids:
            return x_recon[0]
        return x_recon