# Hosting-page header captured together with this file (not part of the module):
#   customdiffusion360's picture
#   first commit
#   ad7bc89
#   raw
#   history blame
#   21.7 kB
from contextlib import contextmanager
from typing import Any, Dict, List, Tuple, Union, DefaultDict
import pytorch_lightning as pl
import torch
from omegaconf import ListConfig, OmegaConf
from safetensors.torch import load_file as load_safetensors
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange
import math
import torch.nn as nn
from ..modules import UNCONDITIONAL_CONFIG
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from ..modules.ema import LitEma
from ..util import (
default,
disabled_train,
get_obj_from_str,
instantiate_from_config,
log_txt_as_img,
)
import collections
from functools import partial
def save_activations(
    activations: DefaultDict,
    name: str,
    module: nn.Module,
    inp: Tuple,
    out: torch.Tensor
) -> None:
    """Forward hook that stores a module's output under ``name``.

    Meant to be bound with ``functools.partial`` so that ``activations`` and
    ``name`` are fixed and the remaining ``(module, inp, out)`` arguments are
    supplied by the forward-hook call.

    An output is recorded only when it is a tuple whose second element is
    ``None``; in that case the detached first element is appended to
    ``activations[name]``. Every other output is ignored. ``module`` and
    ``inp`` are not used.
    """
    recordable = isinstance(out, tuple) and out[1] is None
    if not recordable:
        return
    activations[name].append(out[0].detach())
