mawairon commited on
Commit
1f65033
1 Parent(s): 9f4c137

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -83,12 +83,16 @@ log_reg.load_state_dict(weights['log_reg_state_dict'])
83
  model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
84
  model.eval()
85
 
 
 
 
 
86
  # Define a function to process the DNA sequence
87
  def analyze_dna(sequence):
88
  # Preprocess the input sequence
89
  inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
90
 
91
- print("tokenization done.")
92
  # Get model predictions
93
  _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
94
 
@@ -97,13 +101,16 @@ def analyze_dna(sequence):
97
  # Convert logits to probabilities
98
  probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
99
 
100
- print("Probabilities, done.")
101
  # Get the top 5 most likely classes
102
  top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
103
  top_5_probs = [probabilities[i] for i in top_5_indices]
104
 
105
- # Prepare the output as a list of tuples (class_index, probability)
106
- result = [(index, prob) for index, prob in zip(top_5_indices, top_5_probs)]
 
 
 
107
 
108
  return result
109
 
@@ -112,3 +119,5 @@ demo = gr.Interface(fn=analyze_dna, inputs="text", outputs="json")
112
 
113
  # Launch the interface
114
  demo.launch()
 
 
 
83
  model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
84
  model.eval()
85
 
86
+ # Dictionary to decode model predictions
87
+ label_to_int = pd.read_pkl('label_to_int.pkl')
88
+ int_to_label = {v: k for k, v in label_to_int.items()}
89
+
90
  # Define a function to process the DNA sequence
91
  def analyze_dna(sequence):
92
  # Preprocess the input sequence
93
  inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
94
 
95
+ print("Tokenization done.")
96
  # Get model predictions
97
  _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
98
 
 
101
  # Convert logits to probabilities
102
  probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
103
 
104
+ print("Probabilities done.")
105
  # Get the top 5 most likely classes
106
  top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
107
  top_5_probs = [probabilities[i] for i in top_5_indices]
108
 
109
+ # Map indices to label names
110
+ top_5_labels = [int_to_label[i] for i in top_5_indices]
111
+
112
+ # Prepare the output as a list of tuples (label_name, probability)
113
+ result = [(label, prob) for label, prob in zip(top_5_labels, top_5_probs)]
114
 
115
  return result
116
 
 
119
 
120
  # Launch the interface
121
  demo.launch()
122
+
123
+