import os

import selfies as sf
import streamlit as st
import torch
from transformers import T5EncoderModel

from src.scripts.mytokenizers import Tokenizer
from src.improved_diffusion import gaussian_diffusion as gd
from src.improved_diffusion import dist_util
from src.improved_diffusion.respace import SpacedDiffusion
from src.improved_diffusion.transformer_model import TransformerNetModel


@st.cache_resource
def get_encoder():
    # Frozen BioT5 text encoder that turns the molecule description into
    # the conditioning states for the diffusion model.
    model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
    model.eval()
    return model


@st.cache_resource
def get_tokenizer():
    return Tokenizer()


@st.cache_resource
def get_model():
    # Transformer denoising network, restored from the EMA checkpoint.
    model = TransformerNetModel(
        in_channels=32,
        model_channels=128,
        dropout=0.1,
        vocab_size=35073,
        hidden_size=1024,
        num_attention_heads=16,
        num_hidden_layers=12,
    )
    model.load_state_dict(
        dist_util.load_state_dict(
            os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt"),
            map_location="cpu",
        )
    )
    model.eval()
    return model


@st.cache_resource
def get_diffusion():
    # 2000-step "sqrt" noise schedule, respaced to every 10th timestep
    # (200 sampling steps) for faster inference.
    return SpacedDiffusion(
        use_timesteps=list(range(0, 2000, 10)),
        betas=gd.get_named_beta_schedule("sqrt", 2000),
        model_mean_type=gd.ModelMeanType.START_X,
        model_var_type=gd.ModelVarType.FIXED_LARGE,
        loss_type=gd.LossType.E2E_MSE,
        rescale_timesteps=True,
        model_arch="transformer",
        training_mode="e2e",
    )


tokenizer = get_tokenizer()
encoder = get_encoder()
model = get_model()
diffusion = get_diffusion()
sample_fn = diffusion.ddim_sample_loop

text_input = st.text_area("Enter molecule description")

if text_input:
    # Tokenize the description and encode it with the frozen text encoder.
    output = tokenizer(
        text_input,
        max_length=256,
        truncation=True,
        padding="max_length",
        add_special_tokens=True,
        return_tensors="pt",
        return_attention_mask=True,
    )
    with torch.no_grad():
        caption_state = encoder(
            input_ids=output["input_ids"],
            attention_mask=output["attention_mask"],
        ).last_hidden_state
    caption_mask = output["attention_mask"]

    # DDIM-sample a (batch=1, seq_len=256, emb_dim=32) latent,
    # conditioned on the encoded caption.
    outputs = sample_fn(
        model,
        (1, 256, 32),
        clip_denoised=False,
        denoised_fn=None,
        model_kwargs={},
        top_p=1.0,
        progress=True,
        caption=(caption_state, caption_mask),
    )

    # Project the sampled embeddings back to token logits, take the argmax
    # token at every position, and decode back to a SELFIES string.
    logits = model.get_logits(outputs)
    cands = torch.topk(logits, k=1, dim=-1)
    outputs = cands.indices.squeeze(-1)
    outputs = tokenizer.decode(outputs)

    # Strip special tokens (assumed to be <pad> and </s>; the originals were
    # lost in the source) and stray tabs, then convert SELFIES to SMILES.
    result = sf.decoder(
        outputs[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
    ).replace("\t", "")
    st.write(result)
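# Usage sketch (assumes this script is saved as app.py at the repository root and
# that the EMA checkpoint checkpoints/PLAIN_ema_0.9999_360000.pt is present):
#
#   streamlit run app.py
#
# Streamlit re-executes the whole script on every interaction; the
# @st.cache_resource decorators above keep the encoder, tokenizer, denoising
# model, and diffusion object in memory across reruns, so only tokenization,
# text encoding, and the DDIM sampling loop run per submitted description.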