# File size: 17,760 bytes. (The commit-hash and line-number gutters emitted by the web file viewer were removed; they are not part of the program.)
import spaces
import gradio as gr
import py3Dmol
import io
import numpy as np
import os
import traceback
from esm.sdk import client
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig, ESMProteinError
from esm.utils.structure.protein_chain import ProteinChain
from Bio.Data import PDBData
import biotite.structure as bs
from biotite.structure.io import pdb
from esm.utils import residue_constants as RC
import requests
from dotenv import load_dotenv
import torch
import json
import time
from Bio.PDB import PDBParser
import itertools
# Markdown text blocks rendered by create_demo() inside the "learn more"
# accordion. They are user-facing strings: edit the wording with care.

# Step-by-step usage instructions (rendered with gr.Markdown).
howtouse = """
## How to use
1. Upload a PDB file using the file uploader.
2. Adjust the number of prediction runs per frame using the slider.
3. Set the noise level to add random perturbations to the structure.
4. Choose the number of MD frames to simulate.
5. Click the "Run Prediction" button to start the process.
6. The 3D visualization will show the original structure (grey) and the best predicted structure (green).
7. The alignment result will display the best cRMSD (lower is better).
8. Total and Normalized (per atom) steric clashes (lower is better)
"""

# Background and approach description (rendered with gr.Markdown).
about = """ ## Background
- 3D protein structures typically come from crystal structures, which are densely packed and lack flexibility.
- Different proteins require varying levels of noise to achieve overlap in conformational space.
- We've developed an adaptability model that predicts the appropriate noise level for each protein.
## Our Approach
1. **Adaptability Model**: Trained on Molecular Dynamics (MD) data, our model predicts flexibility at the atomic level.
2. **Correlation**: The adaptability predictions correlate well with the RMSD (Root Mean Square Deviation) from ESM3 sampling.
3. **Noise Application**: We apply noise to simulate protein flexibility, mimicking MD-like behavior.
"""

# Short summary of what the demo does (rendered with gr.Markdown).
about1 = """
## About
This demo uses the ESM3 model to predict protein structures from PDB files.
It runs multiple predictions with added noise and simulated MD frames, displaying the best result based on the lowest cRMSD.
"""
# Load variables from a local .env file (if present) into the process
# environment before the API token is read.
load_dotenv()

API_URL = "https://forge.evolutionaryscale.ai/api/v1"
MODEL = "esm3-open-2024-03"

# The API token is taken from the environment only; fail fast at import time
# when it is missing so the app never starts without credentials.
API_TOKEN = os.environ.get("ESM_API_TOKEN")
if not API_TOKEN:
    raise ValueError("ESM_API_TOKEN environment variable is not set")

