# Source: blackhole_models_pli_busted / inference_app.py
# Last updated by OleinikovasV — commit 5234d8b ("Update inference_app.py", verified)
import time
import gradio as gr
from gradio_molecule3d import Molecule3D
import numpy as np
from scipy.optimize import differential_evolution, NonlinearConstraint
from biotite.structure.io.pdb import PDBFile
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Geometry import Point3D
def generate_input_conformer(
    ligand_smiles: str,
    addHs: bool = False,
    minimize_maxIters: int = -1,
) -> Chem.Mol:
    """Build an RDKit molecule with an embedded 3D conformer from a SMILES string.

    Explicit hydrogens are always added before embedding, because conformer
    generation needs them to produce sensible geometries. Embedding is first
    attempted with ``useBasicKnowledge=True``. If that fails (returns -1), a
    second attempt is made with ``useBasicKnowledge=False``. Only a conformer
    from the first attempt is MMFF-minimized, and only when
    ``minimize_maxIters > 0``.

    Args:
        ligand_smiles: SMILES string of the ligand.
        addHs: Currently unused; the returned molecule always carries explicit
            hydrogens. NOTE(review): presumably intended to control whether
            Hs are stripped before returning — confirm the intended behavior
            before wiring it up, since that would change the default output.
        minimize_maxIters: Maximum MMFF iterations per optimization pass.
            Values <= 0 skip minimization. If the first pass reports that more
            iterations are required, one more pass of the same length is run.

    Returns:
        The hydrogen-added molecule. If both embedding attempts fail it has
        no conformer, so a later ``mol.GetConformer()`` call will raise.

    Raises:
        ValueError: If ``ligand_smiles`` cannot be parsed by RDKit.
    """
    mol = Chem.MolFromSmiles(ligand_smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; without this
        # check the failure only surfaces later as an opaque error from AddHs.
        raise ValueError(f"Could not parse ligand SMILES: {ligand_smiles!r}")
    # need to add Hs to generate sensible conformers
    mol = Chem.AddHs(mol)

    # First attempt: embedding with basic-knowledge terms enabled.
    confid = AllChem.EmbedMolecule(
        mol,
        useRandomCoords=True,
        useBasicKnowledge=True,
        maxAttempts=100,
        randomSeed=42,
    )
    if confid == -1:
        # Embedding failed; retry with a less constrained setup. The result is
        # not checked: if this also fails, the molecule has no conformer.
        AllChem.EmbedMolecule(
            mol,
            useRandomCoords=True,
            useBasicKnowledge=False,
            maxAttempts=100,
            randomSeed=42,
        )
    elif minimize_maxIters > 0:
        # MMFFOptimizeMolecule returns 0 if the optimization converged,
        # -1 if the force field could not be set up,
        # 1 if more iterations are required.
        status = AllChem.MMFFOptimizeMolecule(mol, maxIters=minimize_maxIters)
        if status == 1:
            # Not converged yet: run one more pass of the same length,
            # doubling the total iteration budget.
            AllChem.MMFFOptimizeMolecule(mol, maxIters=minimize_maxIters)
    return mol
def set_protein_to_new_coord(input_pdb_file, new_coord, output_file):
    """Write a copy of a PDB structure in which every atom sits at one point.

    Args:
        input_pdb_file: PDB source, passed unchanged to ``PDBFile.read``.
        new_coord: The single coordinate every atom is moved to. It is
            broadcast against the structure's coordinate array (assumes a
            length-3 point — TODO confirm for callers other than ``predict``).
        output_file: Destination, passed unchanged to ``PDBFile.write``.
    """
    pdb_in = PDBFile.read(input_pdb_file)
    structure = pdb_in.get_structure()

    # An all-ones array of the same shape, scaled by the target point,
    # broadcasts that point across every atom position.
    unit = np.ones_like(structure.coord)
    structure.coord = unit * np.array(new_coord)

    pdb_out = PDBFile()
    pdb_out.set_structure(structure)
    pdb_out.write(output_file)
# def optimize_coordinate(points, bound_buffer=15, dmin=6.05):
# bounds = list(
# zip(
# np.average(points, axis=0) - [bound_buffer]*3,
# np.average(points, axis=0) + [bound_buffer]*3
# )
# )
# # Define the constraint function (ensure dmin distance)
# con = NonlinearConstraint(lambda x: np.min(np.linalg.norm(points - x, axis=1)), dmin, 8)
# # Define the objective function (minimize pairwise distance)
# def objective(x):
# return np.sum(np.linalg.norm(points - x, axis=1))
# # Perform differential evolution to find the optimal coordinate
# result = differential_evolution(objective, bounds, constraints=con)
# return result.x, result.fun
def predict(input_sequence, input_ligand, input_msa, input_protein):
    """Produce a placeholder protein–ligand "prediction" for the Gradio UI.

    No model is run. The ligand gets a conformer from
    ``generate_input_conformer`` and then every ligand atom is placed at the
    origin. Every protein atom is moved to the fixed point (6.02, 0, 0). Both
    results are written to fixed file names in the working directory.

    Args:
        input_sequence: Protein sequence text (unused by this template).
        input_ligand: Ligand SMILES string.
        input_msa: MSA file (unused by this template).
        input_protein: Protein structure input, passed to
            ``set_protein_to_new_coord`` (read with biotite's ``PDBFile``).

    Returns:
        A tuple ``([protein_pdb_path, ligand_sdf_path], metrics, run_time)``,
        where ``metrics`` is an empty dict and ``run_time`` is the elapsed
        wall-clock time in seconds.
    """
    start_time = time.time()
    ligand_file = "test_docking_pose.sdf"
    output_file = "test_out.pdb"

    # Do inference here
    mol = generate_input_conformer(input_ligand)
    if mol.GetNumConformers() == 0:
        # Both embedding attempts failed. Every coordinate is overwritten with
        # zeros below anyway, so attach a zero-initialized conformer instead
        # of crashing in GetConformer().
        mol.AddConformer(Chem.Conformer(mol.GetNumAtoms()), assignId=True)
    conf = mol.GetConformer()

    # set ligand
    for atom_idx in range(mol.GetNumAtoms()):
        conf.SetAtomPosition(atom_idx, Point3D(0, 0, 0))
    molwriter = Chem.SDWriter(ligand_file)
    try:
        molwriter.write(mol)
    finally:
        # Close explicitly so the SDF is flushed to disk before its path is
        # handed to the UI, instead of relying on garbage collection.
        molwriter.close()

    # set protein
    new_coord = [6.02, 0, 0]
    set_protein_to_new_coord(input_protein, new_coord, output_file)

    # return an output pdb file with the protein and ligand with resname LIG or UNK.
    # also return any metrics you want to log, metrics will not be used for evaluation but might be useful for users
    metrics = {}
    end_time = time.time()
    run_time = end_time - start_time
    return [output_file, ligand_file], metrics, run_time
# Gradio UI: input widgets, one example row, the 3D viewer and outputs,
# all wired to predict(). Components are created in display order.
with gr.Blocks() as app:
    gr.Markdown("# Template for inference")
    gr.Markdown("Title, description, and other information about the model")

    with gr.Row():
        input_sequence = gr.Textbox(lines=3, label="Input Protein sequence (FASTA)")
        input_ligand = gr.Textbox(lines=3, label="Input ligand SMILES")
    with gr.Row():
        input_msa = gr.File(label="Input Protein MSA (A3M)")
        input_protein = gr.File(label="Input protein monomer")

    # define any options here
    # for automated inference the default options are used
    # slider_option = gr.Slider(0,10, label="Slider Option")
    # checkbox_option = gr.Checkbox(label="Checkbox Option")
    # dropdown_option = gr.Dropdown(["Option 1", "Option 2", "Option 3"], label="Radio Option")

    btn = gr.Button("Run Inference")

    gr.Examples(
        examples=[
            [
                "",
                "COc1ccc(cc1)n2c3c(c(n2)C(=O)N)CCN(C3=O)c4ccc(cc4)N5CCCCC5=O",
                "empty_file.a3m",
                "test_input.pdb",
            ],
        ],
        inputs=[input_sequence, input_ligand, input_msa, input_protein],
    )

    # Viewer styles per model index: model 0 as grey-carbon spheres, model 1
    # as green-carbon sticks. Presumably the indices follow the order of the
    # files predict() returns (protein PDB, then ligand SDF) — confirm.
    reps = [
        dict(model=0, style="sphere", color="grayCarbon"),
        dict(model=1, style="stick", color="greenCarbon"),
    ]
    out = Molecule3D(reps=reps)
    metrics = gr.JSON(label="Metrics")
    run_time = gr.Textbox(label="Runtime")

    btn.click(
        fn=predict,
        inputs=[input_sequence, input_ligand, input_msa, input_protein],
        outputs=[out, metrics, run_time],
    )

app.launch()