File size: 17,760 Bytes
a653c8d
1f1f2d8
 
 
 
 
 
f74eb0b
f2448d2
1f1f2d8
 
 
 
 
f2448d2
ab3b1db
f2448d2
86c05eb
 
afd257c
 
ab3b1db
21e6511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db3880c
 
21e6511
 
 
 
 
ab3b1db
f2448d2
 
 
 
 
f74eb0b
 
 
f2448d2
 
 
f74eb0b
1f1f2d8
 
 
 
 
 
 
 
f2448d2
afd257c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335fce6
afd257c
1f1f2d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f74eb0b
1f1f2d8
 
 
21e6511
63a236d
 
21e6511
 
63a236d
 
f2448d2
 
 
 
 
 
 
63a236d
f2448d2
63a236d
 
 
ca86d3c
 
63a236d
 
 
 
 
 
 
f2448d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63a236d
 
 
 
 
 
f2448d2
 
 
ca86d3c
 
f2448d2
 
86c05eb
 
fbb545c
 
 
 
 
 
f2448d2
fbb545c
f2448d2
 
 
 
 
 
 
 
 
63a236d
ca86d3c
21e6511
1f1f2d8
 
 
86c05eb
 
 
 
1f1f2d8
 
 
86c05eb
21e6511
f2448d2
 
 
 
86c05eb
63a236d
 
 
1f1f2d8
f2448d2
86c05eb
 
1f1f2d8
21e6511
1f1f2d8
 
9b5ec9d
1f1f2d8
86c05eb
21e6511
9b5ec9d
fbb545c
9b5ec9d
63a236d
86c05eb
fbb545c
 
 
86c05eb
fbb545c
 
 
 
 
86c05eb
fbb545c
86c05eb
9b5ec9d
1f1f2d8
 
 
 
 
 
 
 
 
 
9b5ec9d
 
1f1f2d8
 
 
ba8add4
21e6511
 
 
 
 
db3880c
 
 
 
 
1f1f2d8
 
 
 
 
 
21e6511
 
1f1f2d8
 
 
 
 
9b5ec9d
 
1f1f2d8
 
 
21e6511
9b5ec9d
1f1f2d8
21e6511
 
 
 
 
 
 
 
 
 
 
1f1f2d8
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
import spaces
import gradio as gr
import py3Dmol
import io
import numpy as np
import os
import traceback
from esm.sdk import client
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig, ESMProteinError
from esm.utils.structure.protein_chain import ProteinChain
from Bio.Data import PDBData
import biotite.structure as bs
from biotite.structure.io import pdb
from esm.utils import residue_constants as RC
import requests
from dotenv import load_dotenv
import torch
import json 
import time
from Bio.PDB import PDBParser
import itertools

# Markdown shown in the "learn more" accordion built by create_demo():
# usage steps (howtouse), scientific background (about) and a short
# summary (about1).
howtouse = """
        ## How to use
        1. Upload a PDB file using the file uploader.
        2. Adjust the number of prediction runs per frame using the slider.
        3. Set the noise level to add random perturbations to the structure.
        4. Choose the number of MD frames to simulate.
        5. Optionally adjust the sampling temperature and the number of generation steps.
        6. Click the "Run Prediction" button to start the process.
        7. The 3D visualization will show the original structure (grey) and the best predicted structure (green).
        8. The alignment result will display the best cRMSD (lower is better).
        9. The steric clash results report the total and the normalized (per-atom) number of steric clashes (lower is better).
        """
about = """ ## Background

- 3D protein structures typically come from crystal structures, which are densely packed and lack flexibility.
- Different proteins require varying levels of noise to achieve overlap in conformational space.
- We've developed an adaptability model that predicts the appropriate noise level for each protein.

## Our Approach

1. **Adaptability Model**: Trained on Molecular Dynamics (MD) data, our model predicts flexibility at the atomic level.
2. **Correlation**: The adaptability predictions correlate well with the RMSD (Root Mean Square Deviation) from ESM3 sampling.
3. **Noise Application**: We apply noise to simulate protein flexibility, mimicking MD-like behavior.
"""
# Body lines are kept at column 0: the heading is unindented, so any
# indentation here would survive dedenting and render as a Markdown code block.
about1 = """
## About
This demo uses the ESM3 model to predict protein structures from PDB files.
It runs multiple predictions with added noise and simulated MD frames, displaying the best result based on the lowest cRMSD.
"""