class DiffusionEngine(pl.LightningModule):
def __init__(
    self,
    network_config,
    denoiser_config,
    first_stage_config,
    conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
    sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
    optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
    scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
    loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
    network_wrapper: Union[None, str] = None,
    ckpt_path: Union[None, str] = None,
    use_ema: bool = False,
    ema_decay_rate: float = 0.9999,
    scale_factor: float = 1.0,
    disable_first_stage_autocast=False,
    input_key: str = "jpg",
    log_keys: Union[List, None] = None,
    no_cond_log: bool = False,
    compile_model: bool = False,
    trainkeys='pose',
    multiplier=0.05,
    loss_rgb_lambda=20.,
    loss_fg_lambda=10.,
    loss_bg_lambda=20.,
):
    """Instantiate all sub-models from their configs and select trainable weights.

    Args:
        network_config: config of the diffusion network; must expose
            ``params.rgb`` and ``params.rgb_predict``.
        denoiser_config: config of the denoiser.
        first_stage_config: config of the first-stage model, which is put in
            eval mode and frozen by ``_init_first_stage``.
        conditioner_config: config of the conditioner; ``None`` falls back to
            ``UNCONDITIONAL_CONFIG``.
        sampler_config: optional sampler config (``self.sampler`` is ``None``
            when omitted).
        optimizer_config: optimizer config; defaults to ``torch.optim.AdamW``.
        scheduler_config: optional LR-scheduler config.
        loss_fn_config: optional loss config (``self.loss_fn`` is ``None``
            when omitted).
        network_wrapper: import path of the wrapper class placed around the
            network; defaults to ``OPENAIUNETWRAPPER``.
        ckpt_path: optional checkpoint loaded through ``init_from_ckpt``.
        use_ema: keep an EMA copy of ``self.model``.
        ema_decay_rate: decay passed to ``LitEma``.
        scale_factor: latent scaling used when encoding/decoding.
        disable_first_stage_autocast: disable autocast around the first stage.
        input_key: batch key of the target images.
        log_keys: conditioning keys to log (``None`` logs all of them).
        no_cond_log: disable logging of conditionings.
        compile_model: forwarded to the network wrapper.
        trainkeys: which network parameters stay trainable: ``'pose'``,
            ``'poseattn'`` or ``'all'``. Any other value leaves the
            ``requires_grad`` flags untouched.
        multiplier: learning-rate multiplier of the low-LR parameter group
            built in ``configure_optimizers``.
        loss_rgb_lambda, loss_fg_lambda, loss_bg_lambda: weights of the
            auxiliary loss terms added in ``forward``.
    """
    super().__init__()
    self.log_keys = log_keys
    self.input_key = input_key
    self.trainkeys = trainkeys
    self.multiplier = multiplier
    self.loss_rgb_lambda = loss_rgb_lambda
    self.loss_fg_lambda = loss_fg_lambda
    self.loss_bg_lambda = loss_bg_lambda
    self.rgb = network_config.params.rgb
    self.rgb_predict = network_config.params.rgb_predict
    # conditioner_config may legitimately be None (the unconditional fallback
    # below), so only inspect the second embedder's params when a config exists.
    self.add_token = (
        conditioner_config is not None
        and 'modifier_token' in conditioner_config.params.emb_models[1].params
    )

    self.optimizer_config = default(
        optimizer_config, {"target": "torch.optim.AdamW"}
    )
    model = instantiate_from_config(network_config)
    self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
        model, compile_model=compile_model
    )

    self.denoiser = instantiate_from_config(denoiser_config)
    self.sampler = (
        instantiate_from_config(sampler_config)
        if sampler_config is not None
        else None
    )
    self.conditioner = instantiate_from_config(
        default(conditioner_config, UNCONDITIONAL_CONFIG)
    )
    self.scheduler_config = scheduler_config
    self._init_first_stage(first_stage_config)

    self.loss_fn = (
        instantiate_from_config(loss_fn_config)
        if loss_fn_config is not None
        else None
    )

    self.use_ema = use_ema
    if self.use_ema:
        self.model_ema = LitEma(self.model, decay=ema_decay_rate)
        print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

    self.scale_factor = scale_factor
    self.disable_first_stage_autocast = disable_first_stage_autocast
    self.no_cond_log = no_cond_log

    if ckpt_path is not None:
        self.init_from_ckpt(ckpt_path)

    if self.trainkeys == 'poseattn':
        # Pass 1: freeze everything that is neither a pose parameter nor part
        # of a transformer block; remember the name prefix in front of '.pose'
        # for every pose parameter.
        blocks = set()
        for name, param in self.model.diffusion_model.named_parameters():
            if not ('pose' in name or 'transformer_blocks' in name):
                param.requires_grad = False
            elif 'pose' in name:
                param.requires_grad = True
                blocks.add(name.split('.pose')[0])

        # Pass 2: inside transformer blocks, train only attn1/attn2/pose
        # parameters whose name contains one of the recorded prefixes.
        for name, param in self.model.diffusion_model.named_parameters():
            if 'transformer_blocks' in name:
                is_target = 'attn1' in name or 'attn2' in name or 'pose' in name
                param.requires_grad = is_target and any(
                    each in name for each in blocks
                )
    elif self.trainkeys == 'pose':
        for name, param in self.model.diffusion_model.named_parameters():
            param.requires_grad = 'pose' in name
    elif self.trainkeys == 'all':
        for _, param in self.model.diffusion_model.named_parameters():
            param.requires_grad = True

    self.model = self.model.to(memory_format=torch.channels_last)
def register_activation_hooks(
    self,
) -> None:
    """Attach ``save_activations`` forward hooks to the pose transformer blocks.

    A module is hooked when the second-to-last component of its dotted name
    is ``'transformer_blocks'`` and it has a ``pose_emb_layers`` attribute.
    Recorded outputs accumulate in ``self.activations_dict`` (a fresh
    ``defaultdict(list)`` keyed by module name); the hook handles are stored
    in ``self.handles``.
    """
    self.activations_dict = collections.defaultdict(list)
    registered = []
    for mod_name, mod in self.model.diffusion_model.named_modules():
        parts = mod_name.split('.')
        if len(parts) <= 1 or parts[-2] != 'transformer_blocks':
            continue
        if not hasattr(mod, 'pose_emb_layers'):
            continue
        hook = partial(save_activations, self.activations_dict, mod_name)
        registered.append(mod.register_forward_hook(hook))
    self.handles = registered
