mawairon committed on
Commit ace289f
1 Parent(s): 1bb2663

Update app.py

Files changed (1)
  1. app.py +104 -73
app.py CHANGED
@@ -10,6 +10,10 @@ import base64
 import os
 import huggingface_hub
 from huggingface_hub import hf_hub_download, login
+import model_archs
+from model_archs import BertClassifier, LogisticRegressionTorch, SimpleCNN, MLP, Pool2BN
+import tangermeme
+from tangermeme import one_hot_encode
 
 # Load label mapping
 label_to_int = pd.read_pickle('label_to_int.pkl')
@@ -24,63 +28,67 @@ for k, v in int_to_label.items():
     elif "RUSSIAN" in v:
         int_to_label[k] = "RUSSIA"
 
-class LogisticRegressionTorch(nn.Module):
-    def __init__(self, input_dim: int, output_dim: int):
-        super(LogisticRegressionTorch, self).__init__()
-        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
-        self.linear = nn.Linear(input_dim, output_dim)
-
-    def forward(self, x):
-        x = self.batch_norm(x)
-        out = self.linear(x)
-        return out
-
-class BertClassifier(nn.Module):
-    def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
-        super(BertClassifier, self).__init__()
-        self.bert = bert_model
-        self.classifier = classifier
-        self.num_labels = num_labels
-
-    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None):
-        outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
-        pooled_output = outputs.hidden_states[-1][:, 0, :]
-        logits = self.classifier(pooled_output)
-        return logits
-
-def load_model():
-    metadata_features = 0
-    N_UNIQUE_CLASSES = 38
-
-    base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
-    tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
-
-    input_size = 768 + metadata_features
-    log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
-
-    token = os.getenv('HUGGINGFACE_TOKEN')
-    if token is None:
-        raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
-
-    login(token=token)
-    file_path = hf_hub_download(
-        repo_id="mawairon/noo_test",
-        filename="gena-blastln-bs33-lr4e-05-S168.pth",
-        use_auth_token=token
-    )
-    weights = torch.load(file_path, map_location=torch.device('cpu'))
-
-    base_model.load_state_dict(weights['model_state_dict'])
-    log_reg.load_state_dict(weights['log_reg_state_dict'])
-
-    model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
-    model.eval()
-
-    return model, tokenizer
-
-model, tokenizer = load_model()
-
-def analyze_dna(username, password, sequence):
+
+
+def load_model(model_name: str):
+    metadata_features = 0
+    N_UNIQUE_CLASSES = 38
+
+    if model_name == 'gena-bert':
+        base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
+        tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
+
+        input_size = 768 + metadata_features
+        log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
+
+        token = os.getenv('HUGGINGFACE_TOKEN')
+        if token is None:
+            raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
+
+        login(token=token)
+        file_path = hf_hub_download(
+            repo_id="mawairon/noo_test",
+            filename="gena-blastln-bs33-lr4e-05-S168.pth",
+            use_auth_token=token
+        )
+        weights = torch.load(file_path, map_location=torch.device('cpu'))
+
+        base_model.load_state_dict(weights['model_state_dict'])
+        log_reg.load_state_dict(weights['log_reg_state_dict'])
+
+        model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
+        model.eval()
+
+        return model, tokenizer
+
+    elif model_name == 'CNN':
+        hidden_dim = 2048
+        width = 2048
+        seq_drop_prob = 0.05
+        train_sequence_length = 8000
+        weight_decay = 0.0001
+        num_labs = len(set(y_train))
+
+
+        model_seq = SimpleCNN(18, hidden_dim, additional_layer=False)
+        new_head = torch.nn.Sequential(
+            torch.nn.Dropout(0.5),
+            MLP([hidden_dim*2 , num_labs])
+        )
+
+        model = torch.nn.Sequential(
+            model_seq,
+            new_head
+        )
+        return model, None
+
+    else:
+        return {"error": "Invalid model name"}
+
+
+
+
+def analyze_dna(username, password, sequence, model_name):
 
     valid_usernames = os.getenv('USERNAME').split(',')
     env_password = os.getenv('PASSWORD')
@@ -89,6 +97,7 @@ def analyze_dna(username, password, sequence):
         return {"error": "Invalid username or password"}, ""
 
     try:
+
         # Remove all whitespace characters
         sequence = sequence.replace(" ", "").replace("\n", "").replace("\t", "").replace("\r", "")
 
@@ -98,25 +107,43 @@ def analyze_dna(username, password, sequence):
         if len(sequence) < 300:
             return {"error": "Sequence needs to be at least 300 nucleotides long"}, ""
 
-        def get_logits(seq):
-            inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
-            with torch.no_grad():
-                logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
-            return logits
-
-        if len(sequence) > 3000:
-            num_shifts = len(sequence) // 1000
-            logits_sum = None
-            for i in range(num_shifts):
-                shifted_sequence = sequence[i*1000:] + sequence[:i*1000]
-                logits = get_logits(shifted_sequence)
-                if logits_sum is None:
-                    logits_sum = logits
-                else:
-                    logits_sum += logits
-            logits_avg = logits_sum / num_shifts
-        else:
-            logits_avg = get_logits(sequence)
+        model, tokenizer = load_model(model_name)
+
+        def get_logits(seq, model_name):
+            if model_name == 'gena-bert':
+                inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
+                with torch.no_grad():
+                    logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
+                return logits
+
+            elif model_name == 'CNN':
+                # Truncate sequence
+                SEQUENCE_LENGTH = 8000
+                seq = seq[:SEQUENCE_LENGTH]
+
+                # Pad sequences to the desired length
+                seq = seq.ljust(length, pad_char)[:SEQUENCE_LENGTH]
+
+                # Apply one-hot encoding to the 'sequence' column
+                input = seq.one_hot_encode()
+                with torch.no_grad():
+                    logits = model(input)
+                return logits
+
+
+        # if (len(sequence) > 3000 and model_name == 'gena-bert') or (len(sequence) > 10000 and model_name == 'CNN'):
+        #     num_shifts = len(sequence) // 1000
+        #     logits_sum = None
+        #     for i in range(num_shifts):
+        #         shifted_sequence = sequence[i*1000:] + sequence[:i*1000]
+        #         logits = get_logits(shifted_sequence)
+        #         if logits_sum is None:
+        #             logits_sum = logits
+        #         else:
+        #             logits_sum += logits
+        #     logits_avg = logits_sum / num_shifts
+        # else:
+        logits_avg = get_logits(sequence)
 
         probabilities = torch.nn.functional.softmax(logits_avg, dim=-1).squeeze().tolist()
         top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
@@ -147,9 +174,13 @@ demo = gr.Interface(
     inputs=[
         gr.Textbox(label="Username"),
         gr.Textbox(label="Password", type="password"),
-        gr.Textbox(label="DNA Sequence")
+        gr.Textbox(label="DNA Sequence"),
+        gr.Dropdown(label="Model", choices=[
+            "gena-bert",
+            "CNN"
+        ])
     ],
-    outputs=["json", "html"]
+    outputs=["json", "HTML"]
 )
 
 # Launch the interface
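A note on the new CNN branch of load_model: `num_labs = len(set(y_train))` refers to a training-time variable that does not exist in app.py, so selecting "CNN" in the UI would raise a NameError before the classification head is built. A minimal stand-in, assuming the deployed label mapping covers the same classes the CNN was trained on (an assumption, not something this commit states), is to reuse the class count already defined a few lines above:

        # Hypothetical fix: derive the head size from values already present in app.py
        # instead of the undefined training-time variable y_train.
        num_labs = N_UNIQUE_CLASSES          # defined as 38 at the top of load_model
        # or: num_labs = len(label_to_int)   # the label mapping loaded at import time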
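Similarly, the CNN path in get_logits references `length` and `pad_char`, which are never defined, and calls `seq.one_hot_encode()` even though Python strings have no such method; the import added at the top of the file suggests tangermeme's `one_hot_encode` was intended. Below is a minimal sketch of what this branch appears to be aiming for. The helper name `cnn_logits`, the 'N' padding character, and the assumption that the encoder returns a (4, length) tensor with non-ACGT characters encoded as all-zero columns are assumptions, not taken from the commit:

import torch
from tangermeme import one_hot_encode  # as imported in this commit; the function may live in tangermeme.utils

def cnn_logits(seq: str, model, sequence_length: int = 8000):
    # Truncate, then pad with 'N' (assumed pad character) to the fixed input length the CNN expects.
    seq = seq[:sequence_length].ljust(sequence_length, 'N')
    # Assumed: one_hot_encode returns a (4, sequence_length) tensor for a DNA string.
    x = one_hot_encode(seq).unsqueeze(0).float()  # add a batch dimension and cast for the conv layers
    with torch.no_grad():
        logits = model(x)
    return logits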
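Finally, get_logits now takes (seq, model_name), but the one live call site keeps the old single-argument form, so `logits_avg = get_logits(sequence)` will raise a TypeError for either model; the commented-out shift-averaging block has the same issue if it is ever re-enabled. The presumable fix is:

        logits_avg = get_logits(sequence, model_name)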