import torch
import torch.nn as nn
import torch.nn.functional as F
from src.smb.level import MarioLevel
from performer_pytorch import SelfAttention
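
# Time-conditioned UNet for 16x16 Mario level tensors with MarioLevel.n_types
# channels in and out. Performer linear self-attention is applied at every
# resolution of the encoder/decoder, and a sinusoidal embedding of the timestep
# t (e.g. a diffusion timestep) is injected into each Down/Up block.
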
class CustomUpsample(nn.Module):
def __init__(self, in_channels, out_channels, target_size):
super(CustomUpsample, self).__init__()
if target_size == 4: # Upsampling from 2x2 to 4x4
stride, kernel_size, padding = 1, 3, 0
elif target_size == 7: # Upsampling from 4x4 to 7x7
stride, kernel_size, padding = 2, 3, 1
elif target_size == 8: # Upsampling from 4x4 to 8x8
stride, kernel_size, padding = 2, 4, 1
elif target_size == 14: # Upsampling from 7x7 to 14x14
stride, kernel_size, padding = 2, 2, 0
elif target_size == 16: # Upsampling from 8x8 to 16x16
stride, kernel_size, padding = 2, 4, 1
        else:
            raise ValueError(f"Unsupported target_size {target_size}; expected one of 4, 7, 8, 14, 16.")
self.upsample = nn.ConvTranspose2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding
)
def forward(self, x):
return self.upsample(x)
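
# Example (sketch): the mapping used by UNet.up1 below, 2x2 -> 4x4 with 256 channels:
#   up = CustomUpsample(in_channels=256, out_channels=256, target_size=4)
#   up(torch.randn(1, 256, 2, 2)).shape  # torch.Size([1, 256, 4, 4])
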
class PerformerSelfAttention(nn.Module):
def __init__(self, channels, size, n_heads=4):
super(PerformerSelfAttention, self).__init__()
self.channels = channels
self.size = size
self.n_heads = n_heads
        # Performer linear attention (SelfAttention from performer-pytorch)
self.performer_attention = SelfAttention(
dim=channels,
heads=self.n_heads
)
self.ln = nn.LayerNorm([channels])
self.ff_self = nn.Sequential(
nn.LayerNorm([channels]),
nn.Linear(channels, channels),
nn.GELU(),
nn.Linear(channels, channels),
)
    def forward(self, x):
        # Flatten the spatial grid into a token sequence: (B, C, H, W) -> (B, H*W, C)
        x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
        x_ln = self.ln(x)
        # Performer self-attention over the H*W tokens, with residual connections
        attention_value = self.performer_attention(x_ln)
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value
        # Restore the spatial layout: (B, H*W, C) -> (B, C, H, W)
        return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)
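
# Shape contract (sketch): input and output are both (B, channels, size, size);
# internally the grid is flattened to a (B, size*size, channels) token sequence
# for the Performer attention, e.g.
#   sa = PerformerSelfAttention(128, 8)
#   sa(torch.randn(1, 128, 8, 8)).shape  # torch.Size([1, 128, 8, 8])
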
class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
super().__init__()
self.residual = residual
if not mid_channels:
mid_channels = out_channels
# Define the two convolution layers
self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)
self.group_norm1 = nn.GroupNorm(1, mid_channels)
self.gelu1 = nn.GELU()
self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False)
self.group_norm2 = nn.GroupNorm(1, out_channels)
def forward(self, x):
# Apply the first convolution layer
x1 = self.conv1(x)
x1 = self.group_norm1(x1)
x1 = self.gelu1(x1)
# Apply the second convolution layer
x2 = self.conv2(x1)
x2 = self.group_norm2(x2)
# Apply residual connection and GELU activation
if self.residual:
return F.gelu(x + x2)
else:
return x2
class Down(nn.Module):
def __init__(self, in_channels, out_channels, emb_dim=32):
super().__init__()
# Max pooling followed by two DoubleConv layers
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2, ceil_mode=True),
DoubleConv(in_channels, in_channels, residual=True),
DoubleConv(in_channels, out_channels),
)
# Embedding layer to incorporate time information
self.emb_layer = nn.Sequential(
nn.SiLU(),
nn.Linear(
emb_dim,
out_channels
),
)
def forward(self, x, t):
x = self.maxpool_conv(x)
# Apply the embedding layer and broadcast the output to match spatial dimensions
emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
return x + emb
class Up(nn.Module):
def __init__(self, in_channels, out_channels, emb_dim=32, target_size=7):
super().__init__()
        self.up = CustomUpsample(in_channels=in_channels // 2, out_channels=in_channels // 2, target_size=target_size)
# DoubleConv layers after concatenation
self.conv = nn.Sequential(
DoubleConv(in_channels, in_channels, residual=True),
DoubleConv(in_channels, out_channels, in_channels // 2),
)
# Embedding layer to incorporate time information
self.emb_layer = nn.Sequential(
nn.SiLU(),
nn.Linear(
emb_dim,
out_channels
),
)
def forward(self, x, skip_x, t):
# Upsample the input tensor
x = self.up(x)
# Concatenate the upsampled tensor with the skip tensor from the encoder
x = torch.cat([skip_x, x], dim=1)
x = self.conv(x)
# Apply the embedding layer and broadcast the output to match spatial dimensions
emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
return x + emb
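
# Shape contract (sketch): Up first upsamples x with CustomUpsample, then
# concatenates it with the encoder skip tensor along channels (in_channels counts
# both inputs), and maps the result to out_channels. E.g. UNet.up1 below combines
# a 256x2x2 bottleneck with a 256x4x4 skip tensor into a 128x4x4 output.
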
class UNet(nn.Module):
def __init__(self, c_in=MarioLevel.n_types, c_out=MarioLevel.n_types, time_dim=32, device="cuda"):
super().__init__()
self.device = device
self.time_dim = time_dim
self.inc = DoubleConv(c_in, 64) # 64x16x16
self.down1 = Down(64, 128) # 128x8x8
self.sa1 = PerformerSelfAttention(128, 8) # 128x8x8
self.down2 = Down(128, 256) # 256x4x4
self.sa2 = PerformerSelfAttention(256, 4) # 256x4x4
self.down3 = Down(256, 256) # 256x2x2
self.sa3 = PerformerSelfAttention(256, 2) # 256x2x2
self.bot1 = DoubleConv(256, 512)
self.bot2 = DoubleConv(512, 512)
self.bot3 = DoubleConv(512, 256) # 256x2x2
self.up1 = Up(512, 128, target_size=4) # 128x4x4
        self.sa4 = PerformerSelfAttention(128, 4) # 128x4x4
self.up2 = Up(256, 64, target_size=8) # 64x8x8
        self.sa5 = PerformerSelfAttention(64, 8) # 64x8x8
self.up3 = Up(128, 64, target_size=16) # 64x16x16
self.sa6 = PerformerSelfAttention(64, 16) # 64x16x16
        self.outc = nn.Conv2d(64, c_out, kernel_size=1) # c_out x 16 x 16
    def pos_encoding(self, t, channels):
        # Sinusoidal encoding of the timestep t into a `channels`-dimensional vector
        inv_freq = 1.0 / (
10000
** (torch.arange(0, channels, 2, device=self.device).float() / channels)
)
pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
return pos_enc
def forward(self, x, t):
t = t.unsqueeze(-1).type(torch.float)
t = self.pos_encoding(t, self.time_dim)
x1 = self.inc(x) # 64x16x16
x2 = self.down1(x1, t) # 128x8x8
        x2 = self.sa1(x2) # 128x8x8
x3 = self.down2(x2, t) # 256x4x4
x3 = self.sa2(x3) # 256x4x4
x4 = self.down3(x3, t) # 256x2x2
x4 = self.sa3(x4) # 256x2x2
x4 = self.bot1(x4)
x4 = self.bot2(x4)
x4 = self.bot3(x4) # 256x2x2
        x = self.up1(x4, x3, t) # 128x4x4
        x = self.sa4(x) # 128x4x4
        x = self.up2(x, x2, t) # 64x8x8
        x = self.sa5(x) # 64x8x8
x = self.up3(x, x1, t) # 64x16x16
x = self.sa6(x) # 64x16x16
        output = self.outc(x) # c_out x 16 x 16
return output
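

# --- Minimal usage sketch (not part of the original module) ---
# Assumes `performer-pytorch` is installed and `src.smb.level.MarioLevel` is
# importable; runs on CPU, whereas the class defaults to device="cuda".
if __name__ == "__main__":
    model = UNet(device="cpu")
    x = torch.randn(2, MarioLevel.n_types, 16, 16)  # batch of two 16x16 level patches
    t = torch.randint(0, 1000, (2,))                # e.g. diffusion timesteps
    print(model(x, t).shape)  # expected: (2, MarioLevel.n_types, 16, 16)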