import torch
import torch.nn as nn


class Model(nn.Module):
    """Scores each ViT image patch against a pooled RoBERTa text embedding."""

    def __init__(self, vit, roberta, tokenizer, device):
        super().__init__()
        # 1x1 convolutions act as learned linear projections over the
        # 768-dimensional token embeddings of each backbone.
        self.bertmap = nn.Conv1d(768, 768, 1)
        self.vitmap = nn.Conv1d(768, 768, 1)
        # Single-channel 1x1 convolution over the patch scores,
        # initialized to the identity (weight 1, bias 0).
        self.conv1d = nn.Conv1d(1, 1, 1)
        self.conv1d.weight = nn.Parameter(torch.tensor([[[1.]]]))
        self.conv1d.bias = nn.Parameter(torch.tensor([0.]))
        self.add_module("vit", vit)
        self.add_module("roberta", roberta)
        self.tokenizer = tokenizer
        self.device = device

    def forward(self, image, cats):
        # Vision branch: drop the [CLS] token, then project each patch
        # embedding. Conv1d expects channels first, hence the transposes.
        vit_out = self.vit(image)                    # [B, N+1, 768]
        vit_out = vit_out[:, 1:, :]                  # [B, N, 768]
        vit_out = torch.transpose(vit_out, 2, 1)     # [B, 768, N]
        vit_out = self.vitmap(vit_out)
        vit_out = torch.transpose(vit_out, 2, 1)     # [B, N, 768]

        # Text branch: tokenize the category string(s), encode with RoBERTa,
        # and project the token embeddings the same way.
        token_out = self.tokenizer.encode_plus(
            cats,
            padding=True,
            add_special_tokens=True,
            return_token_type_ids=True,
            return_tensors='pt',
        ).to(self.device)
        bert_out = self.roberta(**token_out)
        hidden_state = bert_out.last_hidden_state       # [B, L, 768]
        hidden_state = torch.transpose(hidden_state, 2, 1)
        hidden_state = self.bertmap(hidden_state)
        hidden_state = torch.transpose(hidden_state, 2, 1)

        # Pool the text by taking the first (<s>) token, then score every
        # image patch against it with a dot product.
        pooled_bert_out = hidden_state[:, 0]                        # [B, 768]
        pooled_bert_out = torch.unsqueeze(pooled_bert_out, dim=2)   # [B, 768, 1]
        out = torch.matmul(vit_out, pooled_bert_out)                # [B, N, 1]
        out = torch.transpose(out, 2, 1)                            # [B, 1, N]
        return torch.squeeze(self.conv1d(out), dim=1)               # [B, N]
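

# --- Usage sketch (an assumption, not part of the original file) ---
# The backbones below come from Hugging Face `transformers`; any modules with
# the same interfaces would do. Model.forward assumes `vit(image)` returns
# token embeddings of shape [B, N+1, 768], so the Hugging Face ViTModel
# (which returns a ModelOutput object) is wrapped in a small adapter.
if __name__ == "__main__":
    from transformers import RobertaModel, RobertaTokenizerFast, ViTModel

    class ViTTokens(nn.Module):
        """Adapter: calling it on an image batch returns the raw ViT token
        embeddings [B, N+1, 768] instead of a ModelOutput."""

        def __init__(self, hf_vit):
            super().__init__()
            self.hf_vit = hf_vit

        def forward(self, image):
            return self.hf_vit(pixel_values=image).last_hidden_state

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    roberta = RobertaModel.from_pretrained("roberta-base")
    vit = ViTTokens(ViTModel.from_pretrained("google/vit-base-patch16-224-in21k"))

    model = Model(vit, roberta, tokenizer, device).to(device)
    model.eval()

    image = torch.randn(1, 3, 224, 224, device=device)  # dummy image batch
    with torch.no_grad():
        scores = model(image, "a photo of a cat")
    print(scores.shape)  # torch.Size([1, 196]): one score per ViT patch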