sagawa committed on
Commit
42d0bd4
·
verified ·
1 Parent(s): 20c06cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -31
app.py CHANGED
@@ -30,31 +30,35 @@ class Config:
30
 
31
 
32
  def predict_stability_with_pdb(model_choice, organism_choice, pdb_files, cfg=Config()):
33
- results = []
 
 
 
 
 
 
34
  for pdb_file in pdb_files:
35
  try:
36
  pdb_path = pdb_file.name
37
  os.system("chmod 777 bin/foldseek")
38
  sequences = get_foldseek_seq(pdb_path)
39
  if not sequences:
40
- results.append({"file_name": pdb_path,
41
- "raw prediction value": None,
42
- "binary prediction value": None
43
- })
44
  continue
45
 
46
  sequence = sequences[2] if model_choice == "SaProt" else sequences[0]
47
- output = predict_stability_core(model_choice, organism_choice, sequence, cfg)
48
-
49
- results.append({"file_name": pdb_path,
50
- "raw prediction value": output["raw prediction values"][0],
51
- "binary prediction value": output["binary prediction values"][0]
52
- })
53
  except Exception as e:
54
- results.append({"file_name": pdb_file.name,
55
- "raw prediction value": None,
56
- "binary prediction value": None
57
- })
 
 
 
58
 
59
  df = pd.DataFrame(results)
60
  output_csv = "/tmp/predictions.csv"
@@ -72,13 +76,13 @@ def predict_stability_with_sequence(model_choice, organism_choice, sequence, cfg
72
  return f"An error occurred: {str(e)}"
73
 
74
 
75
- def predict_stability_core(model_choice, organism_choice, sequence, cfg=Config()):
76
  cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3"
77
  cfg.model = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
78
  cfg.architecture = model_choice
79
  cfg.model_path = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
80
 
81
- output = predict(cfg, sequence)
82
  return output
83
 
84
 
@@ -92,7 +96,7 @@ def get_foldseek_seq(pdb_path):
92
  return parsed_seqs
93
 
94
 
95
- def predict(cfg, sequence):
96
  cfg.token_length = 2 if cfg.architecture == "SaProt" else 1
97
  cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
98
 
@@ -100,7 +104,7 @@ def predict(cfg, sequence):
100
  cfg.max_length += 1
101
 
102
  seed_everything(cfg.seed)
103
- df = pd.DataFrame({cfg.sequence_col: [sequence]})
104
 
105
  tokenizer = AutoTokenizer.from_pretrained(
106
  cfg.model_path, padding_side=cfg.padding_side
@@ -134,19 +138,8 @@ def predict(cfg, sequence):
134
  predictions += preds.cpu().tolist()
135
 
136
  predictions = list(itertools.chain.from_iterable(predictions))
137
- outputs = {
138
- "raw prediction values": predictions,
139
- "binary prediction values": [1 if x > 0.5 else 0 for x in predictions]
140
- }
141
-
142
- html_output = f"""
143
- <div style='border: 2px solid #4CAF50; padding: 10px; border-radius: 10px;'>
144
- <p><strong>Raw prediction value:</strong> {outputs['raw prediction values'][0]}</p>
145
- <p><strong>Binary prediction values:</strong> {outputs['binary prediction values'][0]}</p>
146
- </div>
147
- """
148
 
149
- return html_output
150
 
151
 
152
  # Gradio Interface
 
30
 
31
 
32
  def predict_stability_with_pdb(model_choice, organism_choice, pdb_files, cfg=Config()):
33
+ results = {"file_name": [],
34
+ "raw prediction value": [],
35
+ "binary prediction value": []
36
+ }
37
+ file_names = []
38
+ sequences = []
39
+
40
  for pdb_file in pdb_files:
41
  try:
42
  pdb_path = pdb_file.name
43
  os.system("chmod 777 bin/foldseek")
44
  sequences = get_foldseek_seq(pdb_path)
45
  if not sequences:
46
+ results["file_name"].append(pdb_file.name)
47
+ results["raw prediction value"].append(None)
48
+ results["binary prediction value"].append(None)
 
49
  continue
50
 
51
  sequence = sequences[2] if model_choice == "SaProt" else sequences[0]
52
+ file_names.append(pdb_file.name)
53
+ sequences.append(sequence)
 
 
 
 
54
  except Exception as e:
55
+ results["file_name"].append(pdb_file.name)
56
+ results["raw prediction value"].append(None)
57
+ results["binary prediction value"].append(None)
58
+ raw_prediction, binary_prediction = predict_stability_core(model_choice, organism_choice, sequences, cfg)
59
+ results["file_name"] = results["file_name"] + file_names
60
+ results["raw prediction value"] = results["raw prediction value"] + raw_prediction
61
+ results["binary prediction value"] = results["binary prediction value"] + binary_prediction
62
 
63
  df = pd.DataFrame(results)
64
  output_csv = "/tmp/predictions.csv"
 
76
  return f"An error occurred: {str(e)}"
77
 
78
 
79
def predict_stability_core(model_choice, organism_choice, sequences, cfg=None):
    """Point ``cfg`` at the PLTNUM checkpoint for the request and run ``predict``.

    Args:
        model_choice: Architecture name; interpolated into the checkpoint id
            and stored as ``cfg.architecture``.
        organism_choice: ``"Human"`` selects the HeLa checkpoint; any other
            value selects the NIH3T3 checkpoint.
        sequences: Sequences forwarded unchanged to ``predict``.
        cfg: Configuration object that is updated in place. When omitted, a
            fresh ``Config()`` is created for this call.

    Returns:
        Whatever ``predict(cfg, sequences)`` returns.
    """
    # A ``cfg=Config()`` default is built once at definition time and shared
    # by every call, so the attribute writes below would leak between calls.
    # Build the default lazily instead.
    if cfg is None:
        cfg = Config()

    cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3"
    # The same checkpoint id is used for both the model name and its path.
    checkpoint = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
    cfg.model = checkpoint
    cfg.architecture = model_choice
    cfg.model_path = checkpoint

    return predict(cfg, sequences)
87
 
88
 
 
96
  return parsed_seqs
97
 
98
 
99
+ def predict(cfg, sequences):
100
  cfg.token_length = 2 if cfg.architecture == "SaProt" else 1
101
  cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
102
 
 
104
  cfg.max_length += 1
105
 
106
  seed_everything(cfg.seed)
107
+ df = pd.DataFrame({cfg.sequence_col: sequences})
108
 
109
  tokenizer = AutoTokenizer.from_pretrained(
110
  cfg.model_path, padding_side=cfg.padding_side
 
138
  predictions += preds.cpu().tolist()
139
 
140
  predictions = list(itertools.chain.from_iterable(predictions))
 
 
 
 
 
 
 
 
 
 
 
141
 
142
+ return predictions, [1 if x > 0.5 else 0 for x in predictions]
143
 
144
 
145
  # Gradio Interface