Update app.py
app.py CHANGED
@@ -7,13 +7,13 @@ import pandas as pd
 import matplotlib.pyplot as plt
 import io
 import base64
+import os
 
-#
+# Load label mapping
 label_to_int = pd.read_pickle('label_to_int.pkl')
 int_to_label = {v: k for k, v in label_to_int.items()}
 
 class LogisticRegressionTorch(nn.Module):
-
     def __init__(self, input_dim: int, output_dim: int):
         super(LogisticRegressionTorch, self).__init__()
         self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
@@ -25,102 +25,64 @@ class LogisticRegressionTorch(nn.Module):
         return out
 
 class BertClassifier(nn.Module):
-
     def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
         super(BertClassifier, self).__init__()
-        self.bert = bert_model
+        self.bert = bert_model
         self.classifier = classifier
         self.num_labels = num_labels
 
-    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
-                token_type_ids: torch.Tensor = None, labels: torch.Tensor = None):
-        # Extract outputs from the BERT model
+    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None):
         outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
-
-        # Take the hidden states from the last layer and extract the hidden state of the first token for each element in the batch
         pooled_output = outputs.hidden_states[-1][:, 0, :]
-
-        assert pooled_output.shape == (input_ids.shape[0], 768), f"Expected shape ({input_ids.shape[0]}, 768), but got {pooled_output.shape}"
-        # to-do later!
-
-        # Pass the pooled output to the classifier to get the logits
         logits = self.classifier(pooled_output)
+        return logits
 
-
-
-
-        if labels is not None:
-            loss_fct = nn.CrossEntropyLoss()
-            pred = logits.view(-1, self.num_labels)
-            observed = labels.view(-1)
-            loss = loss_fct(pred, observed)
+def load_model():
+    metadata_features = 0
+    N_UNIQUE_CLASSES = 38
 
-
-
+    base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
+    tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
 
-
+    input_size = 768 + metadata_features
+    log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
 
-
-
+    model_weights_path = os.getenv('MODEL_PATH')
+    weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
 
-base_model
-
+    base_model.load_state_dict(weights['model_state_dict'])
+    log_reg.load_state_dict(weights['log_reg_state_dict'])
 
-
-
-log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
+    model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
+    model.eval()
 
-
-import os
+    return model, tokenizer
 
-
-model_weights_path = os.getenv('MODEL_PATH')
-weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
-
-base_model.load_state_dict(weights['model_state_dict'])
-log_reg.load_state_dict(weights['log_reg_state_dict'])
-
-# Creating Model
-model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
-model.eval()
+model, tokenizer = load_model()
 
 def analyze_dna(sequence):
     try:
-        # Check if the sequence contains only valid characters
         if not all(nucleotide in 'ACTGN' for nucleotide in sequence):
-            return "Error: Sequence contains invalid characters"
+            return "Error: Sequence contains invalid characters", ""
 
-        # Check if the sequence is at least 300 nucleotides long
         if len(sequence) < 300:
-            return "Error: Sequence needs to be at least 300 nucleotides long"
+            return "Error: Sequence needs to be at least 300 nucleotides long", ""
 
-        # Preprocess the input sequence
         inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
+        logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
 
-        # Get model predictions
-        _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
-
-        # Convert logits to probabilities
         probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
-
-        # Get the top 5 most likely classes
        top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
         top_5_probs = [probabilities[i] for i in top_5_indices]
-
-        # Map indices to label names
         top_5_labels = [int_to_label[i] for i in top_5_indices]
-
-        # Prepare the output as a list of tuples (label_name, probability)
         result = [(label, prob) for label, prob in zip(top_5_labels, top_5_probs)]
 
-        # Plot histogram
         fig, ax = plt.subplots(figsize=(10, 6))
         ax.barh(top_5_labels, top_5_probs, color='skyblue')
         ax.set_xlabel('Probability')
         ax.set_title('Top 5 Most Likely Labels')
-        plt.gca().invert_yaxis()
+        plt.gca().invert_yaxis()
 
-        # Save plot to a PNG image in memory
         buf = io.BytesIO()
         plt.savefig(buf, format='png')
         buf.seek(0)
@@ -129,10 +91,8 @@ def analyze_dna(sequence):
 
         return result, f'<img src="data:image/png;base64,{image_base64}" />'
 
-    except
-        # Return the error message
+    except Exception as e:
         return str(e), ""
-
 
 # Create a Gradio interface
 demo = gr.Interface(fn=analyze_dna, inputs="text", outputs=["json", "html"])
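
For local smoke-testing outside of Spaces, here is a minimal sketch of how the updated entrypoint might be exercised. The checkpoint filename and the `app` module name are assumptions; on Spaces, `MODEL_PATH` would be set as a Space variable or secret, and `demo.launch()` is only needed when the runtime does not launch the interface itself.

import os

# Assumption: a local checkpoint containing 'model_state_dict' and
# 'log_reg_state_dict'; app.py reads its path from the MODEL_PATH env var,
# so it must be set before the module is imported.
os.environ['MODEL_PATH'] = 'weights.pt'  # hypothetical filename

from app import analyze_dna, demo  # assumes the file above is saved as app.py

# A synthetic 400-nt sequence passes both validations (ACTGN-only, >= 300 nt).
result, plot_html = analyze_dna('ACTG' * 100)
print(result)  # top-5 (label, probability) pairs, or an error string

demo.launch()  # serve the Gradio interface locally

Since `forward` now returns bare logits, `analyze_dna` calls the model without tuple unpacking; any other caller still expecting the old `(loss, logits)` pair would need the same one-line change.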