sagawa commited on
Commit
9329e39
·
verified ·
1 Parent(s): 33e1d55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -8
app.py CHANGED
@@ -2,23 +2,51 @@ import gradio as gr
2
  import sys
3
  import random
4
  import os
 
 
 
 
5
  sys.path.append("scripts/")
6
  from foldseek_util import get_struc_seq
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # Assuming 'predict_stability' is your function that predicts protein stability
9
- def predict_stability(model_choice, organism_choice, pdb_file=None, sequence=None):
10
  # Check if pdb_file is provided
11
  if pdb_file:
12
  pdb_path = pdb_file.name # Get the path of the uploaded PDB file
13
  os.system("chmod 777 bin/foldseek")
14
- sequence = get_foldseek_seq(pdb_path)
15
- if not sequence:
16
  return "Failed to extract sequence from the PDB file."
17
-
 
 
 
 
 
 
 
 
18
  # If sequence is provided directly
19
  if sequence:
20
- # Add logic to predict stability using the sequence
21
- return f"Predicted Stability using {model_choice} for {organism_choice}: Example Output with sequence {sequence[:10]}..."
 
 
 
22
  else:
23
  return "No valid input provided."
24
 
@@ -33,6 +61,56 @@ def get_foldseek_seq(pdb_path):
33
  return parsed_seqs
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # Gradio Interface
37
  with gr.Blocks() as demo:
38
  gr.Markdown(
@@ -47,7 +125,7 @@ with gr.Blocks() as demo:
47
  # Model and Organism selection in the same row to avoid layout issues
48
  with gr.Row():
49
  model_choice = gr.Radio(
50
- choices=["SaProt", "ESM-2"],
51
  label="Select PLTNUM's base model.",
52
  value="SaProt"
53
  )
@@ -82,7 +160,7 @@ with gr.Blocks() as demo:
82
  gr.Markdown(
83
  """
84
  ### How to Use:
85
- - **Select Model**: Choose between 'SaProt' or 'ESM-2' for your prediction.
86
  - **Select Organism**: Choose between 'Mouse' or 'Human'.
87
  - **Upload PDB File**: Choose the 'Upload PDB File' tab and upload your file.
88
  - **Enter Sequence**: Alternatively, switch to the 'Enter Protein Sequence' tab and input your sequence.
 
2
  import sys
3
  import random
4
  import os
5
+ import pandas as pd
6
+ import torch
7
+ from torch.utils.data import DataLoader
8
+ from transformers import AutoTokenizer
9
  sys.path.append("scripts/")
10
  from foldseek_util import get_struc_seq
11
+ from utils import seed_everything
12
+ from models import PLTNUM_PreTrainedModel
13
+ from datasets import PLTNUMDataset
14
+
15
+ class Config:
16
+ batch_size = 2
17
+ use_amp = False
18
+ num_workers = 1
19
+ max_length = 512
20
+ used_sequence = "left"
21
+ padding_side = "right"
22
+ task = "classification"
23
+ sequence_col = "sequence"
24
 
25
  # Assuming 'predict_stability' is your function that predicts protein stability
26
+ def predict_stability(cfg, model_choice, organism_choice, pdb_file=None, sequence=None):
27
  # Check if pdb_file is provided
28
  if pdb_file:
29
  pdb_path = pdb_file.name # Get the path of the uploaded PDB file
30
  os.system("chmod 777 bin/foldseek")
31
+ sequences = get_foldseek_seq(pdb_path)
32
+ if not sequences:
33
  return "Failed to extract sequence from the PDB file."
34
+ if model_choice == "SaProt":
35
+ sequence = sequences[2]
36
+ else:
37
+ sequence = sequences[0]
38
+
39
+ if organism_choice == "Human":
40
+ cell_line = "HeLa"
41
+ else:
42
+ cell_line = "NIH3T3"
43
  # If sequence is provided directly
44
  if sequence:
45
+ cfg.model = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
46
+ cfg.architecture = model_choice
47
+ cfg.model_path = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
48
+ output = predict(cfg, sequence)
49
+ return f"Predicted Stability using {model_choice} for {organism_choice}: Example Output with sequence {sequence}..."
50
  else:
51
  return "No valid input provided."
52
 
 
61
  return parsed_seqs
62
 
63
 
64
+ def predict(cfg, sequence):
65
+ cfg.token_length = 2 if cfg.architecture == "SaProt" else 1
66
+ cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
67
+
68
+ if cfg.used_sequence == "both":
69
+ cfg.max_length += 1
70
+
71
+ seed_everything(cfg.seed)
72
+
73
+ df = pd.DataFrame({cfg.sequence_col: [sequence]})
74
+
75
+ tokenizer = AutoTokenizer.from_pretrained(
76
+ cfg.model_path, padding_side=cfg.padding_side
77
+ )
78
+ cfg.tokenizer = tokenizer
79
+
80
+ dataset = PLTNUMDataset(cfg, df, train=False)
81
+ dataloader = DataLoader(
82
+ dataset,
83
+ batch_size=cfg.batch_size,
84
+ shuffle=False,
85
+ num_workers=cfg.num_workers,
86
+ pin_memory=True,
87
+ drop_last=False,
88
+ )
89
+
90
+ model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
91
+ model.to(cfg.device)
92
+
93
+ # predictions = predict_fn(loader, model, cfg)
94
+ model.eval()
95
+ predictions = []
96
+
97
+ for inputs, _ in dataloader:
98
+ inputs = inputs.to(cfg.device)
99
+ with torch.no_grad():
100
+ with torch.amp.autocast(enabled=cfg.use_amp):
101
+ preds = (
102
+ torch.sigmoid(model(inputs))
103
+ if cfg.task == "classification"
104
+ else model(inputs)
105
+ )
106
+ predictions += preds.cpu().tolist()
107
+ outputs = {}
108
+ outputs["raw prediction values"] = predictions
109
+ outputs["binary prediction values"] = [1 if x > 0.5 else 0 for x in predictions]
110
+ return outputs
111
+
112
+
113
+
114
  # Gradio Interface
115
  with gr.Blocks() as demo:
116
  gr.Markdown(
 
125
  # Model and Organism selection in the same row to avoid layout issues
126
  with gr.Row():
127
  model_choice = gr.Radio(
128
+ choices=["SaProt", "ESM2"],
129
  label="Select PLTNUM's base model.",
130
  value="SaProt"
131
  )
 
160
  gr.Markdown(
161
  """
162
  ### How to Use:
163
+ - **Select Model**: Choose between 'SaProt' or 'ESM2' for your prediction.
164
  - **Select Organism**: Choose between 'Mouse' or 'Human'.
165
  - **Upload PDB File**: Choose the 'Upload PDB File' tab and upload your file.
166
  - **Enter Sequence**: Alternatively, switch to the 'Enter Protein Sequence' tab and input your sequence.