"""Gradio demo: ESM3 structure prediction on noise-perturbed PDB structures.

The app reads an uploaded PDB file, perturbs its coordinates with Gaussian
noise, asks the ESM3 Forge model to predict the structure, aligns every
prediction onto the input, and shows the prediction with the lowest cRMSD in
a py3Dmol viewer.  It also reports a simple covalent-radius based steric
clash count for the uploaded file.
"""

import html
import io
import itertools
import json
import os
import time
import traceback
from typing import Optional

import biotite.structure as bs
import gradio as gr
import numpy as np
import py3Dmol
import requests
import torch
from Bio.Data import PDBData
from Bio.PDB import PDBParser
from biotite.structure.io import pdb
from dotenv import load_dotenv
from esm.sdk import client
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig, ESMProteinError
from esm.utils import residue_constants as RC
from esm.utils.structure.protein_chain import ProteinChain

load_dotenv()

API_URL = "https://forge.evolutionaryscale.ai/api/v1"
MODEL = "esm3-open-2024-03"
API_TOKEN = os.environ.get("ESM_API_TOKEN")

if not API_TOKEN:
    raise ValueError("ESM_API_TOKEN environment variable is not set")

# The token comes from the environment only.  A literal token used to be
# hard-coded here; it must be treated as leaked and revoked.
model = client(
    model=MODEL,
    url=API_URL,
    token=API_TOKEN,
)

# Three-letter -> one-letter codes for the 20 standard amino acids.
amino3to1 = {
    'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
    'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
    'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
    'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
}

# Reverse mapping, used when writing PDB records (PDB needs 3-letter names).
_AMINO1TO3 = {one: three for three, one in amino3to1.items()}

# Covalent radii dictionary (Å), keyed by upper-case element symbol.
COVALENT_RADIUS = {
    "H": 0.31, "HE": 0.28, "LI": 1.28, "BE": 0.96, "B": 0.84, "C": 0.76,
    "N": 0.71, "O": 0.66, "F": 0.57, "NE": 0.58, "NA": 1.66, "MG": 1.41,
    "AL": 1.21, "SI": 1.11, "P": 1.07, "S": 1.05, "CL": 1.02, "AR": 1.06,
    "K": 2.03, "CA": 1.76, "SC": 1.7, "TI": 1.6, "V": 1.53, "CR": 1.39,
    "MN": 1.5, "FE": 1.42, "CO": 1.38, "NI": 1.24, "CU": 1.32, "ZN": 1.22,
    "GA": 1.22, "GE": 1.2, "AS": 1.19, "SE": 1.2, "BR": 1.2, "KR": 1.16,
    "RB": 2.2, "SR": 1.95, "Y": 1.9, "ZR": 1.75, "NB": 1.64, "MO": 1.54,
    "TC": 1.47, "RU": 1.46, "RH": 1.42, "PD": 1.39, "AG": 1.45, "CD": 1.44,
    "IN": 1.42, "SN": 1.39, "SB": 1.39, "TE": 1.38, "I": 1.39, "XE": 1.4,
    "CS": 2.44, "BA": 2.15, "LA": 2.07, "CE": 2.04, "PR": 2.03, "ND": 2.01,
    "PM": 1.99, "SM": 1.98, "EU": 1.98, "GD": 1.96, "TB": 1.94, "DY": 1.92,
    "HO": 1.92, "ER": 1.89, "TM": 1.9, "YB": 1.87, "LU": 1.87, "HF": 1.75,
    "TA": 1.7, "W": 1.62, "RE": 1.51, "OS": 1.44, "IR": 1.41, "PT": 1.36,
    "AU": 1.36, "HG": 1.32, "TL": 1.45, "PB": 1.46, "BI": 1.48, "PO": 1.4,
    "AT": 1.5, "RN": 1.5, "FR": 2.6, "RA": 2.21, "AC": 2.15, "TH": 2.06,
    "PA": 2.0, "U": 1.96, "NP": 1.9, "PU": 1.87, "AM": 1.8, "CM": 1.69,
    "BK": 2.0, "CF": 2.0, "ES": 2.0, "FM": 2.0, "MD": 2.0, "NO": 2.0,
    "LR": 2.0, "RF": 2.0, "DB": 2.0, "SG": 2.0, "BH": 2.0, "HS": 2.0,
    "MT": 2.0, "DS": 2.0, "RG": 2.0, "CN": 2.0, "UUT": 2.0, "UUQ": 2.0,
    "UUP": 2.0, "UUH": 2.0, "UUS": 2.0, "UUO": 2.0,
}

