File size: 2,956 Bytes
09c907a
 
 
 
a4eba41
 
 
09c907a
 
 
 
 
 
 
 
 
 
 
a4eba41
 
09c907a
a4eba41
09c907a
a4eba41
09c907a
 
a4eba41
 
 
09c907a
a4eba41
 
09c907a
a4eba41
09c907a
 
a4eba41
09c907a
 
 
 
 
 
 
 
a4eba41
09c907a
 
 
 
 
a4eba41
 
 
09c907a
 
 
 
 
 
 
 
a4eba41
09c907a
a4eba41
 
 
 
09c907a
a4eba41
09c907a
a4eba41
 
09c907a
 
 
 
 
 
 
 
 
 
a4eba41
09c907a
 
a4eba41
09c907a
 
a4eba41
09c907a
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.controlled_sampling.advanced_manufacturing import (
    CatalystGenerator,
    AdvancedManufacturing,
)
from gt4sd.algorithms.registry import ApplicationsRegistry

from utils import draw_grid_generate

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def run_inference(
    algorithm_version: str,
    target_binding_energy: float,
    primer_smiles: str,
    length: float,
    number_of_points: int,
    number_of_steps: int,
    number_of_samples: int,
):
    """Sample catalyst candidates and render them with ``draw_grid_generate``.

    Args:
        algorithm_version: Version of the ``AdvancedManufacturing`` algorithm
            to load (the UI's "Algorithm version" dropdown).
        target_binding_energy: Value passed as ``target`` to
            ``AdvancedManufacturing`` (the UI's "Target binding energy").
        primer_smiles: SMILES string forwarded as ``primer_smiles`` to the
            configuration; an empty string (or ``None``) means no primer.
        length: Forwarded as ``generated_length`` (the UI's "Maximal sequence
            length"); coerced to ``int``.
        number_of_points: Forwarded as ``number_of_points``; coerced to ``int``.
        number_of_steps: Forwarded as ``number_of_steps``; coerced to ``int``.
        number_of_samples: Number of items requested from ``model.sample``;
            coerced to ``int``.

    Returns:
        Whatever ``draw_grid_generate`` produces for the sampled items laid
        out in 5 columns (displayed by the app's ``gr.HTML`` output).
    """
    # Gradio sliders and the pandas-loaded examples table may hand over
    # numbers as floats (e.g. 32.0); these values are counts/lengths, so
    # normalize them before they reach the model configuration.
    config = CatalystGenerator(
        algorithm_version=algorithm_version,
        number_of_points=int(number_of_points),
        number_of_steps=int(number_of_steps),
        generated_length=int(length),
        # Tolerate a missing value and copy/paste whitespace around the SMILES.
        primer_smiles=(primer_smiles or "").strip(),
    )
    model = AdvancedManufacturing(config, target=float(target_binding_energy))
    samples = list(model.sample(int(number_of_samples)))

    return draw_grid_generate(samples=samples, n_cols=5, seeds=[])


if __name__ == "__main__":

    # Preparation: collect the versions of every registered application whose
    # algorithm name contains "Advanced"; these populate the dropdown.
    algos = [
        application["algorithm_version"]
        for application in ApplicationsRegistry.list_available()
        if "Advanced" in application["algorithm_name"]
    ]

    # Preferred default version; if the registry does not offer it, fall back
    # to the first available version so the dropdown starts on a valid choice.
    default_version = "NCCR_rnn_suzuki_aug16_smiles"
    if algos and default_version not in algos:
        default_version = algos[0]

    # Load metadata shipped next to this script.
    metadata_root = pathlib.Path(__file__).parent / "model_cards"

    # The examples CSV has no header row (header=None); empty cells become ""
    # so Gradio receives strings instead of NaN.
    examples = pd.read_csv(metadata_root / "examples.csv", header=None).fillna("")

    # Explicit encoding so the markdown decodes identically on every platform.
    article = (metadata_root / "article.md").read_text(encoding="utf-8")
    description = (metadata_root / "description.md").read_text(encoding="utf-8")

    demo = gr.Interface(
        fn=run_inference,
        title="Advanced Manufacturing",
        inputs=[
            gr.Dropdown(
                algos,
                label="Algorithm version",
                value=default_version,
            ),
            gr.Slider(minimum=1, maximum=100, value=10, label="Target binding energy"),
            gr.Textbox(
                label="Primer SMILES",
                placeholder="FP(F)F.CP(C)c1ccccc1.[Au]",
                lines=1,
            ),
            gr.Slider(
                minimum=5,
                maximum=400,
                value=100,
                label="Maximal sequence length",
                step=1,
            ),
            gr.Slider(
                minimum=16, maximum=128, value=32, label="Number of points", step=1
            ),
            gr.Slider(
                minimum=16, maximum=128, value=50, label="Number of steps", step=1
            ),
            gr.Slider(
                minimum=1, maximum=50, value=10, label="Number of samples", step=1
            ),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)