sagawa committed on
Commit
3bb8b09
·
verified ·
1 Parent(s): 104ce05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -40
app.py CHANGED
@@ -16,45 +16,39 @@ from datasets_ import PLTNUMDataset
16
 
17
 
18
  class Config:
19
- batch_size = 2
20
- use_amp = False
21
- num_workers = 1
22
- max_length = 512
23
- used_sequence = "left"
24
- padding_side = "right"
25
- task = "classification"
26
- sequence_col = "sequence"
27
- seed = 42
28
-
29
-
30
- # Assuming 'predict_stability' is your function that predicts protein stability
31
  def predict_stability(model_choice, organism_choice, pdb_file=None, sequence=None, cfg=Config()):
32
- # Check if pdb_file is provided
33
  if pdb_file:
34
  pdb_path = pdb_file.name # Get the path of the uploaded PDB file
35
  os.system("chmod 777 bin/foldseek")
36
  sequences = get_foldseek_seq(pdb_path)
37
  if not sequences:
38
  return "Failed to extract sequence from the PDB file."
39
- if model_choice == "SaProt":
40
- sequence = sequences[2]
41
- else:
42
- sequence = sequences[0]
43
-
44
- if organism_choice == "Human":
45
- cell_line = "HeLa"
46
- else:
47
- cell_line = "NIH3T3"
48
- # If sequence is provided directly
49
- if sequence:
50
- cfg.model = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
51
- cfg.architecture = model_choice
52
- cfg.model_path = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
53
- output = predict(cfg, sequence)
54
- return output
55
- else:
56
  return "No valid input provided."
57
 
 
 
 
 
 
 
 
 
 
58
 
59
  def get_foldseek_seq(pdb_path):
60
  parsed_seqs = get_struc_seq(
@@ -93,7 +87,6 @@ def predict(cfg, sequence):
93
  model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
94
  model.to(cfg.device)
95
 
96
- # predictions = predict_fn(loader, model, cfg)
97
  model.eval()
98
  predictions = []
99
 
@@ -107,24 +100,23 @@ def predict(cfg, sequence):
107
  else model(inputs)
108
  )
109
  predictions += preds.cpu().tolist()
110
- outputs = {}
111
  predictions = list(itertools.chain.from_iterable(predictions))
112
- outputs["raw prediction values"] = predictions
113
- outputs["binary prediction values"] = [1 if x > 0.5 else 0 for x in predictions]
 
 
114
 
115
  html_output = f"""
116
  <div style='border: 2px solid #4CAF50; padding: 10px; border-radius: 10px;'>
117
- <h3 style='color: #4CAF50;'>Protein: {sequence}</h3>
118
- <p><strong>Stability:</strong> {outputs['raw prediction values']}</p>
119
- <p><strong>Organism:</strong> {outputs['binary prediction values']}</p>
120
  </div>
121
  """
122
 
123
  return html_output
124
 
125
 
126
-
127
-
128
  # Gradio Interface
129
  with gr.Blocks() as demo:
130
  gr.Markdown(
@@ -156,7 +148,6 @@ with gr.Blocks() as demo:
156
  with gr.TabItem("Upload PDB File"):
157
  gr.Markdown("### Upload your PDB file:")
158
  pdb_file = gr.File(label="Upload PDB File")
159
-
160
  predict_button = gr.Button("Predict Stability")
161
  prediction_output = gr.HTML(
162
  label="Stability Prediction"
 
16
 
17
 
18
  class Config:
19
+ def __init__(self):
20
+ self.batch_size = 2
21
+ self.use_amp = False
22
+ self.num_workers = 1
23
+ self.max_length = 512
24
+ self.used_sequence = "left"
25
+ self.padding_side = "right"
26
+ self.task = "classification"
27
+ self.sequence_col = "sequence"
28
+ self.seed = 42
29
+
30
+
31
  def predict_stability(model_choice, organism_choice, pdb_file=None, sequence=None, cfg=Config()):
 
32
  if pdb_file:
33
  pdb_path = pdb_file.name # Get the path of the uploaded PDB file
34
  os.system("chmod 777 bin/foldseek")
35
  sequences = get_foldseek_seq(pdb_path)
36
  if not sequences:
37
  return "Failed to extract sequence from the PDB file."
38
+ sequence = sequences[2] if model_choice == "SaProt" else sequences[0]
39
+
40
+ if not sequence:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  return "No valid input provided."
42
 
43
+ cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3"
44
+ cfg.model = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
45
+ cfg.architecture = model_choice
46
+ cfg.model_path = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
47
+
48
+ output = predict(cfg, sequence)
49
+ return output
50
+
51
+
52
 
53
  def get_foldseek_seq(pdb_path):
54
  parsed_seqs = get_struc_seq(
 
87
  model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
88
  model.to(cfg.device)
89
 
 
90
  model.eval()
91
  predictions = []
92
 
 
100
  else model(inputs)
101
  )
102
  predictions += preds.cpu().tolist()
103
+
104
  predictions = list(itertools.chain.from_iterable(predictions))
105
+ outputs = {
106
+ "raw prediction values": predictions,
107
+ "binary prediction values": [1 if x > 0.5 else 0 for x in predictions]
108
+ }
109
 
110
  html_output = f"""
111
  <div style='border: 2px solid #4CAF50; padding: 10px; border-radius: 10px;'>
112
+ <p><strong>Stability:</strong> {outputs['raw prediction values'][0]}</p>
113
+ <p><strong>Organism:</strong> {outputs['binary prediction values'][0]}</p>
 
114
  </div>
115
  """
116
 
117
  return html_output
118
 
119
 
 
 
120
  # Gradio Interface
121
  with gr.Blocks() as demo:
122
  gr.Markdown(
 
148
  with gr.TabItem("Upload PDB File"):
149
  gr.Markdown("### Upload your PDB file:")
150
  pdb_file = gr.File(label="Upload PDB File")
 
151
  predict_button = gr.Button("Predict Stability")
152
  prediction_output = gr.HTML(
153
  label="Stability Prediction"