Spaces:
Sleeping
Sleeping
""" MLP module w/ dropout and configurable activation layer | |
Hacked together by / Copyright 2020 Ross Wightman | |
""" | |
from torch import nn as nn | |
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks

    Two linear projections with an activation in between. The single dropout module
    is applied twice: once after the activation and once after the output projection.

    Args:
        in_features: size of the input feature (last) dimension.
        hidden_features: width of the hidden layer; falls back to ``in_features`` when falsy.
        out_features: size of the output features; falls back to ``in_features`` when falsy.
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability shared by both dropout applications.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # a falsy width (None or 0) means "same as the input width"
        if not hidden_features:
            hidden_features = in_features
        if not out_features:
            out_features = in_features
        # registration order (fc1, act, fc2, drop) is kept so parameter init and
        # state_dict key order stay the same
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # fc1 -> act -> drop -> fc2 -> drop
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x
class GluMlp(nn.Module):
    """ MLP w/ GLU style gating
    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202

    ``fc1`` produces ``hidden_features`` channels, which are split in half along the
    last dim. The first half carries the values. The second half is passed through
    ``act_layer`` and multiplies the values. ``fc2`` therefore maps
    ``hidden_features // 2`` channels to ``out_features``.

    Args:
        in_features: size of the input feature (last) dimension.
        hidden_features: output width of ``fc1``; must be even. Falls back to
            ``in_features`` when falsy.
        out_features: size of the output features; falls back to ``in_features`` when falsy.
        act_layer: gate activation class, instantiated with no arguments.
        drop: dropout probability, applied after gating and again after ``fc2``.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        # the value half and the gate half must have the same width
        assert hidden % 2 == 0
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden // 2, out)
        self.drop = nn.Dropout(drop)

    def init_weights(self):
        """ Re-initialise the gate half of ``fc1`` (the rows forward() treats as gates).

        Gate rows get weights drawn from N(0, 1e-6) and a bias of 1. The gate
        pre-activation therefore starts out close to 1 for moderate inputs.
        """
        half = self.fc1.bias.shape[0] // 2
        gate_bias = self.fc1.bias[half:]
        gate_weight = self.fc1.weight[half:]
        nn.init.ones_(gate_bias)
        nn.init.normal_(gate_weight, std=1e-6)

    def forward(self, x):
        values, gates = self.fc1(x).chunk(2, dim=-1)
        gated = self.drop(values * self.act(gates))
        return self.drop(self.fc2(gated))
class GatedMlp(nn.Module):
    """ MLP as used in gMLP

    Pipeline: ``fc1`` -> activation -> dropout -> ``gate`` -> ``fc2`` -> dropout.

    When ``gate_layer`` is given, it is constructed with ``hidden_features`` and ``fc2``
    is sized for half that width. The gate layer is therefore expected to halve the
    last dimension. Without a gate layer, ``gate`` is an identity and ``fc2`` consumes
    the full hidden width.

    Args:
        in_features: size of the input feature (last) dimension.
        hidden_features: output width of ``fc1``; must be even when ``gate_layer`` is
            given. Falls back to ``in_features`` when falsy.
        out_features: size of the output features; falls back to ``in_features`` when falsy.
        act_layer: activation class, instantiated with no arguments.
        gate_layer: optional gate module class, called with ``hidden_features``.
        drop: dropout probability, applied after the activation and after ``fc2``.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 gate_layer=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        if gate_layer is None:
            self.gate = nn.Identity()
            fc2_in = hidden_features
        else:
            assert hidden_features % 2 == 0
            self.gate = gate_layer(hidden_features)
            # FIXME base reduction on gate property?
            fc2_in = hidden_features // 2
        self.fc2 = nn.Linear(fc2_in, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(self.gate(hidden)))
class ConvMlp(nn.Module):
    """ MLP using 1x1 convs that keeps spatial dims

    Pipeline: ``fc1`` (1x1 conv) -> ``norm`` -> activation -> dropout -> ``fc2`` (1x1 conv).
    Dropout is applied only once, after the activation. The ``fc2`` output is
    returned without dropout.

    Args:
        in_features: number of input channels, as consumed by ``nn.Conv2d``.
        hidden_features: hidden channel count; falls back to ``in_features`` when falsy.
        out_features: output channel count; falls back to ``in_features`` when falsy.
        act_layer: activation class, instantiated with no arguments.
        norm_layer: optional norm class called with ``hidden_features``. An
            ``nn.Identity`` is used when it is falsy.
        drop: dropout probability applied after the activation.
        bias: whether both 1x1 convs have a learnable bias. Defaults to ``True``,
            which was previously hard-coded. ``False`` drops them, e.g. when ``fc1``'s
            bias would be cancelled by a following mean-subtracting norm layer.
    """
    def __init__(
            self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.,
            bias=True):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias)
        self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)  # the only dropout in this block
        x = self.fc2(x)
        return x