Spaces:
Runtime error
Runtime error
File size: 1,177 Bytes
d4ab5ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
from typing import Any, Callable

import torch
from torch import Tensor
from transformers.image_utils import ImageInput
class AddGaussianNoise:
    """Add Gaussian noise to a tensor (e.g. a preprocessed image).

    Each element is shifted by an independent sample of
    ``randn * std + mean``, i.e. noise drawn from ``N(mean, std**2)``.

    Args:
        mean (float): mean of the Gaussian noise
        std (float): standard deviation of the Gaussian noise
    """

    def __init__(self, mean: float = 0.0, std: float = 1.0):
        self.std = std
        self.mean = mean

    def __call__(self, tensor: Tensor) -> Tensor:
        """Return a noisy copy of ``tensor``; the input is not modified.

        The noise is sampled on ``tensor``'s device. Floating-point inputs
        keep their dtype; other inputs (int, bool, ...) get noise in the
        default float dtype, so the result is promoted to floating point.
        """
        if tensor.is_floating_point():
            dtype = tensor.dtype
        else:
            dtype = torch.get_default_dtype()
        # Sample on the input's device: a CPU noise tensor added to a CUDA
        # input raises a device-mismatch RuntimeError.
        noise = torch.randn(tensor.shape, dtype=dtype, device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self) -> str:
        return f"{type(self).__name__}(mean={self.mean}, std={self.std})"
class UnNest:
    """Un-nest the output of a feature extractor.

    Wraps a callable (e.g. a HuggingFace feature extractor) whose result is
    indexable as ``result["pixel_values"][0]`` and returns that first entry
    as a standalone tensor.

    Args:
        feature_extractor: callable applied to each input.
    """

    def __init__(self, feature_extractor: Callable[..., Any]):
        self.feature_extractor = feature_extractor

    def __call__(self, x: "ImageInput") -> Tensor:
        """Run the feature extractor on ``x`` and return its first pixel array.

        Returns:
            A new tensor (never a view of the extractor's output). A 3-D
            result is returned as is; any other rank gets a new leading axis,
            so a 2-D ``[H, W]`` array becomes ``[1, H, W]``.
        """
        # Pass the input through the feature extractor
        features = self.feature_extractor(x)
        # Un-nest the pixel_values batch: keep only its first element
        pixel_values = features["pixel_values"][0]
        # Always copy. torch.tensor() on an existing tensor copies too, but
        # emits a UserWarning recommending clone().detach() instead.
        if isinstance(pixel_values, Tensor):
            pixel_values = pixel_values.detach().clone()
        else:
            pixel_values = torch.tensor(pixel_values)
        # HuggingFace models expect 3D tensors [C, H, W]. Test the rank, not
        # len(): len() is the size of the first axis, so e.g. a [1, H, W]
        # tensor would wrongly be unsqueezed to 4-D.
        if pixel_values.ndim == 3:
            return pixel_values
        return pixel_values.unsqueeze(0)
|