|
import bz2 |
|
import shutil |
|
import tempfile |
|
from pathlib import Path |
|
|
|
import gradio as gr |
|
import pypythia.msa |
|
import pypythia.prediction |
|
import pypythia.predictor |
|
import pypythia.raxmlng |
|
|
|
|
|
def get_default_raxmlng(version="1.1.0"):
    """Return the path to the bundled RAxML-NG binary, unpacking it on first use.

    A bz2-compressed build named ``raxml-ng-v{version}-linux-64.bz2`` is
    expected next to this module. On the first call it is decompressed into
    the user's home directory; later calls reuse that copy.

    Args:
        version: RAxML-NG version used to build the archive and binary names.
            Defaults to ``"1.1.0"``, the version this module ships with.

    Returns:
        ``Path`` to the decompressed binary, with its execute bits set.

    Raises:
        FileNotFoundError: If the binary is not unpacked yet and the bundled
            ``.bz2`` archive is missing.
        OSError: If the home directory cannot be written to.
    """
    binary_name = f"raxml-ng-v{version}-linux-64"
    uncompressed_raxmlng = Path.home() / binary_name

    if not uncompressed_raxmlng.exists():
        compressed_raxmlng = Path(__file__).parent / f"{binary_name}.bz2"
        tmp_path = None
        try:
            # Decompress into a sibling temp file and rename it into place, so
            # an interrupted run never leaves a truncated binary that the
            # exists() check above would then accept on every later call.
            with bz2.BZ2File(compressed_raxmlng) as bz, tempfile.NamedTemporaryFile(
                dir=uncompressed_raxmlng.parent,
                prefix=f".{binary_name}.",
                delete=False,
            ) as tmp:
                tmp_path = Path(tmp.name)
                shutil.copyfileobj(bz, tmp)
            tmp_path.chmod(tmp_path.stat().st_mode | 0o111)
            tmp_path.replace(uncompressed_raxmlng)
        except BaseException:
            # Remove the partial temp file; the error itself is re-raised.
            if tmp_path is not None:
                tmp_path.unlink(missing_ok=True)
            raise

    # The file is used as the raxml-ng executable, so it needs the execute bit.
    # bz2 does not store permissions, and copies unpacked by earlier versions
    # of this function may lack it, so repair that here as well.
    mode = uncompressed_raxmlng.stat().st_mode
    if not mode & 0o100:
        uncompressed_raxmlng.chmod(mode | 0o111)

    return uncompressed_raxmlng
|
|
|
|
|
def predict_difficulty(uploaded_file):
    """Predict how difficult phylogenetic inference will be for an uploaded MSA.

    Args:
        uploaded_file: The upload handed over by the ``gr.File`` input. This is
            either a readable binary file-like object, which is rewound and
            copied, or a path to the uploaded file as ``str``/``Path``.
            Which form arrives presumably depends on the installed Gradio
            version (TODO confirm).

    Returns:
        ``(difficulty, msa_features)``: the result of
        ``DifficultyPredictor.predict`` and the feature values computed by
        ``pypythia.prediction.get_all_features`` that were passed to it.
    """
    predictor_file = (
        Path(pypythia.__file__).parent / "predictors" / "predictor_lgb_v1.0.0.pckl"
    )
    # Close the pickled model once the predictor is built, instead of leaking
    # one open handle per request.
    with predictor_file.open("rb") as predictor_handle:
        predictor = pypythia.predictor.DifficultyPredictor(predictor_handle)

    # Prefer a raxml-ng found on PATH; otherwise fall back to the bundled binary.
    raxmlng = pypythia.raxmlng.RAxMLNG(
        shutil.which("raxml-ng") or get_default_raxmlng()
    )

    # Copy the upload into a temp file we own. This gives MSA/raxml-ng a path
    # that stays valid for the whole computation and is deleted afterwards.
    with tempfile.NamedTemporaryFile() as msa_file:
        if isinstance(uploaded_file, (str, Path)):
            with open(uploaded_file, "rb") as source:
                shutil.copyfileobj(source, msa_file)
        else:
            uploaded_file.seek(0)  # the upload may already have been read
            shutil.copyfileobj(uploaded_file, msa_file)
        # Flush so the bytes are on disk before anything reopens it by name.
        msa_file.flush()

        msa = pypythia.msa.MSA(msa_file.name)
        msa_features = pypythia.prediction.get_all_features(raxmlng, msa)
        difficulty = predictor.predict(msa_features)

    return difficulty, msa_features
|
|
|
|
|
# Gradio front end: a single MSA upload goes in; the predicted difficulty and
# the feature values it was computed from come out.
pythia_demo = gr.Interface(
    fn=predict_difficulty,
    inputs=gr.File(label="MSA file (.phy or .msa)"),
    outputs=[
        gr.Number(label="Difficulty", precision=5),
        gr.JSON(label="Features used for prediction"),
    ],
)

# Serve the app (runs whenever this module is executed or imported).
pythia_demo.launch()
|
|