zwt123home123
commited on
Update model.py
Browse files
model.py
CHANGED
@@ -47,3 +47,8 @@ class GroupedAutoEncoder(nn.Module):
|
|
47 |
# Concatenate groups back together
|
48 |
# reconstructed = torch.cat(decoded_groups, dim=1)
|
49 |
return reconstructed
|
|
|
|
|
|
|
|
|
|
|
|
47 |
# Concatenate groups back together
|
48 |
# reconstructed = torch.cat(decoded_groups, dim=1)
|
49 |
return reconstructed
|
50 |
+
|
51 |
+
input_dim = 5120
|
52 |
+
hidden_dim = 320
|
53 |
+
num_groups = 40
|
54 |
+
model = GroupedAutoEncoder(input_dim=input_dim, hidden_dim=hidden_dim, num_groups=num_groups).cuda()
|