File size: 3,206 Bytes
7cacf8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
import selfies as sf
from transformers import T5EncoderModel
from src.scripts.mytokenizers import Tokenizer
from src.improved_diffusion import gaussian_diffusion as gd
from src.improved_diffusion.respace import SpacedDiffusion
from src.improved_diffusion.transformer_model import TransformerNetModel
import streamlit as st
import spaces
import os


@st.cache_resource
def get_encoder(device):
    """Load the BioT5 text encoder and cache it via ``st.cache_resource``.

    Args:
        device: torch device that the encoder weights are moved to.

    Returns:
        The ``T5EncoderModel`` for ``"QizhiPei/biot5-base-text2mol"``,
        resident on ``device`` and switched to eval mode.
    """
    hub_id = "QizhiPei/biot5-base-text2mol"
    # ``.to`` and ``.eval`` both return the module itself, so they chain.
    # Eval mode turns off training-only behaviour such as dropout.
    return T5EncoderModel.from_pretrained(hub_id).to(device).eval()


@st.cache_resource
def get_tokenizer():
    """Build the project ``Tokenizer`` once and cache it via ``st.cache_resource``."""
    text_tokenizer = Tokenizer()
    return text_tokenizer


@st.cache_resource
def get_model(device):
    """Build the transformer denoiser and load its checkpoint weights.

    The hyper-parameters below must match the ones the checkpoint was
    trained with; ``load_state_dict`` is called with its default strict
    checking, so any mismatch raises.

    Args:
        device: torch device (or device string) used both as the
            ``map_location`` for loading and as the final model device.

    Returns:
        The ``TransformerNetModel`` with loaded weights, on ``device``,
        in eval mode.
    """
    denoiser = TransformerNetModel(
        in_channels=32,
        model_channels=128,
        dropout=0.1,
        vocab_size=35073,
        hidden_size=1024,
        num_attention_heads=16,
        num_hidden_layers=12,
    )

    # Filename suggests EMA weights (rate 0.9999) saved at step 360000 — confirm.
    ckpt_path = os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt")
    weights = torch.load(ckpt_path, map_location=torch.device(device))
    denoiser.load_state_dict(weights)

    denoiser.to(device)
    denoiser.eval()
    return denoiser


@st.cache_resource
def get_diffusion():
    """Build the respaced diffusion process used for sampling.

    The base process has 2000 steps on the "sqrt" beta schedule. Only every
    10th timestep (0, 10, ..., 1990 — 200 in total) is passed as
    ``use_timesteps`` to ``SpacedDiffusion``.

    Returns:
        A ``SpacedDiffusion`` instance, cached via ``st.cache_resource``.
    """
    base_steps = 2000
    kept_timesteps = list(range(0, base_steps, 10))
    return SpacedDiffusion(
        use_timesteps=kept_timesteps,
        betas=gd.get_named_beta_schedule("sqrt", base_steps),
        # START_X presumably means the network predicts x_0 directly — confirm.
        model_mean_type=gd.ModelMeanType.START_X,
        model_var_type=gd.ModelVarType.FIXED_LARGE,
        loss_type=gd.LossType.E2E_MSE,
        rescale_timesteps=True,
        model_arch="transformer",
        training_mode="e2e",
    )

@spaces.GPU
def generate(text_input):
    """Generate a molecule string from a natural-language description.

    Pipeline:
    1. Tokenize the description (padded/truncated to 256 tokens).
    2. Encode it with the BioT5 encoder.
    3. Run the diffusion sampling loop conditioned on the encoder states.
    4. Greedily map the sampled sequence to token ids and decode it.
    5. Strip special tokens and convert the SELFIES string with
       ``selfies.decoder``.

    Relies on the module-level ``tokenizer``, ``encoder``, ``model``,
    ``diffusion`` and ``device`` globals.

    Args:
        text_input: free-text description of the desired molecule.

    Returns:
        The string returned by ``selfies.decoder`` with tab characters
        removed (presumably a SMILES string — confirm for the selfies
        version in use).
    """
    # no_grad: this is pure inference, so don't build autograd graphs for
    # the encoder / logits forward passes (saves memory on the GPU worker).
    with st.spinner("Please wait..."), torch.no_grad():
        tokens = tokenizer(
            text_input,
            max_length=256,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        # Move the mask once; it is needed by both the encoder and the sampler.
        caption_mask = tokens["attention_mask"].to(device)
        caption_state = encoder(
            input_ids=tokens["input_ids"].to(device),
            attention_mask=caption_mask,
        ).last_hidden_state

        # Shape presumably (batch, seq_len, channels): 256 matches the
        # tokenizer max_length and 32 matches the model's in_channels.
        samples = diffusion.p_sample_loop(
            model,
            (1, 256, 32),
            clip_denoised=False,
            denoised_fn=None,
            model_kwargs={},
            top_p=1.0,
            progress=True,
            caption=(caption_state, caption_mask),
        )
        # as_tensor avoids the extra copy (and the UserWarning) that
        # torch.tensor() produces when given an existing tensor, while
        # still accepting array-like input.
        logits = model.get_logits(torch.as_tensor(samples))
        # Greedy decoding: the highest-scoring token id at every position.
        token_ids = logits.argmax(dim=-1)
        decoded = tokenizer.decode(token_ids)
        selfies_str = (
            decoded[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
        )
        return sf.decoder(selfies_str).replace("\t", "")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Each loader is wrapped in st.cache_resource, so weights are loaded on the
# first script run and the same objects are reused on later reruns.
tokenizer = get_tokenizer()
encoder = get_encoder(device)
model = get_model(device)
diffusion = get_diffusion()

st.title("Lang2mol-Diff")
text_input = st.text_area("Enter molecule description")
button = st.button("Submit")
if button:
    if not text_input or not text_input.strip():
        # Don't launch a full diffusion run (and the @spaces.GPU request)
        # for an empty prompt.
        st.warning("Please enter a molecule description first.")
    else:
        result = generate(text_input)
        st.write(result)