# albef-vqa/data/transforms.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import re
from typing import List, Tuple, Union
import torch
from torchtext.transforms import PadTransform, Sequential, ToTensor, Truncate
from torchvision import transforms
from transformers.models.bert.tokenization_bert import BertTokenizer

# mean and standard deviation from the ALBEF repo:
# https://github.com/salesforce/ALBEF/blob/main/dataset/__init__.py#L16
MEAN = (0.48145466, 0.4578275, 0.40821073)
STD_DEV = (0.26862954, 0.26130258, 0.27577711)


class ALBEFTextTransform:
"""
    Remove punctuation and trailing spaces from the input text and transform it
    into a Tensor of token ids using BertTokenizer.

Args:
pretrained_tokenizer (str): Pretrained tokenizer to use.
Default: "bert-base-uncased"
do_pre_process (bool): Whether to pre-process input text.
Defaults to True.
        truncate (bool): Whether to truncate input text to max_seq_len.
Defaults to False.
        pad_to_max_seq_len (bool): Whether to pad the sequence to max_seq_len.
            Defaults to False.
        add_end_token (bool): Whether to keep the [SEP] token at the end of each text.
            Defaults to True.
max_seq_len (int): The max sequence length after truncating or padding.
Defaults to 25.
cls_token_id (int): Value to represent the start of each text.
Defaults to 101, Hugging Face's BERT cls token id.
sep_token_id (int): Value to represent the end of each text.
Defaults to 102, Hugging Face's BERT sep token id.
pad_token_id (int): Value with which to pad each text so that all texts are the same length.
Defaults to 0, Hugging Face's BERT pad token id.
Inputs:
text (Union[List[str], str]): Input text to transform.
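
    Example (illustrative sketch; the exact token ids depend on the pretrained
    tokenizer vocabulary):
        >>> text_transform = ALBEFTextTransform(truncate=True, max_seq_len=25)
        >>> input_ids = text_transform("Is there a dog in the picture?")
        >>> batch_ids = text_transform(["a man on a horse", "two dogs playing"])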
"""
def __init__(
self,
pretrained_tokenizer: str = "bert-base-uncased",
do_pre_process: bool = True,
truncate: bool = False,
pad_to_max_seq_len: bool = False,
add_end_token: bool = True,
max_seq_len: int = 25,
cls_token_id: int = 101,
sep_token_id: int = 102,
pad_token_id: int = 0,
):
self.do_pre_process = do_pre_process
self.cls_token_id = cls_token_id
self.sep_token_id = sep_token_id
self.pad_token_id = pad_token_id
self.add_end_token = add_end_token
self.tokenizer = BertTokenizer.from_pretrained(pretrained_tokenizer)
self.transform = Sequential(
Truncate(max_seq_len=max_seq_len) if truncate else torch.nn.Identity(),
ToTensor(padding_value=self.pad_token_id),
PadTransform(max_length=max_seq_len, pad_value=self.pad_token_id)
if pad_to_max_seq_len
else torch.nn.Identity(),
)

    def pre_process(self, text: str) -> str:
        # Remove common punctuation characters and replace hyphens and slashes
        # with spaces.
        text = (
            re.sub(
                r"([,.'!?\"()*#:;~])",
                "",
                text,
            )
            .replace("-", " ")
            .replace("/", " ")
        )
        # Drop any trailing spaces left after punctuation removal.
        text = text.rstrip(" ")
        return text

    def __call__(self, text: Union[List[str], str]) -> torch.Tensor:
        # Clean each input string before tokenization, if requested.
        if self.do_pre_process:
            if isinstance(text, str):
                text = self.pre_process(text)
            else:
                text = [self.pre_process(t) for t in text]
        # BertTokenizer adds the [CLS]/[SEP] special tokens by default.
        tokens = self.tokenizer(text)["input_ids"]
        # Optionally drop the trailing [SEP] token (this check applies to a
        # single, un-batched text).
        if not self.add_end_token and tokens[-1] == self.sep_token_id:
            tokens = tokens[:-1]
        # Truncate, convert to a tensor, and pad according to the configured pipeline.
        input_ids = self.transform(tokens)
        return input_ids


def training_image_transform(
image_size: int = 384,
scale: Tuple[float, float] = (0.5, 1.0),
image_interpolation=transforms.InterpolationMode.BICUBIC,
mean: Tuple[float, float, float] = MEAN,
std_dev: Tuple[float, float, float] = STD_DEV,
) -> transforms.Compose:
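    """
    Compose the ALBEF training image transform: random resized crop to
    ``image_size``, random horizontal flip, RandAugment, conversion to a
    tensor, and normalization with the ALBEF mean/std.
    """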
return transforms.Compose(
[
transforms.RandomResizedCrop(
image_size, scale=scale, interpolation=image_interpolation
),
transforms.RandomHorizontalFlip(),
transforms.RandAugment(2, 7),
transforms.ToTensor(),
transforms.Normalize(mean, std_dev),
]
    )


def testing_image_transform(
image_size: int = 384,
image_interpolation=transforms.InterpolationMode.BICUBIC,
mean: Tuple[float, float, float] = MEAN,
std_dev: Tuple[float, float, float] = STD_DEV,
) -> transforms.Compose:
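    """
    Compose the ALBEF evaluation image transform: resize to
    ``(image_size, image_size)``, conversion to a tensor, and normalization
    with the ALBEF mean/std.
    """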
return transforms.Compose(
[
transforms.Resize(
(image_size, image_size), interpolation=image_interpolation
),
transforms.ToTensor(),
transforms.Normalize(mean, std_dev),
]
)
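

# Minimal usage sketch for these transforms (assumes the ``bert-base-uncased``
# tokenizer can be downloaded); the question string below is an arbitrary
# placeholder.
if __name__ == "__main__":
    text_transform = ALBEFTextTransform(
        truncate=True, pad_to_max_seq_len=True, max_seq_len=25
    )
    image_transform = testing_image_transform(image_size=384)

    # Token ids are padded to max_seq_len, giving a tensor of shape (25,).
    question_ids = text_transform("Is there a dog in the picture?")

    # Applying ``image_transform`` to an RGB PIL image (e.g. one opened with
    # ``PIL.Image.open``) would yield a normalized tensor of shape (3, 384, 384).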