from typing import Any, Dict, List

import torch


class CaptionCollator:
    """Collate image-caption samples into a batch of encoder/decoder inputs."""

    def __init__(self, tokenizer, max_seq_length):
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        captions, patch_images = [], []
        for data in features:
            # Skip samples whose image preprocessing failed.
            if data['patch_image'] is None:
                continue
            captions.append(data['caption'])
            patch_images.append(data['patch_image'])

        # Encoder input: the same prompt ("What does the picture describe?"),
        # repeated once per remaining sample.
        input_ids = self.tokenizer(
            ['图片描述了什么?'] * len(captions),
            return_tensors="pt", max_length=self.max_seq_length,
            truncation=True, padding=True,
        ).input_ids
        # Stack the preprocessed image tensors along the batch dimension.
        patch_images = torch.cat(patch_images, dim=0)

        # Decoder input: the reference captions themselves.
        inputs = self.tokenizer(
            captions, return_tensors="pt", max_length=self.max_seq_length,
            truncation=True, padding=True,
        )
        decoder_input_ids = inputs.input_ids
        attention_mask = inputs.attention_mask

        inputs = {
            'input_ids': input_ids,
            'patch_images': patch_images,
            'decoder_input_ids': decoder_input_ids,
            'attention_mask': attention_mask,
            'return_loss': True,
        }
        return inputs
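

# --- Usage sketch (an assumption, not part of the original file) -------------
# Shows how the collator could be wired into a torch DataLoader. The tokenizer
# checkpoint "bert-base-chinese", the toy samples, and the (1, 3, 224, 224)
# image shape are illustrative placeholders, not values from the original
# project; any tokenizer with the Hugging Face __call__ API should work.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer  # assumes the transformers package is installed

    tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
    collator = CaptionCollator(tokenizer=tokenizer, max_seq_length=64)

    # Toy samples; the None image exercises the "skip failed preprocessing" branch.
    samples = [
        {'caption': '一只猫坐在沙发上', 'patch_image': torch.randn(1, 3, 224, 224)},
        {'caption': '海边的日落', 'patch_image': torch.randn(1, 3, 224, 224)},
        {'caption': '损坏的图片', 'patch_image': None},
    ]
    loader = DataLoader(samples, batch_size=3, collate_fn=collator)
    batch = next(iter(loader))
    # Expect patch_images of shape (2, 3, 224, 224): the None sample was dropped.
    print({k: tuple(v.shape) for k, v in batch.items() if isinstance(v, torch.Tensor)})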