|
|
|
class GroupedAutoEncoder(nn.Module):
    """Autoencoder with a group-wise encoder and a single shared decoder.

    The input features are split along dim 1 into ``num_groups`` equal chunks.
    Each chunk is encoded by its own independent bias-free linear layer; the
    encoded chunks are concatenated and a single bias-free linear decoder maps
    the full hidden code back to the full input dimension.
    """

    def __init__(self, input_dim, hidden_dim, num_groups):
        """
        Args:
            input_dim (int): Total input feature dimension. Must be divisible
                by ``num_groups``.
            hidden_dim (int): Total hidden (code) dimension. Must be divisible
                by ``num_groups``.
            num_groups (int): Number of independent encoder groups.

        Raises:
            ValueError: If ``input_dim`` or ``hidden_dim`` is not divisible by
                ``num_groups``.
        """
        super().__init__()

        # Validate before computing per-group sizes. Use exceptions rather
        # than `assert` so the checks survive `python -O`.
        if input_dim % num_groups != 0:
            raise ValueError("Input dimension must be divisible by the number of groups.")
        if hidden_dim % num_groups != 0:
            raise ValueError("Hidden dimension must be divisible by the number of groups.")

        self.num_groups = num_groups
        self.group_input_dim = input_dim // num_groups
        self.group_hidden_dim = hidden_dim // num_groups

        # One independent (bias-free) linear encoder per group.
        self.encoders = nn.ModuleList([
            nn.Linear(self.group_input_dim, self.group_hidden_dim, bias=False)
            for _ in range(num_groups)
        ])

        # Single shared decoder over the concatenated group codes.
        # (A per-group decoder variant was previously commented out here;
        # removed as dead code.)
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)

        self.init_weights()

    def init_weights(self):
        """Apply Xavier-uniform initialization to all encoder and decoder weights."""
        for encoder in self.encoders:
            nn.init.xavier_uniform_(encoder.weight)
        nn.init.xavier_uniform_(self.decoder.weight)

    def forward(self, x):
        """Encode each feature group independently, then reconstruct.

        Args:
            x: Tensor of shape ``(batch, input_dim)``.

        Returns:
            Reconstruction tensor of shape ``(batch, input_dim)``.
        """
        # Split the feature dimension into equal-size groups.
        group_inputs = torch.split(x, self.group_input_dim, dim=1)

        # Each group goes through its own encoder.
        encoded_groups = [
            encoder(chunk) for chunk, encoder in zip(group_inputs, self.encoders)
        ]

        # Concatenate the group codes and decode jointly.
        return self.decoder(torch.cat(encoded_groups, dim=1))
|
|
|
# Example configuration: 5120 input features in 40 groups of 128,
# compressed to a 320-dim code (8 hidden units per group).
input_dim = 5120
hidden_dim = 320
num_groups = 40

# Fall back to CPU when CUDA is unavailable; the previous unconditional
# `.cuda()` call crashed on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GroupedAutoEncoder(input_dim=input_dim, hidden_dim=hidden_dim, num_groups=num_groups).to(device)
|
|