Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,068 Bytes
0f079b2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
from dataclasses import dataclass
import torch
import torch.nn as nn
from typing import Optional
from diffusers.models.embeddings import Timesteps
import math
import craftsman
from craftsman.models.transformers.attention import ResidualAttentionBlock
from craftsman.models.transformers.utils import init_linear, MLP
from craftsman.utils.base import BaseModule
class UNetDiffusionTransformer(nn.Module):
    """Transformer stack with U-Net style long skip connections.

    The stack is ``layers`` encoder blocks, one middle block and ``layers``
    decoder stages. Each decoder stage takes the most recent unused encoder
    output (last encoder output feeds the first decoder stage), concatenates
    it with the running activation on the last axis, projects the result
    back to ``width`` with a linear layer, optionally applies a LayerNorm,
    and then runs a ``ResidualAttentionBlock``.

    Args:
        n_ctx: Forwarded to every ``ResidualAttentionBlock``.
        width: Size of the last (channel) axis of the activations.
        layers: Number of encoder blocks, and also of decoder stages.
        heads: Forwarded to every ``ResidualAttentionBlock``.
        init_scale: Forwarded to every block and passed to ``init_linear``
            for the skip projections.
        qkv_bias: Forwarded to every ``ResidualAttentionBlock``.
        skip_ln: If True, a ``LayerNorm(width)`` follows each skip projection.
        use_checkpoint: Forwarded to every ``ResidualAttentionBlock``.
    """

    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        init_scale: float = 0.25,
        qkv_bias: bool = False,
        skip_ln: bool = False,
        use_checkpoint: bool = False
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers

        def make_block() -> nn.Module:
            # Every attention block in the stack shares these hyper-parameters.
            return ResidualAttentionBlock(
                n_ctx=n_ctx,
                width=width,
                heads=heads,
                init_scale=init_scale,
                qkv_bias=qkv_bias,
                use_checkpoint=use_checkpoint,
            )

        # NOTE: creation order (encoder -> middle -> decoder; within a decoder
        # stage: resblock -> linear -> layer norm) and attribute names are kept
        # unchanged so random initialisation and state_dict keys stay
        # compatible with existing checkpoints.
        self.encoder = nn.ModuleList([make_block() for _ in range(layers)])
        self.middle_block = make_block()

        self.decoder = nn.ModuleList()
        for _ in range(layers):
            resblock = make_block()
            # Projects the concatenated [skip, x] features (2 * width) back to width.
            linear = nn.Linear(width * 2, width)
            init_linear(linear, init_scale)
            # The stage stores None when skip_ln is off; forward() checks for it.
            layer_norm = nn.LayerNorm(width) if skip_ln else None
            self.decoder.append(nn.ModuleList([resblock, linear, layer_norm]))

    def forward(self, x: torch.Tensor):
        """Run encoder, middle block and decoder with long skip connections.

        Args:
            x: Activations whose last axis has size ``width``.

        Returns:
            The output of the last decoder stage.
        """
        skips = []
        for block in self.encoder:
            x = block(x)
            skips.append(x)

        x = self.middle_block(x)

        # Decoder stages consume encoder outputs in reverse (LIFO) order.
        for resblock, linear, layer_norm in self.decoder:
            x = linear(torch.cat([skips.pop(), x], dim=-1))
            if layer_norm is not None:
                x = layer_norm(x)
            x = resblock(x)
        return x
