my-fc-test-model / model_fc.py
lingy's picture
Upload model
24c13a4
raw
history blame contribute delete
442 Bytes
from transformers import PreTrainedModel
from .configuration_fc import FCConfig
from torch.nn import Linear
class FCModel(PreTrainedModel):
    """Minimal ``transformers`` model wrapping a single fully connected layer.

    The layer maps ``in_features`` inputs to ``config.num_nodes`` outputs.
    ``in_features`` is taken from ``config.in_features`` when the config
    defines it, and otherwise defaults to 10 (the value previously hard-coded).
    """

    config_class = FCConfig

    def __init__(self, config):
        """Build the linear layer from ``config``.

        Args:
            config: An ``FCConfig``. It must provide ``num_nodes`` (the layer's
                output size) and may optionally provide ``in_features``.
        """
        super().__init__(config)
        in_features = getattr(config, "in_features", 10)
        # Keep the attribute name ``model``: saved checkpoints store the
        # weights under state-dict keys derived from it ("model.weight", ...).
        self.model = Linear(in_features=in_features, out_features=config.num_nodes)

    def forward(self, tensor):
        """Apply the linear layer to ``tensor`` and return the result.

        The submodule is invoked via ``__call__`` rather than ``.forward()``
        so that any ``torch.nn.Module`` hooks registered on it still run.
        """
        return self.model(tensor)