# Two atoms clash when distance + tolerance < sum of their covalent radii.
CLASH_TOLERANCE = 0.5

USAGE_MARKDOWN = """
## How to use
1. Upload a PDB file using the file uploader.
2. Adjust the number of prediction runs per frame using the slider.
3. Set the noise level to add random perturbations to the structure.
4. Choose the number of MD frames to simulate.
5. Click the "Run Prediction" button to start the process.
6. The 3D visualization will show the original structure (grey) and the best predicted structure (green).
7. The alignment result will display the best cRMSD (lower is better).
8. Total and Normalized (per atom) steric clashes (lower is better)

## About
This demo uses the ESM3 model to predict protein structures from PDB files.
It runs multiple predictions with added noise and simulated MD frames, displaying the best result based on the lowest cRMSD.
"""


def _upload_path(pdb_file):
    """Return the filesystem path behind an upload, or None if there is none.

    Handles plain path strings / ``os.PathLike`` objects as well as file-like
    wrappers that expose the path through a ``name`` attribute.
    """
    if isinstance(pdb_file, (str, os.PathLike)):
        return os.fspath(pdb_file)
    if hasattr(pdb_file, 'name'):
        return pdb_file.name
    return None


# Function to get the covalent radius of an atom
def get_covalent_radius(atom):
    """Return the covalent radius (Å) for ``atom.element``.

    Elements missing from ``COVALENT_RADIUS`` (including an empty or unset
    element) fall back to 2.0 Å.
    """
    element = (atom.element or "").upper()
    return COVALENT_RADIUS.get(element, 2.0)  # Default to 2.0 Å if element is not in the dictionary


def calculate_clashes_for_pdb(pdb_file):
    """Count steric clashes between all atom pairs of a PDB file.

    A pair clashes when ``distance + CLASH_TOLERANCE`` is smaller than the
    sum of the two covalent radii.

    Args:
        pdb_file: A path, an object with a ``name`` path attribute, or an
            ``io.StringIO`` holding PDB text.

    Returns:
        A 2-tuple of human-readable strings: the total clash count and the
        clash count normalised by the number of atoms.
    """
    path = _upload_path(pdb_file)
    if path is not None:
        source, label = path, path
    elif isinstance(pdb_file, io.StringIO):
        # Parse a fresh handle so the caller's read position is irrelevant.
        source, label = io.StringIO(pdb_file.getvalue()), pdb_file
    else:
        source, label = pdb_file, pdb_file

    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("protein", source)
    atoms = list(structure.get_atoms())
    num_atoms = len(atoms)

    steric_clash_count = 0
    if num_atoms > 1:
        coords = np.array([atom.coord for atom in atoms], dtype=float)
        radii = np.array([get_covalent_radius(atom) for atom in atoms], dtype=float)
        # Compare each atom with every later atom (each unordered pair once),
        # one vectorised row at a time to keep memory at O(n).
        for start, (coord, radius) in enumerate(zip(coords[:-1], radii[:-1]), start=1):
            distances = np.linalg.norm(coords[start:] - coord, axis=1)
            radius_sums = radii[start:] + radius
            steric_clash_count += int(
                np.count_nonzero(distances + CLASH_TOLERANCE < radius_sums)
            )

    # Normalize steric clashes per number of atoms (0.0 for an empty structure)
    norm_ster_clash_count = steric_clash_count / num_atoms if num_atoms else 0.0

    return (
        f"Total steric clashes in {label}: {steric_clash_count}",
        f"Normalized steric clashes per atom in {label}: {norm_ster_clash_count}",
    )