def clear_rendered_feat(self,):
    """Reset ``rendered_feat`` to ``None`` on every pose transformer block.

    A module qualifies when the second-to-last component of its dotted name
    is ``'transformer_blocks'`` and it has a ``pose_emb_layers`` attribute.
    """
    for mod_name, mod in self.model.diffusion_model.named_modules():
        parts = mod_name.split('.')
        in_block_list = len(parts) > 1 and parts[-2] == 'transformer_blocks'
        if in_block_list and hasattr(mod, 'pose_emb_layers'):
            mod.rendered_feat = None
def remove_activation_hooks(
    self, handles
) -> None:
    """Call ``remove()`` on each hook handle in ``handles``, in order."""
    for hook_handle in handles:
        hook_handle.remove()
def init_from_ckpt(
    self,
    path: str,
) -> None:
    """Load weights from ``path`` into this module non-strictly.

    Paths ending in ``"ckpt"`` are read with ``torch.load`` and must contain
    a ``"state_dict"`` entry; paths ending in ``"safetensors"`` are read with
    the safetensors loader. Missing and unexpected keys are printed rather
    than raised.

    Raises:
        NotImplementedError: if ``path`` has neither supported suffix.
    """
    if path.endswith("ckpt"):
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code; only load checkpoints from trusted sources.
        sd = torch.load(path, map_location="cpu")["state_dict"]
    elif path.endswith("safetensors"):
        sd = load_safetensors(path)
    else:
        raise NotImplementedError(
            f"Unsupported checkpoint format for {path!r}: "
            "expected a path ending in 'ckpt' or 'safetensors'"
        )

    missing, unexpected = self.load_state_dict(sd, strict=False)
    print(
        f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
    )
    if len(missing) > 0:
        print(f"Missing Keys: {missing}")
    if len(unexpected) > 0:
        print(f"Unexpected Keys: {unexpected}")
def _init_first_stage(self, config):
    """Build the first-stage model from ``config`` and freeze it.

    The model is switched to eval mode, its ``train`` attribute is replaced
    by ``disabled_train``, every parameter gets ``requires_grad = False``,
    and the result is stored as ``self.first_stage_model``.
    """
    first_stage = instantiate_from_config(config).eval()
    first_stage.train = disabled_train
    for weight in first_stage.parameters():
        weight.requires_grad = False
    self.first_stage_model = first_stage
def get_input(self, batch):
    """Pull the training inputs out of ``batch``.

    Returns a 7-tuple ``(image, reference, pose, mask, mask_ref, depth,
    drop_im)`` read from the keys ``self.input_key``,
    ``self.input_key + '_ref'``, ``'pose'``, ``'mask'``, ``'mask_ref'``,
    ``'depth'`` and ``'drop_im'``. Only the image key is required; the other
    entries default to ``None`` when absent, except ``drop_im`` which
    defaults to ``0.``.
    """
    ref_key = self.input_key + '_ref'
    image = batch[self.input_key]
    reference = batch[ref_key] if ref_key in batch else None
    optional = [
        batch[key] if key in batch else None
        for key in ('pose', 'mask', 'mask_ref', 'depth')
    ]
    drop_im = batch['drop_im'] if "drop_im" in batch else 0.
    return (image, reference, *optional, drop_im)
@torch.no_grad()
def decode_first_stage(self, z):
    """Decode latents ``z`` with the first-stage model, without gradients.

    The latents are first multiplied by ``1.0 / self.scale_factor``; decoding
    runs under CUDA autocast unless ``self.disable_first_stage_autocast`` is
    set.
    """
    use_autocast = not self.disable_first_stage_autocast
    unscaled = 1.0 / self.scale_factor * z
    with torch.autocast("cuda", enabled=use_autocast):
        return self.first_stage_model.decode(unscaled)
