import torch.nn as nn class FC(nn.Module): def __init__(self, in_size, out_size, dropout_r=0., use_relu=True): super(FC, self).__init__() self.dropout_r = dropout_r self.use_relu = use_relu self.linear = nn.Linear(in_size, out_size) if use_relu: self.relu = nn.ReLU(inplace=True) if dropout_r > 0: self.dropout = nn.Dropout(dropout_r) def forward(self, x): x = self.linear(x) if self.use_relu: x = self.relu(x) if self.dropout_r > 0: x = self.dropout(x) return x class MLP(nn.Module): def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True): super(MLP, self).__init__() self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu) self.linear = nn.Linear(mid_size, out_size) def forward(self, x): return self.linear(self.fc(x))