mawairon commited on
Commit
7680154
1 Parent(s): 99d9c5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -67,13 +67,13 @@ def load_model(model_name: str):
67
  seq_drop_prob = 0.05
68
  train_sequence_length = 8000
69
  weight_decay = 0.0001
70
- num_labs = len(set(y_train))
71
 
72
 
73
  model_seq = SimpleCNN(18, hidden_dim, additional_layer=False)
74
  new_head = torch.nn.Sequential(
75
  torch.nn.Dropout(0.5),
76
- MLP([hidden_dim*2 , num_labs])
77
  )
78
 
79
  model = torch.nn.Sequential(
 
67
  seq_drop_prob = 0.05
68
  train_sequence_length = 8000
69
  weight_decay = 0.0001
70
+ num_countries = 38
71
 
72
 
73
  model_seq = SimpleCNN(18, hidden_dim, additional_layer=False)
74
  new_head = torch.nn.Sequential(
75
  torch.nn.Dropout(0.5),
76
+ MLP([hidden_dim*2 , num_countries])
77
  )
78
 
79
  model = torch.nn.Sequential(