# Remote inference client used by run_structure_prediction().
# Always pass the token read from the environment: credentials must never be
# hard-coded in source.
model = client(
    model=MODEL,
    url=API_URL,
    token=API_TOKEN,
)
# Three-letter residue name -> one-letter code for the 20 standard amino
# acids. get_protein() keeps only residues whose name is a key of this table
# and uses the values to build the sequence string.
amino3to1 = {
    'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
    'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
    'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
    'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'
}
# Covalent radii dictionary
# Keys are upper-case element symbols; values are radii in Å (the unit used
# by the fallback in get_covalent_radius()).
# NOTE(review): the provenance of these numbers is not stated in this file —
# confirm the reference table before relying on exact values.
COVALENT_RADIUS = {
    "H": 0.31, "HE": 0.28, "LI": 1.28, "BE": 0.96, "B": 0.84, "C": 0.76, "N": 0.71, "O": 0.66, "F": 0.57, "NE": 0.58,
    "NA": 1.66, "MG": 1.41, "AL": 1.21, "SI": 1.11, "P": 1.07, "S": 1.05, "CL": 1.02, "AR": 1.06, "K": 2.03,
    "CA": 1.76, "SC": 1.7, "TI": 1.6, "V": 1.53, "CR": 1.39, "MN": 1.5, "FE": 1.42, "CO": 1.38, "NI": 1.24,
    "CU": 1.32, "ZN": 1.22, "GA": 1.22, "GE": 1.2, "AS": 1.19, "SE": 1.2, "BR": 1.2, "KR": 1.16, "RB": 2.2,
    "SR": 1.95, "Y": 1.9, "ZR": 1.75, "NB": 1.64, "MO": 1.54, "TC": 1.47, "RU": 1.46, "RH": 1.42, "PD": 1.39,
    "AG": 1.45, "CD": 1.44, "IN": 1.42, "SN": 1.39, "SB": 1.39, "TE": 1.38, "I": 1.39, "XE": 1.4, "CS": 2.44,
    "BA": 2.15, "LA": 2.07, "CE": 2.04, "PR": 2.03, "ND": 2.01, "PM": 1.99, "SM": 1.98, "EU": 1.98, "GD": 1.96,
    "TB": 1.94, "DY": 1.92, "HO": 1.92, "ER": 1.89, "TM": 1.9, "YB": 1.87, "LU": 1.87, "HF": 1.75, "TA": 1.7,
    "W": 1.62, "RE": 1.51, "OS": 1.44, "IR": 1.41, "PT": 1.36, "AU": 1.36, "HG": 1.32, "TL": 1.45, "PB": 1.46,
    "BI": 1.48, "PO": 1.4, "AT": 1.5, "RN": 1.5, "FR": 2.6, "RA": 2.21, "AC": 2.15, "TH": 2.06, "PA": 2.0,
    "U": 1.96, "NP": 1.9, "PU": 1.87, "AM": 1.8, "CM": 1.69, "BK": 2.0, "CF": 2.0, "ES": 2.0, "FM": 2.0,
    "MD": 2.0, "NO": 2.0, "LR": 2.0, "RF": 2.0, "DB": 2.0, "SG": 2.0, "BH": 2.0, "HS": 2.0, "MT": 2.0,
    "DS": 2.0, "RG": 2.0, "CN": 2.0, "UUT": 2.0, "UUQ": 2.0, "UUP": 2.0, "UUH": 2.0, "UUS": 2.0, "UUO": 2.0
}
def get_covalent_radius(atom):
    """Look up the covalent radius of ``atom`` in ``COVALENT_RADIUS``.

    The atom's ``element`` attribute is upper-cased and used as the table
    key. Elements that are not in the table fall back to 2.0 Å.
    """
    fallback_radius = 2.0
    symbol = atom.element.upper()
    if symbol in COVALENT_RADIUS:
        return COVALENT_RADIUS[symbol]
    return fallback_radius
def calculate_clashes_for_pdb(pdb_file):
    """Count steric clashes between all atom pairs of a PDB structure.

    Two atoms are counted as clashing when their distance plus a 0.5 Å
    tolerance is smaller than the sum of their covalent radii.

    Args:
        pdb_file: Whatever ``PDBParser.get_structure`` accepts as its file
            argument; it is passed through unchanged.

    Returns:
        A pair of strings ``(total_clashes, clashes_per_atom)``. For a
        structure without atoms this is ``("0", "0.0")``.
    """
    clash_tolerance = 0.5

    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("protein", pdb_file)
    atoms = list(structure.get_atoms())
    num_atoms = len(atoms)

    # Without atoms there are no pairs, and the normalisation below would
    # divide by zero.
    if num_atoms == 0:
        return "0", "0.0"

    # Resolve every radius once instead of twice per pair.
    atoms_with_radii = [(atom, get_covalent_radius(atom)) for atom in atoms]

    # ``atom1 - atom2`` is the distance between the two atoms.
    steric_clash_count = sum(
        1
        for (atom1, radius1), (atom2, radius2) in itertools.combinations(atoms_with_radii, 2)
        if (atom1 - atom2) + clash_tolerance < radius1 + radius2
    )

    # Normalize steric clashes per number of atoms
    norm_ster_clash_count = steric_clash_count / num_atoms
    return f"{steric_clash_count}", f"{norm_ster_clash_count}"
