mawairon committed
Commit 9bf3f2b
1 Parent(s): a91155d

Update app.py

Files changed (1)
  1. app.py +25 -65
app.py CHANGED
@@ -7,13 +7,13 @@ import pandas as pd
 import matplotlib.pyplot as plt
 import io
 import base64
+import os
 
-# Assuming label_to_int is a dictionary with {label_name: label_index}
+# Load label mapping
 label_to_int = pd.read_pickle('label_to_int.pkl')
 int_to_label = {v: k for k, v in label_to_int.items()}
 
 class LogisticRegressionTorch(nn.Module):
-
     def __init__(self, input_dim: int, output_dim: int):
         super(LogisticRegressionTorch, self).__init__()
         self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
@@ -25,102 +25,64 @@ class LogisticRegressionTorch(nn.Module):
         return out
 
 class BertClassifier(nn.Module):
-
     def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
         super(BertClassifier, self).__init__()
-        self.bert = bert_model # Assume bert_model is an instance of a pre-trained BertModel
+        self.bert = bert_model
         self.classifier = classifier
         self.num_labels = num_labels
 
-    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
-                token_type_ids: torch.Tensor = None, labels: torch.Tensor = None):
-        # Extract outputs from the BERT model
+    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None):
         outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
-
-        # Take the hidden states from the last layer and extract the hidden state of the first token for each element in the batch
         pooled_output = outputs.hidden_states[-1][:, 0, :]
-
-        assert pooled_output.shape == (input_ids.shape[0], 768), f"Expected shape ({input_ids.shape[0]}, 768), but got {pooled_output.shape}"
-        # to-do later!
-
-        # Pass the pooled output to the classifier to get the logits
         logits = self.classifier(pooled_output)
+        return logits
 
-        # Compute loss if labels are provided (assuming using CrossEntropyLoss for classification)
-        loss = None
-
-        if labels is not None:
-            loss_fct = nn.CrossEntropyLoss()
-            pred = logits.view(-1, self.num_labels)
-            observed = labels.view(-1)
-            loss = loss_fct(pred, observed)
-
-        # Return the loss and logits
-        return loss, logits
-
-# Load the Hugging Face model and tokenizer
+def load_model():
+    metadata_features = 0
+    N_UNIQUE_CLASSES = 38
 
-metadata_features = 0
-N_UNIQUE_CLASSES = 38
+    base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
+    tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
 
-base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
-tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
+    input_size = 768 + metadata_features
+    log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
 
-# Initialize the classifier
-input_size = 768 + metadata_features # featurizer output size + metadata size
-log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
+    model_weights_path = os.getenv('MODEL_PATH')
+    weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
 
-# Load Weights
-import os
+    base_model.load_state_dict(weights['model_state_dict'])
+    log_reg.load_state_dict(weights['log_reg_state_dict'])
 
-# Get the model path from the environment variable
-model_weights_path = os.getenv('MODEL_PATH')
-weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
-
-base_model.load_state_dict(weights['model_state_dict'])
-log_reg.load_state_dict(weights['log_reg_state_dict'])
-
-# Creating Model
-model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
-model.eval()
+    model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
+    model.eval()
+
+    return model, tokenizer
+
+model, tokenizer = load_model()
 
 def analyze_dna(sequence):
     try:
-        # Check if the sequence contains only valid characters
         if not all(nucleotide in 'ACTGN' for nucleotide in sequence):
-            return "Error: Sequence contains invalid characters"
+            return "Error: Sequence contains invalid characters", ""
 
-        # Check if the sequence is at least 300 nucleotides long
         if len(sequence) < 300:
-            return "Error: Sequence needs to be at least 300 nucleotides long"
+            return "Error: Sequence needs to be at least 300 nucleotides long", ""
 
-        # Preprocess the input sequence
         inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
-
-        # Get model predictions
-        _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
+        logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
 
-        # Convert logits to probabilities
         probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
-
-        # Get the top 5 most likely classes
         top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
         top_5_probs = [probabilities[i] for i in top_5_indices]
-
-        # Map indices to label names
        top_5_labels = [int_to_label[i] for i in top_5_indices]
-
-        # Prepare the output as a list of tuples (label_name, probability)
         result = [(label, prob) for label, prob in zip(top_5_labels, top_5_probs)]
 
-        # Plot histogram
         fig, ax = plt.subplots(figsize=(10, 6))
         ax.barh(top_5_labels, top_5_probs, color='skyblue')
         ax.set_xlabel('Probability')
         ax.set_title('Top 5 Most Likely Labels')
-        plt.gca().invert_yaxis() # Highest probabilities at the top
+        plt.gca().invert_yaxis()
 
-        # Save plot to a PNG image in memory
         buf = io.BytesIO()
         plt.savefig(buf, format='png')
         buf.seek(0)
@@ -129,10 +91,8 @@ def analyze_dna(sequence):
 
         return result, f'<img src="data:image/png;base64,{image_base64}" />'
 
-    except ValueError as e:
-        # Return the error message
+    except Exception as e:
         return str(e), ""
-
 
 # Create a Gradio interface
 demo = gr.Interface(fn=analyze_dna, inputs="text", outputs=["json", "html"])
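
Note on the checkpoint format: load_model() reads the file at MODEL_PATH with torch.load() and expects a dict with 'model_state_dict' (the GENA-LM backbone) and 'log_reg_state_dict' (the classifier head). The training code is not part of this Space, so the snippet below is only a hypothetical sketch of how a compatible checkpoint could be written; the variable names and the 'model_weights.pt' path are illustrative.

# Hypothetical sketch -- not part of this commit. Assumes trained modules
# `base_model` (the GENA-LM backbone) and `log_reg` (a LogisticRegressionTorch
# instance) are in scope, e.g. at the end of a training script.
import torch

torch.save(
    {
        'model_state_dict': base_model.state_dict(),    # key read by load_model()
        'log_reg_state_dict': log_reg.state_dict(),     # key read by load_model()
    },
    'model_weights.pt',  # illustrative filename; point MODEL_PATH at this file
)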
 
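To exercise the refactored pipeline outside the Gradio UI, a smoke test along these lines should work, assuming MODEL_PATH points at a compatible checkpoint before app is imported (importing app runs load_model() once at module level and fetches the GENA-LM weights on first use):

# Hypothetical smoke test -- not part of this commit.
import os
os.environ['MODEL_PATH'] = 'model_weights.pt'  # illustrative path

from app import analyze_dna  # import triggers load_model()

result, histogram_html = analyze_dna('ACGT' * 100)  # 400 nt, alphabet ACTGN
print(result)                              # top-5 (label, probability) pairs
assert histogram_html.startswith('<img')   # inline base64 PNG on success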