File size: 751 Bytes
a5bbcdb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch
import torch.nn as nn
from utils.bert_model import BertForSequenceEncoder
class sentence_retrieval_model(nn.Module):
    """Score inputs for sentence retrieval with a BERT encoder and a linear head.

    The encoder output goes through dropout and a single-unit linear
    projection. The result is squashed with ``tanh``, so every score lies in
    the open interval (-1, 1).
    """

    def __init__(self, args):
        """Build the encoder, the dropout layer and the scoring head.

        Args:
            args: mapping read for three keys:
                - ``'bert_pretrain'``: passed to
                  ``BertForSequenceEncoder.from_pretrained``.
                - ``'bert_hidden_dim'``: input width of the scoring head.
                - ``'dropout'``: drop probability for ``nn.Dropout``.
        """
        super().__init__()
        # Construction order is kept as-is: it fixes the ordering of
        # parameters() / state_dict() entries for this module.
        self.pred_model = BertForSequenceEncoder.from_pretrained(
            args['bert_pretrain']
        )
        self.bert_hidden_dim = args['bert_hidden_dim']
        self.dropout = nn.Dropout(args['dropout'])
        self.proj_match = nn.Linear(self.bert_hidden_dim, 1)

    def forward(self, inp_tensor, msk_tensor, seg_tensor):
        """Compute tanh-bounded scores.

        Args:
            inp_tensor: first positional argument to the encoder
                (presumably token ids; TODO confirm against the caller).
            msk_tensor: second positional argument to the encoder
                (presumably the attention mask).
            seg_tensor: third positional argument to the encoder
                (presumably segment / token-type ids).

        Returns:
            The encoder representation after dropout, projected to one unit.
            The trailing size-1 dimension is squeezed away and the result is
            passed through ``tanh``, so values lie in (-1, 1).
        """
        # The encoder must return exactly two values; only the second is
        # scored (presumably the pooled representation; confirm in
        # utils.bert_model).
        _, pooled = self.pred_model(inp_tensor, msk_tensor, seg_tensor)
        logits = self.proj_match(self.dropout(pooled)).squeeze(-1)
        return torch.tanh(logits)