Simon Duerr committed on
Commit
ed7e222
·
1 Parent(s): a5f62d5

add UI alpha

Browse files
Files changed (3) hide show
  1. README.md +2 -1
  2. app.py +315 -0
  3. requirements.txt +3 -0
README.md CHANGED
@@ -10,4 +10,5 @@ pinned: false
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
10
  license: mit
11
  ---
12
 
13
+ UI for RoseTTAfold2 All Atom version built by @simonduerr
14
+
app.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Input UI for RoseTTAfold All Atom
3
+
4
+ using two custom gradio components: gradio_molecule3d and gradio_cofoldinginput
5
+ """
6
+
7
+
8
+ import gradio as gr
9
+ from gradio_cofoldinginput import CofoldingInput
10
+
11
+ from gradio_molecule3d import Molecule3D
12
+
13
+ import json
14
+ import yaml
15
+ from openbabel import openbabel
16
+
17
+ import zipfile
18
+ import tempfile
19
+
20
+ import os
21
+
22
+ from Bio.PDB import PDBParser, PDBIO
23
+
24
# Default Hydra base configuration for RFAA inference. It is shown in the
# "Base config" tab (where the user may edit it) and written out as base.yaml
# next to the job-specific config. Nested sections use two-space indentation.
baseconfig = """job_name: "structure_prediction"
output_path: ""
checkpoint_path: RFAA_paper_weights.pt
database_params:
  sequencedb: ""
  hhdb: "pdb100_2021Mar03/pdb100_2021Mar03"
  command: make_msa.sh
  num_cpus: 4
  mem: 64
protein_inputs: null
na_inputs: null
sm_inputs: null
covale_inputs: null
residue_replacement: null

chem_params:
  use_phospate_frames_for_NA: True
  use_cif_ordering_for_trp: True

loader_params:
  n_templ: 4
  MAXLAT: 128
  MAXSEQ: 1024
  MAXCYCLE: 4
  BLACK_HOLE_INIT: False
  seqid: 150.0


legacy_model_param:
  n_extra_block: 4
  n_main_block: 32
  n_ref_block: 4
  n_finetune_block: 0
  d_msa: 256
  d_msa_full: 64
  d_pair: 192
  d_templ: 64
  n_head_msa: 8
  n_head_pair: 6
  n_head_templ: 4
  d_hidden_templ: 64
  p_drop: 0.0
  use_chiral_l1: True
  use_lj_l1: True
  use_atom_frames: True
  recycling_type: "all"
  use_same_chain: True
  lj_lin: 0.75
  SE3_param:
    num_layers: 1
    num_channels: 32
    num_degrees: 2
    l0_in_features: 64
    l0_out_features: 64
    l1_in_features: 3
    l1_out_features: 2
    num_edge_features: 64
    n_heads: 4
    div: 4
  SE3_ref_param:
    num_layers: 2
    num_channels: 32
    num_degrees: 2
    l0_in_features: 64
    l0_out_features: 64
    l1_in_features: 3
    l1_out_features: 2
    num_edge_features: 64
    n_heads: 4
    div: 4
