Update app.py
app.py
CHANGED
@@ -1,11 +1,93 @@
 import gradio as gr
 import transformers
-from transformers import AutoTokenizer,
+from transformers import AutoTokenizer, AutoModel, BertModel, BertTokenizer
+import torch
+import torch.nn as nn
+
+
+class LogisticRegressionTorch(nn.Module):
+
+    def __init__(self,
+                 input_dim: int,
+                 output_dim: int):
+
+        super(LogisticRegressionTorch, self).__init__()
+        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
+        self.linear = nn.Linear(input_dim, output_dim)
+
+    def forward(self, x):
+
+        x = self.batch_norm(x)
+        out = self.linear(x)
+        return out
+
+class BertClassifier(nn.Module):
+
+    def __init__(self,
+                 bert_model: BertModel,
+                 classifier: LogisticRegressionTorch,
+                 num_labels: int):
+
+        super(BertClassifier, self).__init__()
+        self.bert = bert_model  # Assume bert_model is an instance of a pre-trained BertModel
+        self.classifier = classifier
+        self.num_labels = num_labels
+
+    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
+                token_type_ids: torch.Tensor = None, labels: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
+        # Extract outputs from the BERT model
+        outputs = self.bert(input_ids, attention_mask=attention_mask)
+
+        # Sanity check for outputs
+        assert 'hidden_states' in outputs, "BERT model output does not contain 'hidden_states'."
+
+        # Take the hidden states from the last layer and extract the hidden state of the first token ([CLS]) for each element in the batch
+        pooled_output = outputs.hidden_states[-1][:, 0, :]
+
+        assert pooled_output.shape == (input_ids.shape[0], 768), f"Expected shape ({input_ids.shape[0]}, 768), but got {pooled_output.shape}"
+        # to-do later!
+
+        # Pass the pooled output to the classifier to get the logits
+        logits = self.classifier(pooled_output)
+
+        # Compute loss if labels are provided (using CrossEntropyLoss for classification)
+        loss = None
+
+        if labels is not None:
+
+            loss_fct = nn.CrossEntropyLoss()
+            pred = logits.view(-1, self.num_labels)
+            observed = labels.view(-1)
+            loss = loss_fct(pred, observed)
+            # assert loss_fct(float(observed), observed) < 1e-6
+
+        # Return the loss and logits
+        return loss, logits
+
+
 
 # Load the Hugging Face model and tokenizer
-
-
-
+import torch.nn as nn
+from transformers import AutoTokenizer
+
+metadata_features = 0
+N_UNIQUE_CLASSES = 38
+
+
+base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
+tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
+
+# Initialize the classifier
+input_size = 768 + metadata_features  # featurizer output size + metadata size
+log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
+
+# Load weights
+model_weights_path = '/your_model_weights.pth'
+weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
+
+model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
+base_model.load_state_dict(weights['model_state_dict'])
+log_reg.load_state_dict(weights['log_reg_state_dict'])
 
 # Define a function to process the DNA sequence
 def analyze_dna(sequence):
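
The hunk's trailing context cuts off at the signature of analyze_dna. For orientation, here is a minimal sketch of how such a function could run inference with the model and tokenizer loaded above; the 512-token truncation, the top-5 cutoff, and the class_{i} label format are illustrative assumptions, not part of this commit:

import torch
import torch.nn.functional as F

def analyze_dna(sequence):
    # Tokenize the raw DNA string into input IDs and an attention mask
    inputs = tokenizer(sequence, return_tensors='pt', truncation=True, max_length=512)

    # eval() matters here: the classifier's BatchNorm1d cannot run in
    # training mode on a batch of size 1
    model.eval()
    with torch.no_grad():
        # BertClassifier.forward returns (loss, logits); loss is None when no labels are passed
        _, logits = model(input_ids=inputs['input_ids'],
                          attention_mask=inputs['attention_mask'])

    # Turn logits into class probabilities and report the top 5 classes
    probs = F.softmax(logits, dim=-1).squeeze(0)
    top = torch.topk(probs, k=5)
    return {f'class_{i.item()}': p.item() for p, i in zip(top.values, top.indices)}

A gr.Interface(fn=analyze_dna, inputs='text', outputs='json').launch() wiring would be the usual next step in a Gradio app.py, but that part lies outside this hunk.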
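
The loading code expects the checkpoint at model_weights_path to hold two keys, 'model_state_dict' and 'log_reg_state_dict'. A checkpoint with that layout could have been produced as below; this is a sketch of the assumed save-side counterpart, not code from the commit (the path is the same placeholder the commit uses):

import torch

# Hypothetical save-side counterpart of the torch.load() call above:
# one file holding both state dicts under the keys the app reads back
torch.save({
    'model_state_dict': base_model.state_dict(),
    'log_reg_state_dict': log_reg.state_dict(),
}, '/your_model_weights.pth')

Note that loading the state dicts into base_model and log_reg after they have been wrapped in BertClassifier still works: the wrapper stores references to the same modules, so the weights loaded here are the ones the wrapped model uses.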