from dataclasses import dataclass
from typing import Optional
import math

import torch
import torch.nn as nn
from diffusers.models.embeddings import Timesteps

import craftsman
from craftsman.models.transformers.attention import ResidualAttentionBlock
from craftsman.models.transformers.utils import init_linear, MLP
from craftsman.utils.base import BaseModule


class UNetDiffusionTransformer(nn.Module):
    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        init_scale: float = 0.25,
        qkv_bias: bool = False,
        skip_ln: bool = False,
        use_checkpoint: bool = False,
    ):
        super().__init__()

        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers

        # Encoder: a stack of residual self-attention blocks whose intermediate
        # outputs are kept for the U-Net style skip connections in the decoder.
        self.encoder = nn.ModuleList()
        for _ in range(layers):
            resblock = ResidualAttentionBlock(
                n_ctx=n_ctx,
                width=width,
                heads=heads,
                init_scale=init_scale,
                qkv_bias=qkv_bias,
                use_checkpoint=use_checkpoint,
            )
            self.encoder.append(resblock)

        self.middle_block = ResidualAttentionBlock(
            n_ctx=n_ctx,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            use_checkpoint=use_checkpoint,
        )

        # Decoder: each stage fuses the matching encoder output through a linear
        # projection on the concatenated features, optionally followed by LayerNorm.
        self.decoder = nn.ModuleList()
        for _ in range(layers):
            resblock = ResidualAttentionBlock(
                n_ctx=n_ctx,
                width=width,
                heads=heads,
                init_scale=init_scale,
                qkv_bias=qkv_bias,
                use_checkpoint=use_checkpoint,
            )
            linear = nn.Linear(width * 2, width)
            init_linear(linear, init_scale)
            layer_norm = nn.LayerNorm(width) if skip_ln else None
            self.decoder.append(nn.ModuleList([resblock, linear, layer_norm]))

    def forward(self, x: torch.Tensor):
        enc_outputs = []
        for block in self.encoder:
            x = block(x)
            enc_outputs.append(x)

        x = self.middle_block(x)

        for resblock, linear, layer_norm in self.decoder:
            # Concatenate the skip connection from the matching encoder stage.
            x = torch.cat([enc_outputs.pop(), x], dim=-1)
            x = linear(x)
            if layer_norm is not None:
                x = layer_norm(x)
            x = resblock(x)

        return x


class SimpleDenoiser(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: Optional[str] = None
        input_channels: int = 32
        output_channels: int = 32
        n_ctx: int = 512
        width: int = 768
        layers: int = 6
        heads: int = 12
        context_dim: int = 1024
        context_ln: bool = True
        skip_ln: bool = False
        init_scale: float = 0.25
        flip_sin_to_cos: bool = False
        use_checkpoint: bool = False

    cfg: Config

    def configure(self) -> None:
        super().configure()

        init_scale = self.cfg.init_scale * math.sqrt(1.0 / self.cfg.width)

        self.backbone = UNetDiffusionTransformer(
            n_ctx=self.cfg.n_ctx,
            width=self.cfg.width,
            layers=self.cfg.layers,
            heads=self.cfg.heads,
            skip_ln=self.cfg.skip_ln,
            init_scale=init_scale,
            use_checkpoint=self.cfg.use_checkpoint,
        )
        self.ln_post = nn.LayerNorm(self.cfg.width)
        self.input_proj = nn.Linear(self.cfg.input_channels, self.cfg.width)
        self.output_proj = nn.Linear(self.cfg.width, self.cfg.output_channels)

        # timestep embedding
        self.time_embed = Timesteps(
            self.cfg.width,
            flip_sin_to_cos=self.cfg.flip_sin_to_cos,
            downscale_freq_shift=0,
        )
        self.time_proj = MLP(width=self.cfg.width, init_scale=init_scale)

        # condition (context) projector, optionally preceded by a LayerNorm
        if self.cfg.context_ln:
            self.context_embed = nn.Sequential(
                nn.LayerNorm(self.cfg.context_dim),
                nn.Linear(self.cfg.context_dim, self.cfg.width),
            )
        else:
            self.context_embed = nn.Linear(self.cfg.context_dim, self.cfg.width)
        # optionally load pretrained denoiser weights
        if self.cfg.pretrained_model_name_or_path:
            pretrained_ckpt = torch.load(self.cfg.pretrained_model_name_or_path, map_location="cpu")
            # Lightning-style checkpoints nest the weights under 'state_dict';
            # unwrap first so the prefix filtering below sees the actual keys.
            if 'state_dict' in pretrained_ckpt:
                pretrained_ckpt = pretrained_ckpt['state_dict']
            # keep only the denoiser weights and strip the 'denoiser_model.' prefix
            _pretrained_ckpt = {}
            for k, v in pretrained_ckpt.items():
                if k.startswith('denoiser_model.'):
                    _pretrained_ckpt[k.replace('denoiser_model.', '')] = v
            pretrained_ckpt = _pretrained_ckpt
            self.load_state_dict(pretrained_ckpt, strict=True)

    def forward(self,
                model_input: torch.FloatTensor,
                timestep: torch.LongTensor,
                context: torch.FloatTensor):
        r"""
        Args:
            model_input (torch.FloatTensor): [bs, n_data, c]
            timestep (torch.LongTensor): [bs,]
            context (torch.FloatTensor): [bs, context_tokens, c]

        Returns:
            sample (torch.FloatTensor): [bs, n_data, c]
        """
        _, n_data, _ = model_input.shape

        # 1. time embedding
        t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)

        # 2. condition projector
        context = self.context_embed(context)

        # 3. denoiser: prepend the time and context tokens, run the backbone,
        #    then keep only the data tokens for the output projection
        x = self.input_proj(model_input)
        x = torch.cat([t_emb, context, x], dim=1)
        x = self.backbone(x)
        x = self.ln_post(x)
        x = x[:, -n_data:]            # [bs, n_data, width]
        sample = self.output_proj(x)  # [bs, n_data, output_channels]

        return sample
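

if __name__ == "__main__":
    # Hedged usage sketch (not part of the original file). It assumes the
    # threestudio-style BaseModule constructor, i.e. SimpleDenoiser(cfg) where
    # cfg may be a plain dict parsed into Config; if the project wires modules
    # up through its own registry/config loader, use that entry point instead.
    # The batch size, token counts, and timestep range below are hypothetical.
    denoiser = SimpleDenoiser(
        {
            "input_channels": 32,
            "output_channels": 32,
            "n_ctx": 512,
            "width": 768,
            "layers": 6,
            "heads": 12,
            "context_dim": 1024,
        }
    )
    bs, n_data, n_context_tokens = 2, 512, 77
    model_input = torch.randn(bs, n_data, 32)
    timestep = torch.randint(0, 1000, (bs,))
    context = torch.randn(bs, n_context_tokens, 1024)
    sample = denoiser(model_input, timestep, context)
    print(sample.shape)  # expected: torch.Size([2, 512, 32])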