# Pull variables from a local .env file (if any) into os.environ.
load_dotenv()

# Forge endpoint and the ESM3 model name used for every prediction.
API_URL = "https://forge.evolutionaryscale.ai/api/v1"
MODEL = "esm3-open-2024-03"
# The credential is read from the environment so that no secret lives in
# source control. NOTE(review): a literal token was previously committed
# here; treat it as leaked and rotate it.
API_TOKEN = os.environ.get("ESM_API_TOKEN")
if not API_TOKEN:
    raise ValueError("ESM_API_TOKEN environment variable is not set")

# Shared remote inference client used by run_structure_prediction().
model = client(
    model=MODEL,
    url=API_URL,
    token=API_TOKEN,
)

# Three-letter PDB residue names of the 20 standard amino acids, paired
# position-by-position with their one-letter codes.
amino3to1 = dict(
    zip(
        (
            "ALA", "CYS", "ASP", "GLU", "PHE",
            "GLY", "HIS", "ILE", "LYS", "LEU",
            "MET", "ASN", "PRO", "GLN", "ARG",
            "SER", "THR", "VAL", "TRP", "TYR",
        ),
        "ACDEFGHIKLMNPQRSTVWY",
    )
)


# Per-element radius table used by the steric-clash check
# (calculate_clashes_for_pdb via get_covalent_radius).
# Keys are upper-case element symbols: get_covalent_radius looks them up with
# atom.element.upper(), so e.g. "CA" here is calcium and "SG" is seaborgium,
# not the PDB atom names C-alpha / cysteine S-gamma.
# Values are presumably covalent radii in Angstrom (unknown elements fall back
# to 2.0 A in get_covalent_radius). NOTE(review): the data source for these
# values is not documented; the trailing 2.0 entries and the "UU*" keys look
# like placeholders for superheavy elements — confirm.
COVALENT_RADIUS = {
    "H": 0.31, "HE": 0.28, "LI": 1.28, "BE": 0.96, "B": 0.84, "C": 0.76, "N": 0.71, "O": 0.66, "F": 0.57, "NE": 0.58,
    "NA": 1.66, "MG": 1.41, "AL": 1.21, "SI": 1.11, "P": 1.07, "S": 1.05, "CL": 1.02, "AR": 1.06, "K": 2.03,
    "CA": 1.76, "SC": 1.7, "TI": 1.6, "V": 1.53, "CR": 1.39, "MN": 1.5, "FE": 1.42, "CO": 1.38, "NI": 1.24,
    "CU": 1.32, "ZN": 1.22, "GA": 1.22, "GE": 1.2, "AS": 1.19, "SE": 1.2, "BR": 1.2, "KR": 1.16, "RB": 2.2,
    "SR": 1.95, "Y": 1.9, "ZR": 1.75, "NB": 1.64, "MO": 1.54, "TC": 1.47, "RU": 1.46, "RH": 1.42, "PD": 1.39,
    "AG": 1.45, "CD": 1.44, "IN": 1.42, "SN": 1.39, "SB": 1.39, "TE": 1.38, "I": 1.39, "XE": 1.4, "CS": 2.44,
    "BA": 2.15, "LA": 2.07, "CE": 2.04, "PR": 2.03, "ND": 2.01, "PM": 1.99, "SM": 1.98, "EU": 1.98, "GD": 1.96,
    "TB": 1.94, "DY": 1.92, "HO": 1.92, "ER": 1.89, "TM": 1.9, "YB": 1.87, "LU": 1.87, "HF": 1.75, "TA": 1.7,
    "W": 1.62, "RE": 1.51, "OS": 1.44, "IR": 1.41, "PT": 1.36, "AU": 1.36, "HG": 1.32, "TL": 1.45, "PB": 1.46,
    "BI": 1.48, "PO": 1.4, "AT": 1.5, "RN": 1.5, "FR": 2.6, "RA": 2.21, "AC": 2.15, "TH": 2.06, "PA": 2.0,
    "U": 1.96, "NP": 1.9, "PU": 1.87, "AM": 1.8, "CM": 1.69, "BK": 2.0, "CF": 2.0, "ES": 2.0, "FM": 2.0,
    "MD": 2.0, "NO": 2.0, "LR": 2.0, "RF": 2.0, "DB": 2.0, "SG": 2.0, "BH": 2.0, "HS": 2.0, "MT": 2.0,
    "DS": 2.0, "RG": 2.0, "CN": 2.0, "UUT": 2.0, "UUQ": 2.0, "UUP": 2.0, "UUH": 2.0, "UUS": 2.0, "UUO": 2.0
}

