mawairon committed
Commit 3f3c29c · verified · 1 Parent(s): 112178f

Update app.py

Files changed (1)
  1. app.py +86 -4
app.py CHANGED
@@ -1,11 +1,93 @@
 import gradio as gr
 import transformers
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
+from transformers import AutoTokenizer, AutoModel, BertModel, BertTokenizer
+import torch
+import torch.nn as nn
+
+
+class LogisticRegressionTorch(nn.Module):
+
+    def __init__(self,
+                 input_dim: int,
+                 output_dim: int):
+        super(LogisticRegressionTorch, self).__init__()
+        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
+        self.linear = nn.Linear(input_dim, output_dim)
+
+    def forward(self, x):
+        x = self.batch_norm(x)
+        out = self.linear(x)
+        return out
+
+
+class BertClassifier(nn.Module):
+
+    def __init__(self,
+                 bert_model: BertModel,
+                 classifier: LogisticRegressionTorch,
+                 num_labels: int):
+        super(BertClassifier, self).__init__()
+        self.bert = bert_model  # Assume bert_model is an instance of a pre-trained BertModel
+        self.classifier = classifier
+        self.num_labels = num_labels
+
+    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
+                token_type_ids: torch.Tensor = None,
+                labels: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
+        # Extract outputs from the BERT model
+        outputs = self.bert(input_ids, attention_mask=attention_mask)
+
+        # Sanity check for outputs
+        assert 'hidden_states' in outputs, "BERT model output does not contain 'hidden_states'."
+
+        # Take the hidden states from the last layer and extract the hidden state
+        # of the first ([CLS]) token for each element in the batch
+        pooled_output = outputs.hidden_states[-1][:, 0, :]
+
+        assert pooled_output.shape == (input_ids.shape[0], 768), \
+            f"Expected shape ({input_ids.shape[0]}, 768), but got {pooled_output.shape}"
+        # to-do later!
+
+        # Pass the pooled output to the classifier to get the logits
+        logits = self.classifier(pooled_output)
+
+        # Compute loss if labels are provided (using CrossEntropyLoss for classification)
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            pred = logits.view(-1, self.num_labels)
+            observed = labels.view(-1)
+            loss = loss_fct(pred, observed)
+            # assert loss_fct(float(observed), observed) < 1e-6
+
+        # Return the loss and logits
+        return loss, logits
+
 
 # Load the Hugging Face model and tokenizer
-model_name = 'AIRI-Institute/gena-lm-bert-base-lastln-t2t' # Replace with the actual model name
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForSequenceClassification.from_pretrained(model_name)
+metadata_features = 0
+N_UNIQUE_CLASSES = 38
+
+base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t',
+                                       trust_remote_code=True, output_hidden_states=True)
+tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t',
+                                          trust_remote_code=True)
+
+# Initialize the classifier: featurizer output size + metadata size
+input_size = 768 + metadata_features
+log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
+
+# Load weights
+model_weights_path = '/your_model_weights.pth'
+weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
+
+model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
+base_model.load_state_dict(weights['model_state_dict'])
+log_reg.load_state_dict(weights['log_reg_state_dict'])
 
 # Define a function to process the DNA sequence
 def analyze_dna(sequence):
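
The hunk's context ends at the signature of analyze_dna, so the function body is not visible in this diff. Below is a minimal sketch of how the pieces could fit together. Everything in it (the tokenizer call, the softmax readout, the Gradio wiring) is an illustrative assumption, not code from this commit; it relies only on the names defined above and on the checkpoint layout with 'model_state_dict' and 'log_reg_state_dict' keys that the two load_state_dict calls already imply.

    # Hypothetical completion of analyze_dna (not part of this commit)
    def analyze_dna(sequence):
        # Tokenize the raw DNA string with the GENA-LM tokenizer loaded above
        inputs = tokenizer(sequence, return_tensors='pt')

        model.eval()  # BatchNorm1d needs eval mode to handle a batch of one
        with torch.no_grad():
            # labels is None, so BertClassifier.forward returns (None, logits)
            _, logits = model(inputs['input_ids'],
                              attention_mask=inputs['attention_mask'])

        # Reduce the logits over the N_UNIQUE_CLASSES classes to one prediction
        probs = torch.softmax(logits, dim=-1).squeeze(0)
        top_prob, top_class = probs.max(dim=-1)
        return f"Predicted class: {top_class.item()} (p={top_prob.item():.3f})"

    # Hypothetical Gradio wiring, matching `import gradio as gr` at the top of app.py
    demo = gr.Interface(fn=analyze_dna, inputs='text', outputs='text')
    demo.launch()

A plain text-to-text gr.Interface is the simplest wiring consistent with the imports; the actual app may present the prediction differently.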