def read_pdb_io(pdb_file):
    """Read PDB text from ``pdb_file`` and return it as a stream and a string.

    Args:
        pdb_file: One of
            * an ``io.StringIO`` that already holds the PDB text,
            * a filesystem path given as ``str`` or ``os.PathLike``,
            * an object whose ``name`` attribute is the path of the file
              (e.g. an uploaded temporary file).

    Returns:
        ``(pdb_io, pdb_content)`` where ``pdb_io`` is a fresh ``io.StringIO``
        positioned at the start of the text and ``pdb_content`` is the text.

    Raises:
        ValueError: If the input type is not supported or the text is empty
            or whitespace only.
    """
    if isinstance(pdb_file, io.StringIO):
        pdb_content = pdb_file.getvalue()
    else:
        # Paths are checked before ``.name`` because ``pathlib.Path.name`` is
        # only the final path component, not a usable location.
        if isinstance(pdb_file, (str, os.PathLike)):
            path = pdb_file
        elif hasattr(pdb_file, 'name'):
            path = pdb_file.name
        else:
            raise ValueError("Unsupported file type")
        with open(path, 'r') as f:
            pdb_content = f.read()

    if not pdb_content.strip():
        raise ValueError("The PDB file is empty.")

    return io.StringIO(pdb_content), pdb_content
def get_protein(pdb_file) -> ESMProtein:
    """Build an ``ESMProtein`` from the standard amino-acid residues of a PDB file.

    The file is read through ``read_pdb_io`` and parsed with biotite. Only
    residues whose name is a key of ``amino3to1`` are kept. All kept residues
    are placed in a single ``ProteinChain`` with id ``"test"`` and chain id
    ``"A"``; atoms whose name is in ``RC.atom_order`` fill the atom37
    position and mask arrays, every other atom is ignored.

    Args:
        pdb_file: Any input accepted by ``read_pdb_io``.

    Returns:
        The protein created by ``ESMProtein.from_protein_chain``.

    Raises:
        ValueError: If the file cannot be read or parsed, has no atoms, or
            has no standard amino-acid residues. The original exception is
            attached as ``__cause__``.
    """

    def _first(value):
        """Return the first entry of an array-like annotation, else the value itself."""
        if isinstance(value, (list, tuple, np.ndarray)):
            return value[0]
        return value

    try:
        pdb_io, _ = read_pdb_io(pdb_file)

        # Parse the PDB file using biotite.
        # Request a single model explicitly: without ``model`` biotite returns
        # an AtomArrayStack, and iterating a stack walks models instead of
        # atoms, so the per-atom loop below would only see the first atom of
        # each residue.
        parsed_file = pdb.PDBFile.read(pdb_io)
        structure = parsed_file.get_structure(model=1)

        # Check if the structure contains any atoms
        if structure.array_length() == 0:
            raise ValueError("The PDB file does not contain any valid atoms")

        # Keep standard amino acids only, remembering the one-letter code and
        # residue number alongside the residue's atoms.
        residues = []
        for res in bs.residue_iter(structure):
            one_letter = amino3to1.get(_first(res.res_name))
            if one_letter is not None:
                residues.append((one_letter, _first(res.res_id), res))

        if not residues:
            raise ValueError("No valid amino acid residues found in the PDB file")

        sequence = ''.join(one_letter for one_letter, _, _ in residues)
        residue_indices = [res_id for _, res_id, _ in residues]

        # Create a ProteinChain object with empty (NaN / masked-out) atoms
        protein_chain = ProteinChain(
            id="test",
            sequence=sequence,
            chain_id="A",
            entity_id=None,
            residue_index=np.array(residue_indices, dtype=int),
            insertion_code=np.full(len(sequence), "", dtype="<U4"),
            atom37_positions=np.full((len(sequence), 37, 3), np.nan),
            atom37_mask=np.zeros((len(sequence), 37), dtype=bool),
            confidence=np.ones(len(sequence), dtype=np.float32)
        )

        # Fill in atom positions and mask
        for i, (_, _, res) in enumerate(residues):
            for atom in res:
                atom_name = _first(atom.atom_name)
                if atom_name not in RC.atom_order:
                    continue
                idx = RC.atom_order[atom_name]
                coord = atom.coord
                if coord.ndim > 1:
                    coord = coord[0]  # Take the first coordinate set if multiple are present
                protein_chain.atom37_positions[i, idx] = coord
                protein_chain.atom37_mask[i, idx] = True

        return ESMProtein.from_protein_chain(protein_chain)
    except Exception as e:
        print(f"Error processing PDB file: {str(e)}")
        raise ValueError(f"Unable to process the PDB file: {str(e)}") from e