@torch.no_grad()
def encode_first_stage(self, x):
    """Encode ``x`` with the first-stage model, without gradients.

    Encoding runs under CUDA autocast unless
    ``self.disable_first_stage_autocast`` is set; the encoded result is then
    multiplied by ``self.scale_factor``.
    """
    use_autocast = not self.disable_first_stage_autocast
    with torch.autocast("cuda", enabled=use_autocast):
        latent = self.first_stage_model.encode(x)
    return self.scale_factor * latent
def forward(self, x, x_rgb, xr, pose, mask, mask_ref, opacity, drop_im, batch):
    """Compute the training loss and a dict of scalar values for logging.

    Every argument except ``drop_im`` is handed to ``self.loss_fn`` together
    with the model, denoiser and conditioner. ``drop_im`` weights the
    per-sample auxiliary losses below.

    Returns:
        ``(loss_mean, loss_dict)``: the total loss tensor and a dict holding
        ``"loss"`` plus, when the corresponding branch runs, ``"loss_fg"``,
        ``"loss_bg"`` and ``"loss_rgb"`` as Python floats.
    """
    loss, loss_fg, loss_bg, loss_rgb = self.loss_fn(self.model, self.denoiser, self.conditioner, x, x_rgb, xr, pose, mask, mask_ref, opacity, batch)
    loss_mean = loss.mean()
    # "loss" is captured before the auxiliary terms are added, so the logged
    # value is the base loss only.
    loss_dict = {"loss": loss_mean.item()}
    # The fg/bg terms are skipped while global_step is 0.
    if self.rgb and self.global_step > 0:
        # Average over dim 1, then take a drop_im-weighted mean over samples;
        # 1e-12 avoids division by zero when drop_im sums to 0.
        # NOTE(review): requires drop_im to be a tensor; get_input falls back
        # to the float 0. when the batch has no 'drop_im' - confirm callers.
        loss_fg = (loss_fg.mean(1)*drop_im.reshape(-1)).sum()/(drop_im.sum() + 1e-12)
        loss_bg = (loss_bg.mean(1)*drop_im.reshape(-1)).sum()/(drop_im.sum() + 1e-12)
        loss_mean += self.loss_fg_lambda*loss_fg
        loss_mean += self.loss_bg_lambda*loss_bg
        loss_dict["loss_fg"] = loss_fg.item()
        loss_dict["loss_bg"] = loss_bg.item()
    # NOTE(review): this block is reconstructed as a sibling of the rgb block
    # above rather than nested inside it - confirm against the upstream file.
    if self.rgb_predict and loss_rgb.mean() > 0:
        loss_rgb = (loss_rgb.mean(1)*drop_im.reshape(-1)).sum()/(drop_im.sum() + 1e-12)
        loss_mean += self.loss_rgb_lambda*loss_rgb
        loss_dict["loss_rgb"] = loss_rgb.item()
    return loss_mean, loss_dict
def shared_step(self, batch: Dict) -> Any:
    """Encode a batch into latents and run the loss computation.

    The target image is kept as a detached clone (``x_rgb``) and also
    encoded with the first-stage model. When reference images are present
    they are encoded view by view and scaled per sample by ``drop_im``.
    ``batch["global_step"]`` is set to the current global step before the
    module is called.

    Returns:
        ``(loss, loss_dict)`` as produced by calling the module itself.
    """
    image, refs, pose, mask, mask_ref, opacity, drop_im = self.get_input(batch)
    x_rgb = image.clone().detach()
    latent = self.encode_first_stage(image).to(memory_format=torch.channels_last)
    if refs is not None:
        b, n = refs.shape[0], refs.shape[1]
        flat_refs = rearrange(refs, "b n ... -> (b n) ...")
        refs = rearrange(self.encode_first_stage(flat_refs), "(b n) ... -> b n ...", b=b, n=n)
        keep = drop_im.reshape(b, 1, 1, 1, 1)
        refs = keep*refs + (1-drop_im.reshape(b, 1, 1, 1, 1))*torch.zeros_like(refs)
    batch["global_step"] = self.global_step
    return self(latent, x_rgb, refs, pose, mask, mask_ref, opacity, drop_im, batch)
