Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from ldm.modules.attention import default, zero_module, checkpoint | |
from ldm.modules.diffusionmodules.openaimodel import UNetModel | |
from ldm.modules.diffusionmodules.util import timestep_embedding | |
class DepthAttention(nn.Module):
    """Per-pixel attention from a 2D feature map over the depth axis of a 3D volume.

    For each spatial location (h, w), the query vector computed from ``x``
    attends over the ``d`` depth samples of ``context`` at the same location.
    The attention-weighted values are summed over depth and projected back to
    ``query_dim`` channels with a 1x1 convolution.

    Args:
        query_dim: channel count of the query feature map ``x``.
        context_dim: channel count of the context volume; ``None`` falls back
            to ``query_dim``.
        heads: number of attention heads.
        dim_head: channels per head.
        output_bias: whether the final 1x1 output convolution has a bias.
    """

    def __init__(self, query_dim, context_dim, heads, dim_head, output_bias=True):
        super().__init__()
        inner_dim = dim_head * heads
        # Inlined ``default(context_dim, query_dim)``: use query_dim when no
        # context_dim is given.
        context_dim = query_dim if context_dim is None else context_dim
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.dim_head = dim_head
        # 1x1 projections: a 2D conv for the query map, 3D convs for the volume.
        self.to_q = nn.Conv2d(query_dim, inner_dim, 1, 1, bias=False)
        self.to_k = nn.Conv3d(context_dim, inner_dim, 1, 1, bias=False)
        self.to_v = nn.Conv3d(context_dim, inner_dim, 1, 1, bias=False)
        self.to_out = nn.Conv2d(inner_dim, query_dim, 1, 1, bias=output_bias)

    def forward(self, x, context):
        """Attend from every pixel of ``x`` over the depth samples of ``context``.

        Args:
            x: query feature map of shape (b, query_dim, h, w).
            context: volume of shape (b, context_dim, d, h, w). Its batch and
                spatial sizes must match those of ``x``.

        Returns:
            Tensor of shape (b, query_dim, h, w).

        Raises:
            ValueError: if the batch or spatial sizes of ``x`` and ``context``
                disagree.
        """
        hn, hd = self.heads, self.dim_head
        b, _, h, w = x.shape
        ctx_b, _, d, ctx_h, ctx_w = context.shape
        # Without this check, a context whose h*w matches but whose layout
        # differs (e.g. 5x3 vs 3x5) would reshape silently and mix up pixels.
        if (ctx_b, ctx_h, ctx_w) != (b, h, w):
            raise ValueError(
                f"context batch/spatial size {(ctx_b, ctx_h, ctx_w)} does not "
                f"match x {(b, h, w)}"
            )
        q = self.to_q(x).reshape(b, hn, hd, h, w)
        k = self.to_k(context).reshape(b, hn, hd, d, h, w)
        v = self.to_v(context).reshape(b, hn, hd, d, h, w)
        # Dot product over the per-head channels. einsum contracts directly
        # instead of materialising the (b, hn, hd, d, h, w) element-wise product.
        sim = torch.einsum("bnchw,bncdhw->bndhw", q, k) * self.scale
        attn = sim.softmax(dim=2)  # normalise over the depth axis
        out = torch.einsum("bndhw,bncdhw->bnchw", attn, v)
        return self.to_out(out.reshape(b, hn * hd, h, w))
class DepthTransformer(nn.Module):
    """Residual block that conditions a 2D feature map on a 3D context volume.

    The input is projected to ``n_heads * d_head`` channels and attends over
    the depth axis of the projected context with :class:`DepthAttention`.
    ``proj_out`` then maps the result back to ``dim`` channels, and it is
    added to the input.

    Args:
        dim: channel count of the input/output feature map ``x``.
        n_heads: number of attention heads.
        d_head: channels per attention head.
        context_dim: channel count of the context volume. It is required in
            practice, because ``proj_context`` cannot be built without it.
        checkpoint: flag passed to ldm's ``checkpoint`` helper in
            :meth:`forward` (presumably toggles gradient checkpointing).

    Raises:
        TypeError: if ``context_dim`` is ``None``.
    """

    def __init__(self, dim, n_heads, d_head, context_dim=None, checkpoint=True):
        super().__init__()
        # Checked up front: nn.Conv3d(None, ...) below would otherwise fail
        # with an opaque error from inside torch.
        if context_dim is None:
            raise TypeError("DepthTransformer requires context_dim (got None)")
        inner_dim = n_heads * d_head
        self.proj_in = nn.Sequential(
            nn.Conv2d(dim, inner_dim, 1, 1),
            nn.GroupNorm(8, inner_dim),
            nn.SiLU(True),
        )
        self.proj_context = nn.Sequential(
            nn.Conv3d(context_dim, context_dim, 1, 1, bias=False),  # no bias
            nn.GroupNorm(8, context_dim),
            # Author's intent: ReLU only, so a zero input gives a zero output.
            # NOTE(review): GroupNorm's learnable bias is added before this
            # ReLU, so that only holds while the bias stays <= 0 — confirm.
            nn.ReLU(True),
        )
        # Cross-attention from the projected feature map over the depth axis
        # of the projected context volume (not self-attention).
        self.depth_attn = DepthAttention(
            query_dim=inner_dim,
            heads=n_heads,
            dim_head=d_head,
            context_dim=context_dim,
            output_bias=False,
        )
        self.proj_out = nn.Sequential(
            nn.GroupNorm(8, inner_dim),
            nn.ReLU(True),
            nn.Conv2d(inner_dim, inner_dim, 3, 1, 1, bias=False),
            nn.GroupNorm(8, inner_dim),
            nn.ReLU(True),
            # Wrapped in ldm's zero_module (presumably zero-initialised so the
            # residual branch starts out as a no-op — confirm).
            zero_module(nn.Conv2d(inner_dim, dim, 3, 1, 1, bias=False)),
        )
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        """Apply the block, routed through ldm's ``checkpoint`` helper.

        Args:
            x: feature map of shape (b, dim, h, w).
            context: volume of shape (b, context_dim, d, h, w) with the same
                batch and spatial size as ``x``. It is required despite the
                ``None`` default.

        Returns:
            Tensor with the same shape as ``x``.
        """
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context):
        """Un-checkpointed body: project, depth-attend, project back, add residual."""
        x_in = x
        x = self.proj_in(x)
        context = self.proj_context(context)
        x = self.depth_attn(x, context)
        return self.proj_out(x) + x_in