def add_noise_to_coordinates(protein: ESMProtein, noise_level: float) -> ESMProtein:
    """Return a new protein whose coordinates carry added Gaussian noise.

    Every coordinate value receives an independent standard-normal sample
    scaled by ``noise_level``. The returned ``ESMProtein`` is built from the
    input's sequence and the perturbed coordinates only.
    """
    original_coords = protein.coordinates
    perturbation = np.random.randn(*original_coords.shape) * noise_level
    return ESMProtein(
        sequence=protein.sequence,
        coordinates=original_coords + perturbation,
    )
def run_structure_prediction(protein: ESMProtein, temperature: float, num_steps: int) -> ESMProtein:
    """Ask the module-level ``model`` client for a structure prediction.

    Args:
        protein: Input passed to ``model.generate``.
        temperature: Sampling temperature for the generation config.
        num_steps: Number of generation steps for the generation config.

    Returns:
        The predicted ``ESMProtein``, or ``None`` when the service returns an
        ``ESMProteinError``, returns an unexpected type, or an exception is
        raised. Failures are reported with ``print``.
    """
    config = GenerationConfig(
        track="structure",
        num_steps=num_steps,
        temperature=temperature,
    )
    try:
        result = model.generate(protein, config)
        if isinstance(result, ESMProtein):
            return result
        if not isinstance(result, ESMProteinError):
            # Raised inside the try on purpose: the handler below reports it.
            raise ValueError(f"Unexpected response type: {type(result)}")
        print(f"ESMProteinError during structure prediction: {result.error_msg}")
        return None
    except Exception as e:
        print(f"Error during structure prediction: {str(e)}")
        return None
@spaces.GPU
def align_after_prediction(protein: ESMProtein, structure_prediction: ESMProtein) -> tuple[ESMProtein, float]:
    """Align a predicted structure onto the input protein and compute their RMSD.

    Both proteins are converted to protein chains and compared over the first
    ``min(len(a), len(b))`` residues.

    Returns:
        ``(aligned_protein, crmsd)`` on success, or ``(None, float('inf'))``
        when ``structure_prediction`` is ``None`` or any exception occurs
        (the exception is reported with ``print``).
    """
    failure = (None, float('inf'))
    if structure_prediction is None:
        return failure

    try:
        predicted_chain = structure_prediction.to_protein_chain()
        reference_chain = protein.to_protein_chain()

        # Restrict the comparison to the residues both chains have.
        shared_length = min(len(predicted_chain.sequence), len(reference_chain.sequence))
        residue_inds = np.arange(0, shared_length)
        index_kwargs = {"mobile_inds": residue_inds, "target_inds": residue_inds}

        aligned_chain = predicted_chain.align(reference_chain, **index_kwargs)
        crmsd = predicted_chain.rmsd(reference_chain, **index_kwargs)

        return ESMProtein.from_protein_chain(aligned_chain), crmsd
    except AttributeError as e:
        # Extra diagnostics: an AttributeError here points at an unexpected
        # object being passed in as the prediction.
        print(f"Error during alignment: {str(e)}")
        print(f"Structure prediction type: {type(structure_prediction)}")
        print(f"Structure prediction attributes: {dir(structure_prediction)}")
        return failure
    except Exception as e:
        print(f"Unexpected error during alignment: {str(e)}")
        return failure
@spaces.GPU
def visualize_after_pred(protein: ESMProtein, aligned: ESMProtein):
    """Create a py3Dmol view overlaying the input and the aligned prediction.

    The input protein is added first and styled as a light-grey cartoon; the
    aligned prediction is added second and styled light green through the
    ``{"model": -1}`` selector. Returns ``None`` when ``aligned`` is ``None``.
    """
    if aligned is None:
        return None

    viewer = py3Dmol.view(width=800, height=600)

    # Input structure.
    reference_pdb = protein_to_pdb(protein)
    viewer.addModel(reference_pdb, "pdb")
    viewer.setStyle({"cartoon": {"color": "lightgrey"}})

    # Aligned prediction.
    predicted_pdb = protein_to_pdb(aligned)
    viewer.addModel(predicted_pdb, "pdb")
    viewer.setStyle({"model": -1}, {"cartoon": {"color": "lightgreen"}})

    viewer.zoomTo()
    return viewer
