File size: 751 Bytes
a5bbcdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
import torch.nn as nn

from utils.bert_model import BertForSequenceEncoder

class sentence_retrieval_model(nn.Module):
    def __init__(self, args):
        super(sentence_retrieval_model, self).__init__()
        self.pred_model = BertForSequenceEncoder.from_pretrained(args['bert_pretrain'])
        self.bert_hidden_dim = args['bert_hidden_dim']
        self.dropout = nn.Dropout(args['dropout'])
        self.proj_match = nn.Linear(self.bert_hidden_dim, 1)


    def forward(self, inp_tensor, msk_tensor, seg_tensor):
        _, inputs = self.pred_model(inp_tensor, msk_tensor, seg_tensor)
        inputs = self.dropout(inputs)
        score = self.proj_match(inputs).squeeze(-1)
        score = torch.tanh(score)
        return score