# Web-page residue from the hosting site (author: simonduerr; commit "Update app.py", 02a1c15); not part of the program.
import gradio as gr
import numpy as np
import os
import ray
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from transformers import pipeline as pl
from GPUtil import showUtilization as gpu_usage
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sys
import plotly.graph_objects as go
import torch
import gc
import jax
from numba import cuda
# Startup diagnostics: report whether torch can see a CUDA device and which
# directory the app was started from.
print('GPU available',torch.cuda.is_available())
#print('__CUDA Device Name:',torch.cuda.get_device_name(0))
print(os.getcwd())
# Put the AlphaFold directory on the import path (once) so that the
# `alphafold.*` imports that follow can be resolved.
if "/home/user/app/alphafold" not in sys.path:
    sys.path.append("/home/user/app/alphafold")
from alphafold.common import protein
from alphafold.data import pipeline
from alphafold.data import templates
from alphafold.model import data
from alphafold.model import config
from alphafold.model import model
def mk_mock_template(query_sequence):
    """Build a placeholder ("blank") template feature dict for a sequence.

    The template consists solely of gap characters: atom positions and atom
    masks are all zeros and the amino-acid types are the one-hot encoding of a
    run of ``"-"`` as long as ``query_sequence``.

    Args:
        query_sequence: The query sequence; only its length is used.

    Returns:
        A dict with the keys ``template_all_atom_positions``,
        ``template_all_atom_masks``, ``template_aatype`` and
        ``template_domain_names``. The three arrays carry an extra leading
        axis of size 1; the domain names are the single bytes entry
        ``b"none"``.
    """
    residue_constants = templates.residue_constants
    num_res = len(query_sequence)
    gap_sequence = "-" * num_res

    blank_positions = np.zeros((num_res, residue_constants.atom_type_num, 3))
    blank_masks = np.zeros((num_res, residue_constants.atom_type_num))
    gap_onehot = residue_constants.sequence_to_onehot(
        gap_sequence, residue_constants.HHBLITS_AA_TO_ID
    )

    # ``[None]`` prepends a length-1 axis to every array.
    return {
        "template_all_atom_positions": blank_positions[None],
        "template_all_atom_masks": blank_masks[None],
        "template_aatype": np.array(gap_onehot)[None],
        "template_domain_names": ["none".encode()],
    }
def predict_structure(prefix, feature_dict, model_runners, random_seed=0):
    """Run every supplied model runner and write one unrelaxed PDB per model.

    Args:
        prefix: File-name prefix; each model's structure is written to
            ``/home/user/app/{prefix}_unrelaxed_{model_name}.pdb``.
        feature_dict: Raw input features handed to each runner's
            ``process_features``.
        model_runners: Mapping of model name to a runner object exposing
            ``process_features`` and ``predict``.
        random_seed: Seed forwarded to ``process_features``.

    Returns:
        A dict mapping each model name to that model's ``"plddt"`` prediction
        output. Empty when ``model_runners`` is empty.
    """
    plddts = {}
    for model_name, model_runner in model_runners.items():
        model_inputs = model_runner.process_features(
            feature_dict, random_seed=random_seed
        )
        outputs = model_runner.predict(model_inputs)

        # Store the per-residue confidence in the B-factor column, masked so
        # that only atoms present in the final atom mask receive a value.
        atom_mask = outputs["structure_module"]["final_atom_mask"]
        b_factors = outputs["plddt"][:, None] * atom_mask
        unrelaxed = protein.from_prediction(model_inputs, outputs, b_factors)

        pdb_path = f"/home/user/app/{prefix}_unrelaxed_{model_name}.pdb"
        plddts[model_name] = outputs["plddt"]
        print(f"{model_name} {plddts[model_name].mean()}")

        with open(pdb_path, "w") as pdb_file:
            pdb_file.write(protein.to_pdb(unrelaxed))
    return plddts
