File size: 1,639 Bytes
2c5aba6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import torch.nn as nn

class Model(nn.Module):
    """Score image tokens from ``vit`` against a pooled text embedding from ``roberta``.

    Image side: every token of ``vit(image)`` after index 0 (along dim 1)
    goes through ``vitmap``, a learnable kernel-size-1 ``Conv1d(768, 768)``.
    Text side: ``cats`` is tokenized and encoded. Every hidden state goes
    through ``bertmap`` (same kind of projection), and the one at position 0
    along dim 1 is kept. Presumably that is the ``<s>``/CLS token; confirm
    for the tokenizer in use.

    The dot product of each projected image token with that text vector is
    fed through ``conv1d``, a ``Conv1d(1, 1, 1)`` initialised to weight 1 and
    bias 0. In effect this is a learnable scale and offset on the scores.

    Args:
        vit: Image backbone module. Assumes ``vit(image)`` returns a tensor
            shaped (batch, 1 + num_patches, 768) whose first token is not a
            patch. TODO: confirm against the backbone in use.
        roberta: Text encoder module. It is called with the tokenizer output
            unpacked as keyword arguments, and its result must expose
            ``last_hidden_state`` with a trailing dimension of 768.
        tokenizer: Object providing ``encode_plus``. Presumably a Hugging Face
            tokenizer matching ``roberta``.
        device: Target passed to ``.to()`` on the tokenizer output in ``forward``.
    """

    def __init__(self, vit, roberta, tokenizer, device):
        super().__init__()
        # Keep this construction order: it determines the random initial
        # weights each layer draws and the key order of state_dict().
        self.bertmap = nn.Conv1d(768, 768, 1)
        self.vitmap = nn.Conv1d(768, 768, 1)
        self.conv1d = nn.Conv1d(1, 1, 1)
        self.add_module("vit", vit)
        self.add_module("roberta", roberta)
        self.tokenizer = tokenizer
        # Score head starts as the identity map: 1 * score + 0.
        nn.init.ones_(self.conv1d.weight)
        nn.init.zeros_(self.conv1d.bias)
        # NOTE: plain attribute fixed here; Module.to() does not update it,
        # so tokens keep going to this device if the model is moved later.
        self.device = device

    def forward(self, image, cats):
        """Return a similarity score for every image token after the first.

        Args:
            image: Passed unchanged to ``self.vit``.
            cats: Passed unchanged as the first argument of
                ``self.tokenizer.encode_plus``.

        Returns:
            The score head's output with dim 1 squeezed. Under the shape
            assumptions in the class docstring this is (batch, num_patches).
            That also requires the text side to have batch size 1 (or equal
            to the image batch) so the matmul broadcasts.
        """
        # Image branch: drop token 0, then project each remaining token.
        # Conv1d wants channels first, so swap dims 1 and 2 around it.
        patches = self.vit(image)[:, 1:, :]
        patches = self.vitmap(patches.transpose(1, 2)).transpose(1, 2)

        # Text branch: tokenize, encode, project, keep position 0.
        encoded = self.tokenizer.encode_plus(
            cats,
            padding=True,
            add_special_tokens=True,
            return_token_type_ids=True,
            return_tensors='pt',
        ).to(self.device)
        states = self.roberta(**encoded).last_hidden_state
        states = self.bertmap(states.transpose(1, 2)).transpose(1, 2)
        text_vec = states[:, 0].unsqueeze(-1)

        # (batch, tokens, 768) @ (., 768, 1) -> (batch, tokens, 1). Move the
        # singleton to dim 1 so the 1-channel conv treats it as channels.
        scores = (patches @ text_vec).transpose(1, 2)
        return self.conv1d(scores).squeeze(1)