import torch
import torch.nn as nn
class WRegressor(nn.Module):
    """MLP regression head mapping an embedding vector to a single scalar.

    Architecture: in_features -> 256 -> 64 -> 16 -> 1, with ReLU + Dropout
    after every hidden layer. The output layer has no activation, so the
    prediction is an unbounded real value.
    """

    def __init__(self, in_features: int = 768, dropout: float = 0.5) -> None:
        """Build the linear/ReLU/dropout stack.

        Args:
            in_features: Width of the input embedding. Defaults to 768
                (the hidden size of BERT-style encoders — presumably the
                upstream feature source; confirm against the caller).
            dropout: Dropout probability applied after each hidden layer.
                Defaults to 0.5, matching ``nn.Dropout()``'s default, so
                ``WRegressor()`` builds the exact original network.
        """
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(16, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the regression head.

        Args:
            x: Input tensor of shape ``(..., in_features)``.

        Returns:
            Tensor of shape ``(..., 1)`` holding the regression output.
        """
        return self.linear_relu_stack(x)