Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -83,12 +83,16 @@ log_reg.load_state_dict(weights['log_reg_state_dict'])
|
|
83 |
model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
|
84 |
model.eval()
|
85 |
|
|
|
|
|
|
|
|
|
86 |
# Define a function to process the DNA sequence
|
87 |
def analyze_dna(sequence):
|
88 |
# Preprocess the input sequence
|
89 |
inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
|
90 |
|
91 |
-
print("
|
92 |
# Get model predictions
|
93 |
_, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
|
94 |
|
@@ -97,13 +101,16 @@ def analyze_dna(sequence):
|
|
97 |
# Convert logits to probabilities
|
98 |
probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
|
99 |
|
100 |
-
print("Probabilities
|
101 |
# Get the top 5 most likely classes
|
102 |
top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
|
103 |
top_5_probs = [probabilities[i] for i in top_5_indices]
|
104 |
|
105 |
-
#
|
106 |
-
|
|
|
|
|
|
|
107 |
|
108 |
return result
|
109 |
|
@@ -112,3 +119,5 @@ demo = gr.Interface(fn=analyze_dna, inputs="text", outputs="json")
|
|
112 |
|
113 |
# Launch the interface
|
114 |
demo.launch()
|
|
|
|
|
|
83 |
model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
|
84 |
model.eval()
|
85 |
|
86 |
+
# Dictionary to decode model predictions
|
87 |
+
# pandas has no `read_pkl`; `read_pickle` is the loader that unpickles the file.
# The loaded object must support `.items()` (presumably a label -> index dict).
label_to_int = pd.read_pickle('label_to_int.pkl')
|
88 |
+
# Invert the mapping so a predicted class index can be looked up to get its label.
int_to_label = {index: label for label, index in label_to_int.items()}
|
89 |
+
|
90 |
# Define a function to process the DNA sequence
|
91 |
def analyze_dna(sequence):
|
92 |
# Preprocess the input sequence
|
93 |
inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
|
94 |
|
95 |
+
print("Tokenization done.")
|
96 |
# Get model predictions
|
97 |
_, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
|
98 |
|
|
|
101 |
# Convert logits to probabilities
|
102 |
probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
|
103 |
|
104 |
+
print("Probabilities done.")
|
105 |
# Get the top 5 most likely classes
|
106 |
top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
|
107 |
top_5_probs = [probabilities[i] for i in top_5_indices]
|
108 |
|
109 |
+
# Map indices to label names
|
110 |
+
top_5_labels = [int_to_label[i] for i in top_5_indices]
|
111 |
+
|
112 |
+
# Prepare the output as a list of tuples (label_name, probability)
|
113 |
+
result = [(label, prob) for label, prob in zip(top_5_labels, top_5_probs)]
|
114 |
|
115 |
return result
|
116 |
|
|
|
119 |
|
120 |
# Launch the interface
|
121 |
demo.launch()
|
122 |
+
|
123 |
+
|