Spaces: Running on Zero
import torch.nn.functional as F | |
from torch import nn | |
class PreactResBlock(nn.Sequential):
    """Pre-activation residual block: ``x + conv(gelu(norm(conv(gelu(norm(x))))))``.

    Both convolutions are 3x3 with padding 1 and keep ``dim`` channels, so the
    residual sum preserves the input shape.

    Args:
        dim: number of channels of the input (and output).
    """

    def __init__(self, dim):
        # One normalization group per 16 channels, clamped to at least one group:
        # for dim < 16, dim // 16 == 0 and nn.GroupNorm would fail on a modulo by
        # zero. GroupNorm still requires dim to be divisible by num_groups.
        num_groups = max(1, dim // 16)
        super().__init__(
            nn.GroupNorm(num_groups, dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.GroupNorm(num_groups, dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 3, padding=1),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the sequential pre-activation stack."""
        return x + super().forward(x)
class UNetBlock(nn.Module):
    """One UNet stage: upsample, add skip, 3x3 conv, two residual blocks, downsample.

    Args:
        input_dim: channels of the incoming tensor.
        output_dim: channels produced by the stage; ``None`` means ``input_dim``.
        scale_factor: a value > 1 upsamples the input before processing, a value
            < 1 resamples the returned main output, and 1 leaves resolution as is.
    """

    def __init__(self, input_dim, output_dim=None, scale_factor=1.0):
        super().__init__()
        out_channels = input_dim if output_dim is None else output_dim

        self.pre_conv = nn.Conv2d(input_dim, out_channels, 3, padding=1)
        self.res_block1 = PreactResBlock(out_channels)
        self.res_block2 = PreactResBlock(out_channels)

        # Both resamplers start out as one shared no-op module; at most one of
        # them is swapped for a real nn.Upsample below.
        passthrough = nn.Identity()
        self.downsample = passthrough
        self.upsample = passthrough
        if scale_factor > 1:
            self.upsample = nn.Upsample(scale_factor=scale_factor)
        elif scale_factor < 1:
            self.downsample = nn.Upsample(scale_factor=scale_factor)

    def forward(self, x, h=None):
        """
        Args:
            x: (b c h w), output of the previous stage
            h: optional skip tensor; must match ``x`` after upsampling
        Returns:
            o: ``s`` passed through ``self.downsample`` (identical to ``s``
               unless scale_factor < 1)
            s: features before downsampling, returned as the skip output
        """
        x = self.upsample(x)
        if h is not None:
            assert x.shape == h.shape, f"{x.shape} != {h.shape}"
            x = x + h
        feats = self.res_block2(self.res_block1(self.pre_conv(x)))
        return self.downsample(feats), feats
class UNet(nn.Module):
    """UNet built from ``UNetBlock`` stages with additive skip connections.

    The encoder runs ``num_blocks`` stages that each double the channel count
    (starting from ``hidden_dim``) and resample by 0.5; the middle blocks keep
    channels and resolution; the decoder mirrors the encoder, upsampling by 2
    and adding the matching encoder skip output before each stage.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels produced by the head.
        hidden_dim: channel count after the input projection.
        num_blocks: number of encoder (and decoder) stages.
        num_middle_blocks: number of bottleneck blocks.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=16, num_blocks=4, num_middle_blocks=2):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_proj = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.encoder_blocks = nn.ModuleList(
            [
                UNetBlock(input_dim=hidden_dim * 2**i, output_dim=hidden_dim * 2 ** (i + 1), scale_factor=0.5)
                for i in range(num_blocks)
            ]
        )
        self.middle_blocks = nn.ModuleList(
            [UNetBlock(input_dim=hidden_dim * 2**num_blocks) for _ in range(num_middle_blocks)]
        )
        # Built deepest-first so the decoder order mirrors the encoder.
        self.decoder_blocks = nn.ModuleList(
            [
                UNetBlock(input_dim=hidden_dim * 2 ** (i + 1), output_dim=hidden_dim * 2**i, scale_factor=2)
                for i in reversed(range(num_blocks))
            ]
        )
        self.head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, output_dim, 1),
        )

    def scale_factor(self):
        """Return the total spatial reduction of the encoder, ``2 ** num_blocks``."""
        return 2 ** len(self.encoder_blocks)

    def pad_to_fit(self, x):
        """Zero-pad the bottom and right of ``x`` so h and w are multiples of ``scale_factor()``.

        Args:
            x: (b c h w), input
        Returns:
            x: (b c h' w'), padded input
        """
        # scale_factor is a method; it must be called to get the integer factor.
        factor = self.scale_factor()
        hpad = (factor - x.shape[2] % factor) % factor
        wpad = (factor - x.shape[3] % factor) % factor
        # F.pad lists pads for the last dim first: (w_left, w_right, h_top, h_bottom).
        return F.pad(x, (0, wpad, 0, hpad))

    def forward(self, x):
        """
        Args:
            x: (b c h w), input with c == input_dim
        Returns:
            o: (b c' h w), output with c' == output_dim and the same h, w as ``x``
        """
        shape = x.shape
        # Pad so every downsampling stage divides evenly; cropped back at the end.
        x = self.pad_to_fit(x)
        x = self.input_proj(x)

        s_list = []
        for block in self.encoder_blocks:
            x, s = block(x)
            s_list.append(s)

        for block in self.middle_blocks:
            x, _ = block(x)

        # Decoder stages consume the encoder skip outputs deepest-first.
        for block, s in zip(self.decoder_blocks, reversed(s_list)):
            x, _ = block(x, s)

        x = self.head(x)
        # Remove the padding added by pad_to_fit.
        return x[..., : shape[2], : shape[3]]

    def test(self, shape=None):
        """Print MAC and parameter counts computed by ``ptflops``.

        Args:
            shape: input resolution tuple passed to ptflops as (c h w). When
                ``None``, it is ``(self.input_dim, 512, 256)`` so the channel
                count always matches this model.
        """
        import ptflops  # optional profiling dependency, imported only when used

        if shape is None:
            shape = (self.input_dim, 512, 256)
        macs, params = ptflops.get_model_complexity_info(
            self,
            shape,
            as_strings=True,
            print_per_layer_stat=True,
            verbose=True,
        )
        print(f"macs: {macs}")
        print(f"params: {params}")
def main():
    """Build a 3-channel-in, 3-channel-out UNet and print its complexity report."""
    UNet(3, 3).test()


if __name__ == "__main__":
    main()