liuhuadai's picture
Upload 340 files
6efc863 verified
raw
history blame
872 Bytes
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
CenterCrop
def _convert_to_rgb(image):
    """Return ``image`` converted to RGB mode.

    Delegates to ``image.convert``; presumably ``image`` is a
    ``PIL.Image.Image`` (NOTE(review): confirm with callers).
    """
    target_mode = 'RGB'
    converted = image.convert(target_mode)
    return converted
def image_transform(
    image_size: int,
    is_train: bool,
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
):
    """Build the torchvision preprocessing pipeline for a single image.

    The returned ``Compose`` converts the image to RGB, crops/resizes it to
    ``image_size``, turns it into a tensor with ``ToTensor`` and normalizes
    it channel-wise with ``mean`` and ``std``.

    Args:
        image_size: Target size handed to the crop/resize transforms.
        is_train: If true, apply ``RandomResizedCrop`` with scale 0.9-1.0
            (light augmentation); otherwise apply a deterministic ``Resize``
            followed by ``CenterCrop``.
        mean: Per-channel means passed to ``Normalize``.
        std: Per-channel standard deviations passed to ``Normalize``.
            NOTE(review): the mean/std defaults look like the OpenAI CLIP
            statistics -- confirm they match the target model.

    Returns:
        A ``torchvision.transforms.Compose`` callable applied to an image
        exposing ``.convert`` (presumably a PIL image).
    """
    normalize = Normalize(mean=mean, std=std)

    if is_train:
        geometric = [
            RandomResizedCrop(
                image_size,
                scale=(0.9, 1.0),
                interpolation=InterpolationMode.BICUBIC,
            ),
        ]
    else:
        geometric = [
            # With an int size, Resize scales the shorter side to image_size;
            # CenterCrop then cuts the square centre.
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
        ]

    # Convert to RGB *before* any resampling: Pillow silently forces NEAREST
    # interpolation for palette ("P") and 1-bit ("1") images, so converting
    # afterwards would bypass the requested BICUBIC filter for those inputs.
    # For images that are already RGB the output is unchanged.
    return Compose([
        _convert_to_rgb,
        *geometric,
        ToTensor(),
        normalize,
    ])