def training_step(self, batch, batch_idx):
    """Run one training step, log per-step scalars and return the loss.

    Logs the loss dict and ``global_step`` on every step; when a scheduler is
    configured, the first param group's learning rate is also logged as
    ``lr_abs``.
    """
    # The same flags apply to every per-step log call below.
    per_step = dict(prog_bar=True, logger=True, on_step=True, on_epoch=False)

    loss, loss_dict = self.shared_step(batch)
    self.log_dict(loss_dict, **per_step)
    self.log("global_step", self.global_step, **per_step)

    if self.scheduler_config is not None:
        current_lr = self.optimizers().param_groups[0]["lr"]
        self.log("lr_abs", current_lr, **per_step)

    return loss
def validation_step(self, batch, batch_idx):
    """Compute and return the loss for a validation batch (nothing is logged)."""
    loss, _ = self.shared_step(batch)
    return loss
def on_train_start(self, *args, **kwargs):
    """Fail fast when training starts without a sampler or a loss function.

    Raises:
        ValueError: if ``self.sampler`` or ``self.loss_fn`` is ``None``.
    """
    if None in (self.sampler, self.loss_fn):
        raise ValueError("Sampler and loss function need to be set for training.")
def on_train_batch_end(self, *args, **kwargs):
    """After each training batch, update the EMA weights when EMA is enabled."""
    if not self.use_ema:
        return
    self.model_ema(self.model)
def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
    """Clear gradients by setting them to ``None`` instead of zero-filling.

    ``epoch`` and ``batch_idx`` are accepted but not used.
    """
    optimizer.zero_grad(set_to_none=True)
@contextmanager
def ema_scope(self, context=None):
    """Context manager that temporarily swaps in the EMA weights.

    With EMA enabled, the current parameters are stored and the EMA weights
    are copied into ``self.model`` on entry; the stored parameters are
    restored on exit, including when the body raises. With EMA disabled the
    body simply runs. When ``context`` is given, it is used as the prefix of
    the printed switch/restore messages.
    """
    announce = context is not None
    if self.use_ema:
        self.model_ema.store(self.model.parameters())
        self.model_ema.copy_to(self.model)
        if announce:
            print(f"{context}: Switched to EMA weights")
    try:
        yield None
    finally:
        if self.use_ema:
            self.model_ema.restore(self.model.parameters())
            if announce:
                print(f"{context}: Restored training weights")
def instantiate_optimizer_from_config(self, params, lr, cfg):
    """Create the optimizer class named by ``cfg["target"]``.

    The class is resolved with ``get_obj_from_str`` and called with
    ``params``, ``lr=lr`` and any keyword arguments found under
    ``cfg["params"]`` (none when that key is absent).
    """
    optimizer_cls = get_obj_from_str(cfg["target"])
    extra_kwargs = cfg.get("params", dict())
    return optimizer_cls(params, lr=lr, **extra_kwargs)
