mawairon committed
Commit 996a1ec
1 Parent(s): 513b115

Update app.py

Files changed (1)
  1. app.py +12 -19
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 import transformers
-from transformers import AutoTokenizer, AutoModel, BertModel, BertForMaskedLM, BertTokenizer
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
 import torch
 import torch.nn as nn
 
@@ -12,11 +12,10 @@ class LogisticRegressionTorch(nn.Module):
                  output_dim: int):
 
         super(LogisticRegressionTorch, self).__init__()
-        self.batch_norm = nn.BatchNorm1d(num_features = input_dim)
+        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
         self.linear = nn.Linear(input_dim, output_dim)
 
     def forward(self, x):
-
         x = self.batch_norm(x)
         out = self.linear(x)
         return out
@@ -24,7 +23,7 @@ class LogisticRegressionTorch(nn.Module):
 class BertClassifier(nn.Module):
 
     def __init__(self,
-                 bert_model: BertForMaskedLM,
+                 bert_model: AutoModel,
                  classifier: LogisticRegressionTorch,
                  num_labels: int):
 
@@ -34,13 +33,10 @@ class BertClassifier(nn.Module):
         self.num_labels = num_labels
 
     def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
-                token_type_ids: torch.Tensor = None, labels: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
+                token_type_ids: torch.Tensor = None, labels: torch.Tensor = None):
         # Extract outputs from the BERT model
-        outputs = self.bert(input_ids, attention_mask=attention_mask)
-
-        # Sanity check for outputs
-        assert 'hidden_states' in outputs, "BERT model output does not contain 'hidden_states'."
-
+        outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
+
         # Take the hidden states from the last layer and extract the hidden state of the first token for each element in the batch
         pooled_output = outputs.hidden_states[-1][:, 0, :]
 
@@ -54,12 +50,10 @@ class BertClassifier(nn.Module):
         loss = None
 
         if labels is not None:
-
             loss_fct = nn.CrossEntropyLoss()
             pred = logits.view(-1, self.num_labels)
             observed = labels.view(-1)
             loss = loss_fct(pred, observed)
-            #assert loss_fct(float(observed), observed) < 1e-6
 
         # Return the loss and logits
         return loss, logits
@@ -71,8 +65,7 @@ class BertClassifier(nn.Module):
 metadata_features = 0
 N_UNIQUE_CLASSES = 38 ## or 38
 
-
-base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states = True)
+base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
 tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
 
 # Initialize the classifier
@@ -80,14 +73,14 @@ input_size = 768 + metadata_features # featurizer output size + metadata size
 log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
 
 # Load Weights
-model_weights_path = 'gena-blastln-bs33-lr4e-05-S168.pth'
+model_weights_path = 'model/gena-blastln-bs33-lr4e-05-S168.pth'
 weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
 
 base_model.load_state_dict(weights['model_state_dict'])
 log_reg.load_state_dict(weights['log_reg_state_dict'])
 
 # Creating Model
-model = BertClassifier(base_model, log_reg, num_labels = N_UNIQUE_CLASSES)
+model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
 
 
 # Define a function to process the DNA sequence
@@ -96,10 +89,10 @@ def analyze_dna(sequence):
     inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
 
     # Get model predictions
-    _, outputs = model(**inputs)
+    _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
 
     # Convert logits to probabilities
-    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze().tolist()
+    probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
 
     # Get the top 5 most likely classes
     top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
@@ -108,7 +101,7 @@ def analyze_dna(sequence):
     # Prepare the output as a list of tuples (class_index, probability)
     result = [(index, prob) for index, prob in zip(top_5_indices, top_5_probs)]
 
-    return probabilities
+    return result
 
 # Create a Gradio interface
 demo = gr.Interface(fn=analyze_dna, inputs="text", outputs="json")
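
A note on the reworked forward(): output_hidden_states=True is now also passed explicitly in the call to self.bert, and outputs.hidden_states[-1][:, 0, :] pools the last layer's first ([CLS]) token. A minimal shape sketch with dummy tensors (the 768 hidden size and 38 classes are the script's own constants; the Sequential head below is only a stand-in for LogisticRegressionTorch):

import torch
import torch.nn as nn

# Stand-in for outputs.hidden_states[-1]: (batch, seq_len, hidden_size).
last_hidden = torch.randn(2, 512, 768)

# The diff's pooling: the hidden state of the first ([CLS]) token.
pooled = last_hidden[:, 0, :]
assert pooled.shape == (2, 768)

# LogisticRegressionTorch applies BatchNorm1d before its linear layer, and
# BatchNorm1d raises on a batch of one in train mode, so single-sequence
# inference (as in analyze_dna) presumably relies on eval() being called
# somewhere outside these hunks.
head = nn.Sequential(nn.BatchNorm1d(768), nn.Linear(768, 38))
head.eval()                 # makes batch size 1 safe
logits = head(pooled[:1])   # -> (1, 38)
assert logits.shape == (1, 38)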
 
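The classifier input size is hard-coded as 768; a quick hedged check that this matches the backbone's hidden size (needs network access and trust_remote_code, just like the loads in the script):

from transformers import AutoConfig

# GENA-LM bert-base is BERT-base-sized, so this should hold.
config = AutoConfig.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
assert config.hidden_size == 768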
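The loading code expects the checkpoint to be a dict with 'model_state_dict' and 'log_reg_state_dict' entries. A sketch of the assumed writer side (the modules here are dummy stand-ins; in training they would be the fine-tuned GENA-LM backbone and the LogisticRegressionTorch head, and the path mirrors the one the diff switches to):

import os
import torch
import torch.nn as nn

# Dummy stand-ins for the trained modules.
base_model = nn.Linear(8, 8)
log_reg = nn.Linear(768, 38)

os.makedirs('model', exist_ok=True)
torch.save(
    {
        'model_state_dict': base_model.state_dict(),
        'log_reg_state_dict': log_reg.state_dict(),
    },
    'model/gena-blastln-bs33-lr4e-05-S168.pth',
)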
 
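With the return value switched to result, a call now yields the top-5 pairs directly; a usage sketch (arbitrary sequence, illustrative numbers):

# Requires the model and tokenizer defined in the file to be loaded.
result = analyze_dna("ATGCGTACGTTAGCAATTCCGG")
# -> a list of five (class_index, probability) tuples, e.g.
# [(17, 0.42), (3, 0.21), (25, 0.11), (0, 0.07), (9, 0.05)]

Note that the forward pass still runs without torch.no_grad(), so each request also builds an autograd graph; wrapping the model call in torch.no_grad() would be the usual inference pattern.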
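An aside on the top-5 selection: the sorted(range(len(...))) idiom works, but torch.topk expresses the same thing directly; a sketch with an illustrative probability list:

import torch

probabilities = [0.05, 0.42, 0.01, 0.21, 0.11, 0.07, 0.13]  # illustrative
values, indices = torch.topk(torch.tensor(probabilities), k=5)
# indices.tolist() and values.tolist() mirror top_5_indices / top_5_probs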
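The hunks stop at the interface definition; a Gradio Space would normally follow it with the standard launch call (assumed here, not shown in the diff):

if __name__ == "__main__":
    demo.launch()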