Update app.py
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 import transformers
-from transformers import AutoTokenizer,
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
 import torch
 import torch.nn as nn
 
@@ -12,11 +12,10 @@ class LogisticRegressionTorch(nn.Module):
                  output_dim: int):
 
         super(LogisticRegressionTorch, self).__init__()
-        self.batch_norm = nn.BatchNorm1d(num_features
+        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
         self.linear = nn.Linear(input_dim, output_dim)
 
     def forward(self, x):
-
         x = self.batch_norm(x)
         out = self.linear(x)
         return out
@@ -24,7 +23,7 @@ class BertClassifier(nn.Module):
 class BertClassifier(nn.Module):
 
     def __init__(self,
-                 bert_model:
+                 bert_model: AutoModel,
                  classifier: LogisticRegressionTorch,
                  num_labels: int):
 
@@ -34,13 +33,10 @@ class BertClassifier(nn.Module):
         self.num_labels = num_labels
 
     def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
-                token_type_ids: torch.Tensor = None, labels: torch.Tensor = None)
+                token_type_ids: torch.Tensor = None, labels: torch.Tensor = None):
         # Extract outputs from the BERT model
-        outputs = self.bert(input_ids, attention_mask=attention_mask)
-
-        # Sanity check for outputs
-        assert 'hidden_states' in outputs, "BERT model output does not contain 'hidden_states'."
-
+        outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
+
         # Take the hidden states from the last layer and extract the hidden state of the first token for each element in the batch
         pooled_output = outputs.hidden_states[-1][:, 0, :]
 
@@ -54,12 +50,10 @@ class BertClassifier(nn.Module):
         loss = None
 
         if labels is not None:
-
             loss_fct = nn.CrossEntropyLoss()
             pred = logits.view(-1, self.num_labels)
             observed = labels.view(-1)
             loss = loss_fct(pred, observed)
-            #assert loss_fct(float(observed), observed) < 1e-6
 
         # Return the loss and logits
         return loss, logits
@@ -71,8 +65,7 @@ class BertClassifier(nn.Module):
 metadata_features = 0
 N_UNIQUE_CLASSES = 38 ## or 38
 
-
-base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states = True)
+base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
 tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
 
 # Initialize the classifier
@@ -80,14 +73,14 @@ input_size = 768 + metadata_features # featurizer output size + metadata size
 log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
 
 # Load Weights
-model_weights_path = 'gena-blastln-bs33-lr4e-05-S168.pth'
+model_weights_path = 'model/gena-blastln-bs33-lr4e-05-S168.pth'
 weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
 
 base_model.load_state_dict(weights['model_state_dict'])
 log_reg.load_state_dict(weights['log_reg_state_dict'])
 
 # Creating Model
-model = BertClassifier(base_model, log_reg, num_labels
+model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
 
 
 # Define a function to process the DNA sequence
@@ -96,10 +89,10 @@ def analyze_dna(sequence):
     inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
 
     # Get model predictions
-    _,
+    _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
 
     # Convert logits to probabilities
-    probabilities = torch.nn.functional.softmax(
+    probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
 
     # Get the top 5 most likely classes
    top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
@@ -108,7 +101,7 @@ def analyze_dna(sequence):
     # Prepare the output as a list of tuples (class_index, probability)
     result = [(index, prob) for index, prob in zip(top_5_indices, top_5_probs)]
 
-    return
+    return result
 
 # Create a Gradio interface
 demo = gr.Interface(fn=analyze_dna, inputs="text", outputs="json")
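For reference, a minimal usage sketch of the updated script. The `model.eval()` call, the example sequence, and the `demo.launch()` guard are assumptions added for illustration only; they are not part of the hunks shown above.

# Hypothetical usage sketch (not part of the commit):
model.eval()                                      # keep BatchNorm1d in inference mode for single-sample inputs
with torch.no_grad():
    example_sequence = "ATGGCGTACGTTAGCCTAAGGC"   # made-up DNA string, purely illustrative
    print(analyze_dna(example_sequence))          # -> [(class_index, probability), ...] for the top 5 classes

if __name__ == "__main__":
    demo.launch()                                 # typical way a Gradio Space serves the interface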