def read_pdb_io(pdb_file):
    """Read PDB text from an upload.

    Args:
        pdb_file: An ``io.StringIO``, a filesystem path, or an object whose
            ``name`` attribute is a filesystem path.

    Returns:
        ``(pdb_io, pdb_content)``: a fresh ``io.StringIO`` over the text and
        the text itself.

    Raises:
        ValueError: If the input type is unsupported or the content is blank.
    """
    if isinstance(pdb_file, io.StringIO):
        pdb_content = pdb_file.getvalue()
    else:
        path = _upload_path(pdb_file)
        if path is None:
            raise ValueError("Unsupported file type")
        with open(path, 'r') as f:
            pdb_content = f.read()

    if not pdb_content.strip():
        raise ValueError("The PDB file is empty.")

    return io.StringIO(pdb_content), pdb_content


def get_protein(pdb_file) -> ESMProtein:
    """Build an ``ESMProtein`` from the standard residues of a PDB upload.

    Only residues whose name is one of the 20 standard amino acids are kept;
    all of them are placed in a single chain "A".

    Raises:
        ValueError: If the file cannot be read or parsed, has no atoms, or
            has no standard amino-acid residues.
    """
    try:
        # read_pdb_io already rejects empty content.
        pdb_io, _ = read_pdb_io(pdb_file)

        # Parse the PDB file using biotite.  model=1 yields a single-model
        # AtomArray, so per-residue annotations are 1-D and coords are (n, 3).
        parsed = pdb.PDBFile.read(pdb_io)
        structure = parsed.get_structure(model=1)

        # Check if the structure contains any atoms
        if structure.array_length() == 0:
            raise ValueError("The PDB file does not contain any valid atoms")

        # Filter for amino acids and create a sequence
        valid_residues = [
            res for res in bs.residue_iter(structure)
            if res.res_name[0] in amino3to1
        ]
        if not valid_residues:
            raise ValueError("No valid amino acid residues found in the PDB file")

        sequence = ''.join(amino3to1[res.res_name[0]] for res in valid_residues)
        residue_indices = [res.res_id[0] for res in valid_residues]
        num_residues = len(sequence)

        # NOTE(review): the lines from the insertion_code dtype up to the
        # coordinate assignment were garbled in the file this was recovered
        # from; the atom37 setup below is a reconstruction - confirm against
        # the pristine source.
        atom37_positions = np.full((num_residues, 37, 3), np.nan, dtype=np.float32)
        atom37_mask = np.zeros((num_residues, 37), dtype=bool)
        for i, res in enumerate(valid_residues):
            for atom_name, coord in zip(res.atom_name, res.coord):
                idx = RC.atom_order.get(str(atom_name))
                if idx is None:
                    continue  # not one of the 37 canonical atom slots
                atom37_positions[i, idx] = coord
                atom37_mask[i, idx] = True

        # Create a ProteinChain object
        protein_chain = ProteinChain(
            id="test",
            sequence=sequence,
            chain_id="A",
            entity_id=None,
            residue_index=np.array(residue_indices, dtype=int),
            insertion_code=np.full(num_residues, "", dtype="<U4"),
            atom37_positions=atom37_positions,
            atom37_mask=atom37_mask,
            confidence=np.ones(num_residues, dtype=np.float32),
        )

        return ESMProtein.from_protein_chain(protein_chain)
    except Exception as e:
        print(f"Error processing PDB file: {str(e)}")
        raise ValueError(f"Unable to process the PDB file: {str(e)}") from e


def add_noise_to_coordinates(protein: ESMProtein, noise_level: float) -> ESMProtein:
    """Add Gaussian noise to the atom positions of the protein.

    The noise is drawn with torch so dtype and device match the coordinates;
    NaN (missing-atom) entries stay NaN.  A new ``ESMProtein`` carrying only
    the sequence and the noisy coordinates is returned.
    """
    coordinates = torch.as_tensor(protein.coordinates)
    noise = torch.randn_like(coordinates) * noise_level
    return ESMProtein(sequence=protein.sequence, coordinates=coordinates + noise)


def run_structure_prediction(protein: ESMProtein) -> Optional[ESMProtein]:
    """Run one ESM3 structure-track generation for ``protein``.

    Returns:
        The predicted ``ESMProtein``, or ``None`` when the service reports an
        error, returns an unexpected type, or the call raises.
    """
    structure_prediction_config = GenerationConfig(
        track="structure",
        num_steps=10,
        temperature=0.7,
    )
    try:
        response = model.generate(protein, structure_prediction_config)
        if isinstance(response, ESMProtein):
            return response
        if isinstance(response, ESMProteinError):
            print(f"ESMProteinError during structure prediction: {response.error_msg}")
            return None
        raise ValueError(f"Unexpected response type: {type(response)}")
    except Exception as e:
        # Best effort: a failed run is skipped by the caller, not fatal.
        print(f"Error during structure prediction: {str(e)}")
        return None


