Spaces:
Running
Running
import torch | |
from torch import nn | |
def mi(x: torch.Tensor) -> torch.Tensor:
    """Return the mean of ``x`` over dims 2 and 3, keeping them as size-1 dims.

    ``x`` is presumably an (N, C, H, W) feature map — TODO confirm with
    callers; the code only requires at least four dims and reduces 2 and 3.
    The sum is divided with true division, so integer inputs yield a
    floating-point result.
    """
    height, width = x.shape[2], x.shape[3]
    spatial_total = x.sum(dim=(2, 3), keepdim=True)
    return spatial_total / (height * width)
def sigma(x: torch.Tensor, epsilon=1e-5) -> torch.Tensor:
    """Return the standard deviation of ``x`` over dims 2 and 3.

    Computes ``sqrt(mean((x - mean(x)) ** 2) + epsilon)``. This is the
    population (divide-by-``H*W``) variance, with dims 2 and 3 kept as size 1
    so the result broadcasts against ``x``.

    Args:
        x: Input tensor, presumably an (N, C, H, W) feature map — TODO confirm
            with callers; dims 2 and 3 are the ones reduced.
        epsilon: Added once to the reduced variance before the square root,
            keeping the result strictly positive even for a constant input.

    Returns:
        A tensor shaped like ``x`` except that dims 2 and 3 have size 1.
    """
    variance = mi((x - mi(x)) ** 2)
    # epsilon is applied to the reduced variance (one value per slice), not
    # to every element before summing: the same value, without allocating
    # an extra full-size temporary just to carry the constant.
    return torch.sqrt(variance + epsilon)
class AdaIN(nn.Module):
    """Adaptive Instance Normalization layer.

    Re-normalizes ``content`` so its dim-2/dim-3 statistics match those of
    ``style``. Means and standard deviations come from the module-level ``mi``
    and ``sigma`` helpers.
    """

    def __init__(self, epsilon=1e-5):
        super().__init__()
        # Stabilizer forwarded to ``sigma`` for both the content and style stats.
        self.epsilon = epsilon

    def forward(self, content: torch.Tensor, style: torch.Tensor) -> torch.Tensor:
        """Return ``sigma(style) * (content - mi(content)) / sigma(content) + mi(style)``.

        Statistics are reduced over dims 2 and 3 only. Content and style may
        therefore presumably differ in spatial size, but must broadcast in the
        remaining dims — NOTE(review): confirm expected shapes with callers.
        """
        eps = self.epsilon
        style_std = sigma(style, eps)
        # Zero-mean, unit-std content (per reduced slice).
        normalized = (content - mi(content)) / sigma(content, eps)
        return style_std * normalized + mi(style)