mawairon committed
Commit 356d0ee
1 parent: ba0faf4

Update app.py

Replace the assert-based input checks in analyze_dna with explicit validation inside a try/except block, so invalid input returns the error message (and an empty HTML output) instead of raising, and remove the debug print statements.
Files changed (1): app.py (+48, -44)
app.py CHANGED
@@ -61,7 +61,7 @@ class BertClassifier(nn.Module):
 # Load the Hugging Face model and tokenizer
 
 metadata_features = 0
-N_UNIQUE_CLASSES = 38 # or 38
+N_UNIQUE_CLASSES = 38
 
 base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
 tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
@@ -84,51 +84,55 @@ log_reg.load_state_dict(weights['log_reg_state_dict'])
 model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
 model.eval()
 
-# Define a function to process the DNA sequence
 def analyze_dna(sequence):
-    assert all(nucleotide in 'ACTGN' for nucleotide in sequence), "Sequence contains invalid characters"
-    assert len(sequence) >= 300, "Sequence needs to be at least 300 nucleotides long"
-
-    # Preprocess the input sequence
-    inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
-
-    print("Tokenization done.")
-    # Get model predictions
-    _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
-
-    print("Forward pass done.")
-
-    # Convert logits to probabilities
-    probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
-
-    print("Probabilities done.")
-    # Get the top 5 most likely classes
-    top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
-    top_5_probs = [probabilities[i] for i in top_5_indices]
-
-    # Map indices to label names
-    top_5_labels = [int_to_label[i] for i in top_5_indices]
-
-    # Prepare the output as a list of tuples (label_name, probability)
-    result = [(label, prob) for label, prob in zip(top_5_labels, top_5_probs)]
-
-    # Plot histogram
-
-    fig, ax = plt.subplots(figsize=(10, 6))
-    ax.barh(top_5_labels, top_5_probs, color='skyblue')
-    ax.set_xlabel('Probability')
-    ax.set_title('Top 5 Most Likely Labels')
-    plt.gca().invert_yaxis() # Highest probabilities at the top
-
-    # Save plot to a PNG image in memory
-    buf = io.BytesIO()
-    plt.savefig(buf, format='png')
-    buf.seek(0)
-    image_base64 = base64.b64encode(buf.read()).decode('utf-8')
-    buf.close()
-
-    return result, f'<img src="data:image/png;base64,{image_base64}" />'
+    try:
+        # Check if the sequence contains only valid characters
+        if not all(nucleotide in 'ACTGN' for nucleotide in sequence):
+            raise ValueError("Sequence contains invalid characters")
+
+        # Check if the sequence is at least 300 nucleotides long
+        if len(sequence) < 300:
+            raise ValueError("Sequence needs to be at least 300 nucleotides long")
+
+        # Preprocess the input sequence
+        inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
+
+        # Get model predictions
+        _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
+
+        # Convert logits to probabilities
+        probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
+
+        # Get the top 5 most likely classes
+        top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
+        top_5_probs = [probabilities[i] for i in top_5_indices]
+
+        # Map indices to label names
+        top_5_labels = [int_to_label[i] for i in top_5_indices]
+
+        # Prepare the output as a list of tuples (label_name, probability)
+        result = [(label, prob) for label, prob in zip(top_5_labels, top_5_probs)]
+
+        # Plot histogram
+        fig, ax = plt.subplots(figsize=(10, 6))
+        ax.barh(top_5_labels, top_5_probs, color='skyblue')
+        ax.set_xlabel('Probability')
+        ax.set_title('Top 5 Most Likely Labels')
+        plt.gca().invert_yaxis() # Highest probabilities at the top
+
+        # Save plot to a PNG image in memory
+        buf = io.BytesIO()
+        plt.savefig(buf, format='png')
+        buf.seek(0)
+        image_base64 = base64.b64encode(buf.read()).decode('utf-8')
+        buf.close()
+
+        return result, f'<img src="data:image/png;base64,{image_base64}" />'
+
+    except ValueError as e:
+        # Return the error message
+        return str(e), ""
+
 
 # Create a Gradio interface
 demo = gr.Interface(fn=analyze_dna, inputs="text", outputs=["json", "html"])
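
For reference, a minimal usage sketch of the revised analyze_dna, assuming the globals defined earlier in app.py (tokenizer, model, int_to_label) are loaded; the input sequences and printed values below are illustrative, not outputs from the real model:

# Invalid input no longer raises: the ValueError is caught and its message is
# returned together with an empty HTML string, so both Gradio outputs stay valid.
msg, html = analyze_dna("ACGTN")        # only 5 nucleotides
print(msg)    # "Sequence needs to be at least 300 nucleotides long"
print(html)   # ""

# A valid sequence (hypothetical 300-nucleotide repeat, for illustration only)
# runs the full pipeline and returns the top-5 (label, probability) pairs plus
# an <img> tag embedding the base64-encoded bar chart.
predictions, chart_html = analyze_dna("ACGT" * 75)
print(predictions[0])   # e.g. ("label_name", 0.42) -- depends on the trained model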