# craftsman/models/denoisers/simple_denoiser.py
import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from diffusers.models.embeddings import Timesteps

import craftsman
from craftsman.models.transformers.attention import ResidualAttentionBlock
from craftsman.models.transformers.utils import init_linear, MLP
from craftsman.utils.base import BaseModule


class UNetDiffusionTransformer(nn.Module):
    """Transformer backbone with U-Net style skip connections: each encoder
    block's output is concatenated onto the matching decoder block's input."""

def __init__(
self,
*,
n_ctx: int,
width: int,
layers: int,
heads: int,
init_scale: float = 0.25,
qkv_bias: bool = False,
skip_ln: bool = False,
use_checkpoint: bool = False
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.layers = layers
self.encoder = nn.ModuleList()
for _ in range(layers):
resblock = ResidualAttentionBlock(
n_ctx=n_ctx,
width=width,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias,
use_checkpoint=use_checkpoint
)
self.encoder.append(resblock)
self.middle_block = ResidualAttentionBlock(
n_ctx=n_ctx,
width=width,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias,
use_checkpoint=use_checkpoint
)
        # Each decoder stage is (attention block, skip-fusion linear, optional LayerNorm).
        self.decoder = nn.ModuleList()
for _ in range(layers):
resblock = ResidualAttentionBlock(
n_ctx=n_ctx,
width=width,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias,
use_checkpoint=use_checkpoint
)
linear = nn.Linear(width * 2, width)
init_linear(linear, init_scale)
layer_norm = nn.LayerNorm(width) if skip_ln else None
self.decoder.append(nn.ModuleList([resblock, linear, layer_norm]))

    def forward(self, x: torch.Tensor):
        # Encoder pass: stash each block's output for the skip connections.
        enc_outputs = []
        for block in self.encoder:
            x = block(x)
            enc_outputs.append(x)
        x = self.middle_block(x)
        # Decoder pass: fuse in the matching encoder output (innermost first)
        # via concatenation, project back down to `width`, then attend.
        for resblock, linear, layer_norm in self.decoder:
            x = torch.cat([enc_outputs.pop(), x], dim=-1)
            x = linear(x)
            if layer_norm is not None:
                x = layer_norm(x)
            x = resblock(x)
        return x
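
# Illustrative shape check (the values below are examples, not the shipped
# defaults):
#
#   backbone = UNetDiffusionTransformer(n_ctx=16, width=64, layers=2, heads=4)
#   tokens = torch.randn(2, 16, 64)              # [batch, seq, width]
#   assert backbone(tokens).shape == tokens.shape
#
# The skip concatenation doubles the channel dim at each decoder stage, and
# the per-stage Linear(width * 2, width) projects it back, so the token
# layout is preserved end to end.
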
@craftsman.register("simple-denoiser")
class SimpleDenoiser(BaseModule):
    """Diffusion denoiser that prepends a timestep token and projected context
    tokens to the noisy latent tokens, then runs the full sequence through a
    U-Net style transformer."""

@dataclass
class Config(BaseModule.Config):
pretrained_model_name_or_path: Optional[str] = None
input_channels: int = 32
output_channels: int = 32
n_ctx: int = 512
width: int = 768
layers: int = 6
heads: int = 12
context_dim: int = 1024
context_ln: bool = True
skip_ln: bool = False
init_scale: float = 0.25
flip_sin_to_cos: bool = False
use_checkpoint: bool = False
cfg: Config

    def configure(self) -> None:
super().configure()
init_scale = self.cfg.init_scale * math.sqrt(1.0 / self.cfg.width)
self.backbone = UNetDiffusionTransformer(
n_ctx=self.cfg.n_ctx,
width=self.cfg.width,
layers=self.cfg.layers,
heads=self.cfg.heads,
skip_ln=self.cfg.skip_ln,
init_scale=init_scale,
use_checkpoint=self.cfg.use_checkpoint
)
self.ln_post = nn.LayerNorm(self.cfg.width)
self.input_proj = nn.Linear(self.cfg.input_channels, self.cfg.width)
self.output_proj = nn.Linear(self.cfg.width, self.cfg.output_channels)
# timestep embedding
self.time_embed = Timesteps(self.cfg.width, flip_sin_to_cos=self.cfg.flip_sin_to_cos, downscale_freq_shift=0)
self.time_proj = MLP(width=self.cfg.width, init_scale=init_scale)
if self.cfg.context_ln:
self.context_embed = nn.Sequential(
nn.LayerNorm(self.cfg.context_dim),
nn.Linear(self.cfg.context_dim, self.cfg.width),
)
else:
self.context_embed = nn.Linear(self.cfg.context_dim, self.cfg.width)
        if self.cfg.pretrained_model_name_or_path:
            pretrained_ckpt = torch.load(self.cfg.pretrained_model_name_or_path, map_location="cpu")
            # Lightning-style checkpoints nest the weights under 'state_dict';
            # unwrap that first so the prefix filtering below sees the real keys.
            if 'state_dict' in pretrained_ckpt:
                pretrained_ckpt = pretrained_ckpt['state_dict']
            # Keep only the denoiser weights and strip their 'denoiser_model.' prefix.
            pretrained_ckpt = {
                k[len('denoiser_model.'):]: v
                for k, v in pretrained_ckpt.items()
                if k.startswith('denoiser_model.')
            }
            self.load_state_dict(pretrained_ckpt, strict=True)

    def forward(self,
model_input: torch.FloatTensor,
timestep: torch.LongTensor,
context: torch.FloatTensor):
r"""
Args:
model_input (torch.FloatTensor): [bs, n_data, c]
timestep (torch.LongTensor): [bs,]
context (torch.FloatTensor): [bs, context_tokens, c]
Returns:
sample (torch.FloatTensor): [bs, n_data, c]
"""
_, n_data, _ = model_input.shape
# 1. time
t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)
# 2. conditions projector
context = self.context_embed(context)
# 3. denoiser
x = self.input_proj(model_input)
x = torch.cat([t_emb, context, x], dim=1)
x = self.backbone(x)
x = self.ln_post(x)
        x = x[:, -n_data:]            # drop the time/context tokens: [bs, n_data, width]
        sample = self.output_proj(x)  # [bs, n_data, output_channels]
return sample
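

if __name__ == "__main__":
    # Minimal smoke test, not part of the original module. It assumes the
    # threestudio-style BaseModule constructor used elsewhere in CraftsMan,
    # i.e. SimpleDenoiser(cfg_dict) parses the dict into Config and calls
    # configure(); adjust the construction if the local BaseModule differs.
    denoiser = SimpleDenoiser(
        {"n_ctx": 512, "width": 768, "layers": 6, "heads": 12}
    )
    latents = torch.randn(2, 512, 32)   # [bs, n_data, input_channels]
    t = torch.randint(0, 1000, (2,))    # [bs,]
    cond = torch.randn(2, 77, 1024)     # [bs, context_tokens, context_dim]
    sample = denoiser(latents, t, cond)
    print(sample.shape)                 # expected: torch.Size([2, 512, 32])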