@ray.remote(num_gpus=0, max_calls=1)
def run_protgpt2(startsequence, length, repetitionPenalty, top_k_poolsize, max_seqs):
    """Sample protein sequences from ProtGPT2 inside a Ray task.

    Args:
        startsequence: Prompt text the model continues.
        length: Passed to the pipeline as ``max_length``.
        repetitionPenalty: Passed as ``repetition_penalty``.
        top_k_poolsize: Passed as ``top_k``.
        max_seqs: Passed as ``num_return_sequences``.

    Returns:
        Whatever the ``text-generation`` pipeline returns for the prompt; the
        caller in this file reads the ``"generated_text"`` entry of each item.
    """
    print("running protgpt2")
    print(gpu_usage())

    # The pipeline is created on every call, inside the task.
    generator = pl("text-generation", model="nferruz/ProtGPT2")
    sampling_options = {
        "max_length": length,
        "do_sample": True,
        "top_k": top_k_poolsize,
        "repetition_penalty": repetitionPenalty,
        "num_return_sequences": max_seqs,
        "eos_token_id": 0,
    }
    generated = generator(startsequence, **sampling_options)

    # No explicit GPU cleanup is performed here (earlier attempts with
    # torch.cuda.empty_cache() / numba device.reset() are disabled).
    print("Cleaning up after protGPT2")
    return generated
@ray.remote(num_gpus=0, max_calls=1)
def run_alphafold(startsequence):
    """Predict a structure for one sequence with AlphaFold inside a Ray task.

    The prediction uses the query sequence as its only MSA row and a blank
    template. The structure is written by ``predict_structure`` under the
    prefix ``"test"``.

    Args:
        startsequence: Amino-acid sequence; newline characters are removed
            before featurisation.

    Returns:
        The ``"plddt"`` output of ``model_1``.
    """
    print(gpu_usage())

    # Only model_1 is enabled; model_2 .. model_5 are currently left out.
    enabled_models = ["model_1"]
    model_runners = {}
    for name in enabled_models:
        cfg = config.model_config(name)
        cfg.data.eval.num_ensemble = 1
        params = data.get_model_haiku_params(model_name=name, data_dir="/home/user/app/")
        model_runners[name] = model.RunModel(cfg, params)

    query_sequence = startsequence.replace("\n", "")
    num_res = len(query_sequence)

    sequence_features = pipeline.make_sequence_features(
        sequence=query_sequence, description="none", num_res=num_res
    )
    # Single-row "MSA" containing just the query, with an all-zero deletion row.
    msa_features = pipeline.make_msa_features(
        msas=[[query_sequence]], deletion_matrices=[[[0] * num_res]]
    )
    template_features = mk_mock_template(query_sequence)
    feature_dict = {**sequence_features, **msa_features, **template_features}

    plddts = predict_structure("test", feature_dict, model_runners)
    print("AF2 done")
    # No explicit device cleanup is performed here (earlier attempts via the
    # jax backend / numba device.reset() are disabled).
    return plddts["model_1"]
def update_protGPT2(inp, length, repetitionPenalty, top_k_poolsize, max_seqs):
    """Generate sequences with ProtGPT2 and format them as FASTA-like text.

    Args:
        inp: Start sequence typed by the user.
        length: Maximum sequence length setting.
        repetitionPenalty: Repetition penalty setting.
        top_k_poolsize: Top-k sampling pool size setting.
        max_seqs: Number of sequences to generate.

    Returns:
        One text block holding a record per generated sequence: a header line
        ``>seq{i}, {n} residues `` followed by the sequence wrapped at 70
        characters and a blank line.
    """
    pending = run_protgpt2.remote(inp, length, repetitionPenalty, top_k_poolsize, max_seqs)
    gen_seqs = [entry["generated_text"] for entry in ray.get(pending)]
    print(gen_seqs)

    records = []
    for index, raw in enumerate(gen_seqs):
        # Strip the newlines in the generated text, then re-wrap at 70 columns.
        residues = raw.replace("\n", "")
        wrapped = "\n".join(
            residues[start:start + 70] for start in range(0, len(residues), 70)
        )
        records.append(f">seq{index}, {len(residues)} residues \n{wrapped}\n\n")
    return "".join(records)