"""
95
+
96
def convert_format(input_file, jobname, chain, deleteIndexes, attachmentIndex):
    """Convert a drawn ligand from `cdjson` to SDF after removing atoms.

    The ligand text is first written to ``{jobname}_sm_{chain}.json`` in the
    current working directory, read back through Open Babel, stripped of the
    atoms listed in ``deleteIndexes`` and saved as ``{jobname}_sm_{chain}.sdf``.

    Args:
        input_file: Ligand description as text in Open Babel's `cdjson` format.
        jobname: Job name, used as the file-name prefix.
        chain: Chain identifier of the ligand, used in the file name.
        deleteIndexes: Atom indexes to remove from the molecule.
        attachmentIndex: Index of the atom that forms the covalent bond,
            expressed in the numbering of the unmodified molecule.

    Returns:
        The attachment index shifted down by the number of deleted atoms
        whose index was smaller than it.
    """
    json_path = f"{jobname}_sm_{chain}.json"
    sdf_path = f"{jobname}_sm_{chain}.sdf"

    converter = openbabel.OBConversion()
    converter.SetInAndOutFormats('cdjson', 'sdf')
    converter.AddOption("c", openbabel.OBConversion.OUTOPTIONS, "1")

    # Open Babel reads from disk, so the ligand text is dumped to a file first.
    with open(json_path, "w+") as handle:
        handle.write(input_file)

    molecule = openbabel.OBMol()
    converter.ReadFile(molecule, json_path)

    # Highest index first, so removing one atom never shifts the index of
    # another atom that is still waiting to be removed.
    doomed = sorted(deleteIndexes, reverse=True)
    for atom_index in doomed:
        molecule.DeleteAtom(molecule.GetAtom(atom_index))

    # Every removed atom that preceded the attachment atom moves it down by one.
    shift = sum(1 for atom_index in doomed if atom_index < attachmentIndex)

    converter.WriteFile(molecule, sdf_path)
    return attachmentIndex - shift
120
+
121
+
122
def prepare_input(input, jobname, baseconfig, hard_case):
    """Write all RFAA input files for a job and bundle them into a zip archive.

    FASTA/SDF files are written to the current working directory, the Hydra
    configs to ``/tmp``; everything is then zipped into ``/tmp/{jobname}.zip``
    under a top-level ``{jobname}/`` folder.

    Args:
        input: Value of the CofoldingInput widget. Read as a mapping with a
            ``"chains"`` list (each entry has ``"class"`` and ``"chain"`` plus
            ``"sequence"`` for polymers or one of ``"smiles"``/``"sdf"``/
            ``"name"`` for ligands) and a ``"covMods"`` list.
        jobname: Job name; used for file names, ``job_name`` and ``output_path``.
        baseconfig: Text written verbatim to ``/tmp/base.yaml``.
        hard_case: When truthy, ``loader_params.MAXCYCLE`` is set to 10.

    Returns:
        Tuple ``(yaml_text, zip_path)``: the job config as YAML text and the
        path of the created archive.

    Raises:
        gr.Error: If no chain is provided, or a ligand chain carries none of
            ``smiles``, ``sdf`` or ``name``.
    """
    input_categories = {"protein": "protein_inputs", "DNA": "na_inputs", "RNA": "na_inputs", "ligand": "sm_inputs"}

    if not input or not input["chains"]:
        raise gr.Error("At least one chain must be provided")

    yaml_dict = {"defaults": ["base"], "job_name": jobname, "output_path": jobname}
    list_of_input_files = []

    for chain in input["chains"]:
        category = input_categories[chain["class"]]
        chain_id = chain["chain"]
        section = yaml_dict.setdefault(category, {})

        if category == "sm_inputs":
            if "smiles" in chain:
                entry = {"input_type": "smiles", "input": chain["smiles"]}
            elif "sdf" in chain:
                sdf_name = f"{jobname}_sm_{chain_id}.sdf"
                with open(sdf_name, "w+") as fp:
                    fp.write(chain["sdf"])
                list_of_input_files.append(sdf_name)
                entry = {"input_type": "sdf", "input": f"{jobname}/{sdf_name}"}
            elif "name" in chain:
                # Named ligands are shipped from the bundled metal_sdf folder.
                list_of_input_files.append(f"metal_sdf/{chain['name']}_ideal.sdf")
                entry = {"input_type": "sdf", "input": f"{jobname}/{chain['name']}_ideal.sdf"}
            else:
                # Previously this fell through and reused (or lacked) `entry`.
                raise gr.Error(f"Ligand chain {chain_id} needs a SMILES string, an SDF file or a name")
        else:
            fasta_name = f"{jobname}_{chain_id}.fasta"
            with open(fasta_name, "w+") as fp:
                fp.write(f">chain A\n{chain['sequence']}")
            if category == "na_inputs":
                entry = {"input_type": chain["class"].lower(), "fasta": f"{jobname}/{fasta_name}"}
            else:
                entry = {"fasta_file": f"{jobname}/{fasta_name}"}
            list_of_input_files.append(fasta_name)

        section[chain_id] = entry

    covale_inputs = []
    for covMod in input["covMods"]:
        # Without leaving-group atoms the attachment index is used unchanged;
        # otherwise convert_format returns it renumbered after the deletions.
        attachment_index = covMod["attachmentIndex"]
        if covMod["deleteIndexes"]:
            attachment_index = convert_format(
                covMod["mol"], jobname, covMod["ligand"], covMod["deleteIndexes"], attachment_index
            )
        chirality_protein = covMod["protein_symmetry"] if covMod["protein_symmetry"] in ("CW", "CCW") else "null"
        chirality_ligand = covMod["ligand_symmetry"] if covMod["ligand_symmetry"] in ("CW", "CCW") else "null"
        covale_inputs.append((
            (covMod["protein"], covMod["residue"], covMod["atom"]),
            (covMod["ligand"], attachment_index),
            (chirality_protein, chirality_ligand),
        ))
    if covale_inputs:
        # Double-encoding then stripping the outer quotes yields a string with
        # backslash-escaped inner quotes, stored as a single YAML scalar.
        yaml_dict["covale_inputs"] = json.dumps(json.dumps(covale_inputs))[1:-1].replace("'", "\"")

    if hard_case:
        yaml_dict["loader_params"] = {"MAXCYCLE": 10}

    # The config must use double quotes, so single quotes are converted.
    yaml_text = yaml.dump(yaml_dict).replace("'", "\"")
    with open(f"/tmp/{jobname}.yaml", "w+") as fp:
        fp.write(yaml_text)
    with open("/tmp/base.yaml", "w+") as fp:
        fp.write(baseconfig)
    list_of_input_files.append(f"/tmp/{jobname}.yaml")
    list_of_input_files.append("/tmp/base.yaml")

    zip_path = os.path.join("/tmp/", f"{jobname}.zip")
    with zipfile.ZipFile(zip_path, 'w') as zip_archive:
        # De-duplicate (the same file may be listed twice) in a stable order.
        for file in sorted(set(list_of_input_files)):
            zip_archive.write(
                file,
                arcname=os.path.join(jobname, os.path.basename(file)),
                compress_type=zipfile.ZIP_DEFLATED,
            )

    return yaml_text, zip_path
197
+
198
def run_rf2aa(jobname, zip_archive):
    """Unpack a prepared job archive, run RFAA inference and rescale pLDDT.

    The archive is extracted into the current working directory, inference is
    started as ``python -m rf2aa.run_inference`` on the extracted job folder,
    and the B-factor column of the resulting PDB file is multiplied by 100
    (pLDDT scaled to the 0-100 range) before the file is rewritten in place.

    Args:
        jobname: Job name; also the name of the extracted folder, of the
            Hydra config (``{jobname}.yaml``) and of the output PDB file.
        zip_archive: Path of the archive to extract.

    Returns:
        Path of the predicted structure, ``{cwd}/{jobname}/{jobname}.pdb``.

    Raises:
        gr.Error: If extraction, inference or post-processing fails.
    """
    import subprocess  # function-scope import keeps this block self-contained

    current_dir = os.getcwd()
    pdb_path = f"{current_dir}/{jobname}/{jobname}.pdb"
    try:
        with zipfile.ZipFile(zip_archive, 'r') as zip_ref:
            zip_ref.extractall(current_dir)

        # Argument list instead of a shell string: jobname comes from a text
        # box and must never be interpreted by a shell. check=True turns a
        # failed inference run into an error instead of silently continuing.
        subprocess.run(
            [
                "python", "-m", "rf2aa.run_inference",
                "--config-name", f"{jobname}.yaml",
                "--config-path", f"{current_dir}/{jobname}",
            ],
            check=True,
        )

        # scale pLDDT to 0-100 range in pdb output file
        parser = PDBParser(QUIET=True)
        structure = parser.get_structure(jobname, pdb_path)
        for atom in structure.get_atoms():
            atom.bfactor = atom.bfactor * 100
        io = PDBIO()
        io.set_structure(structure)
        io.save(pdb_path)
    except Exception as e:
        # UI boundary: surface any failure to the user as a Gradio error.
        raise gr.Error(f"Error running RFAA: {e}") from e
    return pdb_path
219
+
220
+
221
+
222
def predict(input, jobname, dry_run, baseconfig, hard_case):
    """Click handler: prepare the job files and optionally run the prediction.

    Args:
        input: Value of the CofoldingInput widget (read via ``input["chains"]``).
        jobname: Job name entered in the UI.
        dry_run: When truthy, only the input files are generated.
        baseconfig: Base config text from the "Base config" tab.
        hard_case: Forwarded to ``prepare_input``.

    Returns:
        Four component updates, in order: the YAML config view, the
        downloadable archive, the run instructions (visible only for a dry
        run) and the 3D viewer (visible only after an actual run).
    """
    yaml_input, zip_archive = prepare_input(input, jobname, baseconfig, hard_case)

    def representation(chain_id, resname, style, color):
        # One viewer representation; only these four fields vary per chain.
        return {
            "model": 0,
            "chain": chain_id,
            "resname": resname,
            "style": style,
            "color": color,
            "residue_range": "",
            "around": 0,
            "byres": False
        }

    reps = []
    for chain in input["chains"]:
        if chain["class"] in ["protein", "RNA", "DNA"]:
            look = ("", "cartoon", "alphafold")
        elif chain["class"] == "ligand" and "name" not in chain:
            look = ("LG1", "stick", "whiteCarbon")
        else:
            # Remaining case: chains that carry a "name" are drawn as spheres.
            look = ("LG1", "sphere", "whiteCarbon")
        reps.append(representation(chain["chain"], *look))

    if dry_run:
        instructions = gr.Markdown(f"""You can run your RFAA job using the following command: <pre>python -m rf2aa.run_inference --config-name {jobname}.yaml --config-path absolute/path/to/unzipped/{jobname}</pre>""", visible=True)
        viewer = Molecule3D(visible=False)
    else:
        pdb_file = run_rf2aa(jobname, zip_archive)
        instructions = gr.Markdown(visible=False)
        viewer = Molecule3D(pdb_file, reps=reps, visible=True)

    return gr.Code(yaml_input, visible=True), gr.File(zip_archive, visible=True), instructions, viewer
266
+
267
# UI layout: job name, an input tab (molecule editor plus run options), a tab
# for editing the base config, and the outputs filled in by predict().
with gr.Blocks() as demo:
    gr.Markdown("# RoseTTAFold All Atom UI")
    gr.Markdown("""This UI allows you to generate input files for RoseTTAFold All Atom (RFAA) using the CofoldingInput widget. The input files can be used to run RFAA on your local machine. <br />
