"""TRCaptionNet: Turkish image captioning with a CLIP visual encoder feeding a
BERT decoder (BERTurk) through cross-attention, with an optional projection in between."""
import os
from typing import Optional

import numpy
import torch
from torch import nn
from PIL import Image
from transformers import BertTokenizer

from Model import clip
from Model.bert import BertLMHeadModel, BertConfig
from Model.clip.model import Transformer


class Proj(nn.Module):
    """Maps vision-encoder features to BERT's 768-dim hidden size via one Transformer block."""

    def __init__(self, encoder_output_size, num_head=16):
        super().__init__()
        self.encoder_output_size = encoder_output_size

        # A single Transformer layer over the visual token sequence, followed by a
        # linear projection to BERT's hidden size.
        self.transformer = Transformer(encoder_output_size, 1, num_head)
        self.linear = nn.Linear(encoder_output_size, 768)

    def forward(self, x):
        # CLIP's Transformer expects (seq_len, batch, dim), so permute in and out around it.
        x = x.permute(1, 0, 2)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)
        return self.linear(x)
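
# Shape sketch (assumption: the local CLIP visual tower returns a (batch, seq_len, width)
# token sequence): Proj preserves the batch and sequence dims and maps width -> 768, e.g.
#   proj = Proj(encoder_output_size=768, num_head=16)
#   out = proj(torch.randn(2, 50, 768))  # -> (2, 50, 768)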


class TRCaptionNet(nn.Module):
    def __init__(self, config: dict):
        super().__init__()

        self.max_length = config["max_length"]
        self.proj_flag = config["proj"]
        assert isinstance(self.proj_flag, bool)
        self.proj_num_head = config["proj_num_head"]

        # Vision encoder: the visual tower of CLIP, loaded in eval mode and cast to fp32.
        self.vision_encoder, preprocess = clip.load(config["clip"], jit=False)
        self.vision_encoder.eval()
        self.vision_encoder = self.vision_encoder.visual.float()
        # Infer the encoder's output feature size with a dummy forward pass.
        with torch.no_grad():
            dummy_input_image = preprocess(Image.fromarray(numpy.zeros((512, 512, 3), dtype=numpy.uint8)))
            dummy_input_image = dummy_input_image.to(next(self.parameters()).device)
            encoder_output_size = self.vision_encoder(dummy_input_image.unsqueeze(0)).shape[-1]

        # Language decoder: a pretrained Hugging Face checkpoint when config["bert"] is a
        # model id, otherwise a fresh decoder built from a local BERT config file, paired
        # with the BERTurk tokenizer.
        if not os.path.isfile(config["bert"]):
            self.language_decoder = BertLMHeadModel.from_pretrained(config["bert"],
                                                                    is_decoder=True,
                                                                    add_cross_attention=True)
            self.tokenizer = BertTokenizer.from_pretrained(config["bert"])
        else:
            med_config = BertConfig.from_json_file(config["bert"])
            self.language_decoder = BertLMHeadModel(config=med_config)
            self.tokenizer = BertTokenizer.from_pretrained("dbmdz/bert-base-turkish-cased")

        # Optional projection from the vision feature size to BERT's hidden size.
        if self.proj_flag:
            if self.proj_num_head is None:
                self.proj = nn.Linear(encoder_output_size, 768)
            else:
                self.proj = Proj(encoder_output_size, self.proj_num_head)
        else:
            self.proj = None

    @torch.no_grad()
    def generate(self, images, max_length: Optional[int] = None, min_length: int = 12, num_beams: int = 3,
                 repetition_penalty: float = 1.1):
        image_embeds = self.vision_encoder(images)

        if self.proj is not None:
            image_embeds = self.proj(image_embeds)

        # Cross-attention inputs for the decoder: image features plus an all-ones mask.
        image_atts = torch.ones(image_embeds.shape[:-1], dtype=torch.long).to(images.device)
        model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask": image_atts}

        # Seed generation with the tokenizer's [CLS] token as the BOS token
        # ([CLS] has id 2 in the BERTurk vocabulary).
        input_ids = torch.full((image_embeds.shape[0], 1), self.tokenizer.cls_token_id,
                               device=images.device, dtype=torch.long)

        outputs = self.language_decoder.generate(input_ids=input_ids,
                                                 max_length=self.max_length if max_length is None else max_length,
                                                 min_length=min_length,
                                                 num_beams=num_beams,
                                                 eos_token_id=self.tokenizer.sep_token_id,
                                                 pad_token_id=self.tokenizer.pad_token_id,
                                                 repetition_penalty=repetition_penalty,
                                                 **model_kwargs)

        captions = [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
        return captions


def test():
    # "proj" and "proj_num_head" are required keys in TRCaptionNet.__init__; the values
    # below are illustrative defaults, not from a released configuration.
    model = TRCaptionNet({
        "max_length": 35,
        "clip": "ViT-B/32",
        "bert": "dbmdz/bert-base-turkish-cased",
        "proj": True,
        "proj_num_head": 16
    })
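
    # Smoke-test sketch: a zero tensor stands in for a batch of preprocessed images
    # (assumption: ViT-B/32 takes 224x224 inputs), so the caption is meaningless;
    # this only exercises the encode -> project -> decode path.
    model.eval()
    device = next(model.parameters()).device
    dummy_images = torch.zeros((1, 3, 224, 224), device=device)
    captions = model.generate(dummy_images)
    print(captions)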


if __name__ == '__main__':
    test()