Update app.py
Browse files
app.py
CHANGED
@@ -16,45 +16,39 @@ from datasets_ import PLTNUMDataset
|
|
16 |
|
17 |
|
18 |
class Config:
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
def predict_stability(model_choice, organism_choice, pdb_file=None, sequence=None, cfg=Config()):
|
32 |
-
# Check if pdb_file is provided
|
33 |
if pdb_file:
|
34 |
pdb_path = pdb_file.name # Get the path of the uploaded PDB file
|
35 |
os.system("chmod 777 bin/foldseek")
|
36 |
sequences = get_foldseek_seq(pdb_path)
|
37 |
if not sequences:
|
38 |
return "Failed to extract sequence from the PDB file."
|
39 |
-
if model_choice == "SaProt"
|
40 |
-
|
41 |
-
|
42 |
-
sequence = sequences[0]
|
43 |
-
|
44 |
-
if organism_choice == "Human":
|
45 |
-
cell_line = "HeLa"
|
46 |
-
else:
|
47 |
-
cell_line = "NIH3T3"
|
48 |
-
# If sequence is provided directly
|
49 |
-
if sequence:
|
50 |
-
cfg.model = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
|
51 |
-
cfg.architecture = model_choice
|
52 |
-
cfg.model_path = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
|
53 |
-
output = predict(cfg, sequence)
|
54 |
-
return output
|
55 |
-
else:
|
56 |
return "No valid input provided."
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
def get_foldseek_seq(pdb_path):
|
60 |
parsed_seqs = get_struc_seq(
|
@@ -93,7 +87,6 @@ def predict(cfg, sequence):
|
|
93 |
model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
|
94 |
model.to(cfg.device)
|
95 |
|
96 |
-
# predictions = predict_fn(loader, model, cfg)
|
97 |
model.eval()
|
98 |
predictions = []
|
99 |
|
@@ -107,24 +100,23 @@ def predict(cfg, sequence):
|
|
107 |
else model(inputs)
|
108 |
)
|
109 |
predictions += preds.cpu().tolist()
|
110 |
-
|
111 |
predictions = list(itertools.chain.from_iterable(predictions))
|
112 |
-
outputs
|
113 |
-
|
|
|
|
|
114 |
|
115 |
html_output = f"""
|
116 |
<div style='border: 2px solid #4CAF50; padding: 10px; border-radius: 10px;'>
|
117 |
-
<
|
118 |
-
<p><strong>
|
119 |
-
<p><strong>Organism:</strong> {outputs['binary prediction values']}</p>
|
120 |
</div>
|
121 |
"""
|
122 |
|
123 |
return html_output
|
124 |
|
125 |
|
126 |
-
|
127 |
-
|
128 |
# Gradio Interface
|
129 |
with gr.Blocks() as demo:
|
130 |
gr.Markdown(
|
@@ -156,7 +148,6 @@ with gr.Blocks() as demo:
|
|
156 |
with gr.TabItem("Upload PDB File"):
|
157 |
gr.Markdown("### Upload your PDB file:")
|
158 |
pdb_file = gr.File(label="Upload PDB File")
|
159 |
-
|
160 |
predict_button = gr.Button("Predict Stability")
|
161 |
prediction_output = gr.HTML(
|
162 |
label="Stability Prediction"
|
|
|
16 |
|
17 |
|
18 |
class Config:
|
19 |
+
def __init__(self):
|
20 |
+
self.batch_size = 2
|
21 |
+
self.use_amp = False
|
22 |
+
self.num_workers = 1
|
23 |
+
self.max_length = 512
|
24 |
+
self.used_sequence = "left"
|
25 |
+
self.padding_side = "right"
|
26 |
+
self.task = "classification"
|
27 |
+
self.sequence_col = "sequence"
|
28 |
+
self.seed = 42
|
29 |
+
|
30 |
+
|
31 |
def predict_stability(model_choice, organism_choice, pdb_file=None, sequence=None, cfg=Config()):
|
|
|
32 |
if pdb_file:
|
33 |
pdb_path = pdb_file.name # Get the path of the uploaded PDB file
|
34 |
os.system("chmod 777 bin/foldseek")
|
35 |
sequences = get_foldseek_seq(pdb_path)
|
36 |
if not sequences:
|
37 |
return "Failed to extract sequence from the PDB file."
|
38 |
+
sequence = sequences[2] if model_choice == "SaProt" else sequences[0]
|
39 |
+
|
40 |
+
if not sequence:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
return "No valid input provided."
|
42 |
|
43 |
+
cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3"
|
44 |
+
cfg.model = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
|
45 |
+
cfg.architecture = model_choice
|
46 |
+
cfg.model_path = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
|
47 |
+
|
48 |
+
output = predict(cfg, sequence)
|
49 |
+
return output
|
50 |
+
|
51 |
+
|
52 |
|
53 |
def get_foldseek_seq(pdb_path):
|
54 |
parsed_seqs = get_struc_seq(
|
|
|
87 |
model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
|
88 |
model.to(cfg.device)
|
89 |
|
|
|
90 |
model.eval()
|
91 |
predictions = []
|
92 |
|
|
|
100 |
else model(inputs)
|
101 |
)
|
102 |
predictions += preds.cpu().tolist()
|
103 |
+
|
104 |
predictions = list(itertools.chain.from_iterable(predictions))
|
105 |
+
outputs = {
|
106 |
+
"raw prediction values": predictions,
|
107 |
+
"binary prediction values": [1 if x > 0.5 else 0 for x in predictions]
|
108 |
+
}
|
109 |
|
110 |
html_output = f"""
|
111 |
<div style='border: 2px solid #4CAF50; padding: 10px; border-radius: 10px;'>
|
112 |
+
<p><strong>Stability:</strong> {outputs['raw prediction values'][0]}</p>
|
113 |
+
<p><strong>Organism:</strong> {outputs['binary prediction values'][0]}</p>
|
|
|
114 |
</div>
|
115 |
"""
|
116 |
|
117 |
return html_output
|
118 |
|
119 |
|
|
|
|
|
120 |
# Gradio Interface
|
121 |
with gr.Blocks() as demo:
|
122 |
gr.Markdown(
|
|
|
148 |
with gr.TabItem("Upload PDB File"):
|
149 |
gr.Markdown("### Upload your PDB file:")
|
150 |
pdb_file = gr.File(label="Upload PDB File")
|
|
|
151 |
predict_button = gr.Button("Predict Stability")
|
152 |
prediction_output = gr.HTML(
|
153 |
label="Stability Prediction"
|