If you launch the UI directly on your local machine you can also directly run the RFAA prediction. <br />
More information in the official GitHub repository: [baker-laboratory/RoseTTAFold-All-Atom](https://github.com/baker-laboratory/RoseTTAFold-All-Atom)
""")
    jobname = gr.Textbox("job1", label="Job Name")
    with gr.Tab("Input"):
        inp = CofoldingInput(label="Input")
        hard_case = gr.Checkbox(False, label="Hard case (increase MAXCYCLE to 10)")
        # SPACE_HOST is unset (None) locally and a non-empty hostname on
        # Hugging Face Spaces, so test for truthiness rather than == "".
        # On Spaces the dry-run option is locked, since RFAA cannot run there.
        if os.environ.get("SPACE_HOST"):
            dry_run = gr.Checkbox(True, label="Only generate input files (dry run)", interactive=False)
        else:
            dry_run = gr.Checkbox(True, label="Only generate input files (dry run)")
    with gr.Tab("Base config"):
        base_config = gr.Code(baseconfig, label="Base config")
    btn = gr.Button("Run")
    config_file = gr.Code(label="YAML Hydra config for RFAA", visible=True)
    runfiles = gr.File(label="files to run RFAA", visible=False)
    instructions = gr.Markdown(visible=False)

    # Viewer representations are built per request inside predict().
    out = Molecule3D(visible=False)

    btn.click(predict, inputs=[inp, jobname, dry_run, base_config, hard_case], outputs=[config_file, runfiles, instructions, out])

if __name__ == "__main__":
    demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio_molecule3d
2
+ gradio_cofoldinginput
3
+ openbabel-wheel