import torch
import selfies as sf
from transformers import T5EncoderModel
from src.scripts.mytokenizers import Tokenizer
from src.improved_diffusion import gaussian_diffusion as gd
from src.improved_diffusion.respace import SpacedDiffusion
from src.improved_diffusion.transformer_model import TransformerNetModel
import gradio as gr
import spaces
import os


def get_encoder(device):
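    """Load the BioT5 text encoder used to embed molecule descriptions."""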
    model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
    model.to(device)
    model.eval()
    return model


def get_tokenizer():
    return Tokenizer()


def get_model(device):
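    """Build the transformer denoising network and load the EMA checkpoint."""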
    model = TransformerNetModel(
        in_channels=32,
        model_channels=128,
        dropout=0.1,
        vocab_size=35073,
        hidden_size=1024,
        num_attention_heads=16,
        num_hidden_layers=12,
    )
    model.load_state_dict(
        torch.load(
            os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt"),
            map_location=torch.device(device),
        )
    )
    model.to(device)
    model.eval()
    return model


def get_diffusion():
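    """Configure spaced diffusion: a 2000-step sqrt noise schedule subsampled to 200 sampling steps."""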
    return SpacedDiffusion(
        use_timesteps=list(range(0, 2000, 10)),
        betas=gd.get_named_beta_schedule("sqrt", 2000),
        model_mean_type=gd.ModelMeanType.START_X,
        model_var_type=gd.ModelVarType.FIXED_LARGE,
        loss_type=gd.LossType.E2E_MSE,
        rescale_timesteps=True,
        model_arch="transformer",
        training_mode="e2e",
    )


@spaces.GPU
def generate(text_input):
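    """Generate a SMILES string from a natural-language molecule description.

    The description is tokenized and encoded with BioT5, the encoding
    conditions the reverse diffusion loop, and the sampled tokens are
    decoded from SELFIES to SMILES.
    """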
    # Tokenize the input description to fixed-length ids with an attention mask.
    output = tokenizer(
        text_input,
        max_length=256,
        truncation=True,
        padding="max_length",
        add_special_tokens=True,
        return_tensors="pt",
        return_attention_mask=True,
    )
    with torch.no_grad():
        # Encode the description with BioT5; its last hidden state is the caption.
        caption_state = encoder(
            input_ids=output["input_ids"].to(device),
            attention_mask=output["attention_mask"].to(device),
        ).last_hidden_state
        caption_mask = output["attention_mask"]

        # Sample a (batch, seq_len, channels) latent conditioned on the caption.
        outputs = diffusion.p_sample_loop(
            model,
            (1, 256, 32),
            clip_denoised=False,
            denoised_fn=None,
            model_kwargs={},
            top_p=1.0,
            progress=True,
            caption=(caption_state.to(device), caption_mask.to(device)),
        )
        # Project the sampled embeddings back to vocabulary logits.
        logits = model.get_logits(outputs)  # p_sample_loop already returns a tensor
    # Greedy decode: keep the top-1 token at each position.
    outputs = torch.topk(logits, k=1, dim=-1).indices.squeeze(-1)
    outputs = tokenizer.decode(outputs)
    # Strip special tokens and translate the SELFIES string into SMILES.
    result = sf.decoder(
        outputs[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
    ).replace("\t", "")
    return result


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = get_tokenizer()
encoder = get_encoder(device)
model = get_model(device)
diffusion = get_diffusion()
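
# Minimal smoke test outside the Gradio UI (hypothetical description; assumes
# the checkpoint under ./checkpoints is present):
# print(generate("The molecule is a ..."))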

# Create a Gradio interface
iface = gr.Interface(
    fn=generate,
    inputs="text",
    outputs="text",
    title="Lang2mol-Diff",
    description="Enter molecule description",
    examples=[
        [
            "The molecule is a apoptosis, cholesterol translocation, stabilizing mitochondrial structure that impacts barth syndrome and non-alcoholic fatty liver disease. The molecule is a stabilizing cytochrome oxidase and a proton trap for oxidative phosphorylation that impacts aging, diabetic heart disease, and tangier disease."
        ],
    ],
)

# Run the interface
iface.launch()