@craftsman.register("simple-denoiser")
class SimpleDenoiser(BaseModule):
    """Transformer denoiser over a set of data tokens.

    Builds the token sequence ``[time token, context tokens, data tokens]``,
    runs it through a ``UNetDiffusionTransformer`` backbone followed by a
    LayerNorm, and projects only the trailing ``n_data`` positions to
    ``output_channels``.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Optional checkpoint path. Weights stored under the "denoiser_model."
        # prefix (optionally nested under "state_dict") are loaded strictly.
        pretrained_model_name_or_path: Optional[str] = None
        input_channels: int = 32   # channels of each input data token
        output_channels: int = 32  # channels of each predicted data token
        n_ctx: int = 512           # forwarded to the backbone blocks
        width: int = 768           # hidden size of the backbone
        layers: int = 6            # backbone encoder (and decoder) depth
        heads: int = 12            # attention heads per backbone block
        context_dim: int = 1024    # channels of the conditioning tokens
        context_ln: bool = True    # LayerNorm the context before projecting it
        skip_ln: bool = False      # LayerNorm after each backbone skip projection
        init_scale: float = 0.25   # multiplied by 1/sqrt(width) in configure()
        flip_sin_to_cos: bool = False  # forwarded to diffusers' Timesteps
        use_checkpoint: bool = False   # forwarded to the backbone blocks

    cfg: Config

    def configure(self) -> None:
        """Build all sub-modules and optionally load pretrained weights."""
        super().configure()

        init_scale = self.cfg.init_scale * math.sqrt(1.0 / self.cfg.width)

        # NOTE: creation order below is unchanged so random initialisation and
        # state_dict keys stay compatible with existing checkpoints.
        self.backbone = UNetDiffusionTransformer(
            n_ctx=self.cfg.n_ctx,
            width=self.cfg.width,
            layers=self.cfg.layers,
            heads=self.cfg.heads,
            skip_ln=self.cfg.skip_ln,
            init_scale=init_scale,
            use_checkpoint=self.cfg.use_checkpoint
        )
        self.ln_post = nn.LayerNorm(self.cfg.width)
        self.input_proj = nn.Linear(self.cfg.input_channels, self.cfg.width)
        self.output_proj = nn.Linear(self.cfg.width, self.cfg.output_channels)

        # timestep embedding: sinusoidal features followed by an MLP
        self.time_embed = Timesteps(self.cfg.width, flip_sin_to_cos=self.cfg.flip_sin_to_cos, downscale_freq_shift=0)
        self.time_proj = MLP(width=self.cfg.width, init_scale=init_scale)

        # context projection: context_dim -> width (optionally normalised first)
        if self.cfg.context_ln:
            self.context_embed = nn.Sequential(
                nn.LayerNorm(self.cfg.context_dim),
                nn.Linear(self.cfg.context_dim, self.cfg.width),
            )
        else:
            self.context_embed = nn.Linear(self.cfg.context_dim, self.cfg.width)

        if self.cfg.pretrained_model_name_or_path:
            self._load_pretrained(self.cfg.pretrained_model_name_or_path)

    def _load_pretrained(self, path: str) -> None:
        """Load this denoiser's weights from the checkpoint at ``path`` (strict)."""
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(path, map_location="cpu")
        self.load_state_dict(self._extract_denoiser_state_dict(ckpt), strict=True)

    @staticmethod
    def _extract_denoiser_state_dict(ckpt, prefix: str = "denoiser_model."):
        """Return the denoiser's own state dict from a loaded checkpoint.

        A ``"state_dict"`` wrapper, if present, must be unwrapped *before*
        filtering by ``prefix``. Filtering first would discard the wrapper key
        and leave an empty dict. Only keys starting with ``prefix`` are kept,
        with that leading prefix removed. If no key carries ``prefix``, the
        (unwrapped) dict is returned unchanged, so a bare denoiser state dict
        also loads.
        """
        state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        stripped = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
        return stripped if stripped else state_dict

    def forward(self,
                model_input: torch.FloatTensor,
                timestep: torch.LongTensor,
                context: torch.FloatTensor):
        r"""
        Args:
            model_input (torch.FloatTensor): [bs, n_data, input_channels]
            timestep (torch.LongTensor): [bs,]. A 0-d tensor, a Python number or
                a one-element tensor is broadcast to the whole batch.
            context (torch.FloatTensor): [bs, context_tokens, context_dim]
        Returns:
            sample (torch.FloatTensor): [bs, n_data, output_channels]
        """
        batch_size, n_data, _ = model_input.shape

        # 1. time: move to the input's device and broadcast shared timesteps
        timestep = torch.as_tensor(timestep, device=model_input.device)
        if timestep.ndim == 0:
            timestep = timestep[None]
        timestep = timestep.expand(batch_size)
        t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)  # [bs, 1, width]

        # 2. conditions projector
        context = self.context_embed(context)  # [bs, context_tokens, width]

        # 3. denoiser
        x = self.input_proj(model_input)  # [bs, n_data, width]
        x = torch.cat([t_emb, context, x], dim=1)  # [bs, 1 + context_tokens + n_data, width]
        x = self.backbone(x)
        x = self.ln_post(x)
        x = x[:, -n_data:]  # keep only the data tokens: [bs, n_data, width]
        sample = self.output_proj(x)  # [bs, n_data, output_channels]
        return sample