def configure_optimizers(self):
    """Build the optimizer, and optionally a per-step LambdaLR scheduler.

    Parameters are split into two groups according to ``self.trainkeys``:

    * ``'pose'``: network parameters whose name contains ``'pose'`` form the
      main group.
    * ``'poseattn'``: as above, plus the ``attn1``/``attn2`` parameters of
      the transformer blocks that contain a pose parameter, which go into
      the low-LR group.
    * ``'all'``: pose parameters in the main group, every other network
      parameter in the low-LR group.

    Trainable embedders among the first two conditioner embedders are added
    to the main group; when ``self.add_token`` is set, their token-embedding
    parameters are made trainable and added as well. The low-LR group uses
    ``self.multiplier * lr``.

    Each parameter is registered at most once: torch optimizers warn about
    duplicates within a group and raise when a parameter appears in more
    than one group.

    Returns:
        The optimizer, or ``([optimizer], [scheduler_dict])`` when
        ``self.scheduler_config`` is set.
    """
    lr = self.learning_rate
    params = []
    lowlrparams = []
    seen = set()

    def add_unique(group, param):
        # Identity-based de-duplication; the first group to claim a parameter wins.
        if id(param) not in seen:
            seen.add(id(param))
            group.append(param)

    if self.trainkeys == 'poseattn':
        blocks = set()
        for name, param in self.model.diffusion_model.named_parameters():
            if 'pose' in name:
                add_unique(params, param)
                blocks.add(name.split('.pose')[0])
                print(name)

        for name, param in self.model.diffusion_model.named_parameters():
            if 'transformer_blocks' not in name or 'pose' in name:
                continue
            if not ('attn1' in name or 'attn2' in name):
                continue
            # A name may contain several recorded prefixes; add it only once.
            if any(each in name for each in blocks):
                add_unique(lowlrparams, param)
    elif self.trainkeys == 'pose':
        for name, param in self.model.diffusion_model.named_parameters():
            if 'pose' in name:
                add_unique(params, param)
                print(name)
    elif self.trainkeys == 'all':
        for name, param in self.model.diffusion_model.named_parameters():
            if 'pose' in name:
                add_unique(params, param)
                print(name)
            else:
                add_unique(lowlrparams, param)

    for i, embedder in enumerate(self.conditioner.embedders[:2]):
        if embedder.is_trainable:
            for param in embedder.parameters():
                add_unique(params, param)
        if self.add_token:
            # The first embedder exposes its token embeddings through
            # transformer.get_input_embeddings(), the second through
            # model.token_embedding.
            if i == 0:
                token_params = embedder.transformer.get_input_embeddings().named_parameters()
            else:
                token_params = embedder.model.token_embedding.named_parameters()
            for name, param in token_params:
                param.requires_grad = True
                print(name, "conditional model param")
                add_unique(params, param)

    if lowlrparams:
        print("different optimizer groups")
        opt = self.instantiate_optimizer_from_config([{'params': params}, {'params': lowlrparams, 'lr': self.multiplier*lr}], lr, self.optimizer_config)
    else:
        opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)

    if self.scheduler_config is not None:
        scheduler = instantiate_from_config(self.scheduler_config)
        print("Setting up LambdaLR scheduler...")
        scheduler = [
            {
                "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
                "interval": "step",
                "frequency": 1,
            }
        ]
        return [opt], scheduler
    return opt
@torch.no_grad()
def sample(
    self,
    cond: Dict,
    uc: Union[Dict, None] = None,
    batch_size: int = 16,
    num_steps=None,
    randn=None,
    shape: Union[None, Tuple, List] = None,
    return_rgb=False,
    mask=None,
    init_im=None,
    **kwargs,
):
    """Draw samples with ``self.sampler``, without gradients.

    Starts from ``randn`` (or fresh noise of shape ``(batch_size, *shape)``
    when ``randn`` is ``None``) moved to ``self.device``. Extra ``kwargs``
    are forwarded to the denoiser on every call. ``mask`` and ``init_im``
    are handed to the sampler only when ``mask`` is given.

    Returns:
        The samples, or ``(samples, rgb_list)`` when ``return_rgb`` is true.
    """
    noise = torch.randn(batch_size, *shape) if randn is None else randn

    def denoiser(input, sigma, c):
        return self.denoiser(self.model, input, sigma, c, **kwargs)

    sampler_kwargs = {"uc": uc}
    if mask is not None:
        sampler_kwargs["mask"] = mask
        sampler_kwargs["init_im"] = init_im
    sampler_kwargs["num_steps"] = num_steps

    samples, rgb_list = self.sampler(denoiser, noise.to(self.device), cond, **sampler_kwargs)
    return (samples, rgb_list) if return_rgb else samples
