zwt123home123 commited on
Commit
84c306c
·
verified ·
1 Parent(s): f3b5cf2

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +5 -0
model.py CHANGED
@@ -47,3 +47,8 @@ class GroupedAutoEncoder(nn.Module):
47
  # Concatenate groups back together
48
  # reconstructed = torch.cat(decoded_groups, dim=1)
49
  return reconstructed
 
 
 
 
 
 
47
  # Concatenate groups back together
48
  # reconstructed = torch.cat(decoded_groups, dim=1)
49
  return reconstructed
50
+
51
+ input_dim = 5120
52
+ hidden_dim = 320
53
+ num_groups = 40
54
+ model = GroupedAutoEncoder(input_dim=input_dim, hidden_dim=hidden_dim, num_groups=num_groups).cuda()