Spaces:
Runtime error
Runtime error
from torch import Tensor | |
from transformers.image_utils import ImageInput | |
import torch | |
class AddGaussianNoise: | |
"""Add Gaussian noise to an image. | |
Args: | |
mean (float): mean of the Gaussian noise | |
std (float): standard deviation of the Gaussian noise | |
""" | |
def __init__(self, mean: float = 0.0, std: float = 1.0): | |
self.std = std | |
self.mean = mean | |
def __call__(self, tensor: Tensor) -> Tensor: | |
return tensor + torch.randn(tensor.size()) * self.std + self.mean | |
def __repr__(self) -> str: | |
return self.__class__.__name__ + "(mean={0}, std={1})".format( | |
self.mean, self.std | |
) | |
class UnNest: | |
"""Un-nest the output of a feature extractor""" | |
def __init__(self, feature_extractor: callable): | |
self.feature_extractor = feature_extractor | |
def __call__(self, x: ImageInput) -> Tensor: | |
# Pass the input through the feature extractor | |
x = self.feature_extractor(x) | |
# Un-nest the pixel_values tensor | |
x = torch.tensor(x["pixel_values"][0]) | |
# HuggingFace models expect 3D tensors [C, H, W] | |
return x if len(x) == 3 else x.unsqueeze(0) | |