# Lang2mol-Diff / app.py
import os

import selfies as sf
import streamlit as st
import torch
from transformers import T5EncoderModel

from src.improved_diffusion import dist_util
from src.improved_diffusion import gaussian_diffusion as gd
from src.improved_diffusion.respace import SpacedDiffusion
from src.improved_diffusion.transformer_model import TransformerNetModel
from src.scripts.mytokenizers import Tokenizer
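

# Heavy objects (text encoder, tokenizer, denoising model, diffusion schedule)
# are wrapped in st.cache_resource so they are built once and reused across
# Streamlit reruns.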
@st.cache_resource
def get_encoder():
model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
model.eval()
return model
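

# Project tokenizer (src.scripts.mytokenizers.Tokenizer): encodes the text
# prompt and later decodes sampled token ids back to a SELFIES string.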
@st.cache_resource
def get_tokenizer():
return Tokenizer()
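

# Denoising network: a Transformer operating on 32-dimensional token
# embeddings, restored from the EMA checkpoint in checkpoints/.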
@st.cache_resource
def get_model():
model = TransformerNetModel(
in_channels=32,
model_channels=128,
dropout=0.1,
vocab_size=35073,
hidden_size=1024,
num_attention_heads=16,
num_hidden_layers=12,
)
model.load_state_dict(
dist_util.load_state_dict(
os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt"),
map_location="cpu",
)
)
model.eval()
return model
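

# Diffusion process: a 2000-step "sqrt" noise schedule subsampled to every
# 10th timestep (200 steps) to speed up sampling.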
@st.cache_resource
def get_diffusion():
return SpacedDiffusion(
        use_timesteps=list(range(0, 2000, 10)),
        betas=gd.get_named_beta_schedule("sqrt", 2000),
        model_mean_type=gd.ModelMeanType.START_X,
        model_var_type=gd.ModelVarType.FIXED_LARGE,
loss_type=gd.LossType.E2E_MSE,
rescale_timesteps=True,
model_arch="transformer",
training_mode="e2e",
)
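

# Build the cached components; DDIM sampling is used over the spaced timesteps.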
tokenizer = get_tokenizer()
encoder = get_encoder()
model = get_model()
diffusion = get_diffusion()
sample_fn = diffusion.ddim_sample_loop
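
# Streamlit UI: a free-text molecule description entered by the user.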
text_input = st.text_area("Enter molecule description")
if not text_input:
    # Nothing to generate yet; halt until the user provides a description.
    st.stop()
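
# Tokenize the description to a fixed 256-token sequence (padded/truncated).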
output = tokenizer(
text_input,
max_length=256,
truncation=True,
padding="max_length",
add_special_tokens=True,
return_tensors="pt",
return_attention_mask=True,
)
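
# Embed the description with the frozen text encoder; the hidden states and
# attention mask condition the diffusion model during sampling.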
with torch.no_grad():
    caption_state = encoder(
        input_ids=output["input_ids"],
        attention_mask=output["attention_mask"],
    ).last_hidden_state
caption_mask = output["attention_mask"]
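
# DDIM-sample a latent of shape (batch, sequence length, embedding dim)
# = (1, 256, 32), conditioned on the encoded caption.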
outputs = sample_fn(
model,
(1, 256, 32),
clip_denoised=False,
denoised_fn=None,
model_kwargs={},
top_p=1.0,
progress=True,
caption=(caption_state, caption_mask),
)
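
# Round the continuous sample back to discrete tokens: project embeddings to
# vocabulary logits and keep the most likely token at each position.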
logits = model.get_logits(outputs)
cands = torch.topk(logits, k=1, dim=-1)
outputs = cands.indices.squeeze(-1)
outputs = tokenizer.decode(outputs)
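
# Strip special tokens from the decoded SELFIES string and convert it to
# SMILES for display.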
result = sf.decoder(
outputs[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
).replace("\t", "")
st.write(result)