moler / app.py
jannisborn's picture
feat: moler updates
f1e36b5 unverified
raw
history blame
2.81 kB
import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.moler import MoLeR, MoLeRDefaultGenerator
from gt4sd.algorithms.registry import ApplicationsRegistry
from utils import draw_grid_generate
# Module-level logger. A NullHandler is attached so this module never configures
# logging output on its own; the hosting application decides where records go.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
# Substring matched against each registry entry's "algorithm_name" when the
# __main__ section collects the selectable algorithm versions.
TITLE = "MoLeR"
def run_inference(
    algorithm_version: str,
    scaffolds: str,
    seed_smiles: str,
    beam_size: int,
    sigma: float,
    number_of_samples: int,
    seed: int,
):
    """Sample molecules from MoLeR and render them next to the user inputs.

    Args:
        algorithm_version: Version string forwarded to ``MoLeRDefaultGenerator``.
        scaffolds: Scaffold SMILES; multiple entries are separated by ``.``.
            An empty string means "no scaffolds".
        seed_smiles: Seed SMILES; multiple entries are separated by ``.``.
            An empty string means "no seeds".
        beam_size: Forwarded to the generator configuration as ``beam_size``.
        sigma: Forwarded to the generator configuration as ``sigma``.
        number_of_samples: Number of items requested from ``MoLeR.sample``.
        seed: Forwarded to the generator configuration as ``seed``.

    Returns:
        The result of ``draw_grid_generate`` called with the seed list, the
        scaffold list and the generated samples (in that order).
    """
    generator_settings = {
        "algorithm_version": algorithm_version,
        "scaffolds": scaffolds,
        "beam_size": beam_size,
        # Fixed value, independent of ``number_of_samples``.
        # NOTE(review): presumably the per-call batch size of the generator;
        # confirm against the GT4SD MoLeR documentation.
        "num_samples": 32,
        "seed": seed,
        "num_workers": 1,
        "seed_smiles": seed_smiles,
        "sigma": sigma,
    }
    configuration = MoLeRDefaultGenerator(**generator_settings)
    generated = list(MoLeR(configuration=configuration).sample(number_of_samples))

    # The raw strings go to the generator unchanged; the drawing helper instead
    # receives them split into individual entries on ".".
    if scaffolds == "":
        scaffold_list = []
    else:
        scaffold_list = scaffolds.split(".")

    if seed_smiles == "":
        seed_list = []
    else:
        seed_list = seed_smiles.split(".")

    return draw_grid_generate(seed_list, scaffold_list, generated)
if __name__ == "__main__":
    # Preparation: collect the versions of every registered algorithm whose
    # name contains TITLE, to populate the version dropdown.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        entry["algorithm_version"]
        for entry in all_algos
        if TITLE in entry["algorithm_name"]
    ]

    # Load metadata shipped next to this script in the "model_cards" folder.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    # examples.csv has no header row; missing cells are replaced by empty
    # strings so that text inputs receive "" rather than NaN.
    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), header=None
    ).fillna("")

    # Decode the markdown explicitly as UTF-8 so the result does not depend on
    # the platform's default locale encoding.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    # The order of the input components must match the positional parameters
    # of run_inference: version, scaffolds, seed SMILES, beams, sigma,
    # number of samples, seed.
    demo = gr.Interface(
        fn=run_inference,
        title="MoLeR (MOlecule-LEvel Representation)",
        inputs=[
            gr.Dropdown(algos, label="Algorithm version", value="v0"),
            gr.Textbox(
                label="Scaffolds",
                placeholder="CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1",
                lines=1,
            ),
            gr.Textbox(
                label="Seed SMILES",
                placeholder="O=C1C2=CC=C(C3=CC=CC=C3)C=C=C2OC2=CC=CC=C12",
                lines=1,
            ),
            gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Beams"),
            gr.Slider(minimum=0.0, maximum=3.0, value=0.01, label="Sigma"),
            gr.Slider(
                minimum=1, maximum=50, value=10, label="Number of samples", step=1
            ),
            gr.Number(value=42, label="Seed", precision=0),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)