mawairon commited on
Commit
66990c3
·
verified ·
1 Parent(s): dc7d693

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -26
app.py CHANGED
@@ -3,17 +3,18 @@ import transformers
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
4
  import torch
5
  import torch.nn as nn
6
- import matplotlib.pyplot as plt
7
  import pandas as pd
 
 
 
8
 
9
-
 
 
10
 
11
  class LogisticRegressionTorch(nn.Module):
12
 
13
- def __init__(self,
14
- input_dim: int,
15
- output_dim: int):
16
-
17
  super(LogisticRegressionTorch, self).__init__()
18
  self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
19
  self.linear = nn.Linear(input_dim, output_dim)
@@ -25,11 +26,7 @@ class LogisticRegressionTorch(nn.Module):
25
 
26
  class BertClassifier(nn.Module):
27
 
28
- def __init__(self,
29
- bert_model: AutoModel,
30
- classifier: LogisticRegressionTorch,
31
- num_labels: int):
32
-
33
  super(BertClassifier, self).__init__()
34
  self.bert = bert_model # Assume bert_model is an instance of a pre-trained BertModel
35
  self.classifier = classifier
@@ -61,22 +58,20 @@ class BertClassifier(nn.Module):
61
  # Return the loss and logits
62
  return loss, logits
63
 
64
-
65
-
66
  # Load the Hugging Face model and tokenizer
67
 
68
  metadata_features = 0
69
- N_UNIQUE_CLASSES = 38 ## or 38
70
 
71
  base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
72
  tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
73
 
74
  # Initialize the classifier
75
- input_size = 768 + metadata_features # featurizer output size + metadata size
76
  log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
77
 
78
  # Load Weights
79
- model_weights_path = 'gena-blastln-bs33-lr4e-05-S168.pth'
80
  weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
81
 
82
  base_model.load_state_dict(weights['model_state_dict'])
@@ -84,11 +79,6 @@ log_reg.load_state_dict(weights['log_reg_state_dict'])
84
 
85
  # Creating Model
86
  model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
87
- model.eval()
88
-
89
- # Dictionary to decode model predictions
90
- label_to_int = pd.read_pkl('label_to_int.pkl')
91
- int_to_label = {v: k for k, v in label_to_int.items()}
92
 
93
  # Define a function to process the DNA sequence
94
  def analyze_dna(sequence):
@@ -113,20 +103,27 @@ def analyze_dna(sequence):
113
  top_5_labels = [int_to_label[i] for i in top_5_indices]
114
 
115
  # Prepare the output as a list of tuples (label_name, probability)
116
- #result = [(label, prob) for label, prob in zip(top_5_labels, top_5_probs)]
 
117
  # Plot histogram
 
118
  fig, ax = plt.subplots(figsize=(10, 6))
119
  ax.barh(top_5_labels, top_5_probs, color='skyblue')
120
  ax.set_xlabel('Probability')
121
  ax.set_title('Top 5 Most Likely Labels')
122
  plt.gca().invert_yaxis() # Highest probabilities at the top
123
 
124
- #return result
 
 
 
 
 
 
 
125
 
126
  # Create a Gradio interface
127
- demo = gr.Interface(fn=analyze_dna, inputs="text", outputs="json")
128
 
129
  # Launch the interface
130
  demo.launch()
131
-
132
-
 
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
4
  import torch
5
  import torch.nn as nn
 
6
  import pandas as pd
7
+ import matplotlib.pyplot as plt
8
+ import io
9
+ import base64
10
 
11
+ # Assuming label_to_int is a dictionary with {label_name: label_index}
12
+ label_to_int = pd.read_pickle('label_to_int.pkl')
13
+ int_to_label = {v: k for k, v in label_to_int.items()}
14
 
15
  class LogisticRegressionTorch(nn.Module):
16
 
17
+ def __init__(self, input_dim: int, output_dim: int):
 
 
 
18
  super(LogisticRegressionTorch, self).__init__()
19
  self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
20
  self.linear = nn.Linear(input_dim, output_dim)
 
26
 
27
  class BertClassifier(nn.Module):
28
 
29
+ def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
 
 
 
 
30
  super(BertClassifier, self).__init__()
31
  self.bert = bert_model # Assume bert_model is an instance of a pre-trained BertModel
32
  self.classifier = classifier
 
58
  # Return the loss and logits
59
  return loss, logits
60
 
 
 
61
  # Load the Hugging Face model and tokenizer
62
 
63
  metadata_features = 0
64
+ N_UNIQUE_CLASSES = 38 # or 38
65
 
66
  base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
67
  tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
68
 
69
  # Initialize the classifier
70
+ input_size = 768 + metadata_features # featurizer output size + metadata size
71
  log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
72
 
73
  # Load Weights
74
+ model_weights_path = 'model/gena-blastln-bs33-lr4e-05-S168.pth'
75
  weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
76
 
77
  base_model.load_state_dict(weights['model_state_dict'])
 
79
 
80
  # Creating Model
81
  model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
 
 
 
 
 
82
 
83
  # Define a function to process the DNA sequence
84
  def analyze_dna(sequence):
 
103
  top_5_labels = [int_to_label[i] for i in top_5_indices]
104
 
105
  # Prepare the output as a list of tuples (label_name, probability)
106
+ result = [(label, prob) for label, prob in zip(top_5_labels, top_5_probs)]
107
+
108
  # Plot histogram
109
+
110
  fig, ax = plt.subplots(figsize=(10, 6))
111
  ax.barh(top_5_labels, top_5_probs, color='skyblue')
112
  ax.set_xlabel('Probability')
113
  ax.set_title('Top 5 Most Likely Labels')
114
  plt.gca().invert_yaxis() # Highest probabilities at the top
115
 
116
+ # Save plot to a PNG image in memory
117
+ buf = io.BytesIO()
118
+ plt.savefig(buf, format='png')
119
+ buf.seek(0)
120
+ image_base64 = base64.b64encode(buf.read()).decode('utf-8')
121
+ buf.close()
122
+
123
+ return result, f'<img src="data:image/png;base64,{image_base64}" />'
124
 
125
  # Create a Gradio interface
126
+ demo = gr.Interface(fn=analyze_dna, inputs="text", outputs=["json", "html"])
127
 
128
  # Launch the interface
129
  demo.launch()