def get_covalent_radius(atom):
    """Return the tabulated radius for the element of *atom*.

    The atom's ``element`` string is upper-cased and looked up in
    ``COVALENT_RADIUS``; elements absent from the table get 2.0 (Angstrom).
    """
    symbol = atom.element.upper()
    try:
        return COVALENT_RADIUS[symbol]
    except KeyError:
        return 2.0

def calculate_clashes_for_pdb(pdb_file):
    """Count steric clashes over every atom pair of a PDB structure.

    Two atoms clash when their distance plus a 0.5 tolerance is still smaller
    than the sum of their radii from ``get_covalent_radius``. Every atom
    yielded by ``structure.get_atoms()`` takes part, and each unordered pair
    is tested once.

    Args:
        pdb_file: Passed unchanged to ``Bio.PDB.PDBParser.get_structure``.

    Returns:
        A ``(total, normalized)`` pair of strings: the number of clashing
        pairs, and that number divided by the atom count. A structure with no
        atoms yields ``("0", "0.0")`` instead of dividing by zero.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("protein", pdb_file)
    atoms = list(structure.get_atoms())
    num_atoms = len(atoms)
    if num_atoms == 0:
        return "0", "0.0"

    coords = np.array([atom.coord for atom in atoms])
    radii = np.array([get_covalent_radius(atom) for atom in atoms])

    # Compare atom i with every later atom j > i in one vectorised step. This
    # visits the same pairs as itertools.combinations but avoids millions of
    # Python-level Atom subtractions, while using O(n) memory per step rather
    # than an O(n^2) distance matrix.
    steric_clash_count = 0
    for i in range(num_atoms - 1):
        distances = np.linalg.norm(coords[i + 1:] - coords[i], axis=1)
        radius_sums = radii[i] + radii[i + 1:]
        steric_clash_count += int(np.count_nonzero(distances + 0.5 < radius_sums))

    # Normalize steric clashes per number of atoms
    norm_ster_clash_count = steric_clash_count / num_atoms
    return f"{steric_clash_count}", f"{norm_ster_clash_count}"

def read_pdb_io(pdb_file):
    """Load PDB text from a buffer, a filesystem path, or an upload handle.

    Args:
        pdb_file: One of
            - an ``io.StringIO`` already holding the PDB text;
            - an ``os.PathLike`` (e.g. ``pathlib.Path``), opened as-is. It is
              checked before ``.name`` because ``Path.name`` is only the
              final path component;
            - any object with a ``.name`` attribute naming a file (e.g. an
              upload handle);
            - a plain ``str`` path.

    Returns:
        ``(pdb_io, pdb_content)``: a fresh ``io.StringIO`` over the text and
        the text itself.

    Raises:
        ValueError: If the input type is unsupported or the text is blank.
    """
    if isinstance(pdb_file, io.StringIO):
        pdb_content = pdb_file.getvalue()
    else:
        if isinstance(pdb_file, os.PathLike):
            path = pdb_file
        elif hasattr(pdb_file, 'name'):
            path = pdb_file.name
        elif isinstance(pdb_file, str):
            path = pdb_file
        else:
            raise ValueError("Unsupported file type")
        with open(path, 'r') as f:
            pdb_content = f.read()

    if not pdb_content.strip():
        raise ValueError("The PDB file is empty.")

    pdb_io = io.StringIO(pdb_content)
    return pdb_io, pdb_content

def get_protein(pdb_file) -> ESMProtein:
    """Parse a PDB input into an ESMProtein with atom37 coordinates.

    Only the first model of the file is read. Residues whose three-letter
    name is in ``amino3to1`` are kept and packed, in file order, into a
    single ProteinChain with chain id "A". Their PDB residue numbers become
    ``residue_index``. NOTE(review): multi-chain inputs are merged into one
    chain, so residue numbers may repeat — confirm this is acceptable.

    Args:
        pdb_file: Anything ``read_pdb_io`` accepts.

    Returns:
        The protein built by ``ESMProtein.from_protein_chain``.

    Raises:
        ValueError: If the file cannot be read or parsed, contains no atoms,
            or contains no standard amino-acid residues. The underlying error
            is chained as ``__cause__``.
    """
    try:
        # read_pdb_io already rejects blank files.
        pdb_io, _ = read_pdb_io(pdb_file)

        # model=1 makes biotite return an AtomArray. Without it,
        # get_structure() returns an AtomArrayStack, and iterating a residue
        # of a stack yields whole models instead of atoms, so only the first
        # atom of every residue used to be recorded.
        structure = pdb.PDBFile.read(pdb_io).get_structure(model=1)

        # Check if the structure contains any atoms
        if structure.array_length() == 0:
            raise ValueError("The PDB file does not contain any valid atoms")

        # Annotations are per-atom arrays, so a residue's name and number are
        # read from its first atom.
        valid_residues = [
            res for res in bs.residue_iter(structure) if res.res_name[0] in amino3to1
        ]
        if not valid_residues:
            raise ValueError("No valid amino acid residues found in the PDB file")

        sequence = ''.join(amino3to1[res.res_name[0]] for res in valid_residues)
        residue_indices = np.array([res.res_id[0] for res in valid_residues], dtype=int)
        num_res = len(sequence)

        # Scatter every recognised atom into its atom37 slot; slots without
        # an atom stay NaN and are masked out.
        atom37_positions = np.full((num_res, 37, 3), np.nan)
        atom37_mask = np.zeros((num_res, 37), dtype=bool)
        for i, res in enumerate(valid_residues):
            for atom_name, coord in zip(res.atom_name, res.coord):
                idx = RC.atom_order.get(atom_name)
                if idx is not None:
                    atom37_positions[i, idx] = coord
                    atom37_mask[i, idx] = True

        protein_chain = ProteinChain(
            id="test",
            sequence=sequence,
            chain_id="A",
            entity_id=None,
            residue_index=residue_indices,
            insertion_code=np.full(num_res, "", dtype="<U4"),
            atom37_positions=atom37_positions,
            atom37_mask=atom37_mask,
            confidence=np.ones(num_res, dtype=np.float32),
        )
        return ESMProtein.from_protein_chain(protein_chain)
    except Exception as e:
        print(f"Error processing PDB file: {str(e)}")
        raise ValueError(f"Unable to process the PDB file: {str(e)}") from e

def add_noise_to_coordinates(protein: ESMProtein, noise_level: float) -> ESMProtein:
    """Return a new ESMProtein whose coordinates carry Gaussian jitter.

    Each coordinate component gets an independent draw from a standard
    normal (NumPy's global random state), scaled by ``noise_level``. Only the
    sequence and the perturbed coordinates are carried over. Positions that
    were NaN stay NaN.
    """
    clean = protein.coordinates
    jitter = noise_level * np.random.randn(*clean.shape)
    return ESMProtein(sequence=protein.sequence, coordinates=clean + jitter)

def run_structure_prediction(protein: ESMProtein, temperature: float, num_steps: int) -> ESMProtein:
    """Generate the structure track for *protein* with the shared ``model`` client.

    Args:
        protein: Input passed to ``model.generate``.
        temperature: Sampling temperature for ``GenerationConfig``.
        num_steps: Number of generation steps for ``GenerationConfig``.

    Returns:
        The generated ESMProtein. Returns ``None`` if the client answers with
        an ESMProteinError, returns any other unexpected type, or raises.
        Failures are printed, not propagated.
    """
    config = GenerationConfig(track="structure", temperature=temperature, num_steps=num_steps)
    try:
        result = model.generate(protein, config)
        if isinstance(result, ESMProtein):
            return result
        if isinstance(result, ESMProteinError):
            print(f"ESMProteinError during structure prediction: {result.error_msg}")
            return None
        # Funnelled into the generic handler below, which reports and returns None.
        raise ValueError(f"Unexpected response type: {type(result)}")
    except Exception as e:
        print(f"Error during structure prediction: {str(e)}")
        return None
    
@spaces.GPU
def align_after_prediction(protein: ESMProtein, structure_prediction: ESMProtein) -> tuple[ESMProtein, float]:
    """Align a prediction onto the reference structure and compute its cRMSD.

    Residues are paired by position over the shorter of the two sequences.
    Alignment uses ``ProteinChain.align`` and the score ``ProteinChain.rmsd``.

    Returns:
        ``(aligned_protein, crmsd)`` on success. Returns
        ``(None, float('inf'))`` when there is no prediction or when the
        alignment raises; the failure is printed.
    """
    if structure_prediction is None:
        return None, float('inf')

    try:
        predicted = structure_prediction.to_protein_chain()
        reference = protein.to_protein_chain()

        # Pair residue k of the prediction with residue k of the reference.
        shared = np.arange(0, min(len(predicted.sequence), len(reference.sequence)))
        index_kwargs = dict(mobile_inds=shared, target_inds=shared)

        superposed = predicted.align(reference, **index_kwargs)
        crmsd = predicted.rmsd(reference, **index_kwargs)
        return ESMProtein.from_protein_chain(superposed), crmsd
    except AttributeError as e:
        # Extra diagnostics: usually means the prediction is not the expected type.
        print(f"Error during alignment: {str(e)}")
        print(f"Structure prediction type: {type(structure_prediction)}")
        print(f"Structure prediction attributes: {dir(structure_prediction)}")
        return None, float('inf')
    except Exception as e:
        print(f"Unexpected error during alignment: {str(e)}")
        return None, float('inf')

@spaces.GPU
def visualize_after_pred(protein: ESMProtein, aligned: ESMProtein):
    """Overlay the reference and the aligned prediction in an 800x600 py3Dmol view.

    The reference is drawn as a light-grey cartoon and the prediction as a
    light-green cartoon.

    Returns:
        The py3Dmol view, or ``None`` when *aligned* is ``None``.
    """
    if aligned is None:
        return None

    viewer = py3Dmol.view(width=800, height=600)
    # Reference first: it is the only model at this point, so the unscoped
    # style applies to it alone.
    viewer.addModel(protein_to_pdb(protein), "pdb")
    viewer.setStyle({"cartoon": {"color": "lightgrey"}})
    # Prediction second; {"model": -1} targets the most recently added model.
    viewer.addModel(protein_to_pdb(aligned), "pdb")
    viewer.setStyle({"model": -1}, {"cartoon": {"color": "lightgreen"}})
    viewer.zoomTo()
    return viewer

def protein_to_pdb(protein: ESMProtein):
    """Serialise *protein* into minimal PDB ``ATOM`` records for py3Dmol.

    One record is written per atom37 slot whose x coordinate is not NaN.
    Residues get their three-letter PDB name (``UNK`` for codes outside
    ``amino3to1``), chain ``A`` and sequential numbers starting at 1. Fields
    follow the fixed-width PDB ATOM layout up to the coordinates; occupancy,
    B-factor and element columns are omitted.

    Returns:
        The PDB text, one newline-terminated line per atom.
    """
    # The resName column must hold the three-letter code; the one-letter code
    # the sequence stores is not a valid PDB residue name.
    one_to_three = {one: three for three, one in amino3to1.items()}
    num_atom_types = len(RC.atom_types)
    lines = []
    for i, (aa, coords) in enumerate(zip(protein.sequence, protein.coordinates)):
        res_name = one_to_three.get(aa, "UNK")
        for j, atom in enumerate(RC.atom_types):
            if torch.isnan(coords[j][0]):
                continue  # slot not populated for this residue
            x, y, z = coords[j].tolist()
            lines.append(
                f"ATOM  {i * num_atom_types + j + 1:5d}  {atom:3s} {res_name:3s} "
                f"A{i + 1:4d}    {x:8.3f}{y:8.3f}{z:8.3f}\n"
            )
    # join() instead of repeated += keeps serialisation linear in atom count.
    return "".join(lines)

@spaces.GPU
def prediction_visualization(pdb_file, num_runs: int, noise_level: float, num_frames: int, temperature: float, num_steps: int, progress=gr.Progress()):
    """Run noisy ESM3 structure predictions and visualise the best one.

    For each of ``num_frames`` frames a fresh noisy copy of the input protein
    is made and predicted ``num_runs`` times. Every successful prediction is
    aligned onto the original, and the run with the lowest cRMSD is rendered.

    Returns:
        ``(view, message)``: the py3Dmol view of the best run and a
        ``"Best cRMSD: ..."`` string. Returns ``(None, "No successful
        predictions")`` when no run produced a rankable result.
    """
    protein = get_protein(pdb_file)
    # Defensive: range() below requires ints.
    num_runs = int(num_runs)
    num_frames = int(num_frames)
    runs = []

    total_iterations = num_frames * num_runs
    progress(0, desc="Starting predictions")

    for frame in progress.tqdm(range(num_frames), desc="Processing frames"):
        # One noisy copy per frame, shared by all runs of that frame.
        noisy_protein = add_noise_to_coordinates(protein, noise_level)

        for i in range(num_runs):
            progress((frame * num_runs + i + 1) / total_iterations, desc=f"Frame {frame+1}, Run {i+1}")
            structure_prediction = run_structure_prediction(noisy_protein, temperature, num_steps)
            if structure_prediction is not None:
                aligned, crmsd = align_after_prediction(protein, structure_prediction)
                # A NaN cRMSD cannot be ranked (every comparison with NaN is
                # False), so it could otherwise be reported as the "best" run.
                if aligned is not None and not np.isnan(float(crmsd)):
                    runs.append((crmsd, aligned))
            time.sleep(0.1)  # Small delay to allow for UI updates

    if not runs:
        return None, "No successful predictions"

    # min() returns the first run with the lowest cRMSD, like sorted(...)[0].
    best_crmsd, best_aligned = min(runs, key=lambda run: run[0])
    view_data = visualize_after_pred(protein, best_aligned)
    return view_data, f"Best cRMSD: {best_crmsd:.4f}"

def run_prediction(pdb_file, num_runs, noise_level, num_frames, temperature, num_steps, progress=gr.Progress()):
    """Gradio click handler: predict, score and render a structure.

    Returns:
        A 4-tuple matching the UI outputs:
        - HTML for the 3D view (or a message / error panel);
        - the cRMSD text;
        - the total steric clash text;
        - the per-atom steric clash text.
        Exceptions are caught and rendered as an error panel.
    """
    import html  # local import: only used to escape text embedded in HTML

    try:
        if pdb_file is None:
            return "Please upload a PDB file.", "No file uploaded", "", ""

        progress(0, desc="Starting prediction")
        view, crmsd_text = prediction_visualization(pdb_file, num_runs, noise_level, num_frames, temperature, num_steps, progress)
        # NOTE(review): clashes are computed on the uploaded input structure,
        # not on the best prediction — confirm this is the intended metric.
        steric_clash_text, norm_steric_clash_text = calculate_clashes_for_pdb(pdb_file)
        if view is None:
            return "No successful predictions were made. Try adjusting the parameters or check the PDB file.", crmsd_text, steric_clash_text, norm_steric_clash_text

        progress(0.9, desc="Rendering visualization")

        # Embed the py3Dmol page in a sandboxed iframe. The whole document is
        # HTML-escaped for the srcdoc attribute, and the browser unescapes it.
        # This is robust to any quotes or ampersands in the generated page,
        # unlike rewriting single quotes to double quotes.
        page = f"<!DOCTYPE html><html>{view._make_html()}</html>"
        escaped_page = html.escape(page, quote=True)
        html_content = f"""
        <iframe style="width: 100%; height: 600px;" name="result" allow="midi; geolocation; microphone; camera; 
        display-capture; encrypted-media;" sandbox="allow-modals allow-forms 
        allow-scripts allow-same-origin allow-popups 
        allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" 
        allowpaymentrequest="" frameborder="0" srcdoc="{escaped_page}"></iframe>
        """

        progress(1.0, desc="Completed")
        return html_content, crmsd_text, steric_clash_text, norm_steric_clash_text
    except Exception as e:
        # Escape before embedding: error text can echo user-supplied file content.
        error_message = html.escape(str(e))
        stack_trace = html.escape(traceback.format_exc())
        return f"""
        <div style='color: red;'>
            <h3>Error:</h3>
            <p>{error_message}</p>
            <h4>Stack Trace:</h4>
            <pre>{stack_trace}</pre>
        </div>
        """, "Error occurred", "", ""
    
def create_demo():
    """Build (but do not launch) the Gradio Blocks UI.

    Layout, top to bottom:
    - a title;
    - two collapsed accordions (background text and a demo video);
    - an input column (PDB upload, sampling sliders, run button) beside a
      result column (3D view and three result text boxes).

    Both the run button and the examples target ``run_prediction``.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Protein Structure Prediction and Visualization with Noise and MD Frames")

        # Collapsible explanatory text.
        with gr.Accordion(label='learn more about MISATO ESM3 conformational sampling', open=False):
            with gr.Row():
                with gr.Column():
                    gr.Markdown(about)
                with gr.Column():
                    gr.Markdown(howtouse)
            with gr.Row():
                gr.Markdown(about1)

        with gr.Accordion(label="watch presentation video", open=False):
            with gr.Row():
                gr.Video(value="demovideo/demo.mp4", label="MISATO Video Submission")

        with gr.Row():
            # Left: inputs.
            with gr.Column(scale=1):
                pdb_file = gr.File(label="Upload PDB file")
                num_runs = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of runs per frame")
                noise_level = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.1, label="Noise level")
                num_frames = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of MD frames")
                temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.7, label="Temperature")
                num_steps = gr.Slider(minimum=1, maximum=10, step=1, value=10, label="Number of steps")
                run_button = gr.Button("Run Prediction")

            # Right: results.
            with gr.Column(scale=2):
                visualization = gr.HTML(label="3D Visualization")
                alignment_result = gr.Textbox(label="Alignment Result")
                steric_clash_result = gr.Textbox(label="Steric Clash Result")
                norm_steric_clash_result = gr.Textbox(label="Normalized Steric Clash Result")

        prediction_inputs = [pdb_file, num_runs, noise_level, num_frames, temperature, num_steps]
        prediction_outputs = [visualization, alignment_result, steric_clash_result, norm_steric_clash_result]

        run_button.click(fn=run_prediction, inputs=prediction_inputs, outputs=prediction_outputs)
        gr.Examples(
            examples=[[f"examples/{pdb_id}.pdb"] for pdb_id in ("1ywi", "5awl", "11gs")],
            inputs=[pdb_file],
            outputs=prediction_outputs,
            fn=run_prediction,
            cache_examples=False,
        )

    return demo

if __name__ == "__main__":
    # Build the app, keeping the module-level name ``demo``.
    demo = create_demo()
    demo.queue()  # enable Gradio's request queue before serving
    demo.launch()  # start the web server (blocks)