arikat commited on
Commit
a1a7f2c
1 Parent(s): 5371b1f

new family model

Browse files
Files changed (1) hide show
  1. app.py +6 -8
app.py CHANGED
@@ -20,20 +20,18 @@ import zipfile
20
  import os
21
 
22
  # Load the model from the file
23
- with open('family_labels.pkl', 'rb') as filefam:
24
  yfam = pickle.load(filefam)
25
 
26
  tokenizerfam = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D") #facebook/esm2_t33_650M_UR50D
27
- label_encoderfam = LabelEncoder()
28
- encoded_labelsfam = label_encoderfam.fit_transform(yfam)
29
- labelsfam = torch.tensor(encoded_labelsfam)
30
- device = 'cpu'
31
  device
32
 
33
- modelfam = EsmForSequenceClassification.from_pretrained("facebook/esm2_t12_35M_UR50D", num_labels=len(set(labelsfam.tolist())))
34
  modelfam = modelfam.to('cpu')
35
 
36
- modelfam.load_state_dict(torch.load("family.pth", map_location=torch.device('cpu')))
37
  modelfam.eval()
38
 
39
  x_testfam = ["""MAEVLRTLAGKPKCHALRPMILFLIMLVLVLFGYGVLSPRSLMPGSLERGFCMAVREPDH
@@ -55,7 +53,7 @@ with torch.no_grad():
55
  _, predicted_labelsfam = torch.max(logitsfam, dim=1)
56
  probabilitiesfam[0]
57
 
58
- decoded_labelsfam = label_encoderfam.inverse_transform(predicted_labelsfam.tolist())
59
  decoded_labelsfam
60
 
61
 
 
20
  import os
21
 
22
  # Load the model from the file
23
+ with open('/home/aarya/Documents/paper3/family_labels.pkl', 'rb') as filefam:
24
  yfam = pickle.load(filefam)
25
 
26
  tokenizerfam = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D") #facebook/esm2_t33_650M_UR50D
27
+
28
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
29
  device
30
 
31
+ modelfam = EsmForSequenceClassification.from_pretrained("facebook/esm2_t12_35M_UR50D", num_labels=len(yfam.classes_))
32
  modelfam = modelfam.to('cpu')
33
 
34
+ modelfam.load_state_dict(torch.load("/home/aarya/Documents/paper3/family.pth"))
35
  modelfam.eval()
36
 
37
  x_testfam = ["""MAEVLRTLAGKPKCHALRPMILFLIMLVLVLFGYGVLSPRSLMPGSLERGFCMAVREPDH
 
53
  _, predicted_labelsfam = torch.max(logitsfam, dim=1)
54
  probabilitiesfam[0]
55
 
56
+ decoded_labelsfam = yfam.inverse_transform(predicted_labelsfam.tolist())
57
  decoded_labelsfam
58
 
59