# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import re
from typing import List, Tuple, Union

import torch
from torchtext.transforms import PadTransform, Sequential, ToTensor, Truncate
from torchvision import transforms
from transformers.models.bert.tokenization_bert import BertTokenizer

# mean and standard deviation from the ALBEF repo:
# https://github.com/salesforce/ALBEF/blob/main/dataset/__init__.py#L16
MEAN = (0.48145466, 0.4578275, 0.40821073)
STD_DEV = (0.26862954, 0.26130258, 0.27577711)
class ALBEFTextTransform:
    """
    Remove punctuation and trailing spaces in input text and transform it into
    a Tensor of token ids using BertTokenizer.

    Args:
        pretrained_tokenizer (str): Pretrained tokenizer to use.
            Default: "bert-base-uncased"
        do_pre_process (bool): Whether to pre-process input text.
            Defaults to True.
        truncate (bool): Whether to truncate input text to max_seq_length.
            Defaults to False.
        pad_to_max_seq_len (bool): Whether to pad the sequence to max_seq_length.
            Defaults to False.
        add_end_token (bool): Whether to add the end-of-sentence token.
            Defaults to True.
        max_seq_len (int): The max sequence length after truncating or padding.
            Defaults to 25.
        cls_token_id (int): Value to represent the start of each text.
            Defaults to 101, Hugging Face's BERT cls token id.
        sep_token_id (int): Value to represent the end of each text.
            Defaults to 102, Hugging Face's BERT sep token id.
        pad_token_id (int): Value with which to pad each text so that all texts are the same length.
            Defaults to 0, Hugging Face's BERT pad token id.

    Inputs:
        text (Union[List[str], str]): Input text to transform.
    """

    def __init__(
        self,
        pretrained_tokenizer: str = "bert-base-uncased",
        do_pre_process: bool = True,
        truncate: bool = False,
        pad_to_max_seq_len: bool = False,
        add_end_token: bool = True,
        max_seq_len: int = 25,
        cls_token_id: int = 101,
        sep_token_id: int = 102,
        pad_token_id: int = 0,
    ):
        self.do_pre_process = do_pre_process
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.pad_token_id = pad_token_id
        self.add_end_token = add_end_token
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_tokenizer)
        # Truncate/pad are optional; Identity passes token ids through untouched.
        self.transform = Sequential(
            Truncate(max_seq_len=max_seq_len) if truncate else torch.nn.Identity(),
            ToTensor(padding_value=self.pad_token_id),
            PadTransform(max_length=max_seq_len, pad_value=self.pad_token_id)
            if pad_to_max_seq_len
            else torch.nn.Identity(),
        )

    def pre_process(self, text: str) -> str:
        """Strip punctuation, replace '-' and '/' with spaces, and drop trailing spaces."""
        text = (
            re.sub(
                r"([,.'!?\"()*#:;~])",
                "",
                text,
            )
            .replace("-", " ")
            .replace("/", " ")
        )
        text = text.rstrip(" ")
        return text

    def _strip_end_token(self, token_ids: List[int]) -> List[int]:
        """Drop a trailing sep token from a single token-id sequence, if present."""
        if token_ids and token_ids[-1] == self.sep_token_id:
            return token_ids[:-1]
        return token_ids

    def __call__(self, text: Union[List[str], str]) -> torch.Tensor:
        if self.do_pre_process:
            if isinstance(text, str):
                text = self.pre_process(text)
            else:
                text = [self.pre_process(t) for t in text]
        tokens = self.tokenizer(text)["input_ids"]
        # Bug fix: for batched (list) input, ``tokens`` is a list of lists, so
        # the former ``tokens[-1] == self.sep_token_id`` check compared a list
        # against an int and never matched -- end tokens were silently kept.
        # Strip the sep token per sequence instead.
        if not self.add_end_token:
            if isinstance(text, str):
                tokens = self._strip_end_token(tokens)
            else:
                tokens = [self._strip_end_token(t) for t in tokens]
        input_ids = self.transform(tokens)
        return input_ids
def training_image_transform(
    image_size: int = 384,
    scale: Tuple[float, float] = (0.5, 1.0),
    image_interpolation=transforms.InterpolationMode.BICUBIC,
    mean: Tuple[float, float, float] = MEAN,
    std_dev: Tuple[float, float, float] = STD_DEV,
) -> transforms.Compose:
    """Build the ALBEF training-time image pipeline: random resized crop,
    horizontal flip, RandAugment, tensor conversion, and normalization."""
    augmentations = [
        transforms.RandomResizedCrop(
            image_size, scale=scale, interpolation=image_interpolation
        ),
        transforms.RandomHorizontalFlip(),
        transforms.RandAugment(2, 7),
        transforms.ToTensor(),
        transforms.Normalize(mean, std_dev),
    ]
    return transforms.Compose(augmentations)
def testing_image_transform(
    image_size: int = 384,
    image_interpolation=transforms.InterpolationMode.BICUBIC,
    mean: Tuple[float, float, float] = MEAN,
    std_dev: Tuple[float, float, float] = STD_DEV,
) -> transforms.Compose:
    """Build the ALBEF eval-time image pipeline: deterministic resize to a
    square, tensor conversion, and normalization (no augmentation)."""
    ops = [
        transforms.Resize(
            (image_size, image_size), interpolation=image_interpolation
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean, std_dev),
    ]
    return transforms.Compose(ops)