arikat commited on
Commit
3af12a2
1 Parent(s): 9bce414

minor edits

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -27,13 +27,13 @@ tokenizerfam = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D") #fa
27
  label_encoderfam = LabelEncoder()
28
  encoded_labelsfam = label_encoderfam.fit_transform(yfam)
29
  labelsfam = torch.tensor(encoded_labelsfam)
30
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
31
  device
32
 
33
  modelfam = EsmForSequenceClassification.from_pretrained("facebook/esm2_t33_650M_UR50D", num_labels=len(set(labelsfam.tolist())))
34
  modelfam = modelfam.to('cpu')
35
 
36
- modelfam.load_state_dict(torch.load("model_650M.pth", map_location=torch.device('cpu')))
37
  modelfam.eval()
38
 
39
  x_testfam = ["""MAEVLRTLAGKPKCHALRPMILFLIMLVLVLFGYGVLSPRSLMPGSLERGFCMAVREPDH
@@ -74,7 +74,7 @@ device
74
  model = EsmForSequenceClassification.from_pretrained("facebook/esm2_t12_35M_UR50D", num_labels=len(label_encoder.classes_))
75
  model = model.to('cpu')
76
 
77
- model.load_state_dict(torch.load("best_model_35M_t12_5v5.pth", map_location=torch.device('cpu'))) #model_best_35v2M.pth
78
  model.eval()
79
 
80
  x_test = ["""MAEVLRTLAGKPKCHALRPMILFLIMLVLVLFGYGVLSPRSLMPGSLERGFCMAVREPDH
 
27
  label_encoderfam = LabelEncoder()
28
  encoded_labelsfam = label_encoderfam.fit_transform(yfam)
29
  labelsfam = torch.tensor(encoded_labelsfam)
30
+ device = 'cpu'
31
  device
32
 
33
  modelfam = EsmForSequenceClassification.from_pretrained("facebook/esm2_t33_650M_UR50D", num_labels=len(set(labelsfam.tolist())))
34
  modelfam = modelfam.to('cpu')
35
 
36
+ modelfam.load_state_dict(torch.load("model_650M.pth"))
37
  modelfam.eval()
38
 
39
  x_testfam = ["""MAEVLRTLAGKPKCHALRPMILFLIMLVLVLFGYGVLSPRSLMPGSLERGFCMAVREPDH
 
74
  model = EsmForSequenceClassification.from_pretrained("facebook/esm2_t12_35M_UR50D", num_labels=len(label_encoder.classes_))
75
  model = model.to('cpu')
76
 
77
+ model.load_state_dict(torch.load("best_model_35M_t12_5v5.pth")) #model_best_35v2M.pth
78
  model.eval()
79
 
80
  x_test = ["""MAEVLRTLAGKPKCHALRPMILFLIMLVLVLFGYGVLSPRSLMPGSLERGFCMAVREPDH