zwt123home123's picture
Update model.py
84c306c verified
class GroupedAutoEncoder(nn.Module):
    """Linear autoencoder with group-wise encoders and one dense decoder.

    The input is split along dimension 1 into contiguous slices of
    ``input_dim // num_groups`` features. Each slice is projected by its own
    bias-free ``nn.Linear`` to ``hidden_dim // num_groups`` features. The
    per-group codes are concatenated along dimension 1 and mapped back to
    ``input_dim`` features by a single bias-free ``nn.Linear``.

    Args:
        input_dim: Total number of input (and reconstructed) features.
        hidden_dim: Total size of the concatenated code.
        num_groups: Number of groups; must be positive and divide both
            ``input_dim`` and ``hidden_dim``.

    Raises:
        ValueError: If ``num_groups`` is not positive, or does not evenly
            divide ``input_dim`` or ``hidden_dim``.
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_groups: int) -> None:
        super().__init__()

        # Validate before dividing: explicit exceptions survive ``python -O``
        # (unlike ``assert``) and a zero group count gets a clear message
        # instead of a ZeroDivisionError.
        if num_groups <= 0:
            raise ValueError("Number of groups must be a positive integer.")
        if input_dim % num_groups != 0:
            raise ValueError(
                "Input dimension must be divisible by the number of groups."
            )
        if hidden_dim % num_groups != 0:
            raise ValueError(
                "Hidden dimension must be divisible by the number of groups."
            )

        self.num_groups = num_groups
        self.group_input_dim = input_dim // num_groups
        self.group_hidden_dim = hidden_dim // num_groups

        # One independent bias-free encoder per group.
        self.encoders = nn.ModuleList([
            nn.Linear(self.group_input_dim, self.group_hidden_dim, bias=False)
            for _ in range(num_groups)
        ])
        # A single dense decoder over the concatenated code, so the
        # reconstruction of every feature can draw on all groups.
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)

        self.init_weights()

    def init_weights(self) -> None:
        """Re-initialise every encoder and the decoder with Xavier-uniform."""
        for encoder in self.encoders:
            nn.init.xavier_uniform_(encoder.weight)
        nn.init.xavier_uniform_(self.decoder.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` group by group and decode it with the dense decoder.

        Args:
            x: Input tensor; the features are split along dimension 1, so
                this presumably expects shape ``(batch, input_dim)`` --
                TODO confirm against callers.

        Returns:
            The decoder output for the concatenated group codes.
        """
        # Chunks of ``group_input_dim`` features, paired with their encoder.
        group_inputs = torch.split(x, self.group_input_dim, dim=1)
        encoded_groups = [
            encoder(group) for group, encoder in zip(group_inputs, self.encoders)
        ]
        return self.decoder(torch.cat(encoded_groups, dim=1))
# Configuration: 5120 input features compressed to a 320-wide code using
# 40 groups (128 input features and 8 code features per group).
input_dim, hidden_dim, num_groups = 5120, 320, 40

# Build the autoencoder, then move its parameters onto the CUDA device.
model = GroupedAutoEncoder(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    num_groups=num_groups,
)
model = model.cuda()