|
from typing import Any, Dict, List |
|
import torch |
|
|
|
|
|
class CaptionCollator(object):
    """Collate image-caption samples into one dict of batched model inputs.

    Each sample passed to ``__call__`` is a dict with a ``'caption'`` entry
    and a ``'patch_image'`` entry. Samples whose ``'patch_image'`` is ``None``
    are dropped. The same text prompt is tokenized once per remaining sample
    to form ``input_ids``; the captions are tokenized to form
    ``decoder_input_ids`` and ``attention_mask``.

    Args:
        tokenizer: Callable invoked as ``tokenizer(texts, return_tensors="pt",
            max_length=..., truncation=True, padding=True)``. Its result must
            expose ``input_ids`` and ``attention_mask`` attributes.
        max_seq_length: Value forwarded to the tokenizer as ``max_length``.
        prompt: Text paired with every image. Defaults to the prompt this
            collator has always used (roughly "What does the image
            describe?").
    """

    # Prompt used when the caller does not supply one.
    DEFAULT_PROMPT = '图片描述了什么?'

    def __init__(self, tokenizer, max_seq_length, prompt=DEFAULT_PROMPT):
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.prompt = prompt

    def _tokenize(self, texts):
        """Tokenize ``texts`` with the truncation/padding settings shared by all calls."""
        return self.tokenizer(
            texts, return_tensors="pt", max_length=self.max_seq_length, truncation=True, padding=True
        )

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Build the batch dict for a list of samples.

        Args:
            features: Samples, each a dict with ``'caption'`` and
                ``'patch_image'`` keys.

        Returns:
            A dict with keys ``'input_ids'``, ``'patch_images'``,
            ``'decoder_input_ids'``, ``'attention_mask'`` and
            ``'return_loss'`` (always ``True``).

        Raises:
            ValueError: If no sample has a usable ``'patch_image'``; there is
                nothing to tokenize or concatenate in that case.
        """
        # Samples without an image cannot contribute to the batch.
        usable = [sample for sample in features if sample['patch_image'] is not None]
        if not usable:
            raise ValueError(
                "CaptionCollator received no sample with a non-None 'patch_image'; "
                "cannot build a batch."
            )

        captions = [sample['caption'] for sample in usable]

        # One copy of the prompt per kept sample, so rows line up with the images.
        input_ids = self._tokenize([self.prompt] * len(captions)).input_ids

        # Images are joined along dim 0.
        # NOTE(review): presumably each patch_image already carries a leading
        # batch dimension of size 1 -- confirm against the dataset.
        patch_images = torch.cat([sample['patch_image'] for sample in usable], dim=0)

        # The attention mask is taken from the caption tokenization, not from
        # the prompt tokenization.
        encoded_captions = self._tokenize(captions)

        return {
            'input_ids': input_ids,
            'patch_images': patch_images,
            'decoder_input_ids': encoded_captions.input_ids,
            'attention_mask': encoded_captions.attention_mask,
            'return_loss': True,
        }
|
|