def align_after_prediction(
    protein: ESMProtein, structure_prediction: ESMProtein
) -> tuple[Optional[ESMProtein], float]:
    """Align a prediction onto the reference protein and compute its cRMSD.

    Residues are matched by position over the first ``min(len)`` residues of
    both chains.

    Returns:
        ``(aligned_protein, crmsd)``, or ``(None, inf)`` when there is no
        prediction or alignment fails.
    """
    if structure_prediction is None:
        return None, float('inf')

    try:
        structure_prediction_chain = structure_prediction.to_protein_chain()
        protein_chain = protein.to_protein_chain()

        # Ensure both chains have the same length
        min_length = min(len(structure_prediction_chain.sequence), len(protein_chain.sequence))
        structure_indices = np.arange(0, min_length)

        # Perform alignment
        aligned_chain = structure_prediction_chain.align(
            protein_chain,
            mobile_inds=structure_indices,
            target_inds=structure_indices,
        )

        # Calculate RMSD
        crmsd = structure_prediction_chain.rmsd(
            protein_chain,
            mobile_inds=structure_indices,
            target_inds=structure_indices,
        )

        return ESMProtein.from_protein_chain(aligned_chain), crmsd
    except AttributeError as e:
        print(f"Error during alignment: {str(e)}")
        print(f"Structure prediction type: {type(structure_prediction)}")
        print(f"Structure prediction attributes: {dir(structure_prediction)}")
        return None, float('inf')
    except Exception as e:
        print(f"Unexpected error during alignment: {str(e)}")
        return None, float('inf')


def protein_to_pdb(protein: ESMProtein):
    """Serialise an ``ESMProtein`` to fixed-column PDB ATOM records.

    Atoms whose coordinates contain NaN are skipped.  Residues are written in
    chain "A", numbered from 1, with three-letter residue names ("UNK" for
    anything outside the 20 standard amino acids).
    """
    coordinates = torch.as_tensor(protein.coordinates)
    lines = []
    serial = 0
    for res_seq, (aa, res_coords) in enumerate(zip(protein.sequence, coordinates), start=1):
        res_name = _AMINO1TO3.get(aa, "UNK")
        for atom, xyz in zip(RC.atom_types, res_coords):
            if torch.isnan(xyz).any():
                continue
            serial += 1
            x, y, z = xyz.tolist()
            # Columns follow the PDB ATOM record layout; atom names of up to
            # three characters start in column 14, hence the leading space.
            lines.append(
                f"ATOM  {serial % 100000:5d} {' ' + atom:<4s} {res_name:>3s} "
                f"A{res_seq % 10000:4d}    {x:8.3f}{y:8.3f}{z:8.3f}"
                f"  1.00  0.00          {atom[0]:>2s}\n"
            )
    return "".join(lines)


def visualize_after_pred(protein: ESMProtein, aligned: ESMProtein):
    """Build a py3Dmol view with the input (grey) and the prediction (green).

    Returns:
        The py3Dmol view, or ``None`` when ``aligned`` is ``None``.
    """
    if aligned is None:
        return None

    viewer = py3Dmol.view(width=800, height=600)
    viewer.addModel(protein_to_pdb(protein), "pdb")
    viewer.setStyle({"cartoon": {"color": "lightgrey"}})
    viewer.addModel(protein_to_pdb(aligned), "pdb")
    # model -1 addresses the most recently added model (the prediction).
    viewer.setStyle({"model": -1}, {"cartoon": {"color": "lightgreen"}})
    viewer.zoomTo()
    viewer.render()
    # Return the view itself so the caller can serialise it to HTML.
    return viewer