def update(inp):
    """Predict the structure of ``inp`` and build the three UI outputs.

    Args:
        inp: The chosen amino-acid sequence.

    Returns:
        A 3-tuple of:
        * the HTML viewer snippet produced by ``molecule`` for the predicted
          structure,
        * a plotly figure of pLDDT against residue index,
        * a summary string ``"<mean> ± <std>"`` of the pLDDT values, each
          formatted to one decimal place.
    """
    print("Running AF on", inp)
    # run alphafold using ray
    plddts = ray.get(run_alphafold.remote(inp))
    print(plddts)

    fig = go.Figure(
        data=go.Scatter(
            x=np.arange(len(plddts)),
            y=plddts,
            hovertemplate='<i>pLDDT</i>: %{y:.2f} <br><i>Residue index:</i> %{x}',
        )
    )
    fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
    )

    # Read the structure from the same absolute location predict_structure
    # writes to (prefix "test", model_1) instead of relying on the current
    # working directory.
    # NOTE(review): the file name is fixed, so overlapping requests would
    # share one output file — confirm whether concurrent use is expected.
    pdb_path = "/home/user/app/test_unrelaxed_model_1.pdb"
    return (
        molecule(pdb_path),
        fig,
        f"{np.mean(plddts):.1f} ± {np.std(plddts):.1f}",
    )
