|
import torch |
|
import numpy as np |
|
|
|
|
|
class Normalize(object):
    """Shift and scale an image as ``(img - mean) / var``.

    Accepts either a bare image (anything supporting ``-`` and ``/`` with
    ``mean``/``var``) or a sample dict holding the image under ``'img'``.
    For dict input only ``'img'`` is transformed; every other entry
    (including ``'gt'``) is passed through unchanged.
    """

    def __init__(self, mean, var):
        """
        Args:
            mean: value subtracted from the image (a scalar, or anything that
                broadcasts against the image, e.g. a per-channel array).
            var: divisor applied after the subtraction.
                NOTE(review): the name suggests a variance, but standard
                normalisation divides by a standard deviation -- confirm
                which one callers actually pass.
        """
        self.mean = mean
        self.var = var

    def __call__(self, sample):
        """Normalise ``sample``.

        Args:
            sample: a bare image, or a dict with at least an ``'img'`` key.

        Returns:
            The normalised image, or a new dict whose ``'img'`` is
            normalised and whose remaining entries are the input's own
            objects. The input dict is not modified.
        """
        if isinstance(sample, dict):
            # Copy so extra keys survive and the caller's dict is untouched.
            out = dict(sample)
            out['img'] = (sample['img'] - self.mean) / self.var
            return out
        return (sample - self.mean) / self.var
|
|
|
|
|
|
|
class RandHorizontalFlip(object):
    """Randomly mirror an image left-right with probability ``prob_aug``.

    Accepts a bare NumPy array or a sample dict holding the image under
    ``'img'``. Only the image is flipped; all other dict entries are passed
    through unchanged. Randomness comes from NumPy's global RNG
    (``np.random``), so ``np.random.seed`` controls it.
    """

    def __init__(self, prob_aug):
        """
        Args:
            prob_aug: probability in [0, 1] that a call flips the image.

        Raises:
            ValueError: if ``prob_aug`` is outside [0, 1] (or is NaN).
        """
        # Written so that NaN also fails the check.
        if not 0.0 <= prob_aug <= 1.0:
            raise ValueError(
                "prob_aug must be in [0, 1], got {!r}".format(prob_aug))
        self.prob_aug = prob_aug

    def __call__(self, sample):
        """Flip ``sample`` left-right with probability ``self.prob_aug``.

        Args:
            sample: an array with ndim >= 2, or a dict with at least an
                ``'img'`` key holding such an array.

        Returns:
            The (possibly flipped) array, or a new dict with the
            (possibly flipped) ``'img'`` and the input's other entries.
            The input dict is not modified.
        """
        # One uniform draw in [0, 1); below prob_aug means "flip".
        do_flip = np.random.random_sample() < self.prob_aug

        if isinstance(sample, dict):
            out = dict(sample)
            if do_flip:
                out['img'] = self._flip(sample['img'])
                # NOTE(review): 'gt' is NOT flipped. If it is a spatial map
                # aligned with the image (mask, density/depth map), it must
                # be flipped as well -- confirm against the dataset.
            return out

        return self._flip(sample) if do_flip else sample

    @staticmethod
    def _flip(array):
        """Return a contiguous copy of ``array`` reversed along axis 1."""
        # np.fliplr returns a negative-stride view; .copy() materialises it so
        # that torch.from_numpy (which rejects negative strides) can use it.
        # NOTE(review): axis 1 is the width only for (H, W[, C]) layouts; a
        # channel-first array would be flipped vertically -- confirm layout.
        return np.fliplr(array).copy()
|
|
|
|
|
class ToTensor(object):
    """Convert NumPy arrays to ``float32`` CPU ``torch`` tensors.

    Accepts a bare array or a sample dict. For dicts both ``'img'`` and
    ``'gt'`` are converted; any other entries are passed through unchanged.
    No axis reordering is performed: each tensor keeps its array's shape.
    """

    @staticmethod
    def _to_float_tensor(array):
        """Wrap ``array`` with ``torch.from_numpy`` and cast it to float32."""
        # from_numpy shares memory with `array`; .float() copies only when the
        # dtype is not already float32 (same result as the legacy
        # `.type(torch.FloatTensor)` for these CPU tensors).
        return torch.from_numpy(array).float()

    def __call__(self, sample):
        """Convert ``sample`` to float32 tensor(s).

        Args:
            sample: a NumPy array, or a dict with NumPy arrays under
                ``'img'`` and ``'gt'``.

        Returns:
            A float32 tensor, or a new dict with ``'img'``/``'gt'`` converted
            and the input's other entries preserved. The input dict is not
            modified.
        """
        if isinstance(sample, dict):
            out = dict(sample)
            out['img'] = self._to_float_tensor(sample['img'])
            out['gt'] = self._to_float_tensor(sample['gt'])
            return out
        return self._to_float_tensor(sample)
|
|