@torch.no_grad()
def samplemulti(
    self,
    cond,
    uc=None,
    batch_size: int = 16,
    num_steps=None,
    randn=None,
    shape: Union[None, Tuple, List] = None,
    return_rgb=False,
    mask=None,
    init_im=None,
    multikwargs=None,
):
    """Draw samples with a sampler that takes the denoiser and model directly.

    Starts from ``randn`` (or fresh noise of shape ``(batch_size, *shape)``
    when ``randn`` is ``None``) moved to ``self.device``; ``multikwargs`` is
    forwarded to the sampler. ``mask`` and ``init_im`` are accepted but not
    used.

    Returns:
        The samples, or ``(samples, rgb_list)`` when ``return_rgb`` is true.
    """
    noise = torch.randn(batch_size, *shape) if randn is None else randn
    samples, rgb_list = self.sampler(
        self.denoiser,
        self.model,
        noise.to(self.device),
        cond,
        uc=uc,
        num_steps=num_steps,
        multikwargs=multikwargs,
    )
    return (samples, rgb_list) if return_rgb else samples
@torch.no_grad()
def log_conditionings(self, batch: Dict, n: int, refernce: bool = True) -> Dict:
    """
    Defines heuristics to log different conditionings.
    These can be lists of strings (text-to-image), tensors, ints, ...

    For each conditioner embedder, the first ``n`` entries of its batch input
    are rendered as text images via ``log_txt_as_img``. The batch key is
    ``embedder.input_keys[0]`` when ``refernce`` is true, otherwise
    ``embedder.input_key``. An embedder is logged when ``self.log_keys`` is
    ``None`` or contains its key; nothing is logged when
    ``self.no_cond_log`` is set.

    Returns:
        Dict mapping each logged key to its rendered image.

    Raises:
        NotImplementedError: for tensors that are not 1-D or 2-D, lists
            whose first element is not a string, and any other input type.
    """
    image_h, image_w = batch[self.input_key].shape[2:]
    log = dict()
    if self.no_cond_log:
        return log

    for embedder in self.conditioner.embedders:
        key = embedder.input_keys[0] if refernce else embedder.input_key
        # Test for None first: membership on a None log_keys (the default)
        # would raise TypeError.
        if self.log_keys is not None and key not in self.log_keys:
            continue

        x = batch[key][:n]
        if isinstance(x, torch.Tensor):
            if x.dim() == 1:
                # class-conditional, convert integer to string
                x = [str(x[i].item()) for i in range(x.shape[0])]
                xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
            elif x.dim() == 2:
                # size and crop cond and the like
                x = [
                    "x".join([str(xx) for xx in x[i].tolist()])
                    for i in range(x.shape[0])
                ]
                xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
            else:
                raise NotImplementedError()
        elif isinstance(x, (List, ListConfig)):
            if isinstance(x[0], str):
                # strings
                xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
            else:
                raise NotImplementedError()
        else:
            raise NotImplementedError()
        log[key] = xc
    return log
