# Page-scrape residue (not part of the module): "Spaces: Runtime error / Runtime error"
"""
File copied from
https://github.com/nicola-decao/diffmask/blob/master/diffmask/models/distributions.py
"""
import torch
import torch.distributions as distr
import torch.nn.functional as F
from torch import Tensor
class BinaryConcrete(distr.relaxed_bernoulli.RelaxedBernoulli):
    """Relaxed Bernoulli (binary Concrete) distribution with extra helpers.

    Adds to ``RelaxedBernoulli`` a closed-form ``cdf``, a ``log_prob`` that is
    ``-inf`` outside the open interval (0, 1), and ``log_expected_L0``.
    """

    def __init__(self, temperature: Tensor, logits: Tensor):
        """Build the distribution from a temperature and unnormalised logits."""
        super().__init__(temperature=temperature, logits=logits)
        # Kept as an attribute so code holding this distribution can create
        # tensors on the same device as its parameters.
        self.device = self.temperature.device

    def _logistic_arg(self, value: Tensor) -> Tensor:
        """Return ``logit(value) * temperature - logits``.

        This is the shared argument of ``cdf`` (through a sigmoid) and
        ``log_expected_L0`` (through a negated softplus).
        """
        value_logit = torch.log(value) - torch.log(1.0 - value)
        return value_logit * self.temperature - self.logits

    def cdf(self, value: Tensor) -> Tensor:
        """Return ``sigmoid(logit(value) * temperature - logits)``."""
        return torch.sigmoid(self._logistic_arg(value))

    def log_prob(self, value: Tensor) -> Tensor:
        """Return the parent log-density inside (0, 1) and ``-inf`` elsewhere.

        ``torch.where`` evaluates both branches, so the parent ``log_prob``
        is still computed for every element of ``value``.
        """
        strictly_inside = (value > 0) & (value < 1)
        return torch.where(
            strictly_inside,
            super().log_prob(value),
            torch.full_like(value, -float("inf")),
        )

    def log_expected_L0(self, value: Tensor) -> Tensor:
        """Return ``-softplus(logit(value) * temperature - logits)``.

        Since ``log(1 - sigmoid(z)) == -softplus(z)``, this equals
        ``log(1 - cdf(value))``.
        """
        return -F.softplus(self._logistic_arg(value))
class Streched(distr.TransformedDistribution):
    """Affine "stretch" of a base distribution: ``x -> l + (r - l) * x``.

    With the defaults, a value in [0, 1] from the base distribution is mapped
    into [-0.1, 1.1].
    """

    def __init__(self, base_dist, l: float = -0.1, r: float = 1.1):
        """Wrap ``base_dist`` with ``AffineTransform(loc=l, scale=r - l)``.

        ``base_dist`` must expose ``device`` and ``log_expected_L0(value)``
        for ``log_expected_L0`` below to work.
        """
        super().__init__(base_dist, distr.AffineTransform(loc=l, scale=r - l))

    def log_expected_L0(self) -> Tensor:
        """Evaluate the base ``log_expected_L0`` at the pre-image of 0.

        The point 0 in the stretched space is mapped back through the
        transforms (last applied first) into the base distribution's space
        and handed to ``base_dist.log_expected_L0``.
        """
        probe = torch.tensor(0.0, device=self.base_dist.device)
        for transform in reversed(self.transforms):
            probe = transform.inv(probe)
        if self._validate_args:
            # Raises if the pre-image of 0 is outside the base support.
            self.base_dist._validate_sample(probe)
        log_l0 = self.base_dist.log_expected_L0(probe)
        return self._monotonize_cdf(log_l0)

    def expected_L0(self) -> Tensor:
        """Return ``exp(log_expected_L0())``."""
        return self.log_expected_L0().exp()
class RectifiedStreched(Streched):
    """Stretched distribution whose samples are clamped to [0, 1].

    Construction is inherited unchanged from ``Streched``.
    """

    def sample(self, sample_shape: torch.Size = torch.Size([])) -> Tensor:
        """Draw a clamped sample by delegating to ``rsample``.

        NOTE(review): autograd is not disabled here, so the result carries
        gradients exactly like ``rsample``. Confirm whether callers rely on
        that before wrapping this in ``torch.no_grad()``.
        """
        return self.rsample(sample_shape)

    def rsample(self, sample_shape: torch.Size = torch.Size([])) -> Tensor:
        """Draw a reparameterised stretched sample and clamp it to [0, 1]."""
        stretched = super().rsample(sample_shape)
        return stretched.clamp(0, 1)