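"""Streamlit demo for Lang2mol-Diff.

Generates a molecule from a natural-language description: the caption is
embedded with the BioT5 encoder, a text-conditioned diffusion model samples a
token-embedding sequence, and the sampled tokens are decoded from SELFIES to a
SMILES string.

Run locally with `streamlit run <this file>` (on Hugging Face Spaces the entry
point is conventionally app.py; the actual file name is not given here).
"""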
import torch
import selfies as sf
from transformers import T5EncoderModel

from src.scripts.mytokenizers import Tokenizer
from src.improved_diffusion import gaussian_diffusion as gd
from src.improved_diffusion.respace import SpacedDiffusion
from src.improved_diffusion.transformer_model import TransformerNetModel

import streamlit as st
import spaces
import os
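
# Model-loading helpers: the BioT5 caption encoder, the project tokenizer, the
# transformer denoising network (restored from an EMA checkpoint), and the
# spaced diffusion sampler.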
def get_encoder(device):
    model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
    model.to(device)
    model.eval()
    return model
def get_tokenizer():
    return Tokenizer()
def get_model(device):
    model = TransformerNetModel(
        in_channels=32,
        model_channels=128,
        dropout=0.1,
        vocab_size=35073,
        hidden_size=1024,
        num_attention_heads=16,
        num_hidden_layers=12,
    )
    model.load_state_dict(
        torch.load(
            os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt"),
            map_location=torch.device(device),
        )
    )
    model.to(device)
    model.eval()
    return model
def get_diffusion():
    return SpacedDiffusion(
        use_timesteps=list(range(0, 2000, 10)),
        betas=gd.get_named_beta_schedule("sqrt", 2000),
        model_mean_type=gd.ModelMeanType.START_X,
        model_var_type=gd.ModelVarType.FIXED_LARGE,
        loss_type=gd.LossType.E2E_MSE,
        rescale_timesteps=True,
        model_arch="transformer",
        training_mode="e2e",
    )
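
# Generation pipeline: encode the caption with BioT5, sample a token-embedding
# sequence with the conditioned diffusion model, then decode it to SELFIES and
# convert to a SMILES string.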
def generate(text_input):
    with st.spinner("Please wait..."):
        # Tokenize the description (padded/truncated to 256 tokens).
        output = tokenizer(
            text_input,
            max_length=256,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        # BioT5 hidden states and their attention mask condition the sampler.
        caption_state = encoder(
            input_ids=output["input_ids"].to(device),
            attention_mask=output["attention_mask"].to(device),
        ).last_hidden_state
        caption_mask = output["attention_mask"]
        # Sample a (batch=1, seq_len=256, dim=32) embedding sequence.
        outputs = diffusion.p_sample_loop(
            model,
            (1, 256, 32),
            clip_denoised=False,
            denoised_fn=None,
            model_kwargs={},
            top_p=1.0,
            progress=True,
            caption=(caption_state.to(device), caption_mask.to(device)),
        )
        # Map the sampled embeddings to vocabulary logits and take the top-1
        # token at each position.
        logits = model.get_logits(torch.tensor(outputs))
        cands = torch.topk(logits, k=1, dim=-1)
        outputs = cands.indices.squeeze(-1)
        # Decode token ids to SELFIES, strip special tokens, and convert to SMILES.
        outputs = tokenizer.decode(outputs)
        result = sf.decoder(
            outputs[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
        ).replace("\t", "")
    return result
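
# Streamlit app: build the tokenizer, caption encoder, denoising model, and
# diffusion sampler, then expose a text box and a Submit button; the generated
# SMILES string is written back to the page.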
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = get_tokenizer()
encoder = get_encoder(device)
model = get_model(device)
diffusion = get_diffusion()

st.title("Lang2mol-Diff")
text_input = st.text_area("Enter molecule description")
button = st.button("Submit")
if button:
    result = generate(text_input)
    st.write(result)