def protein_to_pdb(protein: ESMProtein):
    """Serialise a protein's coordinates into ATOM lines and return them as one string.

    One line is written for every (residue, atom type) pair whose x
    coordinate is not NaN. The atom serial is ``residue_index * 37 +
    atom_index + 1`` and the residue number is ``residue_index + 1``.

    NOTE(review): the line layout is reproduced exactly as found; confirm the
    field spacing against the fixed-column PDB ATOM record before relying on
    strict PDB parsers.
    """
    lines = []
    for res_idx, (aa, residue_coords) in enumerate(zip(protein.sequence, protein.coordinates)):
        res_num = res_idx + 1
        for atom_idx, atom in enumerate(RC.atom_types):
            atom_xyz = residue_coords[atom_idx]
            if torch.isnan(atom_xyz[0]):
                continue  # atom not present in this residue
            x, y, z = atom_xyz.tolist()
            serial = res_idx * 37 + atom_idx + 1
            lines.append(f"ATOM {serial:5d} {atom:3s} {aa:3s} A{res_num:4d} {x:8.3f}{y:8.3f}{z:8.3f}\n")
    return "".join(lines)
@spaces.GPU
def prediction_visualization(pdb_file, num_runs: int, noise_level: float, num_frames: int, temperature: float, num_steps: int, progress=gr.Progress()):
    """Run noisy structure predictions and visualise the one with the lowest cRMSD.

    For each of ``num_frames`` frames the input protein is perturbed once
    with ``noise_level`` noise, then predicted ``num_runs`` times. Every
    prediction that aligns successfully is scored against the unperturbed
    input.

    Returns:
        ``(view, message)``: the py3Dmol view of the best prediction and a
        text with its cRMSD, or ``(None, "No successful predictions")`` when
        no run succeeded.
    """
    protein = get_protein(pdb_file)
    total_iterations = num_frames * num_runs
    scored_predictions = []  # (cRMSD, aligned structure) per successful run

    progress(0, desc="Starting predictions")
    for frame in progress.tqdm(range(num_frames), desc="Processing frames"):
        noisy_protein = add_noise_to_coordinates(protein, noise_level)
        for run in range(num_runs):
            completed = frame * num_runs + run + 1
            progress(completed / total_iterations, desc=f"Frame {frame+1}, Run {run+1}")

            structure_prediction = run_structure_prediction(noisy_protein, temperature, num_steps)
            if structure_prediction is not None:
                aligned, crmsd = align_after_prediction(protein, structure_prediction)
                if aligned is not None:
                    scored_predictions.append((crmsd, aligned))

            # NOTE(review): the delay is assumed to run once per prediction
            # run; the nesting level is not recoverable from the listing.
            time.sleep(0.1)  # Small delay to allow for UI updates

    if not scored_predictions:
        return None, "No successful predictions"

    best_crmsd, best_structure = sorted(scored_predictions, key=lambda entry: entry[0])[0]
    view_data = visualize_after_pred(protein, best_structure)
    return view_data, f"Best cRMSD: {best_crmsd:.4f}"