def read_mol(molpath):
    """Return the full text content of the file at ``molpath``.

    Args:
        molpath: Path of the file to read (opened in text mode).

    Returns:
        The file content as a single string; ``""`` for an empty file.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(molpath, "r") as fp:
        # A single read() replaces the readlines() + repeated "+=" loop and
        # yields the same string.
        return fp.read()
def molecule(pdb):
    """Return an ``<iframe>`` snippet that renders the PDB file in 3Dmol.js.

    The file at ``pdb`` is read with ``read_mol`` and pasted into a
    self-contained HTML page. That page loads 3Dmol.js, draws the model as a
    cartoon coloured by B-factor (thresholds 50 / 70 / 90), shows a hover
    label with residue name, index and B-factor, and offers a "Show side
    chains" toggle and a download button for the structure text. The page is
    then placed in the ``srcdoc`` attribute of an iframe.

    Args:
        pdb: Path of the PDB file to display.

    Returns:
        The iframe markup as a string.
    """
    mol = read_mol(pdb)
    # NOTE(review): `mol` is spliced unescaped into a JavaScript template
    # literal (between the two backticks below); a backtick or "${" in the
    # file content would break the script — confirm the PDB text can never
    # contain them.
    x = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<link rel="stylesheet" href="https://unpkg.com/flowbite@1.4.5/dist/flowbite.min.css" />
<style>
body{
font-family:sans-serif
}
.mol-container {
width: 100%;
height: 800px;
position: relative;
}
.space-x-2 > * + *{
margin-left: 0.5rem;
}
.p-1{
padding:0.5rem;
}
.flex{
display:flex;
align-items: center;
}
.w-4{
width:1rem;
}
.h-4{
height:1rem;
}
.mt-4{
margin-top:1rem;
}
</style>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>
<div class="flex">
<div class="px-4">
<label for="sidechain" class="relative inline-flex items-center mb-4 cursor-pointer ">
<input id="sidechain"type="checkbox" class="sr-only peer">
<div class="w-11 h-6 bg-gray-200 rounded-full peer peer-focus:ring-4 peer-focus:ring-blue-300 dark:peer-focus:ring-blue-800 dark:bg-gray-700 peer-checked:after:translate-x-full peer-checked:after:border-white after:absolute after:top-0.5 after:left-[2px] after:bg-white after:border-gray-300 after:border after:rounded-full after:h-5 after:w-5 after:transition-all dark:border-gray-600 peer-checked:bg-blue-600"></div>
<span class="ml-3 text-sm font-medium text-gray-900 dark:text-gray-300">Show side chains</span>
</label>
</div>
<button type="button" class="text-gray-900 bg-white hover:bg-gray-100 border border-gray-200 focus:ring-4 focus:outline-none focus:ring-gray-100 font-medium rounded-lg text-sm px-5 py-2.5 text-center inline-flex items-center dark:focus:ring-gray-600 dark:bg-gray-800 dark:border-gray-700 dark:text-white dark:hover:bg-gray-700 mr-2 mb-2" id="download">
<svg class="w-6 h-6 mr-2 -ml-1" fill="none" stroke="currentColor" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"></path></svg>
Download predicted structure
</button>
</div>
<div class="text-sm">
<div class="font-medium mt-4"><b>AlphaFold model confidence:</b></div>
<div class="flex space-x-2 py-1"><span class="w-4 h-4"
style="background-color: rgb(0, 83, 214);">&nbsp;</span><span class="legendlabel">Very high
(pLDDT &gt; 90)</span></div>
<div class="flex space-x-2 py-1"><span class="w-4 h-4"
style="background-color: rgb(101, 203, 243);">&nbsp;</span><span class="legendlabel">Confident
(90 &gt; pLDDT &gt; 70)</span></div>
<div class="flex space-x-2 py-1"><span class="w-4 h-4"
style="background-color: rgb(255, 219, 19);">&nbsp;</span><span class="legendlabel">Low (70 &gt;
pLDDT &gt; 50)</span></div>
<div class="flex space-x-2 py-1"><span class="w-4 h-4"
style="background-color: rgb(255, 125, 69);">&nbsp;</span><span class="legendlabel">Very low
(pLDDT &lt; 50)</span></div>
<div class="row column legendDesc"> AlphaFold produces a per-residue confidence
score (pLDDT) between 0 and 100. Some regions below 50 pLDDT may be unstructured in isolation.
</div>
</div>
<script>
let viewer = null;
let voldata = null;
$(document).ready(function () {
let element = $("#container");
let config = { backgroundColor: "white" };
viewer = $3Dmol.createViewer( element, config );
viewer.ui.initiateUI();
let data = `"""
        + mol
        + """`
viewer.addModel( data, "pdb" );
//AlphaFold code from https://gist.github.com/piroyon/30d1c1099ad488a7952c3b21a5bebc96
let colorAlpha = function (atom) {
if (atom.b < 50) {
return "OrangeRed";
} else if (atom.b < 70) {
return "Gold";
} else if (atom.b < 90) {
return "MediumTurquoise";
} else {
return "Blue";
}
};
viewer.setStyle({}, { cartoon: { colorfunc: colorAlpha } });
viewer.zoomTo();
viewer.render();
viewer.zoom(0.8, 2000);
viewer.getModel(0).setHoverable({}, true,
function (atom, viewer, event, container) {
console.log(atom)
if (!atom.label) {
atom.label = viewer.addLabel(atom.resn+atom.resi+" pLDDT=" + atom.b, { position: atom, backgroundColor: "mintcream", fontColor: "black" });
}
},
function (atom, viewer) {
if (atom.label) {
viewer.removeLabel(atom.label);
delete atom.label;
}
}
);
$("#sidechain").change(function () {
if (this.checked) {
BB = ["C", "O", "N"]
viewer.setStyle( {"and": [{resn: ["GLY", "PRO"], invert: true},{atom: BB, invert: true},]},{stick: {colorscheme: "WhiteCarbon", radius: 0.3}, cartoon: { colorfunc: colorAlpha }});
viewer.render()
} else {
viewer.setStyle({cartoon: { colorfunc: colorAlpha }});
viewer.render()
}
});
$("#download").click(function () {
download("gradioFold_model1.pdb", data);
})
});
function download(filename, text) {
var element = document.createElement("a");
element.setAttribute("href", "data:text/plain;charset=utf-8," + encodeURIComponent(text));
element.setAttribute("download", filename);
element.style.display = "none";
document.body.appendChild(element);
element.click();
document.body.removeChild(element);
}
</script>
</body></html>"""
    )
    # The whole page becomes the value of a single-quoted srcdoc attribute.
    # NOTE(review): any single quote inside `x` (including one coming from the
    # PDB text) would end the attribute early — the page template itself only
    # uses double quotes; confirm the file content is equally safe.
    return f"""<iframe style="width: 800px; height: 1200px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
def change_sequence(chosenSeq):
    """Return ``chosenSeq`` unchanged (pass-through callback).

    Args:
        chosenSeq: The value to hand back.

    Returns:
        The very same object that was passed in.
    """
    return chosenSeq
# Gradio page: a ProtGPT2 generation section followed by an AlphaFold
# structure-prediction section.
# NOTE(review): the nesting of the Box/Row/Group containers below was
# reconstructed; it only affects page layout — confirm against the intended UI.
proteindream = gr.Blocks()

with proteindream:
    gr.Markdown("# GradioFold")
    gr.Markdown(
        """GradioFold is a web-based tool that combines a large language model trained on natural protein sequence (protGPT2) with structure prediction using AlphaFold.
