import torch; torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200

device = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_activation(activation):
    """Map an activation name to the corresponding callable."""
    if activation == 'tanh':
        activ = torch.tanh
    elif activation == 'relu':
        activ = F.relu
    elif activation == 'mish':
        activ = F.mish
    elif activation == 'sigmoid':
        activ = torch.sigmoid
    elif activation == 'leakyrelu':
        activ = F.leaky_relu
    elif activation == 'exp':
        activ = torch.exp
    else:
        raise ValueError(f'Unknown activation: {activation}')
    return activ
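
# Illustrative usage sketch (assumed example values, not from the original file):
# the returned callable is applied elementwise, e.g. relu zeroes out negatives.
example_act = get_activation('relu')
example_out = example_act(torch.tensor([-1.0, 0.0, 2.0]))  # tensor([0., 0., 2.])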
class SimpleNet(nn.Module):
    """Fully connected network with dropout and a configurable activation."""

    def __init__(self, input_dim, hidden_dims, output_dim, activation, dropout, final_activ=None):
        super().__init__()
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        self.output_dim = output_dim
        # Layer sizes: input -> hidden layers -> output.
        dims = [input_dim] + hidden_dims + [output_dim]
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            self.dropouts.append(nn.Dropout(dropout))
        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)
        # Optional activation applied to the output layer.
        if final_activ is not None:
            self.final_activ = get_activation(final_activ)
            self.use_final_activ = True
        else:
            self.use_final_activ = False
    def forward(self, x):
        for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
            x = layer(x)
            # Dropout + activation on every layer except the last.
            if i_layer != self.n_layers - 1:
                x = self.activation(dropout(x))
        # Optional final activation on the output layer.
        if self.use_final_activ:
            x = self.final_activ(x)
        return x
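
# Minimal end-to-end sketch (assumed shapes and hyperparameters, not from the
# original file): build a small MLP with SimpleNet and run one forward pass.
if __name__ == '__main__':
    net = SimpleNet(input_dim=4, hidden_dims=[32, 32], output_dim=2,
                    activation='relu', dropout=0.1, final_activ='sigmoid').to(device)
    batch = torch.randn(8, 4, device=device)
    out = net(batch)
    print(out.shape)  # torch.Size([8, 2])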