def run_prediction(pdb_file, num_runs, noise_level, num_frames, temperature, num_steps, progress=gr.Progress()):
    """Gradio callback: run the prediction pipeline and build the four UI outputs.

    Args:
        pdb_file: Uploaded PDB file, or ``None`` when nothing was uploaded.
        num_runs, noise_level, num_frames, temperature, num_steps: Forwarded
            to ``prediction_visualization``.
        progress: Gradio progress tracker.

    Returns:
        A 4-tuple ``(html, crmsd_text, steric_clash_text,
        norm_steric_clash_text)``. ``html`` is an iframe embedding the
        py3Dmol view on success, a plain message when no file was uploaded or
        no prediction succeeded, or an error report when an exception was
        raised. Exceptions are never propagated to the caller.
    """
    import html  # local import: only needed to escape the error report

    try:
        if pdb_file is None:
            return "Please upload a PDB file.", "No file uploaded", "", ""

        progress(0, desc="Starting prediction")
        view, crmsd_text = prediction_visualization(pdb_file, num_runs, noise_level, num_frames, temperature, num_steps, progress)
        steric_clash_text, norm_steric_clash_text = calculate_clashes_for_pdb(pdb_file)

        if view is None:
            return "No successful predictions were made. Try adjusting the parameters or check the PDB file.", crmsd_text, steric_clash_text, norm_steric_clash_text

        progress(0.9, desc="Rendering visualization")
        # Convert the py3Dmol view to HTML. Single quotes are swapped for
        # double quotes because the document is embedded in a single-quoted
        # srcdoc attribute below.
        view_html = view._make_html().replace("'", '"')
        html_content = f"""
<iframe style="width: 100%; height: 600px;" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='<!DOCTYPE html><html>{view_html}</html>'></iframe>
"""
        progress(1.0, desc="Completed")
        return html_content, crmsd_text, steric_clash_text, norm_steric_clash_text
    except Exception as e:
        # Escape before embedding: tracebacks contain text such as "<module>"
        # that a browser would otherwise swallow as an unknown tag, and the
        # message may echo content taken from the uploaded file.
        error_message = html.escape(str(e))
        stack_trace = html.escape(traceback.format_exc())
        return f"""
<div style='color: red;'>
<h3>Error:</h3>
<p>{error_message}</p>
<h4>Stack Trace:</h4>
<pre>{stack_trace}</pre>
</div>
""", "Error occurred", "", ""
def create_demo():
    """Build the Gradio Blocks UI and return it (without launching it).

    The page has two collapsed accordions (info text and a presentation
    video), a column of inputs, a column of outputs, a button wired to
    ``run_prediction`` and a list of example PDB files.

    NOTE(review): the nesting of the ``with`` blocks below was reconstructed;
    the listing this code came from carried no indentation. Confirm the
    layout against the running app.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Protein Structure Prediction and Visualization with Noise and MD Frames")
        # Collapsed info section: background and usage side by side, summary below.
        with gr.Accordion(label='learn more about MISATO ESM3 conformational sampling', open=False):
            with gr.Row():
                with gr.Column():
                    gr.Markdown(about)
                with gr.Column():
                    gr.Markdown(howtouse)
            with gr.Row():
                gr.Markdown(about1)
        # Collapsed video section.
        with gr.Accordion(label="watch presentation video", open=False):
            with gr.Row():
                gr.Video(value="demovideo/demo.mp4", label="MISATO Video Submission")
        with gr.Row():
            # Inputs: file upload plus the sampling parameters.
            with gr.Column(scale=1):
                pdb_file = gr.File(label="Upload PDB file")
                num_runs = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of runs per frame")
                noise_level = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.1, label="Noise level")
                num_frames = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of MD frames")
                temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.7, label="Temperature")
                num_steps = gr.Slider(minimum=1, maximum=10, step=1, value=10, label="Number of steps")
                run_button = gr.Button("Run Prediction")
            # Outputs: one component per element of run_prediction's 4-tuple.
            with gr.Column(scale=2):
                visualization = gr.HTML(label="3D Visualization")
                alignment_result = gr.Textbox(label="Alignment Result")
                steric_clash_result = gr.Textbox(label="Steric Clash Result")
                norm_steric_clash_result = gr.Textbox(label="Normalized Steric Clash Result")
        # The input order must match run_prediction's positional parameters.
        run_button.click(
            fn=run_prediction,
            inputs=[pdb_file, num_runs, noise_level, num_frames, temperature, num_steps],
            outputs=[visualization, alignment_result, steric_clash_result, norm_steric_clash_result]
        )
        # Examples only supply the PDB file input.
        gr.Examples(
            examples=[
                ["examples/1ywi.pdb"],
                ["examples/5awl.pdb"],
                ["examples/11gs.pdb"],
            ],
            inputs=[pdb_file],
            outputs=[visualization, alignment_result, steric_clash_result, norm_steric_clash_result],
            fn=run_prediction,
            cache_examples=False,
        )
    return demo
# Script entry point: build the UI, enable Gradio's request queue, then start
# the app.
if __name__ == "__main__":
    demo = create_demo()
    demo.queue()
    demo.launch()