import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from diffusers.models.embeddings import Timesteps

import craftsman
from craftsman.models.transformers.attention import ResidualAttentionBlock
from craftsman.models.transformers.utils import init_linear, MLP
from craftsman.utils.base import BaseModule


class UNetDiffusionTransformer(nn.Module):
    """Transformer backbone with U-Net style skip connections between encoder and decoder blocks."""

    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        init_scale: float = 0.25,
        qkv_bias: bool = False,
        skip_ln: bool = False,
        use_checkpoint: bool = False,
    ):
        super().__init__()

        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers

        # Encoder: a stack of residual attention blocks whose outputs are kept for skip connections.
        self.encoder = nn.ModuleList()
        for _ in range(layers):
            resblock = ResidualAttentionBlock(
                n_ctx=n_ctx,
                width=width,
                heads=heads,
                init_scale=init_scale,
                qkv_bias=qkv_bias,
                use_checkpoint=use_checkpoint,
            )
            self.encoder.append(resblock)

        # Bottleneck block between encoder and decoder.
        self.middle_block = ResidualAttentionBlock(
            n_ctx=n_ctx,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            use_checkpoint=use_checkpoint,
        )

        # Decoder: each stage fuses the matching encoder output with the current features
        # through a linear projection (width * 2 -> width), optionally followed by a LayerNorm.
        self.decoder = nn.ModuleList()
        for _ in range(layers):
            resblock = ResidualAttentionBlock(
                n_ctx=n_ctx,
                width=width,
                heads=heads,
                init_scale=init_scale,
                qkv_bias=qkv_bias,
                use_checkpoint=use_checkpoint,
            )
            linear = nn.Linear(width * 2, width)
            init_linear(linear, init_scale)

            layer_norm = nn.LayerNorm(width) if skip_ln else None

            self.decoder.append(nn.ModuleList([resblock, linear, layer_norm]))

    def forward(self, x: torch.Tensor):
        # Encoder pass: keep each block's output for the skip connections.
        enc_outputs = []
        for block in self.encoder:
            x = block(x)
            enc_outputs.append(x)

        x = self.middle_block(x)

        # Decoder pass: fuse the most recent encoder output first (mirror ordering).
        for resblock, linear, layer_norm in self.decoder:
            x = torch.cat([enc_outputs.pop(), x], dim=-1)
            x = linear(x)

            if layer_norm is not None:
                x = layer_norm(x)

            x = resblock(x)

        return x


@craftsman.register("simple-denoiser")
class SimpleDenoiser(BaseModule):

    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: Optional[str] = None
        input_channels: int = 32
        output_channels: int = 32
        n_ctx: int = 512
        width: int = 768
        layers: int = 6
        heads: int = 12
        context_dim: int = 1024
        context_ln: bool = True
        skip_ln: bool = False
        init_scale: float = 0.25
        flip_sin_to_cos: bool = False
        use_checkpoint: bool = False

    cfg: Config

    def configure(self) -> None:
        super().configure()

        # Scale the residual-block initialization by 1 / sqrt(width).
        init_scale = self.cfg.init_scale * math.sqrt(1.0 / self.cfg.width)

        self.backbone = UNetDiffusionTransformer(
            n_ctx=self.cfg.n_ctx,
            width=self.cfg.width,
            layers=self.cfg.layers,
            heads=self.cfg.heads,
            skip_ln=self.cfg.skip_ln,
            init_scale=init_scale,
            use_checkpoint=self.cfg.use_checkpoint,
        )
        self.ln_post = nn.LayerNorm(self.cfg.width)
        self.input_proj = nn.Linear(self.cfg.input_channels, self.cfg.width)
        self.output_proj = nn.Linear(self.cfg.width, self.cfg.output_channels)

        # Sinusoidal timestep embedding followed by an MLP projection to the model width.
        self.time_embed = Timesteps(self.cfg.width, flip_sin_to_cos=self.cfg.flip_sin_to_cos, downscale_freq_shift=0)
        self.time_proj = MLP(width=self.cfg.width, init_scale=init_scale)

        # Project conditioning tokens to the model width, optionally normalizing them first.
        if self.cfg.context_ln:
            self.context_embed = nn.Sequential(
                nn.LayerNorm(self.cfg.context_dim),
                nn.Linear(self.cfg.context_dim, self.cfg.width),
            )
        else:
            self.context_embed = nn.Linear(self.cfg.context_dim, self.cfg.width)

        if self.cfg.pretrained_model_name_or_path:
            pretrained_ckpt = torch.load(self.cfg.pretrained_model_name_or_path, map_location="cpu")
            # Unwrap checkpoints that store their weights under a 'state_dict' key
            # before filtering; filtering first would silently drop that key.
            if 'state_dict' in pretrained_ckpt:
                pretrained_ckpt = pretrained_ckpt['state_dict']
            # Keep only the denoiser weights and strip their 'denoiser_model.' prefix.
            pretrained_ckpt = {
                k.replace('denoiser_model.', ''): v
                for k, v in pretrained_ckpt.items()
                if k.startswith('denoiser_model.')
            }
            self.load_state_dict(pretrained_ckpt, strict=True)

    def forward(self,
                model_input: torch.FloatTensor,
                timestep: torch.LongTensor,
                context: torch.FloatTensor):
        r"""
        Args:
            model_input (torch.FloatTensor): [bs, n_data, c]
            timestep (torch.LongTensor): [bs,]
            context (torch.FloatTensor): [bs, context_tokens, c]

        Returns:
            sample (torch.FloatTensor): [bs, n_data, c]
        """
        _, n_data, _ = model_input.shape

        # Embed the diffusion timestep and add a token dimension: [bs, 1, width].
        t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)

        # Project conditioning tokens to the model width: [bs, context_tokens, width].
        context = self.context_embed(context)

        # Project the noisy input, prepend the time and context tokens, and run the backbone.
        x = self.input_proj(model_input)
        x = torch.cat([t_emb, context, x], dim=1)
        x = self.backbone(x)
        x = self.ln_post(x)
        # Keep only the data tokens (drop the time and context tokens) before the output projection.
        x = x[:, -n_data:]
        sample = self.output_proj(x)

        return sample
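

if __name__ == "__main__":
    # Minimal smoke test for the U-Net transformer backbone defined above. This is
    # only a sketch: it assumes the craftsman package (ResidualAttentionBlock,
    # init_linear) is importable, and it uses arbitrary small hyperparameters and
    # random tensors rather than any real configuration from the repository.
    backbone = UNetDiffusionTransformer(
        n_ctx=16,
        width=64,
        layers=2,
        heads=4,
        skip_ln=True,
    )
    dummy = torch.randn(2, 16, 64)  # [batch, tokens, width]
    out = backbone(dummy)
    print(out.shape)  # expected: torch.Size([2, 16, 64])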