Spaces:
Runtime error
Runtime error
File size: 1,848 Bytes
5d756f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import torch
from .base import BaseGenerator
from torchvision.transforms.functional import gaussian_blur
import torch.nn.functional as F
class PixelationGenerator(BaseGenerator):
    """Pixelate an image and paste the conditioning pixels back in.

    The image is resampled down to a ``pixelation_size`` x ``pixelation_size``
    grid and then back up to its original spatial size, both steps using
    bilinear interpolation with ``align_corners=True``. The result is blended
    with ``condition`` through ``mask``: ``condition`` contributes with weight
    ``mask`` and the resampled image with weight ``1 - mask`` (presumably
    ``mask`` is binary, 1 = keep condition -- confirm against callers).
    """

    def __init__(self, pixelation_size, **kwargs):
        super().__init__(z_channels=0)
        self.pixelation_size = pixelation_size
        self.z_channels = 0
        self.latent_space = None

    def forward(self, img, condition, mask, **kwargs):
        height, width = img.shape[-2:]
        coarse_grid = (self.pixelation_size, self.pixelation_size)
        # Shrink to the coarse grid, then stretch back to the input resolution.
        coarse = F.interpolate(
            img, size=coarse_grid, mode="bilinear", align_corners=True)
        restored = F.interpolate(
            coarse, size=(height, width), mode="bilinear", align_corners=True)
        blended = condition * mask + restored * (1 - mask)
        return {"img": blended}
class MaskOutGenerator(BaseGenerator):
    """Replace the image with a constant or random fill outside the mask.

    The fill is blended with ``condition`` through ``mask``: ``condition``
    contributes with weight ``mask`` and the fill with weight ``1 - mask``.

    Args:
        noise: ``"constant"`` fills with zeros (``torch.zeros_like``);
            ``"rand"`` fills with samples from ``torch.rand_like``.

    Raises:
        ValueError: if ``noise`` is not one of the supported modes.
    """

    _NOISE_MODES = ("rand", "constant")

    def __init__(self, noise: str, **kwargs):
        # Validate with a real exception rather than ``assert``: asserts are
        # stripped under ``python -O``, which would let an unknown mode through.
        if noise not in self._NOISE_MODES:
            raise ValueError(
                f"noise must be one of {self._NOISE_MODES}, got {noise!r}")
        super().__init__(z_channels=0)
        self.noise = noise
        self.z_channels = 0
        self.latent_space = None

    def forward(self, img, condition, mask, **kwargs):
        if self.noise == "constant":
            fill = torch.zeros_like(img)
        elif self.noise == "rand":
            fill = torch.rand_like(img)
        else:
            # Without this branch the original ``img`` would be blended back
            # in unmodified, silently defeating the mask-out.
            raise ValueError(
                f"noise must be one of {self._NOISE_MODES}, got {self.noise!r}")
        out = fill * (1 - mask) + condition * mask
        return {"img": out}
class IdentityGenerator(BaseGenerator):
    """Pass-through generator that returns the input image unchanged.

    ``condition``, ``mask`` and any extra keyword arguments are accepted so the
    call signature matches the other generators, but they are not used.
    """

    def __init__(self):
        super().__init__(z_channels=0)

    def forward(self, img, condition, mask, **kwargs):
        result = {"img": img}
        return result
class GaussianBlurGenerator(BaseGenerator):
    """Gaussian-blur the image outside the mask.

    The output keeps ``img`` with weight ``mask`` and uses the blurred image
    with weight ``1 - mask`` (presumably ``mask`` is binary, 1 = keep --
    confirm against callers).

    Args:
        sigma: standard deviation of the Gaussian, passed to torchvision's
            ``gaussian_blur``. The kernel size is about ``3 * sigma``, clamped
            to fit the image.
    """

    def __init__(self, sigma=7, **kwargs):
        super().__init__(z_channels=0)
        self.sigma = sigma

    def _kernel_size(self, height, width):
        """Return an odd kernel size of about ``3 * sigma`` that fits the image.

        torchvision's ``gaussian_blur`` requires an odd, positive kernel size.
        Its reflect padding needs ``kernel_size // 2`` to be smaller than each
        spatial dimension. Clamping to ``min(height, width)`` and rounding an
        even value down to odd satisfies both constraints.
        """
        size = min(int(self.sigma * 3), height, width)
        if size % 2 == 0:
            size -= 1
        return max(size, 1)

    def forward(self, img, condition, mask, **kwargs):
        height, width = img.shape[-2:]
        kernel_size = self._kernel_size(height, width)
        img_blur = gaussian_blur(img, kernel_size=kernel_size, sigma=self.sigma)
        return dict(img=img * mask + (1 - mask) * img_blur)
|