@torch.no_grad()
def log_images(
    self,
    batch: Dict,
    N: int = 8,
    sample: bool = True,
    ucg_keys: List[str] = None,
    **kwargs,
) -> Dict:
    """Collect inputs, reconstructions and (optionally) samples for logging.

    Args:
        batch: data batch; read through ``get_input`` and also passed to the
            conditioner.
        N: maximum number of batch items to log.
        sample: when true, run the sampler under EMA weights and log the
            decoded samples.
        ucg_keys: conditioner keys whose embeddings are zeroed for the
            unconditional branch; when empty or ``None`` all conditioner
            input keys are used.
        **kwargs: accepted but not used in this method.

    Returns:
        Dict that may contain ``"reference"``, ``"inputs"``, ``"mask"``,
        ``"depth"``, ``"reconstructions"``, the entries produced by
        ``log_conditionings``, ``"samples"`` and ``"predicted_rgb"``.

    NOTE(review): block nesting below was reconstructed from data flow
    (e.g. ``rgb_list`` only exists inside ``if sample``) - confirm against
    the upstream file.
    """
    log = dict()

    # The sixth value returned by get_input is the batch's 'depth' entry.
    x, xr, pose, mask, mask_ref, depth, drop_im = self.get_input(batch)

    # With reference images the embedders are queried for `input_keys`
    # (plural), otherwise for `input_key`.
    if xr is not None:
        conditioner_input_keys = [e.input_keys for e in self.conditioner.embedders]
    else:
        conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]

    if ucg_keys:
        assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
            "Each defined ucg key for sampling must be in the provided conditioner input keys,"
            f"but we have {ucg_keys} vs. {conditioner_input_keys}"
        )
    else:
        ucg_keys = conditioner_input_keys

    c, uc = self.conditioner.get_unconditional_conditioning(
        batch,
        force_uc_zero_embeddings=ucg_keys
        if len(self.conditioner.embedders) > 0
        else [],
    )

    N = min(x.shape[0], N)
    x = x.to(self.device)[:N]
    zr = None
    if xr is not None:
        xr = xr.to(self.device)[:N]
        # b = items kept, n = size of xr's second dimension
        # (presumably reference views per item - TODO confirm)
        b, n = xr.shape[0], xr.shape[1]
        log["reference"] = rearrange(xr, "b n ... -> (b n) ...", b=b, n=n)
        # Encode all references in one flat batch, then restore the (b, n) layout.
        zr = rearrange(self.encode_first_stage(rearrange(xr, "b n ... -> (b n) ...", b=b, n=n)), "(b n) ... -> b n ...", b=b, n=n)

    log["inputs"] = x
    b = x.shape[0]
    if mask is not None:
        log["mask"] = mask
    if depth is not None:
        log["depth"] = depth
    z = self.encode_first_stage(x)

    if uc is not None:
        # Double the per-item inputs; zeroed reference latents come first.
        # Presumably this matches an [unconditional, conditional] batch built
        # by the sampler - TODO confirm.
        if xr is not None:
            zr = torch.cat([torch.zeros_like(zr), zr])
            # NOTE(review): requires drop_im to be a tensor; get_input falls
            # back to the float 0. when the batch has no 'drop_im'.
            drop_im = torch.cat([drop_im, drop_im])
        # NOTE(review): pose is None when the batch has no 'pose' key, which
        # would fail in the else branch - confirm callers always provide it.
        if isinstance(pose, list):
            pose = pose[:N]*2
        else:
            pose = torch.cat([pose[:N]] * 2)

    # Extra keyword arguments forwarded to the denoiser through self.sample.
    sampling_kwargs = {'input_ref':zr}
    sampling_kwargs['pose'] = pose
    sampling_kwargs['mask_ref'] = None
    sampling_kwargs['drop_im'] = drop_im

    log["reconstructions"] = self.decode_first_stage(z)
    log.update(self.log_conditionings(batch, N, refernce=True if xr is not None else False))

    # Trim tensor conditionings to the logged items: (n+1)*N rows with
    # references, N rows without.
    for k in c:
        if isinstance(c[k], torch.Tensor):
            if xr is not None:
                c[k], uc[k] = map(lambda y: y[k][:(n+1)*N].to(self.device), (c, uc))
            else:
                c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
    if sample:
        with self.ema_scope("Plotting"):
            samples, rgb_list = self.sample(
                c, shape=z.shape[1:], uc=uc, batch_size=N, return_rgb=True, **sampling_kwargs
            )
        samples = self.decode_first_stage(samples)
        log["samples"] = samples
        if len(rgb_list) > 0:
            # Dim 1 of rgb_list[0] is treated as a flattened square grid:
            # reshape to (-1, size, size, 3), then move channels to dim 1.
            size = int(math.sqrt(rgb_list[0].size(1)))
            log["predicted_rgb"] = rgb_list[0].reshape(-1, size, size, 3).permute(0, 3, 1, 2)
    return log