File size: 2,375 Bytes
7d76d6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import logging
import pathlib

import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.moler import MoLeR, MoLeRDefaultGenerator

from gt4sd.algorithms.registry import ApplicationsRegistry
from utils import draw_grid_generate

# Module-level logger. A NullHandler is attached so this module adds no output
# handler of its own; records are only emitted if the application configures
# logging.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Substring looked up in each registry entry's "algorithm_name" to decide which
# algorithm versions are offered in the UI.
TITLE = "MoLeR"


def run_inference(
    algorithm_version: str,
    scaffolds: str,
    beam_size: int,
    number_of_samples: int,
    seed: int,
):
    """Generate molecules with MoLeR and render them together with the seeds.

    Args:
        algorithm_version: version string forwarded to the generator
            configuration.
        scaffolds: scaffold string forwarded to the generator configuration.
            It is also split on "." to obtain the seed molecules passed to the
            drawing helper; an empty string yields no seed molecules.
        beam_size: forwarded to the generator configuration.
        number_of_samples: number of items requested from ``model.sample``.
        seed: forwarded to the generator configuration.

    Returns:
        The result of ``draw_grid_generate`` for the seed molecules and the
        generated samples.
    """
    # NOTE: ``num_samples`` is fixed at 4 here, independently of
    # ``number_of_samples``, which is only passed to ``sample`` below.
    generator_settings = {
        "algorithm_version": algorithm_version,
        "scaffolds": scaffolds,
        "beam_size": beam_size,
        "num_samples": 4,
        "seed": seed,
        "num_workers": 1,
    }
    generator = MoLeR(configuration=MoLeRDefaultGenerator(**generator_settings))

    # Materialize the sampler output before handing it to the drawing helper.
    generated = list(generator.sample(number_of_samples))

    if scaffolds == "":
        seed_molecules = []
    else:
        seed_molecules = scaffolds.split(".")

    return draw_grid_generate(seed_molecules, generated)


if __name__ == "__main__":

    # Preparation: collect the version of every registered algorithm whose
    # name contains TITLE; these populate the "Algorithm version" dropdown.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        algo["algorithm_version"]
        for algo in all_algos
        if TITLE in algo["algorithm_name"]
    ]

    # Load metadata from the "model_cards" directory next to this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    # The CSV is read without a header row; missing cells become empty strings.
    # Presumably the columns follow the order of the inputs below -- verify
    # against examples.csv.
    examples = pd.read_csv(metadata_root.joinpath("examples.csv"), header=None).fillna(
        ""
    )

    # Read the markdown explicitly as UTF-8 so decoding does not depend on the
    # platform's locale default encoding.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(encoding="utf-8")

    demo = gr.Interface(
        fn=run_inference,
        title="MoLeR (MOlecule-LEvel Representation)",
        # Input order must match the positional parameters of run_inference.
        inputs=[
            # NOTE(review): the default "v0" is assumed to be one of `algos`.
            gr.Dropdown(algos, label="Algorithm version", value="v0"),
            gr.Textbox(
                label="Scaffolds",
                placeholder="CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1",
                lines=1,
            ),
            gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Beam_size"),
            gr.Slider(
                minimum=1, maximum=50, value=10, label="Number of samples", step=1
            ),
            gr.Number(value=42, label="Seed", precision=0),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)