Type a start sequence that protGPT2 can complete or let protGPT2 generate a complete sequence without a start token."""
    )

    # --- Sequence generation -------------------------------------------------
    gr.Markdown("## protGPT2")
    gr.Markdown(
        """
Enter a start sequence and have the language model complete it OR leave empty.
"""
    )
    with gr.Box():
        with gr.Row():
            inp = gr.Textbox(placeholder="M", label="Start sequence")
            length = gr.Number(value=50, label="Max sequence length")
        with gr.Row():
            repetitionPenalty = gr.Slider(minimum=1, maximum=5, value=1.2, label="Repetition penalty")
            top_k_poolsize = gr.Slider(minimum=700, maximum=52056, value=950, label="Top-K sampling pool size")
            max_seqs = gr.Slider(minimum=2, maximum=20, value=5, step=1, label="Number of sequences to generate")
        btn = gr.Button("Predict sequences using protGPT2")
        results = gr.Textbox(label="Results", lines=15)
    # The generated FASTA-like text is shown in the "Results" box.
    btn.click(fn=update_protGPT2, inputs=[inp, length, repetitionPenalty, top_k_poolsize, max_seqs], outputs=results)

    # --- Structure prediction ------------------------------------------------
    gr.Markdown("## AlphaFold")
    gr.Markdown(
        "Select a generated sequence above and copy it in the field below for structure prediction using AlphaFold2. You can also edit the sequence. Predictions will take around 2-5 minutes to be processed. Proteins larger than about 1000 residues will not fit into memory."
    )
    with gr.Group():
        chosenSeq = gr.Textbox(label="Chosen sequence")
        btn2 = gr.Button("Predict structure")
    with gr.Group():
        meanpLDDT = gr.Textbox(label="Mean pLDDT of chosen sequence")
        with gr.Row():
            mol = gr.HTML()
            plot = gr.Plot(label="pLDDT")

    # The document/computer icons below were mis-encoded (mojibake) and are
    # restored to the intended emoji.
    gr.Markdown(
        """## Acknowledgements
More information about the used algorithms can be found below.
All code is available on [Huggingface](https://huggingface.co/spaces/simonduerr/protGPT2_gradioFold/blob/main) and licensed under MIT license.
- ProtGPT2: Ferruz et.al 📄[BioRxiv](https://doi.org/10.1101/2022.03.09.483666) 💻[Code](https://huggingface.co/nferruz/ProtGPT2)
- AlphaFold2: Jumper et.al 📄[Paper](https://doi.org/10.1038/s41586-021-03819-2) 💻[Code](https://github.com/deepmind/alphafold) Model parameters released under CC BY 4.0
- ColabFold: Mirdita et.al 📄[Paper](https://doi.org/10.1101/2021.08.15.456425 ) 💻[Code](https://github.com/sokrypton/ColabFold)
- 3Dmol.js: Rego & Koes 📄[Paper](https://academic.oup.com/bioinformatics/article/31/8/1322/213186) 💻 [Code](https://github.com/3dmol/3Dmol.js)
Created by [@simonduerr](https://twitter.com/simonduerr)
Thanks to Hugginface team for sponsoring a free GPU for this demo.
"""
    )
    #seqChoice.change(fn=update_seqs, inputs=seqChoice, outputs=chosenSeq)
    # One click fills the 3D viewer, the pLDDT plot and the mean-pLDDT box.
    btn2.click(fn=update, inputs=chosenSeq, outputs=[mol, plot, meanpLDDT])
# Start Ray with "./alphafold" configured as the runtime environment's
# working directory, then serve the Gradio app without a public share link.
ray.init(runtime_env={"working_dir": "./alphafold"})
proteindream.launch(share=False)