moler / app.py
jannisborn's picture
update
7eafb33 unverified
raw
history blame
2.38 kB
import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.moler import MoLeR, MoLeRDefaultGenerator
from gt4sd.algorithms.registry import ApplicationsRegistry
from utils import draw_grid_generate
# Module-level logger; a NullHandler is attached so this module does not
# emit log output unless the application configures logging itself.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
# Substring matched against registry "algorithm_name" entries in the
# script entry point below to select the algorithm versions offered in the UI.
TITLE = "MoLeR"
def run_inference(
    algorithm_version: str,
    scaffolds: str,
    beam_size: int,
    number_of_samples: int,
    seed: int,
):
    """Sample from a MoLeR model and render the result via ``draw_grid_generate``.

    Args:
        algorithm_version: Forwarded to ``MoLeRDefaultGenerator``.
        scaffolds: Forwarded unchanged to the configuration; it is also split
            on ``"."`` to build the seed-molecule list handed to the drawing
            helper. An empty string yields an empty seed list.
        beam_size: Forwarded to ``MoLeRDefaultGenerator``.
        number_of_samples: Argument passed to ``model.sample``.
        seed: Forwarded to ``MoLeRDefaultGenerator``.

    Returns:
        Whatever ``draw_grid_generate`` returns for the seed molecules and
        the sampled items.
    """
    # NOTE: ``num_samples`` and ``num_workers`` are fixed here and do not
    # depend on ``number_of_samples``.
    generator_settings = {
        "algorithm_version": algorithm_version,
        "scaffolds": scaffolds,
        "beam_size": beam_size,
        "num_samples": 4,
        "seed": seed,
        "num_workers": 1,
    }
    configuration = MoLeRDefaultGenerator(**generator_settings)
    generator = MoLeR(configuration=configuration)

    # Materialise the sampler output before handing it to the renderer.
    generated = [*generator.sample(number_of_samples)]

    if scaffolds == "":
        scaffold_smiles = []
    else:
        scaffold_smiles = scaffolds.split(".")

    return draw_grid_generate(scaffold_smiles, generated)
if __name__ == "__main__":
    # Preparation: collect the algorithm versions of every registry entry
    # whose algorithm name contains TITLE.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        entry["algorithm_version"]
        for entry in all_algos
        if TITLE in entry["algorithm_name"]
    ]

    # Load metadata shipped next to this script in the "model_cards" folder.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    # examples.csv has no header row; missing cells become empty strings so
    # they can be used directly as example input values.
    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), header=None
    ).fillna("")
    # Decode the markdown explicitly as UTF-8 instead of relying on the
    # locale-dependent default encoding, which can fail or garble non-ASCII
    # text on some platforms.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    # Input widgets are listed in the same order as run_inference's parameters.
    demo = gr.Interface(
        fn=run_inference,
        title="MoLeR (MOlecule-LEvel Representation)",
        inputs=[
            gr.Dropdown(algos, label="Algorithm version", value="v0"),
            gr.Textbox(
                label="Scaffolds",
                placeholder="CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1",
                lines=1,
            ),
            gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Beam_size"),
            gr.Slider(
                minimum=1, maximum=50, value=10, label="Number of samples", step=1
            ),
            gr.Number(value=42, label="Seed", precision=0),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)