import torch
import torch.nn as nn


class Modulation(nn.Module):
    """Predicts a per-channel scale and shift from a condition vector and
    applies them to the input tensor along its last dimension."""

    def __init__(
        self,
        embedding_dim: int,
        condition_dim: int,
        zero_init: bool = False,
        single_layer: bool = False,
    ):
        super().__init__()
        self.silu = nn.SiLU()
        if single_layer:
            self.linear1 = nn.Identity()
        else:
            self.linear1 = nn.Linear(condition_dim, condition_dim)
        self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)
        # Only zero-init the last linear layer, so the module starts out as an
        # identity mapping (scale and shift are both zero at initialization).
        if zero_init:
            nn.init.zeros_(self.linear2.weight)
            nn.init.zeros_(self.linear2.bias)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        # condition: (batch, condition_dim) -> emb: (batch, 2 * embedding_dim)
        emb = self.linear2(self.silu(self.linear1(condition)))
        scale, shift = torch.chunk(emb, 2, dim=1)
        # Broadcast over the sequence dimension of x: (batch, seq, embedding_dim).
        x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return x
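

# Minimal usage sketch (illustrative; the dimensions and shapes below are
# assumptions, not taken from the original repo): modulate a token sequence
# of width 256 with a 128-dimensional condition vector.
if __name__ == "__main__":
    mod = Modulation(embedding_dim=256, condition_dim=128, zero_init=True)
    x = torch.randn(4, 10, 256)      # (batch, seq, embedding_dim)
    condition = torch.randn(4, 128)  # (batch, condition_dim)
    out = mod(x, condition)
    assert out.shape == x.shape
    # With zero_init=True, the predicted scale and shift are zero at init,
    # so the module initially behaves as an identity mapping.
    assert torch.allclose(out, x)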