class DepthWiseAttention(UNetModel):
    """UNet variant that injects per-resolution 3D volumes through depth attention.

    The hidden state passes through a :class:`DepthTransformer` after the
    middle block and after output blocks 3-11. Each transformer's context is
    looked up in ``source_dict`` by the hidden state's last dimension.

    Args:
        volume_dims: channel counts ``(d0, d1, d2, d3)`` of the context
            volumes. ``d3`` feeds the middle conditioning block; ``d2``, ``d1``
            and ``d0`` feed the output-side conditioning blocks.
        *args, **kwargs: forwarded to ``UNetModel``. ``model_channels`` and
            ``channel_mult`` must be passed as keyword arguments, because they
            are read from ``kwargs`` here.
    """

    def __init__(self, volume_dims=(5, 16, 32, 64), *args, **kwargs):
        super().__init__(*args, **kwargs)
        base_ch = kwargs['model_channels']
        mult = kwargs['channel_mult']
        # NOTE(review): DepthTransformer normalises its context with
        # nn.GroupNorm(8, context_dim), which raises for channel counts not
        # divisible by 8. The default d0=5 therefore cannot be constructed;
        # presumably callers always pass volume_dims explicitly — confirm.
        d0, d1, d2, d3 = volume_dims

        # Middle-block conditioning (group labelled "4" in the original).
        self.middle_conditions = DepthTransformer(base_ch * mult[2], 4, d3 // 2, context_dim=d3)

        self.output_conditions = nn.ModuleList()
        # Output-block index -> index into output_conditions (3..11 -> 0..8).
        self.output_b2c = {blk: blk - 3 for blk in range(3, 12)}
        # (channel_mult index, context volume dim) for output conditions 0..8,
        # built in this exact order. The trailing labels are the group markers
        # from the original comments, presumably the feature-map width.
        specs = (
            (2, d2), (2, d2),                      # 8
            (2, d1), (1, d1), (1, d1),             # 16
            (1, d0), (0, d0), (0, d0), (0, d0),    # 32
        )
        for mult_idx, vol_dim in specs:
            self.output_conditions.append(
                DepthTransformer(base_ch * mult[mult_idx], 4, vol_dim // 2, context_dim=vol_dim)
            )

    def forward(self, x, timesteps=None, context=None, source_dict=None, **kwargs):
        """Run the UNet with depth-attention conditioning.

        Args:
            x: input tensor for the UNet.
            timesteps: diffusion timesteps, embedded with ``timestep_embedding``.
            context: passed unchanged to every input, middle and output block.
            source_dict: maps a feature-map width (``h.shape[-1]``) to the
                context volume for the matching DepthTransformer.
            **kwargs: accepted and ignored.

        Returns:
            Output of ``self.out`` on the final hidden state, after casting that
            state back to ``x``'s dtype.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels, repeat_only=False))
        skips = []
        h = x.type(self.dtype)
        for block in self.input_blocks:
            h = block(h, emb, context)
            skips.append(h)

        h = self.middle_block(h, emb, context)
        h = self.middle_conditions(h, context=source_dict[h.shape[-1]])

        for blk_idx, block in enumerate(self.output_blocks):
            h = block(torch.cat([h, skips.pop()], dim=1), emb, context)
            cond_idx = self.output_b2c.get(blk_idx)
            if cond_idx is not None:
                h = self.output_conditions[cond_idx](h, context=source_dict[h.shape[-1]])

        return self.out(h.type(x.dtype))

    def get_trainable_parameters(self):
        """Return the depth-conditioning parameters only: middle first, then output."""
        return list(self.middle_conditions.parameters()) + list(self.output_conditions.parameters())