from torch import Tensor
from torch.nn import Linear
from transformers import PreTrainedModel

from .configuration_fc import FCConfig


class FCModel(PreTrainedModel):
    """A single fully connected layer wrapped as a ``PreTrainedModel``.

    The wrapped ``torch.nn.Linear`` maps ``in_features`` inputs (10 unless
    the config says otherwise) to ``config.num_nodes`` outputs.
    """

    config_class = FCConfig

    # Input width used when the config does not define ``in_features``;
    # this preserves the original hard-coded behaviour.
    _DEFAULT_IN_FEATURES = 10

    def __init__(self, config):
        """Build the linear layer described by *config*.

        Args:
            config: An ``FCConfig``. ``config.num_nodes`` sets the output
                width. If the config has an ``in_features`` attribute it
                sets the input width; otherwise 10 is used.
        """
        super().__init__(config)
        in_features = getattr(config, "in_features", self._DEFAULT_IN_FEATURES)
        self.model = Linear(
            in_features=in_features,
            out_features=config.num_nodes,
        )

    def forward(self, tensor: Tensor) -> Tensor:
        """Apply the linear layer to *tensor*.

        Args:
            tensor: Input whose last dimension equals the layer's
                ``in_features``.

        Returns:
            The layer output, whose last dimension is ``config.num_nodes``.
        """
        # Invoke the module (``__call__``) rather than ``.forward`` directly,
        # so any registered forward / pre-forward hooks are executed.
        return self.model(tensor)