def prediction_visualization(pdb_file, num_runs: int, noise_level: float, num_frames: int, progress=gr.Progress()):
    """Run noisy predictions and visualise the one with the lowest cRMSD.

    For each frame a freshly perturbed copy of the input is created and
    predicted ``num_runs`` times.

    Returns:
        ``(view, message)``; ``view`` is ``None`` when no run succeeded.
    """
    # Sliders may deliver floats; range() needs ints.
    num_runs = int(num_runs)
    num_frames = int(num_frames)

    protein = get_protein(pdb_file)
    runs = []

    total_iterations = num_frames * num_runs
    progress(0, desc="Starting predictions")

    for frame in progress.tqdm(range(num_frames), desc="Processing frames"):
        noisy_protein = add_noise_to_coordinates(protein, noise_level)

        for i in range(num_runs):
            progress((frame * num_runs + i + 1) / total_iterations, desc=f"Frame {frame+1}, Run {i+1}")

            structure_prediction = run_structure_prediction(noisy_protein)
            if structure_prediction is not None:
                aligned, crmsd = align_after_prediction(protein, structure_prediction)
                if aligned is not None:
                    runs.append((crmsd, aligned))
            time.sleep(0.1)  # Small delay to allow for UI updates

    if not runs:
        return None, "No successful predictions"

    best_crmsd, best_aligned = min(runs, key=lambda run: run[0])
    view_data = visualize_after_pred(protein, best_aligned)
    return view_data, f"Best cRMSD: {best_crmsd:.4f}"


def run_prediction(pdb_file, num_runs, noise_level, num_frames, progress=gr.Progress()):
    """Gradio click handler.

    Always returns four values, matching the four output components wired in
    ``create_demo``: visualization HTML, alignment text, total steric clash
    text and normalised steric clash text.
    """
    try:
        if pdb_file is None:
            return "Please upload a PDB file.", "No file uploaded", "", ""

        progress(0, desc="Starting prediction")
        view_data, crmsd_text = prediction_visualization(pdb_file, num_runs, noise_level, num_frames, progress)
        steric_clash_text, norm_steric_clash_text = calculate_clashes_for_pdb(pdb_file)

        if view_data is None:
            return (
                "No successful predictions were made. Try adjusting the parameters or check the PDB file.",
                crmsd_text,
                steric_clash_text,
                norm_steric_clash_text,
            )

        progress(0.9, desc="Rendering visualization")

        # NOTE(review): the original markup of this template was missing from
        # the recovered source.  An iframe with srcdoc is used here because
        # the viewer needs its script executed - confirm against the pristine
        # file.
        viewer_html = html.escape(view_data._make_html(), quote=True)
        html_content = f"""
<iframe style="width: 100%; height: 600px; border: none;" srcdoc="{viewer_html}"></iframe>
"""

        progress(1.0, desc="Completed")
        return html_content, crmsd_text, steric_clash_text, norm_steric_clash_text
    except Exception as e:
        # Top-level UI boundary: show the failure instead of crashing the app.
        # Escape both texts so angle brackets in tracebacks stay visible.
        # NOTE(review): the wrapping tags are reconstructed; only the text
        # content survived in the recovered source.
        error_message = html.escape(str(e))
        stack_trace = html.escape(traceback.format_exc())
        error_html = f"""
<div style="color: red;">
<h3>Error:</h3>
<p>{error_message}</p>
<h3>Stack Trace:</h3>
<pre>{stack_trace}</pre>
</div>
"""
        return error_html, "Error occurred", "", ""


def create_demo():
    """Build the Gradio Blocks UI and wire the prediction handler."""
    with gr.Blocks() as demo:
        gr.Markdown("# Protein Structure Prediction and Visualization with Noise and MD Frames")

        with gr.Row():
            with gr.Column(scale=1):
                pdb_file = gr.File(label="Upload PDB file")
                num_runs = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of runs per frame")
                noise_level = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.1, label="Noise level")
                num_frames = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of MD frames")
                run_button = gr.Button("Run Prediction")

            with gr.Column(scale=2):
                visualization = gr.HTML(label="3D Visualization")
                alignment_result = gr.Textbox(label="Alignment Result")
                steric_clash_result = gr.Textbox(label="Total Steric Clashes")
                norm_steric_clash_result = gr.Textbox(label="Normalized Steric Clashes (per atom)")

        run_button.click(
            fn=run_prediction,
            inputs=[pdb_file, num_runs, noise_level, num_frames],
            outputs=[visualization, alignment_result, steric_clash_result, norm_steric_clash_result],
        )

        gr.Markdown(USAGE_MARKDOWN)

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.queue()
    demo.launch()