# Source: torchdrug / app.py (Hugging Face Space file, author: jannisborn).
# Duplicated from jannisborn/gt4sd-moler at commit 7d76d6f (2.38 kB; virus scan: clean).
import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.moler import MoLeR, MoLeRDefaultGenerator
from gt4sd.algorithms.registry import ApplicationsRegistry
from utils import draw_grid_generate
# Substring used to pick the MoLeR entries out of the GT4SD applications registry.
TITLE = "MoLeR"

# Attach a NullHandler so records emitted by this module do not fall through to
# Python's last-resort stderr handler when the host app has not set up logging.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
def run_inference(
    algorithm_version: str,
    scaffolds: str,
    beam_size: int,
    number_of_samples: int,
    seed: int,
):
    """Sample molecules with MoLeR and render them next to the seed scaffolds.

    Args:
        algorithm_version: MoLeR version to load (one of the versions found in
            the GT4SD applications registry).
        scaffolds: '.'-separated SMILES scaffolds. An empty or whitespace-only
            string means generation without scaffolds.
        beam_size: Beam size forwarded to ``MoLeRDefaultGenerator``.
        number_of_samples: Forwarded to ``model.sample``; presumably the total
            number of molecules returned (confirm against GT4SD).
        seed: Random seed forwarded to ``MoLeRDefaultGenerator``.

    Returns:
        The result of ``draw_grid_generate(seed_mols, samples)``, which the
        Gradio interface displays through a ``gr.HTML`` output.
    """
    # Normalise the scaffold string once. This drops surrounding whitespace and
    # empty fragments (e.g. a trailing '.'), so neither MoLeR nor the drawing
    # helper ever sees a blank SMILES. ``None`` is treated like "".
    seed_mols = [
        fragment.strip()
        for fragment in (scaffolds or "").split(".")
        if fragment.strip()
    ]
    config = MoLeRDefaultGenerator(
        algorithm_version=algorithm_version,
        scaffolds=".".join(seed_mols),
        # Widget/example values may arrive as floats (e.g. 2.0); beam widths
        # and sample counts must be integers.
        beam_size=int(beam_size),
        # NOTE(review): presumably the per-call batch size of the generator;
        # the number of molecules shown is driven by ``number_of_samples`` below.
        num_samples=4,
        seed=seed,
        num_workers=1,
    )
    model = MoLeR(configuration=config)
    samples = list(model.sample(int(number_of_samples)))
    return draw_grid_generate(seed_mols, samples)
if __name__ == "__main__":
    # Preparation: collect every registered version of the MoLeR algorithm.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        entry["algorithm_version"]
        for entry in all_algos
        if TITLE in entry["algorithm_name"]
    ]
    # Preselect "v0" when it is registered. Otherwise fall back to the first
    # available version, so the dropdown never starts on a value it does not offer.
    default_version = "v0" if "v0" in algos else (algos[0] if algos else None)

    # Load the model-card metadata shipped in ./model_cards next to this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    # fillna("") turns empty CSV cells into empty strings instead of NaN, so an
    # example with no scaffolds is passed to run_inference as "".
    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), header=None
    ).fillna("")
    # Read with an explicit encoding; the platform default is not always UTF-8.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=run_inference,
        title="MoLeR (MOlecule-LEvel Representation)",
        inputs=[
            gr.Dropdown(algos, label="Algorithm version", value=default_version),
            gr.Textbox(
                label="Scaffolds",
                placeholder="CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1",
                lines=1,
            ),
            gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Beam_size"),
            gr.Slider(
                minimum=1, maximum=50, value=10, label="Number of samples", step=1
            ),
            gr.Number(value=42, label="Seed", precision=0),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)