File size: 882 Bytes
db26c81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
from torch import nn
from .utils import get_act
from .norm import Norm1D, Norm2D
class Dense(nn.Module):
def __init__(self, input_dim, output_dim, bias=True, norm1d='none', norm2d='none', act='none'):
super(Dense, self).__init__()
assert norm1d == 'none' or norm2d == 'none', "one of [norm1d, norm2d] must be none"
if norm1d != 'none':
self.nn_norm = Norm1D(input_dim, norm1d)
elif norm2d != 'none':
self.nn_norm = Norm2D(input_dim, norm2d)
else:
self.nn_norm = None
self.nn_act = get_act(act)
self.nn_linear = nn.Linear(input_dim, output_dim, bias)
def weight(self):
return self.nn_linear.weight
def forward(self, x):
if self.nn_norm is not None:
x = self.nn_norm(x)
x = self.nn_act(x)
x = self.nn_linear(x)
return x
|